mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
330 Commits
v0.23.2
...
jagrit06/c
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
400f8457ea | ||
|
|
dfb5022eab | ||
|
|
ac207ce7aa | ||
|
|
fce53b61d6 | ||
|
|
8ae4a76308 | ||
|
|
7fde1b6a1e | ||
|
|
aa7b47481a | ||
|
|
56be773610 | ||
|
|
a9bdd67baa | ||
|
|
f2adb5638d | ||
|
|
728d4db582 | ||
|
|
db5c7efcf6 | ||
|
|
7bb96e4249 | ||
|
|
fa89f0b150 | ||
|
|
ca973d1e83 | ||
|
|
828c5f1137 | ||
|
|
7d86a5c108 | ||
|
|
0b807893a7 | ||
|
|
6ad0889c8a | ||
|
|
737dd6d1ac | ||
|
|
aaf78f4c6b | ||
|
|
8831064493 | ||
|
|
be9bc96da4 | ||
|
|
86258f292f | ||
|
|
b26d88591c | ||
|
|
86c6a15571 | ||
|
|
8b25ce62d5 | ||
|
|
da5912e4f2 | ||
|
|
daafee676f | ||
|
|
d32519c8ee | ||
|
|
b405591249 | ||
|
|
3bf81ed1bd | ||
|
|
2204182bba | ||
|
|
3628e5d497 | ||
|
|
a0ae49d397 | ||
|
|
254476718b | ||
|
|
3adba92ebe | ||
|
|
ef631d63af | ||
|
|
970dbe8e25 | ||
|
|
641be9463b | ||
|
|
ab0e608862 | ||
|
|
1588659062 | ||
|
|
b9e88fb976 | ||
|
|
4ad53414dd | ||
|
|
d1165b215e | ||
|
|
dcb8319f3d | ||
|
|
5597fa089c | ||
|
|
9acec364c2 | ||
|
|
7d9d6ef456 | ||
|
|
6f5874a2f2 | ||
|
|
70dc336785 | ||
|
|
4e504039f5 | ||
|
|
d1f4d291e8 | ||
|
|
e1840853ce | ||
|
|
0f5ce173da | ||
|
|
588854195f | ||
|
|
28d068bce6 | ||
|
|
d107d8d495 | ||
|
|
1e496ddb82 | ||
|
|
74eccbf3fa | ||
|
|
08638223ca | ||
|
|
56cc858af9 | ||
|
|
f55c4ed1d6 | ||
|
|
93d70419e7 | ||
|
|
63f663d9c6 | ||
|
|
84b4d96efa | ||
|
|
aec67f2fa6 | ||
|
|
deee214a95 | ||
|
|
45adec102c | ||
|
|
31fc530c76 | ||
|
|
fbb3f65a1a | ||
|
|
6b1b8ea91b | ||
|
|
b2273733ea | ||
|
|
f409b229a4 | ||
|
|
30571e2326 | ||
|
|
d7734edd9f | ||
|
|
2ba69bc8fa | ||
|
|
cb349a291c | ||
|
|
f0a0b077a0 | ||
|
|
49114f28ab | ||
|
|
e7d2ebadd2 | ||
|
|
e569803d7c | ||
|
|
d34f887abc | ||
|
|
5201df5030 | ||
|
|
2d3c26c565 | ||
|
|
6325f60d52 | ||
|
|
42cc9cfbc7 | ||
|
|
8347575ba1 | ||
|
|
b6eec20260 | ||
|
|
0eb035b4b1 | ||
|
|
afb9817599 | ||
|
|
8fb3e7a26c | ||
|
|
8c7bc30ce4 | ||
|
|
85873cb162 | ||
|
|
e14ee12491 | ||
|
|
8b9a3f3cea | ||
|
|
fb4e8b896b | ||
|
|
2ca533b279 | ||
|
|
4a9b29a875 | ||
|
|
a4fcc893cd | ||
|
|
9d10239af7 | ||
|
|
19facd4b20 | ||
|
|
f5299f72cd | ||
|
|
0e0d9ac522 | ||
|
|
8917022deb | ||
|
|
ec0d5db67b | ||
|
|
e76e9b87f0 | ||
|
|
cfb6a244ea | ||
|
|
58f3860306 | ||
|
|
dd4f53db63 | ||
|
|
3d5e17e507 | ||
|
|
33bf1a244b | ||
|
|
772f471ff2 | ||
|
|
2c11d10f8d | ||
|
|
656ed7f780 | ||
|
|
81bb9a2a9e | ||
|
|
5adf185f86 | ||
|
|
c9a9180584 | ||
|
|
76831ed83d | ||
|
|
b3d7b85376 | ||
|
|
cad5c0241c | ||
|
|
b8022c578a | ||
|
|
bc53f8293f | ||
|
|
c552ff2451 | ||
|
|
4fda5fbdf9 | ||
|
|
580776559b | ||
|
|
a14aaa7c9d | ||
|
|
a6d780154f | ||
|
|
6871e2eeb7 | ||
|
|
8402a2acf4 | ||
|
|
fddb6933e1 | ||
|
|
c8b4787e4e | ||
|
|
2188199ff8 | ||
|
|
aa07429bad | ||
|
|
918761a25a | ||
|
|
a4fc671d3e | ||
|
|
f5f65ef48c | ||
|
|
c2dd81a8aa | ||
|
|
d7e680ffe4 | ||
|
|
c371baf53a | ||
|
|
ccf78f566c | ||
|
|
c9fa68664a | ||
|
|
c35f4d089a | ||
|
|
8590c0941e | ||
|
|
095163b8d1 | ||
|
|
99c33d011d | ||
|
|
62fecf3e13 | ||
|
|
7c4eb5d03e | ||
|
|
bae9a6b404 | ||
|
|
004c1d8ef2 | ||
|
|
7ebb2e0193 | ||
|
|
9ce77798b1 | ||
|
|
f8bad60609 | ||
|
|
5866b3857b | ||
|
|
1ca616844b | ||
|
|
2e8cf0b450 | ||
|
|
24f89173d1 | ||
|
|
c6a20b427a | ||
|
|
a5ac9244c4 | ||
|
|
c763fe1be0 | ||
|
|
52dc8c8cd5 | ||
|
|
aede70e81d | ||
|
|
85a8beb5e4 | ||
|
|
0bb89e9e5f | ||
|
|
5685ceb3c7 | ||
|
|
0408ba0a76 | ||
|
|
cbad6c3093 | ||
|
|
1b021f6984 | ||
|
|
95b7551d65 | ||
|
|
db5a7c6192 | ||
|
|
6ef2f67e7f | ||
|
|
f76ee1ffd2 | ||
|
|
54a71f270a | ||
|
|
55b4062dd8 | ||
|
|
79071bfba4 | ||
|
|
7774b87cbd | ||
|
|
35c87741cf | ||
|
|
4cbe605214 | ||
|
|
ab8883dd55 | ||
|
|
eebe73001a | ||
|
|
0359bf02c9 | ||
|
|
237f9e58a8 | ||
|
|
8576e6fe36 | ||
|
|
0654543dcc | ||
|
|
48ef3e74e2 | ||
|
|
7d4b378952 | ||
|
|
7ff5c41e06 | ||
|
|
602f43e3d1 | ||
|
|
a2cadb8218 | ||
|
|
c1eb9d05d9 | ||
|
|
cf6c939e86 | ||
|
|
130df35e1b | ||
|
|
0751263dec | ||
|
|
eca2f3eb97 | ||
|
|
3aa9cf3f9e | ||
|
|
8f3d208dce | ||
|
|
caaa3f1f8c | ||
|
|
659a51919f | ||
|
|
6661387066 | ||
|
|
a7fae8a176 | ||
|
|
0cae0bdac8 | ||
|
|
5a1a5d5ed1 | ||
|
|
1683975acf | ||
|
|
af705590ac | ||
|
|
825124af8f | ||
|
|
9c5e7da507 | ||
|
|
481349495b | ||
|
|
9daa6b003f | ||
|
|
a3a632d567 | ||
|
|
e496c5a4b4 | ||
|
|
ea890d8710 | ||
|
|
aa5d84f102 | ||
|
|
f1606486d2 | ||
|
|
87720a8908 | ||
|
|
bb6565ef14 | ||
|
|
7bb063bcb3 | ||
|
|
b36dd472bb | ||
|
|
167b759a38 | ||
|
|
99b9868859 | ||
|
|
6b2d5448f2 | ||
|
|
eaf709b83e | ||
|
|
f0e70afff0 | ||
|
|
86984cad68 | ||
|
|
fbc89e3ced | ||
|
|
38c1e720c2 | ||
|
|
600e87e03c | ||
|
|
3836445241 | ||
|
|
1d2c9d6a07 | ||
|
|
e8ac6bd2f5 | ||
|
|
fdadc4f22c | ||
|
|
79b527f45f | ||
|
|
dc4eada7f0 | ||
|
|
70ebc3b598 | ||
|
|
b13f2aed16 | ||
|
|
5f04c0f818 | ||
|
|
55935ccae7 | ||
|
|
b529515eb1 | ||
|
|
3cde719eb7 | ||
|
|
5de6d94a90 | ||
|
|
99eefd2ec0 | ||
|
|
e9e268336b | ||
|
|
7275ac7523 | ||
|
|
c4189a38e4 | ||
|
|
68d1b3256b | ||
|
|
9c6953bda7 | ||
|
|
ef7ece9851 | ||
|
|
ddaa4b7dcb | ||
|
|
dfae2c6989 | ||
|
|
515f104926 | ||
|
|
9ecefd56db | ||
|
|
e5d35aa187 | ||
|
|
00794c42bc | ||
|
|
08a1bf3f10 | ||
|
|
60c4154346 | ||
|
|
f2c85308c1 | ||
|
|
1a28b69ee2 | ||
|
|
ba09f01ce8 | ||
|
|
6cf48872b7 | ||
|
|
7b3b8fa000 | ||
|
|
ec5e2aae61 | ||
|
|
86389bf970 | ||
|
|
3290bfa690 | ||
|
|
8777fd104f | ||
|
|
c41f7565ed | ||
|
|
9ba81e3da4 | ||
|
|
c23888acd7 | ||
|
|
f98ce25ab9 | ||
|
|
de5f38fd48 | ||
|
|
ec2854b13a | ||
|
|
90823d2938 | ||
|
|
5f5770e3a2 | ||
|
|
28f39e9038 | ||
|
|
b2d2b37888 | ||
|
|
fe597e141c | ||
|
|
72ca1539e0 | ||
|
|
13b26775f1 | ||
|
|
05d7118561 | ||
|
|
98b901ad66 | ||
|
|
5580b47291 | ||
|
|
bc62932984 | ||
|
|
a6b5d6e759 | ||
|
|
a8931306e1 | ||
|
|
fecdb8717e | ||
|
|
916fd273ea | ||
|
|
0da8506552 | ||
|
|
eda7a7b43e | ||
|
|
022eabb734 | ||
|
|
aba899cef8 | ||
|
|
6a40e1c176 | ||
|
|
9307b2ab8b | ||
|
|
522d8d3917 | ||
|
|
a84cc0123f | ||
|
|
f018e248cd | ||
|
|
cfd7237a80 | ||
|
|
4eef8102c9 | ||
|
|
69e4dd506b | ||
|
|
25814a9458 | ||
|
|
2a980a76ce | ||
|
|
d343782c8b | ||
|
|
4e1994e9d7 | ||
|
|
65a38c452b | ||
|
|
7b7e2352cd | ||
|
|
1177d28395 | ||
|
|
005e7efa64 | ||
|
|
b42d13ec84 | ||
|
|
9adcd1a650 | ||
|
|
3c164fca8c | ||
|
|
95e335db7b | ||
|
|
f90206ad74 | ||
|
|
3779150750 | ||
|
|
0a9777aa5c | ||
|
|
45ad06aac8 | ||
|
|
c6ea2ba329 | ||
|
|
2770a10240 | ||
|
|
d2a94f9e6a | ||
|
|
32da94507a | ||
|
|
736a340478 | ||
|
|
117e1355a2 | ||
|
|
3c3e558c60 | ||
|
|
cffceda6ee | ||
|
|
048805ad2c | ||
|
|
d14c9fe7ea | ||
|
|
5db90ce822 | ||
|
|
d699cc1330 | ||
|
|
c4230747a1 | ||
|
|
5245f12a46 | ||
|
|
a198b2787e | ||
|
|
04edad8c59 | ||
|
|
392b3060b0 | ||
|
|
85b34d59bc |
@@ -7,15 +7,9 @@ parameters:
|
|||||||
nightly_build:
|
nightly_build:
|
||||||
type: boolean
|
type: boolean
|
||||||
default: false
|
default: false
|
||||||
weekly_build:
|
|
||||||
type: boolean
|
|
||||||
default: false
|
|
||||||
test_release:
|
test_release:
|
||||||
type: boolean
|
type: boolean
|
||||||
default: false
|
default: false
|
||||||
linux_release:
|
|
||||||
type: boolean
|
|
||||||
default: false
|
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
build_documentation:
|
build_documentation:
|
||||||
@@ -24,8 +18,8 @@ jobs:
|
|||||||
type: boolean
|
type: boolean
|
||||||
default: false
|
default: false
|
||||||
macos:
|
macos:
|
||||||
xcode: "15.2.0"
|
xcode: "16.2.0"
|
||||||
resource_class: macos.m1.medium.gen1
|
resource_class: m2pro.medium
|
||||||
steps:
|
steps:
|
||||||
- checkout
|
- checkout
|
||||||
- run:
|
- run:
|
||||||
@@ -38,7 +32,7 @@ jobs:
|
|||||||
pip install --upgrade pip
|
pip install --upgrade pip
|
||||||
pip install --upgrade cmake
|
pip install --upgrade cmake
|
||||||
pip install -r docs/requirements.txt
|
pip install -r docs/requirements.txt
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install . -v
|
pip install . -v
|
||||||
- when:
|
- when:
|
||||||
condition:
|
condition:
|
||||||
not: << parameters.upload-docs >>
|
not: << parameters.upload-docs >>
|
||||||
@@ -70,9 +64,9 @@ jobs:
|
|||||||
git push -f origin gh-pages
|
git push -f origin gh-pages
|
||||||
|
|
||||||
linux_build_and_test:
|
linux_build_and_test:
|
||||||
docker:
|
machine:
|
||||||
- image: cimg/python:3.9
|
image: ubuntu-2204:current
|
||||||
|
resource_class: large
|
||||||
steps:
|
steps:
|
||||||
- checkout
|
- checkout
|
||||||
- run:
|
- run:
|
||||||
@@ -84,33 +78,35 @@ jobs:
|
|||||||
- run:
|
- run:
|
||||||
name: Install dependencies
|
name: Install dependencies
|
||||||
command: |
|
command: |
|
||||||
pip install --upgrade cmake
|
export DEBIAN_FRONTEND=noninteractive
|
||||||
pip install nanobind==2.4.0
|
export NEEDRESTART_MODE=a
|
||||||
pip install numpy
|
|
||||||
sudo apt-get update
|
sudo apt-get update
|
||||||
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev
|
||||||
|
sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
|
||||||
|
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||||
- run:
|
- run:
|
||||||
name: Install Python package
|
name: Install Python package
|
||||||
command: |
|
command: |
|
||||||
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
|
uv venv
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
uv pip install cmake
|
||||||
python3 setup.py build_ext --inplace
|
uv pip install -e ".[dev]" -v
|
||||||
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
|
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
|
||||||
python3 setup.py develop
|
|
||||||
- run:
|
- run:
|
||||||
name: Generate package stubs
|
name: Generate package stubs
|
||||||
command: |
|
command: |
|
||||||
echo "stubs"
|
uv pip install typing_extensions
|
||||||
pip install typing_extensions
|
uv run --no-project setup.py generate_stubs
|
||||||
python setup.py generate_stubs
|
|
||||||
- run:
|
- run:
|
||||||
name: Run Python tests
|
name: Run Python tests
|
||||||
command: |
|
command: |
|
||||||
python3 -m unittest discover python/tests -v
|
source .venv/bin/activate
|
||||||
|
python -m unittest discover python/tests -v
|
||||||
|
mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
|
||||||
|
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
|
||||||
|
if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
|
||||||
- run:
|
- run:
|
||||||
name: Build CPP only
|
name: Build CPP only
|
||||||
command: |
|
command: |
|
||||||
|
source .venv/bin/activate
|
||||||
mkdir -p build && cd build
|
mkdir -p build && cd build
|
||||||
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
|
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
|
||||||
make -j `nproc`
|
make -j `nproc`
|
||||||
@@ -122,58 +118,63 @@ jobs:
|
|||||||
parameters:
|
parameters:
|
||||||
xcode_version:
|
xcode_version:
|
||||||
type: string
|
type: string
|
||||||
default: "15.2.0"
|
default: "16.2.0"
|
||||||
|
macosx_deployment_target:
|
||||||
|
type: string
|
||||||
|
default: ""
|
||||||
macos:
|
macos:
|
||||||
xcode: << parameters.xcode_version >>
|
xcode: << parameters.xcode_version >>
|
||||||
resource_class: macos.m1.medium.gen1
|
environment:
|
||||||
|
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
|
||||||
|
resource_class: m2pro.medium
|
||||||
steps:
|
steps:
|
||||||
- checkout
|
- checkout
|
||||||
- run:
|
- run:
|
||||||
name: Install dependencies
|
name: Install dependencies
|
||||||
command: |
|
command: |
|
||||||
brew install python@3.9
|
HOMEBREW_NO_AUTO_UPDATE=1 HOMEBREW_NO_INSTALL_CLEANUP=1 \
|
||||||
brew install openmpi
|
brew install openmpi uv
|
||||||
python3.9 -m venv env
|
|
||||||
source env/bin/activate
|
|
||||||
pip install --upgrade pip
|
|
||||||
pip install --upgrade cmake
|
|
||||||
pip install nanobind==2.4.0
|
|
||||||
pip install numpy
|
|
||||||
pip install torch
|
|
||||||
pip install tensorflow
|
|
||||||
pip install unittest-xml-reporting
|
|
||||||
- run:
|
- run:
|
||||||
name: Install Python package
|
name: Install Python package
|
||||||
command: |
|
command: |
|
||||||
source env/bin/activate
|
uv venv --python 3.9
|
||||||
DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install -e . -v
|
uv pip install \
|
||||||
|
nanobind==2.4.0 \
|
||||||
|
cmake \
|
||||||
|
numpy \
|
||||||
|
torch \
|
||||||
|
tensorflow \
|
||||||
|
unittest-xml-reporting
|
||||||
|
DEBUG=1 CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
|
||||||
|
uv pip install -e . -v
|
||||||
- run:
|
- run:
|
||||||
name: Generate package stubs
|
name: Generate package stubs
|
||||||
command: |
|
command: |
|
||||||
source env/bin/activate
|
uv pip install typing_extensions
|
||||||
pip install typing_extensions
|
uv run --no-project setup.py generate_stubs
|
||||||
python setup.py generate_stubs
|
|
||||||
- run:
|
- run:
|
||||||
name: Run Python tests
|
name: Run Python tests
|
||||||
command: |
|
command: |
|
||||||
source env/bin/activate
|
source .venv/bin/activate
|
||||||
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
|
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
|
||||||
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
|
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
|
||||||
mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
|
mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
|
||||||
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py
|
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
|
||||||
|
if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
|
||||||
- run:
|
- run:
|
||||||
name: Build example extension
|
name: Build example extension
|
||||||
command: |
|
command: |
|
||||||
source env/bin/activate
|
source .venv/bin/activate
|
||||||
cd examples/extensions
|
cd examples/extensions
|
||||||
pip install -r requirements.txt
|
uv pip install -r requirements.txt
|
||||||
python setup.py build_ext -j8
|
uv run --no-project setup.py build_ext --inplace
|
||||||
|
uv run --no-project python test.py
|
||||||
- store_test_results:
|
- store_test_results:
|
||||||
path: test-results
|
path: test-results
|
||||||
- run:
|
- run:
|
||||||
name: Build CPP only
|
name: Build CPP only
|
||||||
command: |
|
command: |
|
||||||
source env/bin/activate
|
source .venv/bin/activate
|
||||||
mkdir -p build && cd build && cmake .. && make -j `sysctl -n hw.ncpu`
|
mkdir -p build && cd build && cmake .. && make -j `sysctl -n hw.ncpu`
|
||||||
- run:
|
- run:
|
||||||
name: Run CPP tests
|
name: Run CPP tests
|
||||||
@@ -182,7 +183,7 @@ jobs:
|
|||||||
- run:
|
- run:
|
||||||
name: Build small binary
|
name: Build small binary
|
||||||
command: |
|
command: |
|
||||||
source env/bin/activate
|
source .venv/bin/activate
|
||||||
cd build/
|
cd build/
|
||||||
cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
|
cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
|
||||||
-DBUILD_SHARED_LIBS=ON \
|
-DBUILD_SHARED_LIBS=ON \
|
||||||
@@ -194,13 +195,60 @@ jobs:
|
|||||||
- run:
|
- run:
|
||||||
name: Run Python tests with JIT
|
name: Run Python tests with JIT
|
||||||
command: |
|
command: |
|
||||||
source env/bin/activate
|
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
|
|
||||||
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
|
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
|
||||||
pip install -e . -v
|
uv pip install -e .
|
||||||
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
|
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
|
||||||
METAL_DEBUG_ERROR_MODE=0 \
|
METAL_DEBUG_ERROR_MODE=0 \
|
||||||
python -m xmlrunner discover -v python/tests -o test-results/gpu_jit
|
uv run --no-project python -m xmlrunner discover \
|
||||||
|
-v python/tests \
|
||||||
|
-o test-results/gpu_jit
|
||||||
|
|
||||||
|
cuda_build_and_test:
|
||||||
|
parameters:
|
||||||
|
image_date:
|
||||||
|
type: string
|
||||||
|
default: "2023.11.1"
|
||||||
|
machine:
|
||||||
|
image: "linux-cuda-12:<< parameters.image_date >>"
|
||||||
|
resource_class: gpu.nvidia.small.gen2
|
||||||
|
steps:
|
||||||
|
- checkout
|
||||||
|
- restore_cache:
|
||||||
|
keys:
|
||||||
|
- cuda-<< parameters.image_date >>-{{ arch }}-
|
||||||
|
- run:
|
||||||
|
name: Install dependencies
|
||||||
|
command: |
|
||||||
|
sudo apt-get update
|
||||||
|
sudo apt-get install libcudnn9-dev-cuda-12
|
||||||
|
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
||||||
|
curl -sL https://github.com/ccache/ccache/releases/download/v4.11.3/ccache-4.11.3-linux-x86_64.tar.xz | tar xJf -
|
||||||
|
sudo mv ccache-4.11.3-linux-x86_64/ccache /usr/bin/ccache
|
||||||
|
rm -rf ccache-4.11.3-linux-x86_64
|
||||||
|
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||||
|
- run:
|
||||||
|
name: Install Python package
|
||||||
|
command: |
|
||||||
|
uv venv
|
||||||
|
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
|
||||||
|
uv pip install -e ".[dev]" -v
|
||||||
|
- run:
|
||||||
|
name: Run Python tests
|
||||||
|
command: |
|
||||||
|
source .venv/bin/activate
|
||||||
|
LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
|
||||||
|
LOW_MEMORY=1 DEVICE=gpu python -m tests discover python/tests -v
|
||||||
|
- run:
|
||||||
|
name: CCache report
|
||||||
|
command: |
|
||||||
|
ccache --show-stats
|
||||||
|
ccache --zero-stats
|
||||||
|
ccache --max-size 400MB
|
||||||
|
ccache --cleanup
|
||||||
|
- save_cache:
|
||||||
|
key: cuda-<< parameters.image_date >>-{{ arch }}-{{ epoch }}
|
||||||
|
paths:
|
||||||
|
- /home/circleci/.cache/ccache
|
||||||
|
|
||||||
build_release:
|
build_release:
|
||||||
parameters:
|
parameters:
|
||||||
@@ -209,13 +257,18 @@ jobs:
|
|||||||
default: "3.9"
|
default: "3.9"
|
||||||
xcode_version:
|
xcode_version:
|
||||||
type: string
|
type: string
|
||||||
default: "15.2.0"
|
default: "16.2.0"
|
||||||
build_env:
|
build_env:
|
||||||
type: string
|
type: string
|
||||||
default: ""
|
default: ""
|
||||||
|
macosx_deployment_target:
|
||||||
|
type: string
|
||||||
|
default: ""
|
||||||
macos:
|
macos:
|
||||||
xcode: << parameters.xcode_version >>
|
xcode: << parameters.xcode_version >>
|
||||||
resource_class: macos.m1.medium.gen1
|
resource_class: m2pro.medium
|
||||||
|
environment:
|
||||||
|
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
|
||||||
steps:
|
steps:
|
||||||
- checkout
|
- checkout
|
||||||
- run:
|
- run:
|
||||||
@@ -236,8 +289,7 @@ jobs:
|
|||||||
name: Install Python package
|
name: Install Python package
|
||||||
command: |
|
command: |
|
||||||
source env/bin/activate
|
source env/bin/activate
|
||||||
DEV_RELEASE=1 \
|
env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
|
|
||||||
pip install . -v
|
pip install . -v
|
||||||
- run:
|
- run:
|
||||||
name: Generate package stubs
|
name: Generate package stubs
|
||||||
@@ -249,9 +301,18 @@ jobs:
|
|||||||
name: Build Python package
|
name: Build Python package
|
||||||
command: |
|
command: |
|
||||||
source env/bin/activate
|
source env/bin/activate
|
||||||
<< parameters.build_env >> \
|
python setup.py clean --all
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
|
<< parameters.build_env >> MLX_BUILD_STAGE=1 python -m build -w
|
||||||
python -m build -w
|
- when:
|
||||||
|
condition:
|
||||||
|
equal: ["3.9", << parameters.python_version >>]
|
||||||
|
steps:
|
||||||
|
- run:
|
||||||
|
name: Build common package
|
||||||
|
command: |
|
||||||
|
source env/bin/activate
|
||||||
|
python setup.py clean --all
|
||||||
|
<< parameters.build_env >> MLX_BUILD_STAGE=2 python -m build -w
|
||||||
- when:
|
- when:
|
||||||
condition: << parameters.build_env >>
|
condition: << parameters.build_env >>
|
||||||
steps:
|
steps:
|
||||||
@@ -268,52 +329,100 @@ jobs:
|
|||||||
python_version:
|
python_version:
|
||||||
type: string
|
type: string
|
||||||
default: "3.9"
|
default: "3.9"
|
||||||
extra_env:
|
build_env:
|
||||||
type: string
|
type: string
|
||||||
default: "DEV_RELEASE=1"
|
default: ""
|
||||||
docker:
|
machine:
|
||||||
- image: ubuntu:20.04
|
image: ubuntu-2204:current
|
||||||
|
resource_class: large
|
||||||
steps:
|
steps:
|
||||||
- checkout
|
- checkout
|
||||||
- run:
|
- run:
|
||||||
name: Build wheel
|
name: Build wheel
|
||||||
command: |
|
command: |
|
||||||
PYTHON=python<< parameters.python_version >>
|
PYTHON=python<< parameters.python_version >>
|
||||||
apt-get update
|
export DEBIAN_FRONTEND=noninteractive
|
||||||
apt-get upgrade -y
|
export NEEDRESTART_MODE=a
|
||||||
DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC apt-get -y install tzdata
|
sudo apt-get update
|
||||||
apt-get install -y apt-utils
|
TZ=Etc/UTC sudo apt-get -y install tzdata
|
||||||
apt-get install -y software-properties-common
|
sudo add-apt-repository -y ppa:deadsnakes/ppa
|
||||||
add-apt-repository -y ppa:deadsnakes/ppa
|
sudo apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full
|
||||||
apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full
|
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev
|
||||||
apt-get install -y libblas-dev liblapack-dev liblapacke-dev
|
|
||||||
apt-get install -y build-essential git
|
|
||||||
$PYTHON -m venv env
|
$PYTHON -m venv env
|
||||||
source env/bin/activate
|
source env/bin/activate
|
||||||
pip install --upgrade pip
|
pip install --upgrade pip
|
||||||
pip install --upgrade cmake
|
pip install --upgrade cmake
|
||||||
pip install nanobind==2.4.0
|
|
||||||
pip install --upgrade setuptools
|
|
||||||
pip install numpy
|
|
||||||
pip install auditwheel
|
pip install auditwheel
|
||||||
pip install patchelf
|
pip install patchelf
|
||||||
pip install build
|
pip install build
|
||||||
pip install twine
|
pip install twine
|
||||||
<< parameters.extra_env >> \
|
<< parameters.build_env >> pip install ".[dev]" -v
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
|
||||||
pip install . -v
|
|
||||||
pip install typing_extensions
|
pip install typing_extensions
|
||||||
python setup.py generate_stubs
|
python setup.py generate_stubs
|
||||||
<< parameters.extra_env >> \
|
python setup.py clean --all
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
MLX_BUILD_STAGE=1 << parameters.build_env >> python -m build -w
|
||||||
python -m build --wheel
|
bash python/scripts/repair_linux.sh
|
||||||
auditwheel show dist/*
|
- when:
|
||||||
auditwheel repair dist/* --plat manylinux_2_31_x86_64
|
condition:
|
||||||
|
equal: ["3.9", << parameters.python_version >>]
|
||||||
|
steps:
|
||||||
|
- run:
|
||||||
|
name: Build common package
|
||||||
|
command: |
|
||||||
|
source env/bin/activate
|
||||||
|
python setup.py clean --all
|
||||||
|
<< parameters.build_env >> MLX_BUILD_STAGE=2 \
|
||||||
|
python -m build -w
|
||||||
|
auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_x86_64
|
||||||
|
- when:
|
||||||
|
condition: << parameters.build_env >>
|
||||||
|
steps:
|
||||||
|
- run:
|
||||||
|
name: Upload packages
|
||||||
|
command: |
|
||||||
|
source env/bin/activate
|
||||||
|
twine upload wheelhouse/*.whl
|
||||||
|
- store_artifacts:
|
||||||
|
path: wheelhouse/
|
||||||
|
|
||||||
|
build_cuda_release:
|
||||||
|
parameters:
|
||||||
|
build_env:
|
||||||
|
type: string
|
||||||
|
default: ""
|
||||||
|
machine:
|
||||||
|
image: ubuntu-2204:current
|
||||||
|
resource_class: large
|
||||||
|
steps:
|
||||||
|
- checkout
|
||||||
|
- run:
|
||||||
|
name: Build wheel
|
||||||
|
command: |
|
||||||
|
export DEBIAN_FRONTEND=noninteractive
|
||||||
|
export NEEDRESTART_MODE=a
|
||||||
|
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/cuda-keyring_1.1-1_all.deb
|
||||||
|
sudo dpkg -i cuda-keyring_1.1-1_all.deb
|
||||||
|
sudo apt-get update
|
||||||
|
sudo apt-get install cuda-toolkit-12-9 libcudnn9-dev-cuda-12
|
||||||
|
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
||||||
|
sudo apt-get install zip
|
||||||
|
pip install auditwheel
|
||||||
|
pip install patchelf
|
||||||
|
pip install build
|
||||||
|
pip install twine
|
||||||
|
export PATH=/usr/local/cuda/bin${PATH:+:${PATH}}
|
||||||
|
export LD_LIBRARY_PATH=/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}
|
||||||
|
<< parameters.build_env >> MLX_BUILD_STAGE=2 \
|
||||||
|
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
|
||||||
|
python -m build -w
|
||||||
|
bash python/scripts/repair_cuda.sh
|
||||||
|
- when:
|
||||||
|
condition: << parameters.build_env >>
|
||||||
|
steps:
|
||||||
- run:
|
- run:
|
||||||
name: Upload package
|
name: Upload package
|
||||||
command: |
|
command: |
|
||||||
source env/bin/activate
|
twine upload wheelhouse/*.whl
|
||||||
twine upload wheelhouse/*
|
|
||||||
- store_artifacts:
|
- store_artifacts:
|
||||||
path: wheelhouse/
|
path: wheelhouse/
|
||||||
|
|
||||||
@@ -325,21 +434,23 @@ workflows:
|
|||||||
pattern: "^(?!pull/)[-\\w]+$"
|
pattern: "^(?!pull/)[-\\w]+$"
|
||||||
value: << pipeline.git.branch >>
|
value: << pipeline.git.branch >>
|
||||||
- not: << pipeline.parameters.nightly_build >>
|
- not: << pipeline.parameters.nightly_build >>
|
||||||
- not: << pipeline.parameters.weekly_build >>
|
|
||||||
- not: << pipeline.parameters.test_release >>
|
- not: << pipeline.parameters.test_release >>
|
||||||
jobs:
|
jobs:
|
||||||
- mac_build_and_test:
|
- mac_build_and_test:
|
||||||
matrix:
|
matrix:
|
||||||
parameters:
|
parameters:
|
||||||
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
|
macosx_deployment_target: ["13.5", "14.0"]
|
||||||
- linux_build_and_test
|
- linux_build_and_test
|
||||||
|
- cuda_build_and_test:
|
||||||
|
matrix:
|
||||||
|
parameters:
|
||||||
|
image_date: ["2023.11.1", "2025.05.1"]
|
||||||
- build_documentation
|
- build_documentation
|
||||||
|
|
||||||
build_pypi_release:
|
build_pypi_release:
|
||||||
when:
|
when:
|
||||||
and:
|
and:
|
||||||
- not: << pipeline.parameters.nightly_build >>
|
- not: << pipeline.parameters.nightly_build >>
|
||||||
- not: << pipeline.parameters.weekly_build >>
|
|
||||||
- not: << pipeline.parameters.test_release >>
|
- not: << pipeline.parameters.test_release >>
|
||||||
jobs:
|
jobs:
|
||||||
- build_release:
|
- build_release:
|
||||||
@@ -351,8 +462,70 @@ workflows:
|
|||||||
matrix:
|
matrix:
|
||||||
parameters:
|
parameters:
|
||||||
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||||
xcode_version: ["15.0.0", "15.2.0"]
|
macosx_deployment_target: ["13.5", "14.0", "15.0"]
|
||||||
build_env: ["PYPI_RELEASE=1"]
|
build_env: ["PYPI_RELEASE=1"]
|
||||||
|
xcode_version: ["16.2.0", "15.0.0"]
|
||||||
|
exclude:
|
||||||
|
- macosx_deployment_target: "13.5"
|
||||||
|
xcode_version: "16.2.0"
|
||||||
|
python_version: "3.9"
|
||||||
|
build_env: "PYPI_RELEASE=1"
|
||||||
|
- macosx_deployment_target: "13.5"
|
||||||
|
xcode_version: "16.2.0"
|
||||||
|
python_version: "3.10"
|
||||||
|
build_env: "PYPI_RELEASE=1"
|
||||||
|
- macosx_deployment_target: "13.5"
|
||||||
|
xcode_version: "16.2.0"
|
||||||
|
python_version: "3.11"
|
||||||
|
build_env: "PYPI_RELEASE=1"
|
||||||
|
- macosx_deployment_target: "13.5"
|
||||||
|
xcode_version: "16.2.0"
|
||||||
|
python_version: "3.12"
|
||||||
|
build_env: "PYPI_RELEASE=1"
|
||||||
|
- macosx_deployment_target: "13.5"
|
||||||
|
xcode_version: "16.2.0"
|
||||||
|
python_version: "3.13"
|
||||||
|
build_env: "PYPI_RELEASE=1"
|
||||||
|
- macosx_deployment_target: "14.0"
|
||||||
|
xcode_version: "15.0.0"
|
||||||
|
python_version: "3.9"
|
||||||
|
build_env: "PYPI_RELEASE=1"
|
||||||
|
- macosx_deployment_target: "14.0"
|
||||||
|
xcode_version: "15.0.0"
|
||||||
|
python_version: "3.10"
|
||||||
|
build_env: "PYPI_RELEASE=1"
|
||||||
|
- macosx_deployment_target: "14.0"
|
||||||
|
xcode_version: "15.0.0"
|
||||||
|
python_version: "3.11"
|
||||||
|
build_env: "PYPI_RELEASE=1"
|
||||||
|
- macosx_deployment_target: "14.0"
|
||||||
|
xcode_version: "15.0.0"
|
||||||
|
python_version: "3.12"
|
||||||
|
build_env: "PYPI_RELEASE=1"
|
||||||
|
- macosx_deployment_target: "14.0"
|
||||||
|
xcode_version: "15.0.0"
|
||||||
|
python_version: "3.13"
|
||||||
|
build_env: "PYPI_RELEASE=1"
|
||||||
|
- macosx_deployment_target: "15.0"
|
||||||
|
xcode_version: "15.0.0"
|
||||||
|
python_version: "3.9"
|
||||||
|
build_env: "PYPI_RELEASE=1"
|
||||||
|
- macosx_deployment_target: "15.0"
|
||||||
|
xcode_version: "15.0.0"
|
||||||
|
python_version: "3.10"
|
||||||
|
build_env: "PYPI_RELEASE=1"
|
||||||
|
- macosx_deployment_target: "15.0"
|
||||||
|
xcode_version: "15.0.0"
|
||||||
|
python_version: "3.11"
|
||||||
|
build_env: "PYPI_RELEASE=1"
|
||||||
|
- macosx_deployment_target: "15.0"
|
||||||
|
xcode_version: "15.0.0"
|
||||||
|
python_version: "3.12"
|
||||||
|
build_env: "PYPI_RELEASE=1"
|
||||||
|
- macosx_deployment_target: "15.0"
|
||||||
|
xcode_version: "15.0.0"
|
||||||
|
python_version: "3.13"
|
||||||
|
build_env: "PYPI_RELEASE=1"
|
||||||
- build_documentation:
|
- build_documentation:
|
||||||
filters:
|
filters:
|
||||||
tags:
|
tags:
|
||||||
@@ -360,6 +533,25 @@ workflows:
|
|||||||
branches:
|
branches:
|
||||||
ignore: /.*/
|
ignore: /.*/
|
||||||
upload-docs: true
|
upload-docs: true
|
||||||
|
- build_linux_release:
|
||||||
|
filters:
|
||||||
|
tags:
|
||||||
|
only: /^v.*/
|
||||||
|
branches:
|
||||||
|
ignore: /.*/
|
||||||
|
matrix:
|
||||||
|
parameters:
|
||||||
|
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||||
|
build_env: ["PYPI_RELEASE=1"]
|
||||||
|
- build_cuda_release:
|
||||||
|
filters:
|
||||||
|
tags:
|
||||||
|
only: /^v.*/
|
||||||
|
branches:
|
||||||
|
ignore: /.*/
|
||||||
|
matrix:
|
||||||
|
parameters:
|
||||||
|
build_env: ["PYPI_RELEASE=1"]
|
||||||
|
|
||||||
prb:
|
prb:
|
||||||
when:
|
when:
|
||||||
@@ -375,9 +567,14 @@ workflows:
|
|||||||
requires: [ hold ]
|
requires: [ hold ]
|
||||||
matrix:
|
matrix:
|
||||||
parameters:
|
parameters:
|
||||||
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
|
macosx_deployment_target: ["13.5", "14.0"]
|
||||||
- linux_build_and_test:
|
- linux_build_and_test:
|
||||||
requires: [ hold ]
|
requires: [ hold ]
|
||||||
|
- cuda_build_and_test:
|
||||||
|
requires: [ hold ]
|
||||||
|
matrix:
|
||||||
|
parameters:
|
||||||
|
image_date: ["2023.11.1", "2025.05.1"]
|
||||||
nightly_build:
|
nightly_build:
|
||||||
when:
|
when:
|
||||||
and:
|
and:
|
||||||
@@ -388,27 +585,140 @@ workflows:
|
|||||||
matrix:
|
matrix:
|
||||||
parameters:
|
parameters:
|
||||||
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||||
xcode_version: ["15.0.0", "15.2.0"]
|
macosx_deployment_target: ["13.5", "14.0", "15.0"]
|
||||||
weekly_build:
|
xcode_version: ["16.2.0", "15.0.0"]
|
||||||
|
exclude:
|
||||||
|
- macosx_deployment_target: "13.5"
|
||||||
|
xcode_version: "16.2.0"
|
||||||
|
python_version: "3.9"
|
||||||
|
- macosx_deployment_target: "13.5"
|
||||||
|
xcode_version: "16.2.0"
|
||||||
|
python_version: "3.10"
|
||||||
|
- macosx_deployment_target: "13.5"
|
||||||
|
xcode_version: "16.2.0"
|
||||||
|
python_version: "3.11"
|
||||||
|
- macosx_deployment_target: "13.5"
|
||||||
|
xcode_version: "16.2.0"
|
||||||
|
python_version: "3.12"
|
||||||
|
- macosx_deployment_target: "13.5"
|
||||||
|
xcode_version: "16.2.0"
|
||||||
|
python_version: "3.13"
|
||||||
|
- macosx_deployment_target: "14.0"
|
||||||
|
xcode_version: "15.0.0"
|
||||||
|
python_version: "3.9"
|
||||||
|
- macosx_deployment_target: "14.0"
|
||||||
|
xcode_version: "15.0.0"
|
||||||
|
python_version: "3.10"
|
||||||
|
- macosx_deployment_target: "14.0"
|
||||||
|
xcode_version: "15.0.0"
|
||||||
|
python_version: "3.11"
|
||||||
|
- macosx_deployment_target: "14.0"
|
||||||
|
xcode_version: "15.0.0"
|
||||||
|
python_version: "3.12"
|
||||||
|
- macosx_deployment_target: "14.0"
|
||||||
|
xcode_version: "15.0.0"
|
||||||
|
python_version: "3.13"
|
||||||
|
- macosx_deployment_target: "15.0"
|
||||||
|
xcode_version: "15.0.0"
|
||||||
|
python_version: "3.9"
|
||||||
|
- macosx_deployment_target: "15.0"
|
||||||
|
xcode_version: "15.0.0"
|
||||||
|
python_version: "3.10"
|
||||||
|
- macosx_deployment_target: "15.0"
|
||||||
|
xcode_version: "15.0.0"
|
||||||
|
python_version: "3.11"
|
||||||
|
- macosx_deployment_target: "15.0"
|
||||||
|
xcode_version: "15.0.0"
|
||||||
|
python_version: "3.12"
|
||||||
|
- macosx_deployment_target: "15.0"
|
||||||
|
xcode_version: "15.0.0"
|
||||||
|
python_version: "3.13"
|
||||||
|
- build_linux_release:
|
||||||
|
matrix:
|
||||||
|
parameters:
|
||||||
|
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||||
|
- build_cuda_release
|
||||||
|
|
||||||
|
build_dev_release:
|
||||||
when:
|
when:
|
||||||
and:
|
and:
|
||||||
- equal: [ main, << pipeline.git.branch >> ]
|
- equal: [ main, << pipeline.git.branch >> ]
|
||||||
- << pipeline.parameters.weekly_build >>
|
- << pipeline.parameters.test_release >>
|
||||||
jobs:
|
jobs:
|
||||||
- build_release:
|
- build_release:
|
||||||
matrix:
|
matrix:
|
||||||
parameters:
|
parameters:
|
||||||
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||||
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
|
macosx_deployment_target: ["13.5", "14.0", "15.0"]
|
||||||
build_env: ["DEV_RELEASE=1"]
|
build_env: ["DEV_RELEASE=1"]
|
||||||
linux_test_release:
|
xcode_version: ["16.2.0", "15.0.0"]
|
||||||
when:
|
exclude:
|
||||||
and:
|
- macosx_deployment_target: "13.5"
|
||||||
- equal: [ main, << pipeline.git.branch >> ]
|
xcode_version: "16.2.0"
|
||||||
- << pipeline.parameters.linux_release >>
|
python_version: "3.9"
|
||||||
jobs:
|
build_env: "DEV_RELEASE=1"
|
||||||
|
- macosx_deployment_target: "13.5"
|
||||||
|
xcode_version: "16.2.0"
|
||||||
|
python_version: "3.10"
|
||||||
|
build_env: "DEV_RELEASE=1"
|
||||||
|
- macosx_deployment_target: "13.5"
|
||||||
|
xcode_version: "16.2.0"
|
||||||
|
python_version: "3.11"
|
||||||
|
build_env: "DEV_RELEASE=1"
|
||||||
|
- macosx_deployment_target: "13.5"
|
||||||
|
xcode_version: "16.2.0"
|
||||||
|
python_version: "3.12"
|
||||||
|
build_env: "DEV_RELEASE=1"
|
||||||
|
- macosx_deployment_target: "13.5"
|
||||||
|
xcode_version: "16.2.0"
|
||||||
|
python_version: "3.13"
|
||||||
|
build_env: "DEV_RELEASE=1"
|
||||||
|
- macosx_deployment_target: "14.0"
|
||||||
|
xcode_version: "15.0.0"
|
||||||
|
python_version: "3.9"
|
||||||
|
build_env: "DEV_RELEASE=1"
|
||||||
|
- macosx_deployment_target: "14.0"
|
||||||
|
xcode_version: "15.0.0"
|
||||||
|
python_version: "3.10"
|
||||||
|
build_env: "DEV_RELEASE=1"
|
||||||
|
- macosx_deployment_target: "14.0"
|
||||||
|
xcode_version: "15.0.0"
|
||||||
|
python_version: "3.11"
|
||||||
|
build_env: "DEV_RELEASE=1"
|
||||||
|
- macosx_deployment_target: "14.0"
|
||||||
|
xcode_version: "15.0.0"
|
||||||
|
python_version: "3.12"
|
||||||
|
build_env: "DEV_RELEASE=1"
|
||||||
|
- macosx_deployment_target: "14.0"
|
||||||
|
xcode_version: "15.0.0"
|
||||||
|
python_version: "3.13"
|
||||||
|
build_env: "DEV_RELEASE=1"
|
||||||
|
- macosx_deployment_target: "15.0"
|
||||||
|
xcode_version: "15.0.0"
|
||||||
|
python_version: "3.9"
|
||||||
|
build_env: "DEV_RELEASE=1"
|
||||||
|
- macosx_deployment_target: "15.0"
|
||||||
|
xcode_version: "15.0.0"
|
||||||
|
python_version: "3.10"
|
||||||
|
build_env: "DEV_RELEASE=1"
|
||||||
|
- macosx_deployment_target: "15.0"
|
||||||
|
xcode_version: "15.0.0"
|
||||||
|
python_version: "3.11"
|
||||||
|
build_env: "DEV_RELEASE=1"
|
||||||
|
- macosx_deployment_target: "15.0"
|
||||||
|
xcode_version: "15.0.0"
|
||||||
|
python_version: "3.12"
|
||||||
|
build_env: "DEV_RELEASE=1"
|
||||||
|
- macosx_deployment_target: "15.0"
|
||||||
|
xcode_version: "15.0.0"
|
||||||
|
python_version: "3.13"
|
||||||
|
build_env: "DEV_RELEASE=1"
|
||||||
- build_linux_release:
|
- build_linux_release:
|
||||||
matrix:
|
matrix:
|
||||||
parameters:
|
parameters:
|
||||||
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||||
extra_env: ["PYPI_RELEASE=1"]
|
build_env: ["DEV_RELEASE=1"]
|
||||||
|
- build_cuda_release:
|
||||||
|
matrix:
|
||||||
|
parameters:
|
||||||
|
build_env: ["DEV_RELEASE=1"]
|
||||||
|
|||||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -36,6 +36,7 @@ share/python-wheels/
|
|||||||
.installed.cfg
|
.installed.cfg
|
||||||
*.egg
|
*.egg
|
||||||
MANIFEST
|
MANIFEST
|
||||||
|
uv.lock
|
||||||
|
|
||||||
# vim
|
# vim
|
||||||
*.swp
|
*.swp
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ MLX was developed with contributions from the following individuals:
|
|||||||
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
|
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
|
||||||
- Paul Paczuski: Improved stability of BCE loss calculation
|
- Paul Paczuski: Improved stability of BCE loss calculation
|
||||||
- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
|
- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
|
||||||
|
- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer.
|
||||||
|
|
||||||
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
|
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
|
||||||
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
|
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ if(NOT MLX_VERSION)
|
|||||||
string(REGEX MATCH "#define MLX_VERSION_PATCH ([0-9]+)" _ "${_mlx_h_version}")
|
string(REGEX MATCH "#define MLX_VERSION_PATCH ([0-9]+)" _ "${_mlx_h_version}")
|
||||||
set(_patch ${CMAKE_MATCH_1})
|
set(_patch ${CMAKE_MATCH_1})
|
||||||
set(MLX_PROJECT_VERSION "${_major}.${_minor}.${_patch}")
|
set(MLX_PROJECT_VERSION "${_major}.${_minor}.${_patch}")
|
||||||
|
set(MLX_VERSION ${MLX_PROJECT_VERSION})
|
||||||
else()
|
else()
|
||||||
string(REGEX REPLACE "^([0-9]+\.[0-9]+\.[0-9]+).*" "\\1" MLX_PROJECT_VERSION
|
string(REGEX REPLACE "^([0-9]+\.[0-9]+\.[0-9]+).*" "\\1" MLX_PROJECT_VERSION
|
||||||
${MLX_VERSION})
|
${MLX_VERSION})
|
||||||
@@ -33,15 +34,16 @@ option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
|
|||||||
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
|
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
|
||||||
option(MLX_BUILD_METAL "Build metal backend" ON)
|
option(MLX_BUILD_METAL "Build metal backend" ON)
|
||||||
option(MLX_BUILD_CPU "Build cpu backend" ON)
|
option(MLX_BUILD_CPU "Build cpu backend" ON)
|
||||||
|
option(MLX_BUILD_CUDA "Build cuda backend" OFF)
|
||||||
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
|
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
|
||||||
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
|
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
|
||||||
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
|
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
|
||||||
option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
|
option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
|
||||||
option(MLX_BUILD_BLAS_FROM_SOURCE "Build OpenBLAS from source code" OFF)
|
option(MLX_BUILD_BLAS_FROM_SOURCE "Build OpenBLAS from source code" OFF)
|
||||||
option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
|
option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
|
||||||
|
option(MLX_USE_CCACHE "Use CCache for compilation cache when available" ON)
|
||||||
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
|
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
|
||||||
|
option(USE_SYSTEM_FMT "Use system's provided fmt library" OFF)
|
||||||
add_compile_definitions("MLX_VERSION=${MLX_VERSION}")
|
|
||||||
|
|
||||||
# --------------------- Processor tests -------------------------
|
# --------------------- Processor tests -------------------------
|
||||||
message(
|
message(
|
||||||
@@ -64,10 +66,17 @@ if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
|
|||||||
message(WARNING "Building for x86_64 arch is not officially supported.")
|
message(WARNING "Building for x86_64 arch is not officially supported.")
|
||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
else()
|
else()
|
||||||
set(MLX_BUILD_METAL OFF)
|
set(MLX_BUILD_METAL OFF)
|
||||||
message(WARNING "MLX is prioritised for Apple silicon systems using macOS.")
|
endif()
|
||||||
|
|
||||||
|
if(MLX_USE_CCACHE)
|
||||||
|
find_program(CCACHE_PROGRAM ccache)
|
||||||
|
if(CCACHE_PROGRAM)
|
||||||
|
set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
|
||||||
|
set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
|
||||||
|
set(CMAKE_CUDA_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
|
||||||
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
# ----------------------------- Lib -----------------------------
|
# ----------------------------- Lib -----------------------------
|
||||||
@@ -77,7 +86,6 @@ include(FetchContent)
|
|||||||
cmake_policy(SET CMP0135 NEW)
|
cmake_policy(SET CMP0135 NEW)
|
||||||
|
|
||||||
add_library(mlx)
|
add_library(mlx)
|
||||||
set_target_properties(mlx PROPERTIES COMPILE_WARNING_AS_ERROR ON)
|
|
||||||
|
|
||||||
if(MLX_BUILD_METAL)
|
if(MLX_BUILD_METAL)
|
||||||
set(METAL_LIB "-framework Metal")
|
set(METAL_LIB "-framework Metal")
|
||||||
@@ -85,6 +93,10 @@ if(MLX_BUILD_METAL)
|
|||||||
set(QUARTZ_LIB "-framework QuartzCore")
|
set(QUARTZ_LIB "-framework QuartzCore")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
if(MLX_BUILD_CUDA)
|
||||||
|
enable_language(CUDA)
|
||||||
|
endif()
|
||||||
|
|
||||||
if(MLX_BUILD_METAL AND NOT METAL_LIB)
|
if(MLX_BUILD_METAL AND NOT METAL_LIB)
|
||||||
message(STATUS "Metal not found. Unable to build GPU")
|
message(STATUS "Metal not found. Unable to build GPU")
|
||||||
set(MLX_BUILD_METAL OFF)
|
set(MLX_BUILD_METAL OFF)
|
||||||
@@ -214,23 +226,13 @@ else()
|
|||||||
set(MLX_BUILD_ACCELERATE OFF)
|
set(MLX_BUILD_ACCELERATE OFF)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
find_package(MPI)
|
message(STATUS "Downloading json")
|
||||||
if(MPI_FOUND)
|
FetchContent_Declare(
|
||||||
execute_process(
|
json
|
||||||
COMMAND zsh "-c" "mpirun --version"
|
URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz)
|
||||||
OUTPUT_VARIABLE MPI_VERSION
|
FetchContent_MakeAvailable(json)
|
||||||
ERROR_QUIET)
|
target_include_directories(
|
||||||
if(${MPI_VERSION} MATCHES ".*Open MPI.*")
|
mlx PRIVATE $<BUILD_INTERFACE:${json_SOURCE_DIR}/single_include/nlohmann>)
|
||||||
target_include_directories(mlx PRIVATE ${MPI_INCLUDE_PATH})
|
|
||||||
elseif(MPI_VERSION STREQUAL "")
|
|
||||||
set(MPI_FOUND FALSE)
|
|
||||||
message(
|
|
||||||
WARNING "MPI found but mpirun is not available. Building without MPI.")
|
|
||||||
else()
|
|
||||||
set(MPI_FOUND FALSE)
|
|
||||||
message(WARNING "MPI which is not OpenMPI found. Building without MPI.")
|
|
||||||
endif()
|
|
||||||
endif()
|
|
||||||
|
|
||||||
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)
|
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)
|
||||||
|
|
||||||
@@ -238,12 +240,19 @@ target_include_directories(
|
|||||||
mlx PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
|
mlx PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
|
||||||
$<INSTALL_INTERFACE:include>)
|
$<INSTALL_INTERFACE:include>)
|
||||||
|
|
||||||
FetchContent_Declare(
|
# Do not add mlx_EXPORTS define for shared library.
|
||||||
|
set_target_properties(mlx PROPERTIES DEFINE_SYMBOL "")
|
||||||
|
|
||||||
|
if(USE_SYSTEM_FMT)
|
||||||
|
find_package(fmt REQUIRED)
|
||||||
|
else()
|
||||||
|
FetchContent_Declare(
|
||||||
fmt
|
fmt
|
||||||
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
|
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
|
||||||
GIT_TAG 10.2.1
|
GIT_TAG 10.2.1
|
||||||
EXCLUDE_FROM_ALL)
|
EXCLUDE_FROM_ALL)
|
||||||
FetchContent_MakeAvailable(fmt)
|
FetchContent_MakeAvailable(fmt)
|
||||||
|
endif()
|
||||||
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)
|
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)
|
||||||
|
|
||||||
if(MLX_BUILD_PYTHON_BINDINGS)
|
if(MLX_BUILD_PYTHON_BINDINGS)
|
||||||
|
|||||||
@@ -17,11 +17,11 @@ possible.
|
|||||||
|
|
||||||
You can also run the formatters manually as follows:
|
You can also run the formatters manually as follows:
|
||||||
|
|
||||||
```
|
```shell
|
||||||
clang-format -i file.cpp
|
clang-format -i file.cpp
|
||||||
```
|
```
|
||||||
|
|
||||||
```
|
```shell
|
||||||
black file.py
|
black file.py
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
include CMakeLists.txt
|
include CMakeLists.txt
|
||||||
|
include mlx.pc.in
|
||||||
recursive-include mlx/ *
|
recursive-include mlx/ *
|
||||||
|
include cmake/*
|
||||||
include python/src/*
|
include python/src/*
|
||||||
include python/mlx/py.typed # support type hinting as in PEP-561
|
include python/mlx/py.typed # support type hinting as in PEP-561
|
||||||
|
|||||||
17
README.md
17
README.md
@@ -68,18 +68,23 @@ in the documentation.
|
|||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
|
|
||||||
MLX is available on [PyPI](https://pypi.org/project/mlx/). To install the Python API, run:
|
MLX is available on [PyPI](https://pypi.org/project/mlx/). To install MLX on
|
||||||
|
macOS, run:
|
||||||
|
|
||||||
**With `pip`**:
|
```bash
|
||||||
|
|
||||||
```
|
|
||||||
pip install mlx
|
pip install mlx
|
||||||
```
|
```
|
||||||
|
|
||||||
**With `conda`**:
|
To install the CUDA backend on Linux, run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install mlx[cuda]
|
||||||
```
|
```
|
||||||
conda install -c conda-forge mlx
|
|
||||||
|
To install a CPU-only Linux package, run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install mlx[cpu]
|
||||||
```
|
```
|
||||||
|
|
||||||
Checkout the
|
Checkout the
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
// Copyright © 2023 Apple Inc.
|
// Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
|
#include <cstring>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
|
||||||
|
|||||||
@@ -192,6 +192,22 @@ void time_reductions() {
|
|||||||
|
|
||||||
auto argmin_along_1 = [&a]() { return mx::argmin(a, 1, false); };
|
auto argmin_along_1 = [&a]() { return mx::argmin(a, 1, false); };
|
||||||
TIME(argmin_along_1);
|
TIME(argmin_along_1);
|
||||||
|
|
||||||
|
auto indices = mx::array({1});
|
||||||
|
auto updates = mx::reshape(mx::array({NAN}), {1, 1, 1});
|
||||||
|
std::vector<int> axes{0};
|
||||||
|
auto b = scatter(a, {indices}, updates, axes);
|
||||||
|
mx::eval(b);
|
||||||
|
|
||||||
|
auto max_along_0 = [&b]() { return mx::max(b, 0, false); };
|
||||||
|
TIME(max_along_0);
|
||||||
|
auto max_along_1 = [&b]() { return mx::max(b, 1, false); };
|
||||||
|
TIME(max_along_1);
|
||||||
|
|
||||||
|
auto min_along_0 = [&b]() { return mx::min(b, 0, false); };
|
||||||
|
TIME(min_along_0);
|
||||||
|
auto min_along_1 = [&b]() { return mx::min(b, 1, false); };
|
||||||
|
TIME(min_along_1);
|
||||||
}
|
}
|
||||||
|
|
||||||
void time_gather_scatter() {
|
void time_gather_scatter() {
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import os
|
|||||||
import time
|
import time
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.cuda
|
||||||
import torch.mps
|
import torch.mps
|
||||||
|
|
||||||
|
|
||||||
@@ -44,8 +45,10 @@ def bench(f, *args):
|
|||||||
|
|
||||||
|
|
||||||
def sync_if_needed(x):
|
def sync_if_needed(x):
|
||||||
if x.device != torch.device("cpu"):
|
if x.device == torch.device("mps"):
|
||||||
torch.mps.synchronize()
|
torch.mps.synchronize()
|
||||||
|
elif x.device == torch.device("cuda"):
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@@ -99,6 +102,14 @@ def reduction(op, axis, x):
|
|||||||
sync_if_needed(x)
|
sync_if_needed(x)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sum_and_add(axis, x, y):
|
||||||
|
z = x.sum(axis=axis, keepdims=True)
|
||||||
|
for i in range(50):
|
||||||
|
z = (z + y).sum(axis=axis, keepdims=True)
|
||||||
|
sync_if_needed(x)
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def softmax(axis, x):
|
def softmax(axis, x):
|
||||||
ys = []
|
ys = []
|
||||||
@@ -340,7 +351,11 @@ if __name__ == "__main__":
|
|||||||
args.axis.pop(0)
|
args.axis.pop(0)
|
||||||
|
|
||||||
torch.set_num_threads(1)
|
torch.set_num_threads(1)
|
||||||
device = "cpu" if args.cpu else "mps"
|
device = "mps"
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = "cuda"
|
||||||
|
if args.cpu:
|
||||||
|
device = "cpu"
|
||||||
|
|
||||||
types = args.dtype
|
types = args.dtype
|
||||||
if not types:
|
if not types:
|
||||||
@@ -460,5 +475,8 @@ if __name__ == "__main__":
|
|||||||
elif args.benchmark == "selu":
|
elif args.benchmark == "selu":
|
||||||
print(bench(selu, x))
|
print(bench(selu, x))
|
||||||
|
|
||||||
|
elif args.benchmark == "sum_and_add":
|
||||||
|
print(bench(sum_and_add, axis, *xs))
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown benchmark `{args.benchmark}`.")
|
raise ValueError(f"Unknown benchmark `{args.benchmark}`.")
|
||||||
|
|||||||
107
benchmarks/python/conv_unaligned_bench.py
Normal file
107
benchmarks/python/conv_unaligned_bench.py
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
import math
|
||||||
|
import time
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
N_warmup = 10
|
||||||
|
N_iter_bench = 100
|
||||||
|
N_iter_func = 5
|
||||||
|
|
||||||
|
|
||||||
|
def bench(f, a, b):
|
||||||
|
for i in range(N_warmup):
|
||||||
|
f(a, b)
|
||||||
|
torch.mps.synchronize()
|
||||||
|
|
||||||
|
s = time.perf_counter_ns()
|
||||||
|
for i in range(N_iter_bench):
|
||||||
|
f(a, b)
|
||||||
|
e = time.perf_counter_ns()
|
||||||
|
return (e - s) * 1e-9
|
||||||
|
|
||||||
|
|
||||||
|
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
|
||||||
|
def mx_conv_2D(a, b):
|
||||||
|
ys = []
|
||||||
|
for i in range(N_iter_func):
|
||||||
|
y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
|
||||||
|
ys.append(y)
|
||||||
|
mx.eval(ys)
|
||||||
|
return ys
|
||||||
|
|
||||||
|
return mx_conv_2D
|
||||||
|
|
||||||
|
|
||||||
|
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
|
||||||
|
@torch.no_grad()
|
||||||
|
def pt_conv_2D(a, b):
|
||||||
|
ys = []
|
||||||
|
for i in range(N_iter_func):
|
||||||
|
y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
|
||||||
|
ys.append(y)
|
||||||
|
torch.mps.synchronize()
|
||||||
|
return ys
|
||||||
|
|
||||||
|
return pt_conv_2D
|
||||||
|
|
||||||
|
|
||||||
|
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
|
||||||
|
scale = 1.0 / math.sqrt(kH * kH * C)
|
||||||
|
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
|
||||||
|
b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
|
||||||
|
np_dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
a_mx = mx.array(a_np)
|
||||||
|
b_mx = mx.array(b_np)
|
||||||
|
|
||||||
|
a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
|
||||||
|
b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("mps")
|
||||||
|
|
||||||
|
torch.mps.synchronize()
|
||||||
|
|
||||||
|
f_mx = make_mx_conv_2D(strides, padding, groups)
|
||||||
|
f_pt = make_pt_conv_2D(strides, padding, groups)
|
||||||
|
|
||||||
|
time_torch = bench(f_pt, a_pt, b_pt)
|
||||||
|
time_mlx = bench(f_mx, a_mx, b_mx)
|
||||||
|
|
||||||
|
out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
|
||||||
|
out_pt = torch.conv2d(
|
||||||
|
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
|
||||||
|
)
|
||||||
|
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
|
||||||
|
out_pt = out_pt.numpy(force=True)
|
||||||
|
|
||||||
|
atol = 2e-5 if np_dtype == np.float32 else 1e-4
|
||||||
|
|
||||||
|
if not np.allclose(out_pt, out_mx, atol=atol):
|
||||||
|
print(
|
||||||
|
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return time_mlx, time_torch
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
dtype = "float32"
|
||||||
|
shapes = (
|
||||||
|
(4, 32, 32, 21, 3, 3, 128),
|
||||||
|
(4, 32, 32, 21, 3, 3, 37),
|
||||||
|
(4, 32, 32, 370, 3, 3, 370),
|
||||||
|
(4, 32, 32, 370, 7, 7, 128),
|
||||||
|
(2, 320, 640, 21, 7, 7, 21),
|
||||||
|
)
|
||||||
|
for N, H, W, C, kh, kw, O in shapes:
|
||||||
|
time_mlx, time_torch = bench_shape(
|
||||||
|
N, H, W, C, kh, kw, O, (1, 1), (0, 0), 1, dtype
|
||||||
|
)
|
||||||
|
diff = time_torch / time_mlx - 1.0
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kh:2d}, {kw:2d}, {C:3d}), {dtype}, {100. * diff:+5.2f}%"
|
||||||
|
)
|
||||||
|
if time_mlx >= 2.0 * time_torch:
|
||||||
|
print("ATTENTION ^^^^^^^")
|
||||||
@@ -1,7 +1,6 @@
|
|||||||
# Copyright © 2023-2024 Apple Inc.
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
from time import time
|
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import torch
|
import torch
|
||||||
|
|||||||
74
benchmarks/python/gather_mm_bench.py
Normal file
74
benchmarks/python/gather_mm_bench.py
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
# Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from time_utils import time_fn
|
||||||
|
|
||||||
|
N = 1024
|
||||||
|
D = 1024
|
||||||
|
M = 1024
|
||||||
|
E = 32
|
||||||
|
I = 4
|
||||||
|
|
||||||
|
|
||||||
|
def gather_sort(x, indices):
|
||||||
|
N, M = indices.shape
|
||||||
|
indices = indices.flatten()
|
||||||
|
order = mx.argsort(indices)
|
||||||
|
inv_order = mx.argsort(order)
|
||||||
|
return x.flatten(0, -3)[order // M], indices[order], inv_order
|
||||||
|
|
||||||
|
|
||||||
|
def scatter_unsort(x, inv_order, shape=None):
|
||||||
|
x = x[inv_order]
|
||||||
|
if shape is not None:
|
||||||
|
x = mx.unflatten(x, 0, shape)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def gather_mm_simulate(x, w, indices):
|
||||||
|
x, idx, inv_order = gather_sort(x, indices)
|
||||||
|
for i in range(2):
|
||||||
|
y = mx.concatenate([x[i] @ w[j].T for i, j in enumerate(idx.tolist())], axis=0)
|
||||||
|
x = y[:, None]
|
||||||
|
x = scatter_unsort(x, inv_order, indices.shape)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def time_gather_mm():
|
||||||
|
x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
|
||||||
|
w1 = mx.random.normal((E, M, D)) / 1024**0.5
|
||||||
|
w2 = mx.random.normal((E, D, M)) / 1024**0.5
|
||||||
|
indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
|
||||||
|
sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
|
||||||
|
mx.eval(x, w1, w2, indices, sorted_indices)
|
||||||
|
|
||||||
|
def gather_mm(x, w1, w2, indices, sort):
|
||||||
|
idx = indices
|
||||||
|
inv_order = None
|
||||||
|
if sort:
|
||||||
|
x, idx, inv_order = gather_sort(x, indices)
|
||||||
|
x = mx.gather_mm(x, w1.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
|
||||||
|
x = mx.gather_mm(x, w2.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
|
||||||
|
if sort:
|
||||||
|
x = scatter_unsort(x, inv_order, indices.shape)
|
||||||
|
return x
|
||||||
|
|
||||||
|
time_fn(gather_mm, x, w1, w2, indices, False)
|
||||||
|
time_fn(gather_mm, x, w1, w2, sorted_indices, False)
|
||||||
|
time_fn(gather_mm, x, w1, w2, indices, True)
|
||||||
|
|
||||||
|
x = mx.random.normal((N * I, D)) / 1024**0.5
|
||||||
|
w1 = mx.random.normal((M, D)) / 1024**0.5
|
||||||
|
w2 = mx.random.normal((D, M)) / 1024**0.5
|
||||||
|
mx.eval(x, w1, w2)
|
||||||
|
|
||||||
|
def equivalent_matmul(x, w1, w2):
|
||||||
|
x = x @ w1.T
|
||||||
|
x = x @ w2.T
|
||||||
|
return x
|
||||||
|
|
||||||
|
time_fn(equivalent_matmul, x, w1, w2)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
time_gather_mm()
|
||||||
84
benchmarks/python/gather_qmm_bench.py
Normal file
84
benchmarks/python/gather_qmm_bench.py
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
# Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from time_utils import time_fn
|
||||||
|
|
||||||
|
N = 1024
|
||||||
|
D = 1024
|
||||||
|
M = 1024
|
||||||
|
E = 32
|
||||||
|
I = 4
|
||||||
|
|
||||||
|
|
||||||
|
def gather_sort(x, indices):
|
||||||
|
N, M = indices.shape
|
||||||
|
indices = indices.flatten()
|
||||||
|
order = mx.argsort(indices)
|
||||||
|
inv_order = mx.argsort(order)
|
||||||
|
return x.flatten(0, -3)[order // M], indices[order], inv_order
|
||||||
|
|
||||||
|
|
||||||
|
def scatter_unsort(x, inv_order, shape=None):
|
||||||
|
x = x[inv_order]
|
||||||
|
if shape is not None:
|
||||||
|
x = mx.unflatten(x, 0, shape)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def gather_mm_simulate(x, w, indices):
|
||||||
|
x, idx, inv_order = gather_sort(x, indices)
|
||||||
|
for i in range(2):
|
||||||
|
y = mx.concatenate(
|
||||||
|
[
|
||||||
|
mx.quantized_matmul(x[i], w[0][j], w[1][j], w[2][j], transpose=True)
|
||||||
|
for i, j in enumerate(idx.tolist())
|
||||||
|
],
|
||||||
|
axis=0,
|
||||||
|
)
|
||||||
|
x = y[:, None]
|
||||||
|
x = scatter_unsort(x, inv_order, indices.shape)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def time_gather_qmm():
|
||||||
|
x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
|
||||||
|
w1 = mx.random.normal((E, M, D)) / 1024**0.5
|
||||||
|
w2 = mx.random.normal((E, D, M)) / 1024**0.5
|
||||||
|
w1 = mx.quantize(w1)
|
||||||
|
w2 = mx.quantize(w2)
|
||||||
|
indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
|
||||||
|
sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
|
||||||
|
mx.eval(x, w1, w2, indices, sorted_indices)
|
||||||
|
|
||||||
|
def gather_mm(x, w1, w2, indices, sort):
|
||||||
|
idx = indices
|
||||||
|
inv_order = None
|
||||||
|
if sort:
|
||||||
|
x, idx, inv_order = gather_sort(x, indices)
|
||||||
|
x = mx.gather_qmm(x, *w1, transpose=True, rhs_indices=idx, sorted_indices=sort)
|
||||||
|
x = mx.gather_qmm(x, *w2, transpose=True, rhs_indices=idx, sorted_indices=sort)
|
||||||
|
if sort:
|
||||||
|
x = scatter_unsort(x, inv_order, indices.shape)
|
||||||
|
return x
|
||||||
|
|
||||||
|
time_fn(gather_mm, x, w1, w2, indices, False)
|
||||||
|
time_fn(gather_mm, x, w1, w2, sorted_indices, False)
|
||||||
|
time_fn(gather_mm, x, w1, w2, indices, True)
|
||||||
|
|
||||||
|
x = mx.random.normal((N * I, D)) / 1024**0.5
|
||||||
|
w1 = mx.random.normal((M, D)) / 1024**0.5
|
||||||
|
w2 = mx.random.normal((D, M)) / 1024**0.5
|
||||||
|
w1 = mx.quantize(w1)
|
||||||
|
w2 = mx.quantize(w2)
|
||||||
|
mx.eval(x, w1, w2)
|
||||||
|
|
||||||
|
def equivalent_matmul(x, w1, w2):
|
||||||
|
x = mx.quantized_matmul(x, *w1, transpose=True)
|
||||||
|
x = mx.quantized_matmul(x, *w2, transpose=True)
|
||||||
|
return x
|
||||||
|
|
||||||
|
time_fn(equivalent_matmul, x, w1, w2)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
time_gather_qmm()
|
||||||
@@ -1,5 +1,7 @@
|
|||||||
# Copyright © 2023-2024 Apple Inc.
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
from time_utils import time_fn
|
from time_utils import time_fn
|
||||||
@@ -18,51 +20,63 @@ def layer_norm(x, w, b, eps):
|
|||||||
return y
|
return y
|
||||||
|
|
||||||
|
|
||||||
def time_layer_norm():
|
def time_layer_norm(N, dt):
|
||||||
|
L = 1024
|
||||||
f1 = lambda x, w, b, y: (layer_norm(x, w, b, 1e-5) * y).sum()
|
f1 = lambda x, w, b, y: (layer_norm(x, w, b, 1e-5) * y).sum()
|
||||||
f2 = lambda x, w, b, y: (mx.fast.layer_norm(x, w, b, 1e-5) * y).sum()
|
f2 = lambda x, w, b, y: (mx.fast.layer_norm(x, w, b, 1e-5) * y).sum()
|
||||||
g1 = mx.grad(f1, argnums=(0, 1, 2))
|
g1 = mx.grad(f1, argnums=(0, 1, 2))
|
||||||
g2 = mx.grad(f2, argnums=(0, 1, 2))
|
g2 = mx.grad(f2, argnums=(0, 1, 2))
|
||||||
|
|
||||||
x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
|
x = mx.random.uniform(shape=(8, L, N)).astype(dt)
|
||||||
w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
|
w = mx.random.uniform(shape=(N,)).astype(dt)
|
||||||
b = mx.random.uniform(shape=(4096,)).astype(mx.float16)
|
b = mx.random.uniform(shape=(N,)).astype(dt)
|
||||||
y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
|
y = mx.random.uniform(shape=(8, L, N)).astype(dt)
|
||||||
mx.eval(x, w, b, y)
|
mx.eval(x, w, b, y)
|
||||||
|
|
||||||
def layer_norm_loop(g, x, w, b):
|
def layer_norm_loop(f, x, w, b):
|
||||||
|
for _ in range(32):
|
||||||
|
x = f(x, w, b)
|
||||||
|
return x
|
||||||
|
|
||||||
|
time_fn(layer_norm_loop, partial(layer_norm, eps=1e-5), x, w, b)
|
||||||
|
time_fn(layer_norm_loop, partial(mx.fast.layer_norm, eps=1e-5), x, w, b)
|
||||||
|
|
||||||
|
def layer_norm_grad_loop(g, x, w, b):
|
||||||
gx, gw, gb = x, w, b
|
gx, gw, gb = x, w, b
|
||||||
for _ in range(32):
|
for _ in range(32):
|
||||||
gx, gw, gb = g(gx, gw, gb, y)
|
gx, gw, gb = g(gx, gw, gb, y)
|
||||||
return gx, gw, gb
|
return gx, gw, gb
|
||||||
|
|
||||||
time_fn(layer_norm_loop, g1, x, w, b)
|
time_fn(layer_norm_grad_loop, g1, x, w, b)
|
||||||
time_fn(layer_norm_loop, g2, x, w, b)
|
time_fn(layer_norm_grad_loop, g2, x, w, b)
|
||||||
time_fn(layer_norm_loop, mx.compile(g1), x, w, b)
|
time_fn(layer_norm_grad_loop, mx.compile(g1), x, w, b)
|
||||||
time_fn(layer_norm_loop, mx.compile(g2), x, w, b)
|
time_fn(layer_norm_grad_loop, mx.compile(g2), x, w, b)
|
||||||
|
|
||||||
f1 = lambda x, y: (layer_norm(x, None, None, 1e-5) * y).sum()
|
f1 = lambda x, y: (layer_norm(x, None, None, 1e-5) * y).sum()
|
||||||
f2 = lambda x, y: (mx.fast.layer_norm(x, None, None, 1e-5) * y).sum()
|
f2 = lambda x, y: (mx.fast.layer_norm(x, None, None, 1e-5) * y).sum()
|
||||||
g1 = mx.grad(f1, argnums=(0,))
|
g1 = mx.grad(f1, argnums=(0,))
|
||||||
g2 = mx.grad(f2, argnums=(0,))
|
g2 = mx.grad(f2, argnums=(0,))
|
||||||
|
|
||||||
x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
|
x = mx.random.uniform(shape=(8, L, N)).astype(dt)
|
||||||
w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
|
w = mx.random.uniform(shape=(N,)).astype(dt)
|
||||||
b = mx.random.uniform(shape=(4096,)).astype(mx.float16)
|
b = mx.random.uniform(shape=(N,)).astype(dt)
|
||||||
y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
|
y = mx.random.uniform(shape=(8, L, N)).astype(dt)
|
||||||
mx.eval(x, w, b, y)
|
mx.eval(x, w, b, y)
|
||||||
|
|
||||||
def layer_norm_loop(g, x):
|
def layer_norm_grad_x_loop(g, x):
|
||||||
gx = x
|
gx = x
|
||||||
for _ in range(32):
|
for _ in range(32):
|
||||||
gx = g(gx, y)
|
gx = g(gx, y)
|
||||||
return gx
|
return gx
|
||||||
|
|
||||||
time_fn(layer_norm_loop, g1, x)
|
time_fn(layer_norm_grad_x_loop, g1, x)
|
||||||
time_fn(layer_norm_loop, g2, x)
|
time_fn(layer_norm_grad_x_loop, g2, x)
|
||||||
time_fn(layer_norm_loop, mx.compile(g1), x)
|
time_fn(layer_norm_grad_x_loop, mx.compile(g1), x)
|
||||||
time_fn(layer_norm_loop, mx.compile(g2), x)
|
time_fn(layer_norm_grad_x_loop, mx.compile(g2), x)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
time_layer_norm()
|
for dt in [mx.float32, mx.float16, mx.bfloat16]:
|
||||||
|
for n in [1024, 2048, 4096, 8192, 8192 + 1024]:
|
||||||
|
print(dt, n)
|
||||||
|
time_layer_norm(n, dt)
|
||||||
|
|||||||
@@ -28,11 +28,34 @@ def bench(f, *args):
|
|||||||
return (e - s) * 1e-9
|
return (e - s) * 1e-9
|
||||||
|
|
||||||
|
|
||||||
def mlx_sdpa_fused_inner(q, k, v, scale):
|
def prepare_inputs(B, qL, kL, D, qH, kH, mask, transpose, dtype):
|
||||||
return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=None)
|
np_dtype = getattr(np, dtype)
|
||||||
|
|
||||||
|
shape_q = (B, qL, qH, D) if transpose else (B, qH, qL, D)
|
||||||
|
shape_kv = (B, kL, kH, D) if transpose else (B, kH, kL, D)
|
||||||
|
|
||||||
|
scale = 1.0 / math.sqrt(D)
|
||||||
|
|
||||||
|
q_np = np.random.normal(0.0, 1.0, shape_q).astype(np_dtype)
|
||||||
|
k_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)
|
||||||
|
v_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)
|
||||||
|
|
||||||
|
q_mx = mx.array(q_np)
|
||||||
|
k_mx = mx.array(k_np)
|
||||||
|
v_mx = mx.array(v_np)
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
if mask == "additive":
|
||||||
|
mask_np = np.random.normal(0.0, 1.0, (B, qH, qL, kL)).astype(np_dtype)
|
||||||
|
mask = mx.array(mask_np)
|
||||||
|
elif mask == "bool":
|
||||||
|
mask_np = np.random.uniform(0.0, 1.0, (B, qH, qL, kL)) < 0.5
|
||||||
|
mask = mx.array(mask_np)
|
||||||
|
|
||||||
|
return q_mx, k_mx, v_mx, scale, mask
|
||||||
|
|
||||||
|
|
||||||
def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False):
|
def mlx_ref_attn(q, k, v, scale=1.0, mask=None):
|
||||||
q_dtype = q.dtype
|
q_dtype = q.dtype
|
||||||
q = q * mx.array(scale, q_dtype)
|
q = q * mx.array(scale, q_dtype)
|
||||||
n_q_heads = q.shape[-3]
|
n_q_heads = q.shape[-3]
|
||||||
@@ -41,6 +64,7 @@ def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False):
|
|||||||
|
|
||||||
B = q.shape[0]
|
B = q.shape[0]
|
||||||
L = q.shape[2]
|
L = q.shape[2]
|
||||||
|
kL = k.shape[2]
|
||||||
|
|
||||||
if n_repeats > 1:
|
if n_repeats > 1:
|
||||||
q = mx.reshape(q, [B, n_kv_heads, n_repeats, L, -1])
|
q = mx.reshape(q, [B, n_kv_heads, n_repeats, L, -1])
|
||||||
@@ -48,10 +72,27 @@ def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False):
|
|||||||
v = mx.expand_dims(v, 2)
|
v = mx.expand_dims(v, 2)
|
||||||
|
|
||||||
scores = q @ mx.swapaxes(k, -1, -2)
|
scores = q @ mx.swapaxes(k, -1, -2)
|
||||||
if f32softmax:
|
|
||||||
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(q_dtype)
|
if mask is not None:
|
||||||
|
|
||||||
|
if mask == "causal":
|
||||||
|
q_offset = max(0, kL - L)
|
||||||
|
q_indices = mx.arange(q_offset, q_offset + L)
|
||||||
|
k_indices = mx.arange(kL)
|
||||||
|
mask = q_indices[:, None] >= k_indices[None]
|
||||||
|
|
||||||
|
if n_repeats > 1 and mask.ndim >= 3:
|
||||||
|
if mask.shape[-3] == 1:
|
||||||
|
mask = mx.expand_dims(mask, -3)
|
||||||
else:
|
else:
|
||||||
scores = mx.softmax(scores, axis=-1)
|
mask = mx.unflatten(mask, -3, (n_kv_heads, n_repeats))
|
||||||
|
|
||||||
|
if mask.dtype == mx.bool_:
|
||||||
|
scores = mx.where(mask, scores, -np.float32(np.inf))
|
||||||
|
else:
|
||||||
|
scores += mask
|
||||||
|
|
||||||
|
scores = mx.softmax(scores, axis=-1, precise=True)
|
||||||
|
|
||||||
out = scores @ v
|
out = scores @ v
|
||||||
if n_repeats > 1:
|
if n_repeats > 1:
|
||||||
@@ -60,74 +101,55 @@ def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False):
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def mlx_spda_unfused(q, k, v, scale, transpose):
|
def mlx_fused_attn(q, k, v, scale, mask):
|
||||||
q_out = q
|
return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
|
||||||
|
|
||||||
|
|
||||||
|
def do_attention(f, q, k, v, scale, mask=None, transpose=False):
|
||||||
if transpose:
|
if transpose:
|
||||||
k = mx.transpose(k, (0, 2, 1, 3))
|
q_t = mx.transpose(q, (0, 2, 1, 3))
|
||||||
v = mx.transpose(v, (0, 2, 1, 3))
|
k_t = mx.transpose(k, (0, 2, 1, 3))
|
||||||
|
v_t = mx.transpose(v, (0, 2, 1, 3))
|
||||||
|
o_t = f(q_t, k_t, v_t, scale=scale, mask=mask)
|
||||||
|
return mx.transpose(o_t, (0, 2, 1, 3))
|
||||||
|
else:
|
||||||
|
return f(q, k, v, scale=scale, mask=mask)
|
||||||
|
|
||||||
|
|
||||||
|
def do_attention_bench(f, q, k, v, scale, mask=None, transpose=False):
|
||||||
|
q_out = q
|
||||||
|
|
||||||
for i in range(N_iter_func):
|
for i in range(N_iter_func):
|
||||||
if transpose:
|
q_out = do_attention(f, q_out, k, v, scale, mask=mask, transpose=transpose)
|
||||||
q_out = mx.transpose(q_out, (0, 2, 1, 3))
|
|
||||||
q_out = mlx_sdpa_unfused_inner(q_out, k, v, scale)
|
|
||||||
if transpose:
|
|
||||||
q_out = mx.transpose(q_out, (0, 2, 1, 3))
|
|
||||||
|
|
||||||
mx.eval(q_out)
|
mx.eval(q_out)
|
||||||
return q_out
|
return q_out
|
||||||
|
|
||||||
|
|
||||||
def mlx_spda_fused(q, k, v, scale, transpose):
|
def bench_shape(
|
||||||
q_out = q
|
B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, dtype, transpose=True, mask_in=None
|
||||||
if transpose:
|
):
|
||||||
k = mx.transpose(k, (0, 2, 1, 3))
|
q_mx, k_mx, v_mx, scale, mask = prepare_inputs(
|
||||||
v = mx.transpose(v, (0, 2, 1, 3))
|
B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, mask_in, transpose, dtype
|
||||||
|
|
||||||
for i in range(N_iter_func):
|
|
||||||
if transpose:
|
|
||||||
q_out = mx.transpose(q_out, (0, 2, 1, 3))
|
|
||||||
q_out = mlx_sdpa_fused_inner(q_out, k, v, scale)
|
|
||||||
if transpose:
|
|
||||||
q_out = mx.transpose(q_out, (0, 2, 1, 3))
|
|
||||||
|
|
||||||
mx.eval(q_out)
|
|
||||||
return q_out
|
|
||||||
|
|
||||||
|
|
||||||
def bench_shape(B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose=True):
|
|
||||||
shape_q = (
|
|
||||||
(B, qsl, n_q_heads, head_dim) if transpose else (B, n_q_heads, qsl, head_dim)
|
|
||||||
)
|
|
||||||
shape_kv = (
|
|
||||||
(B, ksl, n_kv_heads, head_dim) if transpose else (B, n_kv_heads, ksl, head_dim)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
q_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_q).astype(np_dtype)
|
time_mlx_unfused = bench(
|
||||||
k_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype)
|
do_attention_bench, mlx_ref_attn, q_mx, k_mx, v_mx, scale, mask, transpose
|
||||||
v_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype)
|
)
|
||||||
|
time_mlx_fused = bench(
|
||||||
|
do_attention_bench, mlx_fused_attn, q_mx, k_mx, v_mx, scale, mask, transpose
|
||||||
|
)
|
||||||
|
|
||||||
scale = math.sqrt(1.0 / head_dim)
|
o_mlx_fused = do_attention(mlx_ref_attn, q_mx, k_mx, v_mx, scale, mask, transpose)
|
||||||
|
o_mlx_unfused = do_attention(
|
||||||
|
mlx_fused_attn, q_mx, k_mx, v_mx, scale, mask, transpose
|
||||||
|
)
|
||||||
|
|
||||||
q_mx = mx.array(q_np)
|
atol = 1e-5 if dtype == "float32" else 2e-4
|
||||||
k_mx = mx.array(k_np)
|
|
||||||
v_mx = mx.array(v_np)
|
|
||||||
|
|
||||||
time_mlx_unfused = bench(mlx_spda_unfused, q_mx, k_mx, v_mx, scale, transpose)
|
if not mx.allclose(o_mlx_fused, o_mlx_unfused, atol=atol, rtol=atol):
|
||||||
time_mlx_fused = bench(mlx_spda_fused, q_mx, k_mx, v_mx, scale, transpose)
|
|
||||||
|
|
||||||
if transpose:
|
|
||||||
q_mx = mx.transpose(q_mx, (0, 2, 1, 3))
|
|
||||||
k_mx = mx.transpose(k_mx, (0, 2, 1, 3))
|
|
||||||
v_mx = mx.transpose(v_mx, (0, 2, 1, 3))
|
|
||||||
|
|
||||||
o_mlx_fused = mlx_sdpa_fused_inner(q_mx, k_mx, v_mx, scale)
|
|
||||||
o_mlx_unfused = mlx_sdpa_unfused_inner(q_mx, k_mx, v_mx, scale, f32softmax=True)
|
|
||||||
|
|
||||||
atol = 1e-5 if np_dtype == np.float32 else 1e-4
|
|
||||||
|
|
||||||
if not mx.allclose(o_mlx_fused, o_mlx_unfused, atol=atol):
|
|
||||||
print(
|
print(
|
||||||
f"Failed at (B: {B}, qsl: {qsl}, ksl: {ksl}, head_dim: {head_dim}, n_qh: {n_q_heads}, n_kvh: {n_kv_heads}) [tpose = {transpose}] with max(|a - b|) = {mx.max(mx.abs(o_mlx_unfused - o_mlx_fused)):3.2e}"
|
f"Failed at (B: {B}, qsl: {qsl}, ksl: {ksl}, head_dim: {head_dim}, n_qh: {n_q_heads}, n_kvh: {n_kv_heads}, mask: {mask_in}) [tpose = {transpose}] with max(|a - b|) = {mx.max(mx.abs(o_mlx_unfused - o_mlx_fused)):3.2e}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return time_mlx_fused, time_mlx_unfused
|
return time_mlx_fused, time_mlx_unfused
|
||||||
@@ -151,39 +173,51 @@ if __name__ == "__main__":
|
|||||||
( 1, 128, 128, 64, 32, 32),
|
( 1, 128, 128, 64, 32, 32),
|
||||||
( 1, 256, 256, 64, 32, 32),
|
( 1, 256, 256, 64, 32, 32),
|
||||||
( 1, 512, 512, 64, 32, 32),
|
( 1, 512, 512, 64, 32, 32),
|
||||||
( 1, 1024, 1024, 64, 32, 32),
|
( 1, 1024, 1024, 64, 32, 8),
|
||||||
( 1, 2048, 2048, 64, 32, 32),
|
( 1, 2048, 2048, 64, 32, 8),
|
||||||
( 1, 4096, 4096, 64, 32, 32),
|
( 1, 4096, 4096, 64, 32, 8),
|
||||||
)
|
)
|
||||||
|
|
||||||
shapes_80 = (
|
shapes_80 = (
|
||||||
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
|
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
|
||||||
( 1, 1024, 1024, 80, 32, 32),
|
( 1, 1024, 1024, 80, 32, 8),
|
||||||
( 1, 2048, 2048, 80, 32, 32),
|
( 1, 2048, 2048, 80, 32, 8),
|
||||||
( 1, 4096, 4096, 80, 32, 32),
|
( 1, 4096, 4096, 80, 32, 8),
|
||||||
)
|
)
|
||||||
|
|
||||||
shapes_128 = (
|
shapes_128 = (
|
||||||
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
|
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
|
||||||
( 1, 1024, 1024, 128, 32, 32),
|
( 1, 1024, 1024, 128, 32, 8),
|
||||||
( 1, 2048, 2048, 128, 32, 32),
|
( 1, 2048, 2048, 128, 32, 8),
|
||||||
( 1, 4096, 4096, 128, 32, 32),
|
( 1, 4096, 4096, 128, 32, 8),
|
||||||
)
|
)
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
shapes = shapes_64 + shapes_80 + shapes_128
|
shapes = shapes_64 + shapes_80 + shapes_128
|
||||||
|
|
||||||
print(" B, qsl, ksl, hdim, n_qh, n_kvh, tpose, dtype, t_unfs, t_fuse, diff%")
|
masks = [None, "bool", "causal"]
|
||||||
|
|
||||||
|
print(
|
||||||
|
" B, qsl, ksl, hdim, n_qh, n_kvh, t, dtype, mask, t_unfs, t_fuse, diff%"
|
||||||
|
)
|
||||||
|
|
||||||
for dtype in dtypes:
|
for dtype in dtypes:
|
||||||
for transpose in transposes:
|
for transpose in transposes:
|
||||||
for B, qsl, ksl, head_dim, n_q_heads, n_kv_heads in shapes:
|
for B, qsl, ksl, head_dim, n_q_heads, n_kv_heads in shapes:
|
||||||
np_dtype = getattr(np, dtype)
|
for mask_in in masks:
|
||||||
time_mlx_fused, time_mlx_unfused = bench_shape(
|
time_mlx_fused, time_mlx_unfused = bench_shape(
|
||||||
B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose
|
B,
|
||||||
|
qsl,
|
||||||
|
ksl,
|
||||||
|
head_dim,
|
||||||
|
n_q_heads,
|
||||||
|
n_kv_heads,
|
||||||
|
dtype,
|
||||||
|
transpose,
|
||||||
|
mask_in,
|
||||||
)
|
)
|
||||||
diff = time_mlx_unfused / time_mlx_fused - 1.0
|
diff = time_mlx_unfused / time_mlx_fused - 1.0
|
||||||
t_str = 1 if transpose else 0
|
t_str = 1 if transpose else 0
|
||||||
print(
|
print(
|
||||||
f"{B:3d}, {qsl:5d}, {ksl:5d}, {head_dim:4d}, {n_q_heads:4d}, {n_kv_heads:5d}, {t_str:5d}, {dtype}, {time_mlx_unfused: 2.3f}, {time_mlx_fused: 2.3f}, {100. * diff:+5.2f}%"
|
f"{B:3d}, {qsl:5d}, {ksl:5d}, {head_dim:4d}, {n_q_heads:4d}, {n_kv_heads:5d}, {t_str:1d}, {dtype}, {str(mask_in):>8}, {time_mlx_unfused: 2.3f}, {time_mlx_fused: 2.3f}, {100. * diff:+5.2f}%"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -51,6 +51,20 @@ def time_maximum():
|
|||||||
time_fn(mx.maximum, a, b)
|
time_fn(mx.maximum, a, b)
|
||||||
|
|
||||||
|
|
||||||
|
def time_max():
|
||||||
|
a = mx.random.uniform(shape=(32, 1024, 1024))
|
||||||
|
a[1, 1] = mx.nan
|
||||||
|
mx.eval(a)
|
||||||
|
time_fn(mx.max, a, 0)
|
||||||
|
|
||||||
|
|
||||||
|
def time_min():
|
||||||
|
a = mx.random.uniform(shape=(32, 1024, 1024))
|
||||||
|
a[1, 1] = mx.nan
|
||||||
|
mx.eval(a)
|
||||||
|
time_fn(mx.min, a, 0)
|
||||||
|
|
||||||
|
|
||||||
def time_negative():
|
def time_negative():
|
||||||
a = mx.random.uniform(shape=(10000, 1000))
|
a = mx.random.uniform(shape=(10000, 1000))
|
||||||
mx.eval(a)
|
mx.eval(a)
|
||||||
@@ -108,6 +122,8 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
time_add()
|
time_add()
|
||||||
time_matmul()
|
time_matmul()
|
||||||
|
time_min()
|
||||||
|
time_max()
|
||||||
time_maximum()
|
time_maximum()
|
||||||
time_exp()
|
time_exp()
|
||||||
time_negative()
|
time_negative()
|
||||||
|
|||||||
@@ -11,13 +11,14 @@ include(CMakeParseArguments)
|
|||||||
# Args: TARGET: Custom target to be added for the metal library TITLE: Name of
|
# Args: TARGET: Custom target to be added for the metal library TITLE: Name of
|
||||||
# the .metallib OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib SOURCES: List
|
# the .metallib OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib SOURCES: List
|
||||||
# of source files INCLUDE_DIRS: List of include dirs DEPS: List of dependency
|
# of source files INCLUDE_DIRS: List of include dirs DEPS: List of dependency
|
||||||
# files (like headers)
|
# files (like headers) DEBUG: Boolean, if true, enables debug compile options
|
||||||
|
# for this specific library. If not provided, uses global MLX_METAL_DEBUG.
|
||||||
#
|
#
|
||||||
# clang format on
|
# clang format on
|
||||||
|
|
||||||
macro(mlx_build_metallib)
|
macro(mlx_build_metallib)
|
||||||
# Parse args
|
# Parse args
|
||||||
set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY)
|
set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY DEBUG)
|
||||||
set(multiValueArgs SOURCES INCLUDE_DIRS DEPS)
|
set(multiValueArgs SOURCES INCLUDE_DIRS DEPS)
|
||||||
cmake_parse_arguments(MTLLIB "" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
|
cmake_parse_arguments(MTLLIB "" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
|
||||||
|
|
||||||
@@ -26,6 +27,10 @@ macro(mlx_build_metallib)
|
|||||||
|
|
||||||
# Collect compile options
|
# Collect compile options
|
||||||
set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math -Wno-c++17-extensions)
|
set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math -Wno-c++17-extensions)
|
||||||
|
if(MLX_METAL_DEBUG OR MTLLIB_DEBUG)
|
||||||
|
set(MTLLIB_COMPILE_OPTIONS ${MTLLIB_COMPILE_OPTIONS} -gline-tables-only
|
||||||
|
-frecord-sources)
|
||||||
|
endif()
|
||||||
|
|
||||||
# Prepare metallib build command
|
# Prepare metallib build command
|
||||||
add_custom_command(
|
add_custom_command(
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ EXCLUDE_PATTERNS = */private/*
|
|||||||
CREATE_SUBDIRS = NO
|
CREATE_SUBDIRS = NO
|
||||||
FULL_PATH_NAMES = YES
|
FULL_PATH_NAMES = YES
|
||||||
RECURSIVE = YES
|
RECURSIVE = YES
|
||||||
GENERATE_HTML = YES
|
GENERATE_HTML = NO
|
||||||
GENERATE_LATEX = NO
|
GENERATE_LATEX = NO
|
||||||
GENERATE_XML = YES
|
GENERATE_XML = YES
|
||||||
XML_PROGRAMLISTING = YES
|
XML_PROGRAMLISTING = YES
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
sphinx
|
sphinx
|
||||||
breathe
|
breathe
|
||||||
sphinx-book-theme
|
sphinx-book-theme
|
||||||
|
sphinx-copybutton
|
||||||
mlx
|
mlx
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import mlx.core as mx
|
|||||||
# -- Project information -----------------------------------------------------
|
# -- Project information -----------------------------------------------------
|
||||||
|
|
||||||
project = "MLX"
|
project = "MLX"
|
||||||
copyright = "2023, MLX Contributors"
|
copyright = "2023, Apple"
|
||||||
author = "MLX Contributors"
|
author = "MLX Contributors"
|
||||||
version = ".".join(mx.__version__.split(".")[:3])
|
version = ".".join(mx.__version__.split(".")[:3])
|
||||||
release = version
|
release = version
|
||||||
@@ -18,6 +18,7 @@ release = version
|
|||||||
# -- General configuration ---------------------------------------------------
|
# -- General configuration ---------------------------------------------------
|
||||||
|
|
||||||
extensions = [
|
extensions = [
|
||||||
|
"sphinx_copybutton",
|
||||||
"sphinx.ext.autodoc",
|
"sphinx.ext.autodoc",
|
||||||
"sphinx.ext.autosummary",
|
"sphinx.ext.autosummary",
|
||||||
"sphinx.ext.intersphinx",
|
"sphinx.ext.intersphinx",
|
||||||
|
|||||||
@@ -8,11 +8,12 @@ MLX supports writing custom Metal kernels through the Python and C++ APIs.
|
|||||||
Simple Example
|
Simple Example
|
||||||
--------------
|
--------------
|
||||||
|
|
||||||
|
.. currentmodule:: mlx.core
|
||||||
|
|
||||||
Let's write a custom kernel that computes ``exp`` elementwise:
|
Let's write a custom kernel that computes ``exp`` elementwise:
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
def exp_elementwise(a: mx.array):
|
|
||||||
source = """
|
source = """
|
||||||
uint elem = thread_position_in_grid.x;
|
uint elem = thread_position_in_grid.x;
|
||||||
T tmp = inp[elem];
|
T tmp = inp[elem];
|
||||||
@@ -25,6 +26,8 @@ Let's write a custom kernel that computes ``exp`` elementwise:
|
|||||||
output_names=["out"],
|
output_names=["out"],
|
||||||
source=source,
|
source=source,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def exp_elementwise(a: mx.array):
|
||||||
outputs = kernel(
|
outputs = kernel(
|
||||||
inputs=[a],
|
inputs=[a],
|
||||||
template=[("T", mx.float32)],
|
template=[("T", mx.float32)],
|
||||||
@@ -39,8 +42,13 @@ Let's write a custom kernel that computes ``exp`` elementwise:
|
|||||||
b = exp_elementwise(a)
|
b = exp_elementwise(a)
|
||||||
assert mx.allclose(b, mx.exp(a))
|
assert mx.allclose(b, mx.exp(a))
|
||||||
|
|
||||||
|
Every time you make a kernel, a new Metal library is created and possibly
|
||||||
|
JIT compiled. To reduce the overhead from that, build the kernel once with
|
||||||
|
:func:`fast.metal_kernel` and then use it many times.
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
We are only required to pass the body of the Metal kernel in ``source``.
|
Only pass the body of the Metal kernel in ``source``. The function
|
||||||
|
signature is generated automatically.
|
||||||
|
|
||||||
The full function signature will be generated using:
|
The full function signature will be generated using:
|
||||||
|
|
||||||
@@ -78,29 +86,34 @@ Putting this all together, the generated function signature for ``myexp`` is as
|
|||||||
|
|
||||||
template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;
|
template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;
|
||||||
|
|
||||||
Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads <https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_ function.
|
Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads
|
||||||
This means we will launch ``mx.prod(grid)`` threads, subdivided into ``threadgroup`` size threadgroups.
|
<https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_
|
||||||
For optimal performance, each thread group dimension should be less than or equal to the corresponding grid dimension.
|
function. This means we will launch ``mx.prod(grid)`` threads, subdivided into
|
||||||
|
``threadgroup`` size threadgroups. For optimal performance, each thread group
|
||||||
|
dimension should be less than or equal to the corresponding grid dimension.
|
||||||
|
|
||||||
Passing ``verbose=True`` to ``mx.fast.metal_kernel.__call__`` will print the generated code for debugging purposes.
|
Passing ``verbose=True`` to :func:`ast.metal_kernel.__call__` will print the
|
||||||
|
generated code for debugging purposes.
|
||||||
|
|
||||||
Using Shape/Strides
|
Using Shape/Strides
|
||||||
-------------------
|
-------------------
|
||||||
|
|
||||||
``mx.fast.metal_kernel`` supports an argument ``ensure_row_contiguous`` which is ``True`` by default.
|
:func:`fast.metal_kernel` supports an argument ``ensure_row_contiguous`` which
|
||||||
This will copy the ``mx.array`` inputs if needed before the kernel is launched to ensure that the memory layout is row contiguous.
|
is ``True`` by default. This will copy the array inputs if needed
|
||||||
Generally this makes writing the kernel easier, since we don't have to worry about gaps or the ordering of the dims
|
before the kernel is launched to ensure that the memory layout is row
|
||||||
when indexing.
|
contiguous. Generally this makes writing the kernel easier, since we don't
|
||||||
|
have to worry about gaps or the ordering of the dims when indexing.
|
||||||
|
|
||||||
If we want to avoid this copy, ``metal_kernel`` automatically passes ``a_shape``, ``a_strides`` and ``a_ndim`` for each
|
If we want to avoid this copy, :func:`fast.metal_kernel` automatically passes
|
||||||
input array ``a`` if any are present in ``source``.
|
``a_shape``, ``a_strides`` and ``a_ndim`` for each input array ``a`` if any are
|
||||||
We can then use MLX's built in indexing utils to fetch the right elements for each thread.
|
present in ``source``. We can then use MLX's built in indexing utils to fetch
|
||||||
|
the right elements for each thread.
|
||||||
|
|
||||||
Let's convert ``myexp`` above to support arbitrarily strided arrays without relying on a copy from ``ensure_row_contiguous``:
|
Let's convert ``myexp`` above to support arbitrarily strided arrays without
|
||||||
|
relying on a copy from ``ensure_row_contiguous``:
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
def exp_elementwise(a: mx.array):
|
|
||||||
source = """
|
source = """
|
||||||
uint elem = thread_position_in_grid.x;
|
uint elem = thread_position_in_grid.x;
|
||||||
// Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
|
// Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
|
||||||
@@ -116,6 +129,8 @@ Let's convert ``myexp`` above to support arbitrarily strided arrays without rely
|
|||||||
output_names=["out"],
|
output_names=["out"],
|
||||||
source=source
|
source=source
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def exp_elementwise(a: mx.array):
|
||||||
outputs = kernel(
|
outputs = kernel(
|
||||||
inputs=[a],
|
inputs=[a],
|
||||||
template=[("T", mx.float32)],
|
template=[("T", mx.float32)],
|
||||||
@@ -183,25 +198,13 @@ We'll start with the following MLX implementation using standard ops:
|
|||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
Now let's use ``mx.custom_function`` together with ``mx.fast.metal_kernel``
|
Now let's use :func:`custom_function` together with :func:`fast.metal_kernel`
|
||||||
to write a fast GPU kernel for both the forward and backward passes.
|
to write a fast GPU kernel for both the forward and backward passes.
|
||||||
|
|
||||||
First we'll implement the forward pass as a fused kernel:
|
First we'll implement the forward pass as a fused kernel:
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
@mx.custom_function
|
|
||||||
def grid_sample(x, grid):
|
|
||||||
|
|
||||||
assert x.ndim == 4, "`x` must be 4D."
|
|
||||||
assert grid.ndim == 4, "`grid` must be 4D."
|
|
||||||
|
|
||||||
B, _, _, C = x.shape
|
|
||||||
_, gN, gM, D = grid.shape
|
|
||||||
out_shape = (B, gN, gM, C)
|
|
||||||
|
|
||||||
assert D == 2, "Last dim of `grid` must be size 2."
|
|
||||||
|
|
||||||
source = """
|
source = """
|
||||||
uint elem = thread_position_in_grid.x;
|
uint elem = thread_position_in_grid.x;
|
||||||
int H = x_shape[1];
|
int H = x_shape[1];
|
||||||
@@ -251,12 +254,26 @@ First we'll implement the forward pass as a fused kernel:
|
|||||||
|
|
||||||
out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
|
out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
|
||||||
"""
|
"""
|
||||||
|
|
||||||
kernel = mx.fast.metal_kernel(
|
kernel = mx.fast.metal_kernel(
|
||||||
name="grid_sample",
|
name="grid_sample",
|
||||||
input_names=["x", "grid"],
|
input_names=["x", "grid"],
|
||||||
output_names=["out"],
|
output_names=["out"],
|
||||||
source=source,
|
source=source,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@mx.custom_function
|
||||||
|
def grid_sample(x, grid):
|
||||||
|
|
||||||
|
assert x.ndim == 4, "`x` must be 4D."
|
||||||
|
assert grid.ndim == 4, "`grid` must be 4D."
|
||||||
|
|
||||||
|
B, _, _, C = x.shape
|
||||||
|
_, gN, gM, D = grid.shape
|
||||||
|
out_shape = (B, gN, gM, C)
|
||||||
|
|
||||||
|
assert D == 2, "Last dim of `grid` must be size 2."
|
||||||
|
|
||||||
outputs = kernel(
|
outputs = kernel(
|
||||||
inputs=[x, grid],
|
inputs=[x, grid],
|
||||||
template=[("T", x.dtype)],
|
template=[("T", x.dtype)],
|
||||||
@@ -281,11 +298,11 @@ On an M1 Max, we see a big performance improvement:
|
|||||||
Grid Sample VJP
|
Grid Sample VJP
|
||||||
---------------
|
---------------
|
||||||
|
|
||||||
Since we decorated ``grid_sample`` with ``mx.custom_function``, we can now define
|
Since we decorated ``grid_sample`` with :func:`custom_function`, we can now
|
||||||
its custom vjp transform so MLX can differentiate it.
|
define its custom vjp transform so MLX can differentiate it.
|
||||||
|
|
||||||
The backwards pass requires atomically updating ``x_grad``/``grid_grad`` and so
|
The backwards pass requires atomically updating ``x_grad``/``grid_grad`` and so
|
||||||
requires a few extra ``mx.fast.metal_kernel`` features:
|
requires a few extra :func:`fast.metal_kernel` features:
|
||||||
|
|
||||||
* ``init_value=0``
|
* ``init_value=0``
|
||||||
Initialize all of the kernel's outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel.
|
Initialize all of the kernel's outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel.
|
||||||
@@ -299,14 +316,6 @@ We can then implement the backwards pass as follows:
|
|||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
@grid_sample.vjp
|
|
||||||
def grid_sample_vjp(primals, cotangent, _):
|
|
||||||
x, grid = primals
|
|
||||||
B, _, _, C = x.shape
|
|
||||||
_, gN, gM, D = grid.shape
|
|
||||||
|
|
||||||
assert D == 2, "Last dim of `grid` must be size 2."
|
|
||||||
|
|
||||||
source = """
|
source = """
|
||||||
uint elem = thread_position_in_grid.x;
|
uint elem = thread_position_in_grid.x;
|
||||||
int H = x_shape[1];
|
int H = x_shape[1];
|
||||||
@@ -406,6 +415,15 @@ We can then implement the backwards pass as follows:
|
|||||||
source=source,
|
source=source,
|
||||||
atomic_outputs=True,
|
atomic_outputs=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@grid_sample.vjp
|
||||||
|
def grid_sample_vjp(primals, cotangent, _):
|
||||||
|
x, grid = primals
|
||||||
|
B, _, _, C = x.shape
|
||||||
|
_, gN, gM, D = grid.shape
|
||||||
|
|
||||||
|
assert D == 2, "Last dim of `grid` must be size 2."
|
||||||
|
|
||||||
# pad the output channels to simd group size
|
# pad the output channels to simd group size
|
||||||
# so that our `simd_sum`s don't overlap.
|
# so that our `simd_sum`s don't overlap.
|
||||||
simdgroup_size = 32
|
simdgroup_size = 32
|
||||||
|
|||||||
@@ -22,12 +22,12 @@ You can do that in MLX directly:
|
|||||||
This function performs that operation while leaving the implementation and
|
This function performs that operation while leaving the implementation and
|
||||||
function transformations to MLX.
|
function transformations to MLX.
|
||||||
|
|
||||||
However you may need to customize the underlying implementation, perhaps to
|
However, you may want to customize the underlying implementation, perhaps to
|
||||||
make it faster or for custom differentiation. In this tutorial we will go
|
make it faster. In this tutorial we will go through adding custom extensions.
|
||||||
through adding custom extensions. It will cover:
|
It will cover:
|
||||||
|
|
||||||
* The structure of the MLX library.
|
* The structure of the MLX library.
|
||||||
* Implementing a CPU operation that redirects to Accelerate_ when appropriate.
|
* Implementing a CPU operation.
|
||||||
* Implementing a GPU operation using metal.
|
* Implementing a GPU operation using metal.
|
||||||
* Adding the ``vjp`` and ``jvp`` function transformation.
|
* Adding the ``vjp`` and ``jvp`` function transformation.
|
||||||
* Building a custom extension and binding it to python.
|
* Building a custom extension and binding it to python.
|
||||||
@@ -45,7 +45,7 @@ Operations
|
|||||||
Operations are the front-end functions that operate on arrays. They are defined
|
Operations are the front-end functions that operate on arrays. They are defined
|
||||||
in the C++ API (:ref:`cpp_ops`), and the Python API (:ref:`ops`) binds them.
|
in the C++ API (:ref:`cpp_ops`), and the Python API (:ref:`ops`) binds them.
|
||||||
|
|
||||||
We would like an operation, :meth:`axpby` that takes in two arrays ``x`` and
|
We would like an operation :meth:`axpby` that takes in two arrays, ``x`` and
|
||||||
``y``, and two scalars, ``alpha`` and ``beta``. This is how to define it in
|
``y``, and two scalars, ``alpha`` and ``beta``. This is how to define it in
|
||||||
C++:
|
C++:
|
||||||
|
|
||||||
@@ -55,7 +55,7 @@ C++:
|
|||||||
* Scale and sum two vectors element-wise
|
* Scale and sum two vectors element-wise
|
||||||
* z = alpha * x + beta * y
|
* z = alpha * x + beta * y
|
||||||
*
|
*
|
||||||
* Follow numpy style broadcasting between x and y
|
* Use NumPy-style broadcasting between x and y
|
||||||
* Inputs are upcasted to floats if needed
|
* Inputs are upcasted to floats if needed
|
||||||
**/
|
**/
|
||||||
array axpby(
|
array axpby(
|
||||||
@@ -66,7 +66,7 @@ C++:
|
|||||||
StreamOrDevice s = {} // Stream on which to schedule the operation
|
StreamOrDevice s = {} // Stream on which to schedule the operation
|
||||||
);
|
);
|
||||||
|
|
||||||
The simplest way to this operation is in terms of existing operations:
|
The simplest way to implement this is with existing operations:
|
||||||
|
|
||||||
.. code-block:: C++
|
.. code-block:: C++
|
||||||
|
|
||||||
@@ -93,9 +93,9 @@ Primitives
|
|||||||
^^^^^^^^^^^
|
^^^^^^^^^^^
|
||||||
|
|
||||||
A :class:`Primitive` is part of the computation graph of an :class:`array`. It
|
A :class:`Primitive` is part of the computation graph of an :class:`array`. It
|
||||||
defines how to create outputs arrays given a input arrays. Further, a
|
defines how to create output arrays given input arrays. Further, a
|
||||||
:class:`Primitive` has methods to run on the CPU or GPU and for function
|
:class:`Primitive` has methods to run on the CPU or GPU and for function
|
||||||
transformations such as ``vjp`` and ``jvp``. Lets go back to our example to be
|
transformations such as ``vjp`` and ``jvp``. Let's go back to our example to be
|
||||||
more concrete:
|
more concrete:
|
||||||
|
|
||||||
.. code-block:: C++
|
.. code-block:: C++
|
||||||
@@ -128,7 +128,7 @@ more concrete:
|
|||||||
/** The vector-Jacobian product. */
|
/** The vector-Jacobian product. */
|
||||||
std::vector<array> vjp(
|
std::vector<array> vjp(
|
||||||
const std::vector<array>& primals,
|
const std::vector<array>& primals,
|
||||||
const array& cotan,
|
const std::vector<array>& cotangents,
|
||||||
const std::vector<int>& argnums,
|
const std::vector<int>& argnums,
|
||||||
const std::vector<array>& outputs) override;
|
const std::vector<array>& outputs) override;
|
||||||
|
|
||||||
@@ -138,13 +138,13 @@ more concrete:
|
|||||||
* representing the vectorized computation and the axis which
|
* representing the vectorized computation and the axis which
|
||||||
* corresponds to the output vectorized dimension.
|
* corresponds to the output vectorized dimension.
|
||||||
*/
|
*/
|
||||||
virtual std::pair<std::vector<array>, std::vector<int>> vmap(
|
std::pair<std::vector<array>, std::vector<int>> vmap(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
const std::vector<int>& axes) override;
|
const std::vector<int>& axes) override;
|
||||||
|
|
||||||
/** Print the primitive. */
|
/** The name of primitive. */
|
||||||
void print(std::ostream& os) override {
|
const char* name() const override {
|
||||||
os << "Axpby";
|
return "Axpby";
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Equivalence check **/
|
/** Equivalence check **/
|
||||||
@@ -153,9 +153,6 @@ more concrete:
|
|||||||
private:
|
private:
|
||||||
float alpha_;
|
float alpha_;
|
||||||
float beta_;
|
float beta_;
|
||||||
|
|
||||||
/** Fall back implementation for evaluation on CPU */
|
|
||||||
void eval(const std::vector<array>& inputs, array& out);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
The :class:`Axpby` class derives from the base :class:`Primitive` class. The
|
The :class:`Axpby` class derives from the base :class:`Primitive` class. The
|
||||||
@@ -188,7 +185,7 @@ Let's reimplement our operation now in terms of our :class:`Axpby` primitive.
|
|||||||
auto promoted_dtype = promote_types(x.dtype(), y.dtype());
|
auto promoted_dtype = promote_types(x.dtype(), y.dtype());
|
||||||
|
|
||||||
// Upcast to float32 for non-floating point inputs x and y
|
// Upcast to float32 for non-floating point inputs x and y
|
||||||
auto out_dtype = is_floating_point(promoted_dtype)
|
auto out_dtype = issubdtype(promoted_dtype, float32)
|
||||||
? promoted_dtype
|
? promoted_dtype
|
||||||
: promote_types(promoted_dtype, float32);
|
: promote_types(promoted_dtype, float32);
|
||||||
|
|
||||||
@@ -234,11 +231,9 @@ the execution of the computation graph, and calls :meth:`Axpby::eval_cpu` or
|
|||||||
Implementing the CPU Back-end
|
Implementing the CPU Back-end
|
||||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
Let's start by implementing a naive and generic version of
|
Let's start by implementing :meth:`Axpby::eval_cpu`.
|
||||||
:meth:`Axpby::eval_cpu`. We declared this as a private member function of
|
|
||||||
:class:`Axpby` earlier called :meth:`Axpby::eval`.
|
|
||||||
|
|
||||||
Our naive method will go over each element of the output array, find the
|
The method will go over each element of the output array, find the
|
||||||
corresponding input elements of ``x`` and ``y`` and perform the operation
|
corresponding input elements of ``x`` and ``y`` and perform the operation
|
||||||
point-wise. This is captured in the templated function :meth:`axpby_impl`.
|
point-wise. This is captured in the templated function :meth:`axpby_impl`.
|
||||||
|
|
||||||
@@ -246,36 +241,46 @@ point-wise. This is captured in the templated function :meth:`axpby_impl`.
|
|||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void axpby_impl(
|
void axpby_impl(
|
||||||
const array& x,
|
const mx::array& x,
|
||||||
const array& y,
|
const mx::array& y,
|
||||||
array& out,
|
mx::array& out,
|
||||||
float alpha_,
|
float alpha_,
|
||||||
float beta_) {
|
float beta_,
|
||||||
// We only allocate memory when we are ready to fill the output
|
mx::Stream stream) {
|
||||||
// malloc_or_wait synchronously allocates available memory
|
out.set_data(mx::allocator::malloc(out.nbytes()));
|
||||||
// There may be a wait executed here if the allocation is requested
|
|
||||||
// under memory-pressured conditions
|
|
||||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
|
||||||
|
|
||||||
// Collect input and output data pointers
|
// Get the CPU command encoder and register input and output arrays
|
||||||
const T* x_ptr = x.data<T>();
|
auto& encoder = mx::cpu::get_command_encoder(stream);
|
||||||
const T* y_ptr = y.data<T>();
|
encoder.set_input_array(x);
|
||||||
T* out_ptr = out.data<T>();
|
encoder.set_input_array(y);
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
|
||||||
|
// Launch the CPU kernel
|
||||||
|
encoder.dispatch([x_ptr = x.data<T>(),
|
||||||
|
y_ptr = y.data<T>(),
|
||||||
|
out_ptr = out.data<T>(),
|
||||||
|
size = out.size(),
|
||||||
|
shape = out.shape(),
|
||||||
|
x_strides = x.strides(),
|
||||||
|
y_strides = y.strides(),
|
||||||
|
alpha_,
|
||||||
|
beta_]() {
|
||||||
|
|
||||||
// Cast alpha and beta to the relevant types
|
// Cast alpha and beta to the relevant types
|
||||||
T alpha = static_cast<T>(alpha_);
|
T alpha = static_cast<T>(alpha_);
|
||||||
T beta = static_cast<T>(beta_);
|
T beta = static_cast<T>(beta_);
|
||||||
|
|
||||||
// Do the element-wise operation for each output
|
// Do the element-wise operation for each output
|
||||||
for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
|
for (size_t out_idx = 0; out_idx < size; out_idx++) {
|
||||||
// Map linear indices to offsets in x and y
|
// Map linear indices to offsets in x and y
|
||||||
auto x_offset = elem_to_loc(out_idx, x.shape(), x.strides());
|
auto x_offset = mx::elem_to_loc(out_idx, shape, x_strides);
|
||||||
auto y_offset = elem_to_loc(out_idx, y.shape(), y.strides());
|
auto y_offset = mx::elem_to_loc(out_idx, shape, y_strides);
|
||||||
|
|
||||||
// We allocate the output to be contiguous and regularly strided
|
// We allocate the output to be contiguous and regularly strided
|
||||||
// (defaults to row major) and hence it doesn't need additional mapping
|
// (defaults to row major) and hence it doesn't need additional mapping
|
||||||
out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
|
out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
|
||||||
}
|
}
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
Our implementation should work for all incoming floating point arrays.
|
Our implementation should work for all incoming floating point arrays.
|
||||||
@@ -284,112 +289,32 @@ Accordingly, we add dispatches for ``float32``, ``float16``, ``bfloat16`` and
|
|||||||
|
|
||||||
.. code-block:: C++
|
.. code-block:: C++
|
||||||
|
|
||||||
/** Fall back implementation for evaluation on CPU */
|
void Axpby::eval_cpu(
|
||||||
void Axpby::eval(
|
const std::vector<mx::array>& inputs,
|
||||||
const std::vector<array>& inputs,
|
std::vector<mx::array>& outputs) {
|
||||||
const std::vector<array>& outputs) {
|
|
||||||
auto& x = inputs[0];
|
auto& x = inputs[0];
|
||||||
auto& y = inputs[1];
|
auto& y = inputs[1];
|
||||||
auto& out = outputs[0];
|
auto& out = outputs[0];
|
||||||
|
|
||||||
// Dispatch to the correct dtype
|
// Dispatch to the correct dtype
|
||||||
if (out.dtype() == float32) {
|
if (out.dtype() == mx::float32) {
|
||||||
return axpby_impl<float>(x, y, out, alpha_, beta_);
|
return axpby_impl<float>(x, y, out, alpha_, beta_, stream());
|
||||||
} else if (out.dtype() == float16) {
|
} else if (out.dtype() == mx::float16) {
|
||||||
return axpby_impl<float16_t>(x, y, out, alpha_, beta_);
|
return axpby_impl<mx::float16_t>(x, y, out, alpha_, beta_, stream());
|
||||||
} else if (out.dtype() == bfloat16) {
|
} else if (out.dtype() == mx::bfloat16) {
|
||||||
return axpby_impl<bfloat16_t>(x, y, out, alpha_, beta_);
|
return axpby_impl<mx::bfloat16_t>(x, y, out, alpha_, beta_, stream());
|
||||||
} else if (out.dtype() == complex64) {
|
} else if (out.dtype() == mx::complex64) {
|
||||||
return axpby_impl<complex64_t>(x, y, out, alpha_, beta_);
|
return axpby_impl<mx::complex64_t>(x, y, out, alpha_, beta_, stream());
|
||||||
} else {
|
} else {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"[Axpby] Only supports floating point types.");
|
"Axpby is only supported for floating point types.");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
This is good as a fallback implementation. We can use the ``axpby`` routine
|
|
||||||
provided by the Accelerate_ framework for a faster implementation in certain
|
|
||||||
cases:
|
|
||||||
|
|
||||||
#. Accelerate does not provide implementations of ``axpby`` for half precision
|
|
||||||
floats. We can only use it for ``float32`` types.
|
|
||||||
#. Accelerate assumes the inputs ``x`` and ``y`` are contiguous and all
|
|
||||||
elements have fixed strides between them. We only direct to Accelerate
|
|
||||||
if both ``x`` and ``y`` are row contiguous or column contiguous.
|
|
||||||
#. Accelerate performs the routine ``Y = (alpha * X) + (beta * Y)`` in-place.
|
|
||||||
MLX expects to write the output to a new array. We must copy the elements
|
|
||||||
of ``y`` into the output and use that as an input to ``axpby``.
|
|
||||||
|
|
||||||
Let's write an implementation that uses Accelerate in the right conditions.
|
|
||||||
It allocates data for the output, copies ``y`` into it, and then calls the
|
|
||||||
:func:`catlas_saxpby` from accelerate.
|
|
||||||
|
|
||||||
.. code-block:: C++
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
void axpby_impl_accelerate(
|
|
||||||
const array& x,
|
|
||||||
const array& y,
|
|
||||||
array& out,
|
|
||||||
float alpha_,
|
|
||||||
float beta_) {
|
|
||||||
// Accelerate library provides catlas_saxpby which does
|
|
||||||
// Y = (alpha * X) + (beta * Y) in place
|
|
||||||
// To use it, we first copy the data in y over to the output array
|
|
||||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
|
||||||
|
|
||||||
// We then copy over the elements using the contiguous vector specialization
|
|
||||||
copy_inplace(y, out, CopyType::Vector);
|
|
||||||
|
|
||||||
// Get x and y pointers for catlas_saxpby
|
|
||||||
const T* x_ptr = x.data<T>();
|
|
||||||
T* y_ptr = out.data<T>();
|
|
||||||
|
|
||||||
T alpha = static_cast<T>(alpha_);
|
|
||||||
T beta = static_cast<T>(beta_);
|
|
||||||
|
|
||||||
// Call the inplace accelerate operator
|
|
||||||
catlas_saxpby(
|
|
||||||
/* N = */ out.size(),
|
|
||||||
/* ALPHA = */ alpha,
|
|
||||||
/* X = */ x_ptr,
|
|
||||||
/* INCX = */ 1,
|
|
||||||
/* BETA = */ beta,
|
|
||||||
/* Y = */ y_ptr,
|
|
||||||
/* INCY = */ 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
For inputs that do not fit the criteria for accelerate, we fall back to
|
|
||||||
:meth:`Axpby::eval`. With this in mind, let's finish our
|
|
||||||
:meth:`Axpby::eval_cpu`.
|
|
||||||
|
|
||||||
.. code-block:: C++
|
|
||||||
|
|
||||||
/** Evaluate primitive on CPU using accelerate specializations */
|
|
||||||
void Axpby::eval_cpu(
|
|
||||||
const std::vector<array>& inputs,
|
|
||||||
const std::vector<array>& outputs) {
|
|
||||||
assert(inputs.size() == 2);
|
|
||||||
auto& x = inputs[0];
|
|
||||||
auto& y = inputs[1];
|
|
||||||
auto& out = outputs[0];
|
|
||||||
|
|
||||||
// Accelerate specialization for contiguous single precision float arrays
|
|
||||||
if (out.dtype() == float32 &&
|
|
||||||
((x.flags().row_contiguous && y.flags().row_contiguous) ||
|
|
||||||
(x.flags().col_contiguous && y.flags().col_contiguous))) {
|
|
||||||
axpby_impl_accelerate<float>(x, y, out, alpha_, beta_);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fall back to common back-end if specializations are not available
|
|
||||||
eval(inputs, outputs);
|
|
||||||
}
|
|
||||||
|
|
||||||
Just this much is enough to run the operation :meth:`axpby` on a CPU stream! If
|
Just this much is enough to run the operation :meth:`axpby` on a CPU stream! If
|
||||||
you do not plan on running the operation on the GPU or using transforms on
|
you do not plan on running the operation on the GPU or using transforms on
|
||||||
computation graphs that contain :class:`Axpby`, you can stop implementing the
|
computation graphs that contain :class:`Axpby`, you can stop implementing the
|
||||||
primitive here and enjoy the speed-ups you get from the Accelerate library.
|
primitive here.
|
||||||
|
|
||||||
Implementing the GPU Back-end
|
Implementing the GPU Back-end
|
||||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
@@ -466,17 +391,17 @@ below.
|
|||||||
auto& d = metal::device(s.device);
|
auto& d = metal::device(s.device);
|
||||||
|
|
||||||
// Allocate output memory
|
// Allocate output memory
|
||||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
out.set_data(allocator::malloc(out.nbytes()));
|
||||||
|
|
||||||
// Resolve name of kernel
|
// Resolve name of kernel
|
||||||
std::ostringstream kname;
|
std::stream kname;
|
||||||
kname << "axpby_" << "general_" << type_to_name(out);
|
kname = "axpby_general_" + type_to_name(out);
|
||||||
|
|
||||||
// Make sure the metal library is available
|
// Load the metal library
|
||||||
d.register_library("mlx_ext");
|
auto lib = d.get_library("mlx_ext", current_binary_dir());
|
||||||
|
|
||||||
// Make a kernel from this metal library
|
// Make a kernel from this metal library
|
||||||
auto kernel = d.get_kernel(kname.str(), "mlx_ext");
|
auto kernel = d.get_kernel(kname, lib);
|
||||||
|
|
||||||
// Prepare to encode kernel
|
// Prepare to encode kernel
|
||||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||||
@@ -544,7 +469,7 @@ one we just defined:
|
|||||||
const std::vector<array>& tangents,
|
const std::vector<array>& tangents,
|
||||||
const std::vector<int>& argnums) {
|
const std::vector<int>& argnums) {
|
||||||
// Forward mode diff that pushes along the tangents
|
// Forward mode diff that pushes along the tangents
|
||||||
// The jvp transform on the primitive can built with ops
|
// The jvp transform on the primitive can be built with ops
|
||||||
// that are scheduled on the same stream as the primitive
|
// that are scheduled on the same stream as the primitive
|
||||||
|
|
||||||
// If argnums = {0}, we only push along x in which case the
|
// If argnums = {0}, we only push along x in which case the
|
||||||
@@ -556,7 +481,7 @@ one we just defined:
|
|||||||
auto scale_arr = array(scale, tangents[0].dtype());
|
auto scale_arr = array(scale, tangents[0].dtype());
|
||||||
return {multiply(scale_arr, tangents[0], stream())};
|
return {multiply(scale_arr, tangents[0], stream())};
|
||||||
}
|
}
|
||||||
// If, argnums = {0, 1}, we take contributions from both
|
// If argnums = {0, 1}, we take contributions from both
|
||||||
// which gives us jvp = tangent_x * alpha + tangent_y * beta
|
// which gives us jvp = tangent_x * alpha + tangent_y * beta
|
||||||
else {
|
else {
|
||||||
return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())};
|
return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())};
|
||||||
@@ -810,7 +735,7 @@ Let's look at a simple script and its results:
|
|||||||
|
|
||||||
print(f"c shape: {c.shape}")
|
print(f"c shape: {c.shape}")
|
||||||
print(f"c dtype: {c.dtype}")
|
print(f"c dtype: {c.dtype}")
|
||||||
print(f"c correct: {mx.all(c == 6.0).item()}")
|
print(f"c is correct: {mx.all(c == 6.0).item()}")
|
||||||
|
|
||||||
Output:
|
Output:
|
||||||
|
|
||||||
@@ -818,13 +743,13 @@ Output:
|
|||||||
|
|
||||||
c shape: [3, 4]
|
c shape: [3, 4]
|
||||||
c dtype: float32
|
c dtype: float32
|
||||||
c correctness: True
|
c is correct: True
|
||||||
|
|
||||||
Results
|
Results
|
||||||
^^^^^^^
|
^^^^^^^
|
||||||
|
|
||||||
Let's run a quick benchmark and see how our new ``axpby`` operation compares
|
Let's run a quick benchmark and see how our new ``axpby`` operation compares
|
||||||
with the naive :meth:`simple_axpby` we first defined on the CPU.
|
with the naive :meth:`simple_axpby` we first defined.
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
@@ -832,13 +757,11 @@ with the naive :meth:`simple_axpby` we first defined on the CPU.
|
|||||||
from mlx_sample_extensions import axpby
|
from mlx_sample_extensions import axpby
|
||||||
import time
|
import time
|
||||||
|
|
||||||
mx.set_default_device(mx.cpu)
|
|
||||||
|
|
||||||
def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
|
def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
|
||||||
return alpha * x + beta * y
|
return alpha * x + beta * y
|
||||||
|
|
||||||
M = 256
|
M = 4096
|
||||||
N = 512
|
N = 4096
|
||||||
|
|
||||||
x = mx.random.normal((M, N))
|
x = mx.random.normal((M, N))
|
||||||
y = mx.random.normal((M, N))
|
y = mx.random.normal((M, N))
|
||||||
@@ -849,24 +772,24 @@ with the naive :meth:`simple_axpby` we first defined on the CPU.
|
|||||||
|
|
||||||
def bench(f):
|
def bench(f):
|
||||||
# Warm up
|
# Warm up
|
||||||
for i in range(100):
|
for i in range(5):
|
||||||
z = f(x, y, alpha, beta)
|
z = f(x, y, alpha, beta)
|
||||||
mx.eval(z)
|
mx.eval(z)
|
||||||
|
|
||||||
# Timed run
|
# Timed run
|
||||||
s = time.time()
|
s = time.time()
|
||||||
for i in range(5000):
|
for i in range(100):
|
||||||
z = f(x, y, alpha, beta)
|
z = f(x, y, alpha, beta)
|
||||||
mx.eval(z)
|
mx.eval(z)
|
||||||
e = time.time()
|
e = time.time()
|
||||||
return e - s
|
return 1000 * (e - s) / 100
|
||||||
|
|
||||||
simple_time = bench(simple_axpby)
|
simple_time = bench(simple_axpby)
|
||||||
custom_time = bench(axpby)
|
custom_time = bench(axpby)
|
||||||
|
|
||||||
print(f"Simple axpby: {simple_time:.3f} s | Custom axpby: {custom_time:.3f} s")
|
print(f"Simple axpby: {simple_time:.3f} ms | Custom axpby: {custom_time:.3f} ms")
|
||||||
|
|
||||||
The results are ``Simple axpby: 0.114 s | Custom axpby: 0.109 s``. We see
|
The results are ``Simple axpby: 1.559 ms | Custom axpby: 0.774 ms``. We see
|
||||||
modest improvements right away!
|
modest improvements right away!
|
||||||
|
|
||||||
This operation is now good to be used to build other operations, in
|
This operation is now good to be used to build other operations, in
|
||||||
|
|||||||
@@ -70,6 +70,7 @@ are the CPU and GPU.
|
|||||||
python/fft
|
python/fft
|
||||||
python/linalg
|
python/linalg
|
||||||
python/metal
|
python/metal
|
||||||
|
python/memory_management
|
||||||
python/nn
|
python/nn
|
||||||
python/optimizers
|
python/optimizers
|
||||||
python/distributed
|
python/distributed
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ silicon computer is
|
|||||||
|
|
||||||
pip install mlx
|
pip install mlx
|
||||||
|
|
||||||
To install from PyPI you must meet the following requirements:
|
To install from PyPI your system must meet the following requirements:
|
||||||
|
|
||||||
- Using an M series chip (Apple silicon)
|
- Using an M series chip (Apple silicon)
|
||||||
- Using a native Python >= 3.9
|
- Using a native Python >= 3.9
|
||||||
@@ -23,12 +23,39 @@ To install from PyPI you must meet the following requirements:
|
|||||||
MLX is only available on devices running macOS >= 13.5
|
MLX is only available on devices running macOS >= 13.5
|
||||||
It is highly recommended to use macOS 14 (Sonoma)
|
It is highly recommended to use macOS 14 (Sonoma)
|
||||||
|
|
||||||
|
CUDA
|
||||||
|
^^^^
|
||||||
|
|
||||||
MLX is also available on conda-forge. To install MLX with conda do:
|
MLX has a CUDA backend which you can install with:
|
||||||
|
|
||||||
.. code-block:: shell
|
.. code-block:: shell
|
||||||
|
|
||||||
conda install conda-forge::mlx
|
pip install mlx[cuda]
|
||||||
|
|
||||||
|
To install the CUDA package from PyPi your system must meet the following
|
||||||
|
requirements:
|
||||||
|
|
||||||
|
- Nvidia architecture >= SM 7.0 (Volta)
|
||||||
|
- Nvidia driver >= 550.54.14
|
||||||
|
- CUDA toolkit >= 12.0
|
||||||
|
- Linux distribution with glibc >= 2.35
|
||||||
|
- Python >= 3.9
|
||||||
|
|
||||||
|
|
||||||
|
CPU-only (Linux)
|
||||||
|
^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
For a CPU-only version of MLX that runs on Linux use:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
pip install mlx[cpu]
|
||||||
|
|
||||||
|
To install the CPU-only package from PyPi your system must meet the following
|
||||||
|
requirements:
|
||||||
|
|
||||||
|
- Linux distribution with glibc >= 2.35
|
||||||
|
- Python >= 3.9
|
||||||
|
|
||||||
|
|
||||||
Troubleshooting
|
Troubleshooting
|
||||||
@@ -65,6 +92,8 @@ Build Requirements
|
|||||||
Python API
|
Python API
|
||||||
^^^^^^^^^^
|
^^^^^^^^^^
|
||||||
|
|
||||||
|
.. _python install:
|
||||||
|
|
||||||
To build and install the MLX python library from source, first, clone MLX from
|
To build and install the MLX python library from source, first, clone MLX from
|
||||||
`its GitHub repo <https://github.com/ml-explore/mlx>`_:
|
`its GitHub repo <https://github.com/ml-explore/mlx>`_:
|
||||||
|
|
||||||
@@ -76,20 +105,20 @@ Then simply build and install MLX using pip:
|
|||||||
|
|
||||||
.. code-block:: shell
|
.. code-block:: shell
|
||||||
|
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL=8 pip install .
|
pip install .
|
||||||
|
|
||||||
For developing, install the package with development dependencies, and use an
|
For developing, install the package with development dependencies, and use an
|
||||||
editable install:
|
editable install:
|
||||||
|
|
||||||
.. code-block:: shell
|
.. code-block:: shell
|
||||||
|
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL=8 pip install -e ".[dev]"
|
pip install -e ".[dev]"
|
||||||
|
|
||||||
Once the development dependencies are installed, you can build faster with:
|
Once the development dependencies are installed, you can build faster with:
|
||||||
|
|
||||||
.. code-block:: shell
|
.. code-block:: shell
|
||||||
|
|
||||||
CMAKE_BUILD_PARALLEL_LEVEL=8 python setup.py build_ext --inplace
|
python setup.py build_ext --inplace
|
||||||
|
|
||||||
Run the tests with:
|
Run the tests with:
|
||||||
|
|
||||||
@@ -107,6 +136,8 @@ IDE:
|
|||||||
C++ API
|
C++ API
|
||||||
^^^^^^^
|
^^^^^^^
|
||||||
|
|
||||||
|
.. _cpp install:
|
||||||
|
|
||||||
Currently, MLX must be built and installed from source.
|
Currently, MLX must be built and installed from source.
|
||||||
|
|
||||||
Similarly to the python library, to build and install the MLX C++ library start
|
Similarly to the python library, to build and install the MLX C++ library start
|
||||||
@@ -185,6 +216,7 @@ should point to the path to the built metal library.
|
|||||||
|
|
||||||
xcrun -sdk macosx --show-sdk-version
|
xcrun -sdk macosx --show-sdk-version
|
||||||
|
|
||||||
|
|
||||||
Binary Size Minimization
|
Binary Size Minimization
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
@@ -213,6 +245,50 @@ be anwywhere from a few hundred millisecond to a few seconds depending on the
|
|||||||
application. Once a kernel is compiled, it will be cached by the system. The
|
application. Once a kernel is compiled, it will be cached by the system. The
|
||||||
Metal kernel cache persists across reboots.
|
Metal kernel cache persists across reboots.
|
||||||
|
|
||||||
|
Linux
|
||||||
|
^^^^^
|
||||||
|
|
||||||
|
To build from source on Linux (CPU only), install the BLAS and LAPACK headers.
|
||||||
|
For example on Ubuntu, run the following:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
apt-get update -y
|
||||||
|
apt-get install libblas-dev liblapack-dev liblapacke-dev -y
|
||||||
|
|
||||||
|
From here follow the instructions to install either the :ref:`Python <python
|
||||||
|
install>` or :ref:`C++ <cpp install>` APIs.
|
||||||
|
|
||||||
|
CUDA
|
||||||
|
^^^^
|
||||||
|
|
||||||
|
To build from source on Linux with CUDA, install the BLAS and LAPACK headers
|
||||||
|
and the CUDA toolkit. For example on Ubuntu, run the following:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
|
||||||
|
dpkg -i cuda-keyring_1.1-1_all.deb
|
||||||
|
apt-get update -y
|
||||||
|
apt-get -y install cuda-toolkit-12-9
|
||||||
|
apt-get install libblas-dev liblapack-dev liblapacke-dev -y
|
||||||
|
|
||||||
|
|
||||||
|
When building either the Python or C++ APIs make sure to pass the cmake flag
|
||||||
|
``MLX_BUILD_CUDA=ON``. For example, to build the Python API run:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON" pip install -e ".[dev]"
|
||||||
|
|
||||||
|
To build the C++ package run:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
mkdir -p build && cd build
|
||||||
|
cmake .. -DMLX_BUILD_CUDA=ON && make -j
|
||||||
|
|
||||||
|
|
||||||
Troubleshooting
|
Troubleshooting
|
||||||
^^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
|||||||
@@ -19,6 +19,8 @@ Array
|
|||||||
array.ndim
|
array.ndim
|
||||||
array.shape
|
array.shape
|
||||||
array.size
|
array.size
|
||||||
|
array.real
|
||||||
|
array.imag
|
||||||
array.abs
|
array.abs
|
||||||
array.all
|
array.all
|
||||||
array.any
|
array.any
|
||||||
@@ -38,6 +40,7 @@ Array
|
|||||||
array.log10
|
array.log10
|
||||||
array.log1p
|
array.log1p
|
||||||
array.log2
|
array.log2
|
||||||
|
array.logcumsumexp
|
||||||
array.logsumexp
|
array.logsumexp
|
||||||
array.max
|
array.max
|
||||||
array.mean
|
array.mean
|
||||||
|
|||||||
@@ -20,3 +20,5 @@ FFT
|
|||||||
irfft2
|
irfft2
|
||||||
rfftn
|
rfftn
|
||||||
irfftn
|
irfftn
|
||||||
|
fftshift
|
||||||
|
ifftshift
|
||||||
|
|||||||
@@ -16,9 +16,12 @@ Linear Algebra
|
|||||||
cross
|
cross
|
||||||
qr
|
qr
|
||||||
svd
|
svd
|
||||||
|
eigvals
|
||||||
|
eig
|
||||||
eigvalsh
|
eigvalsh
|
||||||
eigh
|
eigh
|
||||||
lu
|
lu
|
||||||
lu_factor
|
lu_factor
|
||||||
|
pinv
|
||||||
solve
|
solve
|
||||||
solve_triangular
|
solve_triangular
|
||||||
|
|||||||
16
docs/src/python/memory_management.rst
Normal file
16
docs/src/python/memory_management.rst
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
Memory Management
|
||||||
|
=================
|
||||||
|
|
||||||
|
.. currentmodule:: mlx.core
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: _autosummary
|
||||||
|
|
||||||
|
get_active_memory
|
||||||
|
get_peak_memory
|
||||||
|
reset_peak_memory
|
||||||
|
get_cache_memory
|
||||||
|
set_memory_limit
|
||||||
|
set_cache_limit
|
||||||
|
set_wired_limit
|
||||||
|
clear_cache
|
||||||
@@ -8,13 +8,5 @@ Metal
|
|||||||
|
|
||||||
is_available
|
is_available
|
||||||
device_info
|
device_info
|
||||||
get_active_memory
|
|
||||||
get_peak_memory
|
|
||||||
reset_peak_memory
|
|
||||||
get_cache_memory
|
|
||||||
set_memory_limit
|
|
||||||
set_cache_limit
|
|
||||||
set_wired_limit
|
|
||||||
clear_cache
|
|
||||||
start_capture
|
start_capture
|
||||||
stop_capture
|
stop_capture
|
||||||
|
|||||||
@@ -36,10 +36,12 @@ Operations
|
|||||||
bitwise_or
|
bitwise_or
|
||||||
bitwise_xor
|
bitwise_xor
|
||||||
block_masked_mm
|
block_masked_mm
|
||||||
|
broadcast_arrays
|
||||||
broadcast_to
|
broadcast_to
|
||||||
ceil
|
ceil
|
||||||
clip
|
clip
|
||||||
concatenate
|
concatenate
|
||||||
|
contiguous
|
||||||
conj
|
conj
|
||||||
conjugate
|
conjugate
|
||||||
convolve
|
convolve
|
||||||
@@ -101,6 +103,7 @@ Operations
|
|||||||
log10
|
log10
|
||||||
log1p
|
log1p
|
||||||
logaddexp
|
logaddexp
|
||||||
|
logcumsumexp
|
||||||
logical_not
|
logical_not
|
||||||
logical_and
|
logical_and
|
||||||
logical_or
|
logical_or
|
||||||
|
|||||||
@@ -51,14 +51,14 @@ the saved state. Here's a simple example:
|
|||||||
optimizer.update(model, grads)
|
optimizer.update(model, grads)
|
||||||
|
|
||||||
# Save the state
|
# Save the state
|
||||||
state = tree_flatten(optimizer.state)
|
state = tree_flatten(optimizer.state, destination={})
|
||||||
mx.save_safetensors("optimizer.safetensors", dict(state))
|
mx.save_safetensors("optimizer.safetensors", state)
|
||||||
|
|
||||||
# Later on, for example when loading from a checkpoint,
|
# Later on, for example when loading from a checkpoint,
|
||||||
# recreate the optimizer and load the state
|
# recreate the optimizer and load the state
|
||||||
optimizer = optim.Adam(learning_rate=1e-2)
|
optimizer = optim.Adam(learning_rate=1e-2)
|
||||||
|
|
||||||
state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))
|
state = tree_unflatten(mx.load("optimizer.safetensors"))
|
||||||
optimizer.state = state
|
optimizer.state = state
|
||||||
|
|
||||||
Note, not every optimizer configuation parameter is saved in the state. For
|
Note, not every optimizer configuation parameter is saved in the state. For
|
||||||
|
|||||||
@@ -18,3 +18,5 @@ Common Optimizers
|
|||||||
AdamW
|
AdamW
|
||||||
Adamax
|
Adamax
|
||||||
Lion
|
Lion
|
||||||
|
MultiOptimizer
|
||||||
|
Muon
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ Transforms
|
|||||||
:toctree: _autosummary
|
:toctree: _autosummary
|
||||||
|
|
||||||
eval
|
eval
|
||||||
|
async_eval
|
||||||
compile
|
compile
|
||||||
custom_function
|
custom_function
|
||||||
disable_compile
|
disable_compile
|
||||||
|
|||||||
@@ -151,7 +151,7 @@ parameters, pass them as inputs to the ``call`` wrapper:
|
|||||||
model.update(tree_unflatten(list(params.items())))
|
model.update(tree_unflatten(list(params.items())))
|
||||||
return model(x)
|
return model(x)
|
||||||
|
|
||||||
params = dict(tree_flatten(model.parameters()))
|
params = tree_flatten(model.parameters(), destination={})
|
||||||
mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)
|
mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -107,6 +107,16 @@ same array:
|
|||||||
>>> a
|
>>> a
|
||||||
array([1, 2, 0], dtype=int32)
|
array([1, 2, 0], dtype=int32)
|
||||||
|
|
||||||
|
|
||||||
|
Note, unlike NumPy, updates to the same location are nondeterministic:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
>>> a = mx.array([1, 2, 3])
|
||||||
|
>>> a[[0, 0]] = mx.array([4, 5])
|
||||||
|
|
||||||
|
The first element of ``a`` could be ``4`` or ``5``.
|
||||||
|
|
||||||
Transformations of functions which use in-place updates are allowed and work as
|
Transformations of functions which use in-place updates are allowed and work as
|
||||||
expected. For example:
|
expected. For example:
|
||||||
|
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ set(CMAKE_POSITION_INDEPENDENT_CODE ON)
|
|||||||
option(BUILD_SHARED_LIBS "Build extensions as a shared library" ON)
|
option(BUILD_SHARED_LIBS "Build extensions as a shared library" ON)
|
||||||
|
|
||||||
# ----------------------------- Dependencies -----------------------------
|
# ----------------------------- Dependencies -----------------------------
|
||||||
find_package(MLX CONFIG REQUIRED)
|
|
||||||
find_package(
|
find_package(
|
||||||
Python 3.8
|
Python 3.8
|
||||||
COMPONENTS Interpreter Development.Module
|
COMPONENTS Interpreter Development.Module
|
||||||
@@ -21,6 +20,12 @@ execute_process(
|
|||||||
OUTPUT_VARIABLE nanobind_ROOT)
|
OUTPUT_VARIABLE nanobind_ROOT)
|
||||||
find_package(nanobind CONFIG REQUIRED)
|
find_package(nanobind CONFIG REQUIRED)
|
||||||
|
|
||||||
|
execute_process(
|
||||||
|
COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
|
||||||
|
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||||
|
OUTPUT_VARIABLE MLX_ROOT)
|
||||||
|
find_package(MLX CONFIG REQUIRED)
|
||||||
|
|
||||||
# ----------------------------- Extensions -----------------------------
|
# ----------------------------- Extensions -----------------------------
|
||||||
|
|
||||||
# Add library
|
# Add library
|
||||||
|
|||||||
@@ -1,20 +1,15 @@
|
|||||||
// Copyright © 2023-2024 Apple Inc.
|
// Copyright © 2023-2025 Apple Inc.
|
||||||
|
|
||||||
#include <cassert>
|
#include <dlfcn.h>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
|
||||||
#include "mlx/backend/common/copy.h"
|
|
||||||
#include "mlx/backend/common/utils.h"
|
#include "mlx/backend/common/utils.h"
|
||||||
#include "mlx/backend/cpu/copy.h"
|
#include "mlx/backend/cpu/encoder.h"
|
||||||
#include "mlx/utils.h"
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
#include "axpby/axpby.h"
|
#include "axpby/axpby.h"
|
||||||
|
|
||||||
#ifdef ACCELERATE_NEW_LAPACK
|
|
||||||
#include <vecLib/cblas_new.h>
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#ifdef _METAL_
|
#ifdef _METAL_
|
||||||
#include "mlx/backend/metal/device.h"
|
#include "mlx/backend/metal/device.h"
|
||||||
#include "mlx/backend/metal/utils.h"
|
#include "mlx/backend/metal/utils.h"
|
||||||
@@ -22,6 +17,19 @@
|
|||||||
|
|
||||||
namespace my_ext {
|
namespace my_ext {
|
||||||
|
|
||||||
|
// A helper function to find the location of the current binary on disk.
|
||||||
|
// The Metal library ("mlx_ext.mtllib"), should be in the same directory.
|
||||||
|
std::string current_binary_dir() {
|
||||||
|
static std::string binary_dir = []() {
|
||||||
|
Dl_info info;
|
||||||
|
if (!dladdr(reinterpret_cast<void*>(¤t_binary_dir), &info)) {
|
||||||
|
throw std::runtime_error("Unable to get current binary dir.");
|
||||||
|
}
|
||||||
|
return std::filesystem::path(info.dli_fname).parent_path().string();
|
||||||
|
}();
|
||||||
|
return binary_dir;
|
||||||
|
}
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
// Operation Implementation
|
// Operation Implementation
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
@@ -76,136 +84,65 @@ void axpby_impl(
|
|||||||
const mx::array& y,
|
const mx::array& y,
|
||||||
mx::array& out,
|
mx::array& out,
|
||||||
float alpha_,
|
float alpha_,
|
||||||
float beta_) {
|
float beta_,
|
||||||
// We only allocate memory when we are ready to fill the output
|
mx::Stream stream) {
|
||||||
// malloc_or_wait synchronously allocates available memory
|
out.set_data(mx::allocator::malloc(out.nbytes()));
|
||||||
// There may be a wait executed here if the allocation is requested
|
|
||||||
// under memory-pressured conditions
|
|
||||||
out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
|
|
||||||
|
|
||||||
// Collect input and output data pointers
|
// Get the CPU command encoder and register input and output arrays
|
||||||
const T* x_ptr = x.data<T>();
|
auto& encoder = mx::cpu::get_command_encoder(stream);
|
||||||
const T* y_ptr = y.data<T>();
|
encoder.set_input_array(x);
|
||||||
T* out_ptr = out.data<T>();
|
encoder.set_input_array(y);
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
|
||||||
|
// Launch the CPU kernel
|
||||||
|
encoder.dispatch([x_ptr = x.data<T>(),
|
||||||
|
y_ptr = y.data<T>(),
|
||||||
|
out_ptr = out.data<T>(),
|
||||||
|
size = out.size(),
|
||||||
|
shape = out.shape(),
|
||||||
|
x_strides = x.strides(),
|
||||||
|
y_strides = y.strides(),
|
||||||
|
alpha_,
|
||||||
|
beta_]() {
|
||||||
// Cast alpha and beta to the relevant types
|
// Cast alpha and beta to the relevant types
|
||||||
T alpha = static_cast<T>(alpha_);
|
T alpha = static_cast<T>(alpha_);
|
||||||
T beta = static_cast<T>(beta_);
|
T beta = static_cast<T>(beta_);
|
||||||
|
|
||||||
// Do the element-wise operation for each output
|
// Do the element-wise operation for each output
|
||||||
for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
|
for (size_t out_idx = 0; out_idx < size; out_idx++) {
|
||||||
// Map linear indices to offsets in x and y
|
// Map linear indices to offsets in x and y
|
||||||
auto x_offset = mx::elem_to_loc(out_idx, x.shape(), x.strides());
|
auto x_offset = mx::elem_to_loc(out_idx, shape, x_strides);
|
||||||
auto y_offset = mx::elem_to_loc(out_idx, y.shape(), y.strides());
|
auto y_offset = mx::elem_to_loc(out_idx, shape, y_strides);
|
||||||
|
|
||||||
// We allocate the output to be contiguous and regularly strided
|
// We allocate the output to be contiguous and regularly strided
|
||||||
// (defaults to row major) and hence it doesn't need additional mapping
|
// (defaults to row major) and hence it doesn't need additional mapping
|
||||||
out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
|
out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
|
||||||
}
|
}
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Fall back implementation for evaluation on CPU */
|
void Axpby::eval_cpu(
|
||||||
void Axpby::eval(
|
|
||||||
const std::vector<mx::array>& inputs,
|
const std::vector<mx::array>& inputs,
|
||||||
std::vector<mx::array>& outputs) {
|
std::vector<mx::array>& outputs) {
|
||||||
// Check the inputs (registered in the op while constructing the out array)
|
|
||||||
assert(inputs.size() == 2);
|
|
||||||
auto& x = inputs[0];
|
auto& x = inputs[0];
|
||||||
auto& y = inputs[1];
|
auto& y = inputs[1];
|
||||||
auto& out = outputs[0];
|
auto& out = outputs[0];
|
||||||
|
|
||||||
// Dispatch to the correct dtype
|
// Dispatch to the correct dtype
|
||||||
if (out.dtype() == mx::float32) {
|
if (out.dtype() == mx::float32) {
|
||||||
return axpby_impl<float>(x, y, out, alpha_, beta_);
|
return axpby_impl<float>(x, y, out, alpha_, beta_, stream());
|
||||||
} else if (out.dtype() == mx::float16) {
|
} else if (out.dtype() == mx::float16) {
|
||||||
return axpby_impl<mx::float16_t>(x, y, out, alpha_, beta_);
|
return axpby_impl<mx::float16_t>(x, y, out, alpha_, beta_, stream());
|
||||||
} else if (out.dtype() == mx::bfloat16) {
|
} else if (out.dtype() == mx::bfloat16) {
|
||||||
return axpby_impl<mx::bfloat16_t>(x, y, out, alpha_, beta_);
|
return axpby_impl<mx::bfloat16_t>(x, y, out, alpha_, beta_, stream());
|
||||||
} else if (out.dtype() == mx::complex64) {
|
} else if (out.dtype() == mx::complex64) {
|
||||||
return axpby_impl<mx::complex64_t>(x, y, out, alpha_, beta_);
|
return axpby_impl<mx::complex64_t>(x, y, out, alpha_, beta_, stream());
|
||||||
} else {
|
} else {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"Axpby is only supported for floating point types.");
|
"Axpby is only supported for floating point types.");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
// Primitive Accelerate Backend Implementation
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
#ifdef ACCELERATE_NEW_LAPACK
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
void axpby_impl_accelerate(
|
|
||||||
const mx::array& x,
|
|
||||||
const mx::array& y,
|
|
||||||
mx::array& out,
|
|
||||||
float alpha_,
|
|
||||||
float beta_) {
|
|
||||||
// Accelerate library provides catlas_saxpby which does
|
|
||||||
// Y = (alpha * X) + (beta * Y) in place
|
|
||||||
// To use it, we first copy the data in y over to the output array
|
|
||||||
|
|
||||||
// This specialization requires both x and y be contiguous in the same mode
|
|
||||||
// i.e: corresponding linear indices in both point to corresponding elements
|
|
||||||
// The data in the output array is allocated to match the strides in y
|
|
||||||
// such that x, y, and out are contiguous in the same mode and
|
|
||||||
// no transposition is needed
|
|
||||||
out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
|
|
||||||
|
|
||||||
// We then copy over the elements using the contiguous vector specialization
|
|
||||||
copy_inplace(y, out, mx::CopyType::Vector);
|
|
||||||
|
|
||||||
// Get x and y pointers for catlas_saxpby
|
|
||||||
const T* x_ptr = x.data<T>();
|
|
||||||
T* y_ptr = out.data<T>();
|
|
||||||
|
|
||||||
T alpha = static_cast<T>(alpha_);
|
|
||||||
T beta = static_cast<T>(beta_);
|
|
||||||
|
|
||||||
// Call the inplace accelerate operator
|
|
||||||
catlas_saxpby(
|
|
||||||
/* N = */ out.size(),
|
|
||||||
/* ALPHA = */ alpha,
|
|
||||||
/* X = */ x_ptr,
|
|
||||||
/* INCX = */ 1,
|
|
||||||
/* BETA = */ beta,
|
|
||||||
/* Y = */ y_ptr,
|
|
||||||
/* INCY = */ 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
/** Evaluate primitive on CPU using accelerate specializations */
|
|
||||||
void Axpby::eval_cpu(
|
|
||||||
const std::vector<mx::array>& inputs,
|
|
||||||
std::vector<mx::array>& outputs) {
|
|
||||||
assert(inputs.size() == 2);
|
|
||||||
auto& x = inputs[0];
|
|
||||||
auto& y = inputs[1];
|
|
||||||
auto& out = outputs[0];
|
|
||||||
|
|
||||||
// Accelerate specialization for contiguous single precision float arrays
|
|
||||||
if (out.dtype() == mx::float32 &&
|
|
||||||
((x.flags().row_contiguous && y.flags().row_contiguous) ||
|
|
||||||
(x.flags().col_contiguous && y.flags().col_contiguous))) {
|
|
||||||
axpby_impl_accelerate<float>(x, y, out, alpha_, beta_);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fall back to common backend if specializations are not available
|
|
||||||
eval(inputs, outputs);
|
|
||||||
}
|
|
||||||
|
|
||||||
#else // Accelerate not available
|
|
||||||
|
|
||||||
/** Evaluate primitive on CPU falling back to common backend */
|
|
||||||
void Axpby::eval_cpu(
|
|
||||||
const std::vector<mx::array>& inputs,
|
|
||||||
std::vector<mx::array>& outputs) {
|
|
||||||
eval(inputs, outputs);
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
// Primitive Metal Backend Implementation
|
// Primitive Metal Backend Implementation
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
@@ -217,7 +154,6 @@ void Axpby::eval_gpu(
|
|||||||
const std::vector<mx::array>& inputs,
|
const std::vector<mx::array>& inputs,
|
||||||
std::vector<mx::array>& outputs) {
|
std::vector<mx::array>& outputs) {
|
||||||
// Prepare inputs
|
// Prepare inputs
|
||||||
assert(inputs.size() == 2);
|
|
||||||
auto& x = inputs[0];
|
auto& x = inputs[0];
|
||||||
auto& y = inputs[1];
|
auto& y = inputs[1];
|
||||||
auto& out = outputs[0];
|
auto& out = outputs[0];
|
||||||
@@ -236,25 +172,24 @@ void Axpby::eval_gpu(
|
|||||||
// Allocate output memory with strides based on specialization
|
// Allocate output memory with strides based on specialization
|
||||||
if (contiguous_kernel) {
|
if (contiguous_kernel) {
|
||||||
out.set_data(
|
out.set_data(
|
||||||
mx::allocator::malloc_or_wait(x.data_size() * out.itemsize()),
|
mx::allocator::malloc(x.data_size() * out.itemsize()),
|
||||||
x.data_size(),
|
x.data_size(),
|
||||||
x.strides(),
|
x.strides(),
|
||||||
x.flags());
|
x.flags());
|
||||||
} else {
|
} else {
|
||||||
out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
|
out.set_data(mx::allocator::malloc(out.nbytes()));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Resolve name of kernel (corresponds to axpby.metal)
|
// Resolve name of kernel (corresponds to axpby.metal)
|
||||||
std::ostringstream kname;
|
std::string kname = "axpby_";
|
||||||
kname << "axpby_";
|
kname += (contiguous_kernel ? "contiguous_" : "general_");
|
||||||
kname << (contiguous_kernel ? "contiguous_" : "general_");
|
kname += type_to_name(out);
|
||||||
kname << type_to_name(out);
|
|
||||||
|
|
||||||
// Make sure the metal library is available
|
// Load the metal library
|
||||||
d.register_library("mlx_ext");
|
auto lib = d.get_library("mlx_ext", current_binary_dir());
|
||||||
|
|
||||||
// Make a kernel from this metal library
|
// Make a kernel from this metal library
|
||||||
auto kernel = d.get_kernel(kname.str(), "mlx_ext");
|
auto kernel = d.get_kernel(kname, lib);
|
||||||
|
|
||||||
// Prepare to encode kernel
|
// Prepare to encode kernel
|
||||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
// Copyright © 2023 Apple Inc.
|
// Copyright © 2023-2025 Apple Inc.
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
@@ -74,9 +74,9 @@ class Axpby : public mx::Primitive {
|
|||||||
const std::vector<mx::array>& inputs,
|
const std::vector<mx::array>& inputs,
|
||||||
const std::vector<int>& axes) override;
|
const std::vector<int>& axes) override;
|
||||||
|
|
||||||
/** Print the primitive. */
|
/** The name of primitive. */
|
||||||
void print(std::ostream& os) override {
|
const char* name() const override {
|
||||||
os << "Axpby";
|
return "Axpby";
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Equivalence check **/
|
/** Equivalence check **/
|
||||||
@@ -85,11 +85,6 @@ class Axpby : public mx::Primitive {
|
|||||||
private:
|
private:
|
||||||
float alpha_;
|
float alpha_;
|
||||||
float beta_;
|
float beta_;
|
||||||
|
|
||||||
/** Fall back implementation for evaluation on CPU */
|
|
||||||
void eval(
|
|
||||||
const std::vector<mx::array>& inputs,
|
|
||||||
std::vector<mx::array>& outputs);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace my_ext
|
} // namespace my_ext
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
// Copyright © 2023 Apple Inc.
|
// Copyright © 2023-2025 Apple Inc.
|
||||||
|
|
||||||
#include <metal_stdlib>
|
#include <metal_stdlib>
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
setuptools>=42
|
setuptools>=42
|
||||||
cmake>=3.25
|
cmake>=3.25
|
||||||
mlx>=0.21.0
|
mlx>=0.21.0
|
||||||
nanobind==2.2.0
|
nanobind==2.4.0
|
||||||
|
|||||||
@@ -3,8 +3,10 @@ from mlx_sample_extensions import axpby
|
|||||||
|
|
||||||
a = mx.ones((3, 4))
|
a = mx.ones((3, 4))
|
||||||
b = mx.ones((3, 4))
|
b = mx.ones((3, 4))
|
||||||
c = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
|
c_cpu = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
|
||||||
|
c_gpu = axpby(a, b, 4.0, 2.0, stream=mx.gpu)
|
||||||
|
|
||||||
print(f"c shape: {c.shape}")
|
print(f"c shape: {c_cpu.shape}")
|
||||||
print(f"c dtype: {c.dtype}")
|
print(f"c dtype: {c_cpu.dtype}")
|
||||||
print(f"c correct: {mx.all(c == 6.0).item()}")
|
print(f"c_cpu correct: {mx.all(c_cpu == 6.0).item()}")
|
||||||
|
print(f"c_gpu correct: {mx.all(c_gpu == 6.0).item()}")
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ target_sources(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/dtype_utils.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/export.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/export.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
|
||||||
@@ -17,9 +18,13 @@ target_sources(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/version.cpp
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h)
|
${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h)
|
||||||
|
|
||||||
|
# Define MLX_VERSION only in the version.cpp file.
|
||||||
|
add_library(mlx_version OBJECT ${CMAKE_CURRENT_SOURCE_DIR}/version.cpp)
|
||||||
|
target_compile_definitions(mlx_version PRIVATE MLX_VERSION="${MLX_VERSION}")
|
||||||
|
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:mlx_version>)
|
||||||
|
|
||||||
if(MSVC)
|
if(MSVC)
|
||||||
# Disable some MSVC warnings to speed up compilation.
|
# Disable some MSVC warnings to speed up compilation.
|
||||||
target_compile_options(mlx PUBLIC /wd4068 /wd4244 /wd4267 /wd4804)
|
target_compile_options(mlx PUBLIC /wd4068 /wd4244 /wd4267 /wd4804)
|
||||||
@@ -44,5 +49,19 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
|
|||||||
if(MLX_BUILD_METAL)
|
if(MLX_BUILD_METAL)
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)
|
||||||
else()
|
else()
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_metal)
|
target_sources(mlx
|
||||||
|
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/no_metal.cpp)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
if(MLX_BUILD_CUDA)
|
||||||
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda)
|
||||||
|
else()
|
||||||
|
target_sources(mlx
|
||||||
|
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda/no_cuda.cpp)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
if(MLX_BUILD_METAL OR MLX_BUILD_CUDA)
|
||||||
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/gpu)
|
||||||
|
else()
|
||||||
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_gpu)
|
||||||
endif()
|
endif()
|
||||||
|
|||||||
@@ -4,12 +4,11 @@
|
|||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
|
||||||
#include "mlx/allocator.h"
|
#include "mlx/allocator.h"
|
||||||
#include "mlx/scheduler.h"
|
|
||||||
|
|
||||||
namespace mlx::core::allocator {
|
namespace mlx::core::allocator {
|
||||||
|
|
||||||
Buffer malloc(size_t size) {
|
Buffer malloc(size_t size) {
|
||||||
auto buffer = allocator().malloc(size, /* allow_swap */ true);
|
auto buffer = allocator().malloc(size);
|
||||||
if (size && !buffer.ptr()) {
|
if (size && !buffer.ptr()) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[malloc] Unable to allocate " << size << " bytes.";
|
msg << "[malloc] Unable to allocate " << size << " bytes.";
|
||||||
@@ -22,45 +21,4 @@ void free(Buffer buffer) {
|
|||||||
allocator().free(buffer);
|
allocator().free(buffer);
|
||||||
}
|
}
|
||||||
|
|
||||||
Buffer CommonAllocator::malloc(size_t size, bool) {
|
|
||||||
void* ptr = std::malloc(size + sizeof(size_t));
|
|
||||||
if (ptr != nullptr) {
|
|
||||||
*static_cast<size_t*>(ptr) = size;
|
|
||||||
}
|
|
||||||
return Buffer{ptr};
|
|
||||||
}
|
|
||||||
|
|
||||||
void CommonAllocator::free(Buffer buffer) {
|
|
||||||
std::free(buffer.ptr());
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t CommonAllocator::size(Buffer buffer) const {
|
|
||||||
if (buffer.ptr() == nullptr) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
return *static_cast<size_t*>(buffer.ptr());
|
|
||||||
}
|
|
||||||
|
|
||||||
Buffer malloc_or_wait(size_t size) {
|
|
||||||
auto buffer = allocator().malloc(size);
|
|
||||||
|
|
||||||
while (size && !buffer.ptr() && scheduler::n_active_tasks() > 0) {
|
|
||||||
scheduler::wait_for_one();
|
|
||||||
buffer = allocator().malloc(size);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Try swapping if needed
|
|
||||||
if (size && !buffer.ptr()) {
|
|
||||||
buffer = allocator().malloc(size, /* allow_swap = */ true);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (size && !buffer.ptr()) {
|
|
||||||
std::ostringstream msg;
|
|
||||||
msg << "[malloc_or_wait] Unable to allocate " << size << " bytes.";
|
|
||||||
throw std::runtime_error(msg.str());
|
|
||||||
}
|
|
||||||
|
|
||||||
return buffer;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace mlx::core::allocator
|
} // namespace mlx::core::allocator
|
||||||
|
|||||||
@@ -32,14 +32,10 @@ Buffer malloc(size_t size);
|
|||||||
|
|
||||||
void free(Buffer buffer);
|
void free(Buffer buffer);
|
||||||
|
|
||||||
// Wait for running tasks to finish and free up memory
|
|
||||||
// if allocation fails
|
|
||||||
Buffer malloc_or_wait(size_t size);
|
|
||||||
|
|
||||||
class Allocator {
|
class Allocator {
|
||||||
/** Abstract base class for a memory allocator. */
|
/** Abstract base class for a memory allocator. */
|
||||||
public:
|
public:
|
||||||
virtual Buffer malloc(size_t size, bool allow_swap = false) = 0;
|
virtual Buffer malloc(size_t size) = 0;
|
||||||
virtual void free(Buffer buffer) = 0;
|
virtual void free(Buffer buffer) = 0;
|
||||||
virtual size_t size(Buffer buffer) const = 0;
|
virtual size_t size(Buffer buffer) const = 0;
|
||||||
|
|
||||||
@@ -53,16 +49,4 @@ class Allocator {
|
|||||||
|
|
||||||
Allocator& allocator();
|
Allocator& allocator();
|
||||||
|
|
||||||
class CommonAllocator : public Allocator {
|
|
||||||
/** A general CPU allocator. */
|
|
||||||
public:
|
|
||||||
virtual Buffer malloc(size_t size, bool allow_swap = false) override;
|
|
||||||
virtual void free(Buffer buffer) override;
|
|
||||||
virtual size_t size(Buffer buffer) const override;
|
|
||||||
|
|
||||||
private:
|
|
||||||
CommonAllocator() = default;
|
|
||||||
friend Allocator& allocator();
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace mlx::core::allocator
|
} // namespace mlx::core::allocator
|
||||||
|
|||||||
@@ -56,6 +56,18 @@ std::vector<array> array::make_arrays(
|
|||||||
return outputs;
|
return outputs;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
array array::unsafe_weak_copy(const array& other) {
|
||||||
|
auto cpy = array(other.shape(), other.dtype(), nullptr, {});
|
||||||
|
cpy.set_data(
|
||||||
|
other.buffer(),
|
||||||
|
other.data_size(),
|
||||||
|
other.strides(),
|
||||||
|
other.flags(),
|
||||||
|
[](auto) {});
|
||||||
|
cpy.array_desc_->data_ptr = other.array_desc_->data_ptr;
|
||||||
|
return cpy;
|
||||||
|
}
|
||||||
|
|
||||||
array::array(std::initializer_list<float> data)
|
array::array(std::initializer_list<float> data)
|
||||||
: array_desc_(std::make_shared<ArrayDesc>(
|
: array_desc_(std::make_shared<ArrayDesc>(
|
||||||
Shape{static_cast<ShapeElem>(data.size())},
|
Shape{static_cast<ShapeElem>(data.size())},
|
||||||
@@ -76,35 +88,27 @@ array::array(allocator::Buffer data, Shape shape, Dtype dtype, Deleter deleter)
|
|||||||
set_data(data, deleter);
|
set_data(data, deleter);
|
||||||
}
|
}
|
||||||
|
|
||||||
array::array(
|
|
||||||
allocator::Buffer data,
|
|
||||||
Shape shape,
|
|
||||||
Dtype dtype,
|
|
||||||
Strides strides,
|
|
||||||
size_t data_size,
|
|
||||||
Flags flags,
|
|
||||||
Deleter deleter)
|
|
||||||
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
|
|
||||||
set_data(data, data_size, std::move(strides), flags, deleter);
|
|
||||||
}
|
|
||||||
|
|
||||||
void array::detach() {
|
void array::detach() {
|
||||||
|
array_desc_->primitive = nullptr;
|
||||||
|
for (auto& s : array_desc_->siblings) {
|
||||||
|
s.array_desc_->primitive = nullptr;
|
||||||
|
}
|
||||||
for (auto& s : array_desc_->siblings) {
|
for (auto& s : array_desc_->siblings) {
|
||||||
s.array_desc_->inputs.clear();
|
s.array_desc_->inputs.clear();
|
||||||
s.array_desc_->siblings.clear();
|
s.array_desc_->siblings.clear();
|
||||||
s.array_desc_->position = 0;
|
s.array_desc_->position = 0;
|
||||||
s.array_desc_->primitive = nullptr;
|
|
||||||
}
|
}
|
||||||
array_desc_->inputs.clear();
|
array_desc_->inputs.clear();
|
||||||
array_desc_->siblings.clear();
|
array_desc_->siblings.clear();
|
||||||
array_desc_->position = 0;
|
array_desc_->position = 0;
|
||||||
array_desc_->primitive = nullptr;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool array::is_available() const {
|
bool array::is_available() const {
|
||||||
if (status() == Status::available) {
|
if (status() == Status::available) {
|
||||||
return true;
|
return true;
|
||||||
} else if (status() == Status::evaluated && event().is_signaled()) {
|
} else if (
|
||||||
|
status() == Status::evaluated &&
|
||||||
|
(!event().valid() || event().is_signaled())) {
|
||||||
set_status(Status::available);
|
set_status(Status::available);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
@@ -113,7 +117,10 @@ bool array::is_available() const {
|
|||||||
|
|
||||||
void array::wait() {
|
void array::wait() {
|
||||||
if (!is_available()) {
|
if (!is_available()) {
|
||||||
|
if (event().valid()) {
|
||||||
event().wait();
|
event().wait();
|
||||||
|
detach_event();
|
||||||
|
}
|
||||||
set_status(Status::available);
|
set_status(Status::available);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -174,34 +181,13 @@ void array::copy_shared_buffer(const array& other) {
|
|||||||
copy_shared_buffer(other, other.strides(), other.flags(), other.data_size());
|
copy_shared_buffer(other, other.strides(), other.flags(), other.data_size());
|
||||||
}
|
}
|
||||||
|
|
||||||
void array::move_shared_buffer(
|
|
||||||
array other,
|
|
||||||
const Strides& strides,
|
|
||||||
Flags flags,
|
|
||||||
size_t data_size,
|
|
||||||
size_t offset /* = 0 */) {
|
|
||||||
array_desc_->data = std::move(other.array_desc_->data);
|
|
||||||
array_desc_->strides = strides;
|
|
||||||
array_desc_->flags = flags;
|
|
||||||
array_desc_->data_size = data_size;
|
|
||||||
auto char_offset = sizeof(char) * itemsize() * offset;
|
|
||||||
auto data_ptr = other.array_desc_->data_ptr;
|
|
||||||
other.array_desc_->data_ptr = nullptr;
|
|
||||||
array_desc_->data_ptr =
|
|
||||||
static_cast<void*>(static_cast<char*>(data_ptr) + char_offset);
|
|
||||||
}
|
|
||||||
|
|
||||||
void array::move_shared_buffer(array other) {
|
|
||||||
move_shared_buffer(other, other.strides(), other.flags(), other.data_size());
|
|
||||||
}
|
|
||||||
|
|
||||||
array::~array() {
|
array::~array() {
|
||||||
if (array_desc_ == nullptr) {
|
if (array_desc_ == nullptr) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ignore arrays that might be detached during eval
|
// Detached/detaching
|
||||||
if (status() == array::Status::scheduled) {
|
if (array_desc_->primitive == nullptr) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
54
mlx/array.h
54
mlx/array.h
@@ -10,6 +10,7 @@
|
|||||||
#include "mlx/allocator.h"
|
#include "mlx/allocator.h"
|
||||||
#include "mlx/dtype.h"
|
#include "mlx/dtype.h"
|
||||||
#include "mlx/event.h"
|
#include "mlx/event.h"
|
||||||
|
#include "mlx/small_vector.h"
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
@@ -18,8 +19,8 @@ class Primitive;
|
|||||||
|
|
||||||
using Deleter = std::function<void(allocator::Buffer)>;
|
using Deleter = std::function<void(allocator::Buffer)>;
|
||||||
using ShapeElem = int32_t;
|
using ShapeElem = int32_t;
|
||||||
using Shape = std::vector<ShapeElem>;
|
using Shape = SmallVector<ShapeElem>;
|
||||||
using Strides = std::vector<int64_t>;
|
using Strides = SmallVector<int64_t>;
|
||||||
|
|
||||||
class array {
|
class array {
|
||||||
/* An array is really a node in a graph. It contains a shared ArrayDesc
|
/* An array is really a node in a graph. It contains a shared ArrayDesc
|
||||||
@@ -199,6 +200,13 @@ class array {
|
|||||||
const std::shared_ptr<Primitive>& primitive,
|
const std::shared_ptr<Primitive>& primitive,
|
||||||
const std::vector<array>& inputs);
|
const std::vector<array>& inputs);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get a new array that refers to the same data as the input but with a
|
||||||
|
* non-owning pointer to it. Note the array is detached from the graph and has
|
||||||
|
* no inputs, siblings or primitive.
|
||||||
|
*/
|
||||||
|
static array unsafe_weak_copy(const array& other);
|
||||||
|
|
||||||
/** A unique identifier for an array. */
|
/** A unique identifier for an array. */
|
||||||
std::uintptr_t id() const {
|
std::uintptr_t id() const {
|
||||||
return reinterpret_cast<std::uintptr_t>(array_desc_.get());
|
return reinterpret_cast<std::uintptr_t>(array_desc_.get());
|
||||||
@@ -217,6 +225,10 @@ class array {
|
|||||||
// Not copyable
|
// Not copyable
|
||||||
Data(const Data& d) = delete;
|
Data(const Data& d) = delete;
|
||||||
Data& operator=(const Data& d) = delete;
|
Data& operator=(const Data& d) = delete;
|
||||||
|
Data(Data&& o) : buffer(o.buffer), d(o.d) {
|
||||||
|
o.buffer = allocator::Buffer(nullptr);
|
||||||
|
o.d = [](allocator::Buffer) {};
|
||||||
|
}
|
||||||
~Data() {
|
~Data() {
|
||||||
d(buffer);
|
d(buffer);
|
||||||
}
|
}
|
||||||
@@ -243,18 +255,6 @@ class array {
|
|||||||
bool col_contiguous : 1;
|
bool col_contiguous : 1;
|
||||||
};
|
};
|
||||||
|
|
||||||
/** Build an array from all the info held by the array description. Including
|
|
||||||
* the buffer, strides, flags.
|
|
||||||
*/
|
|
||||||
explicit array(
|
|
||||||
allocator::Buffer data,
|
|
||||||
Shape shape,
|
|
||||||
Dtype dtype,
|
|
||||||
Strides strides,
|
|
||||||
size_t data_size,
|
|
||||||
Flags flags,
|
|
||||||
Deleter deleter = allocator::free);
|
|
||||||
|
|
||||||
/** The array's primitive. */
|
/** The array's primitive. */
|
||||||
Primitive& primitive() const {
|
Primitive& primitive() const {
|
||||||
return *(array_desc_->primitive);
|
return *(array_desc_->primitive);
|
||||||
@@ -344,11 +344,11 @@ class array {
|
|||||||
return allocator::allocator().size(buffer());
|
return allocator::allocator().size(buffer());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return a copy of the shared pointer
|
// Return the shared pointer to the array::Data struct
|
||||||
// to the array::Data struct
|
const std::shared_ptr<Data>& data_shared_ptr() const {
|
||||||
std::shared_ptr<Data> data_shared_ptr() const {
|
|
||||||
return array_desc_->data;
|
return array_desc_->data;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return a raw pointer to the arrays data
|
// Return a raw pointer to the arrays data
|
||||||
template <typename T>
|
template <typename T>
|
||||||
T* data() {
|
T* data() {
|
||||||
@@ -361,15 +361,10 @@ class array {
|
|||||||
}
|
}
|
||||||
|
|
||||||
enum Status {
|
enum Status {
|
||||||
// The ouptut of a computation which has not been scheduled.
|
// The output of a computation which has not been scheduled.
|
||||||
// For example, the status of `x` in `auto x = a + b`.
|
// For example, the status of `x` in `auto x = a + b`.
|
||||||
unscheduled,
|
unscheduled,
|
||||||
|
|
||||||
// The ouptut of a computation which has been scheduled but `eval_*` has
|
|
||||||
// not yet been called on the array's primitive. A possible
|
|
||||||
// status of `x` in `auto x = a + b; eval(x);`
|
|
||||||
scheduled,
|
|
||||||
|
|
||||||
// The array's `eval_*` function has been run, but the computation is not
|
// The array's `eval_*` function has been run, but the computation is not
|
||||||
// necessarily complete. The array will have memory allocated and if it is
|
// necessarily complete. The array will have memory allocated and if it is
|
||||||
// not a tracer then it will be detached from the graph.
|
// not a tracer then it will be detached from the graph.
|
||||||
@@ -406,6 +401,10 @@ class array {
|
|||||||
array_desc_->event = std::move(e);
|
array_desc_->event = std::move(e);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void detach_event() const {
|
||||||
|
array_desc_->event = Event{};
|
||||||
|
}
|
||||||
|
|
||||||
// Mark the array as a tracer array (true) or not.
|
// Mark the array as a tracer array (true) or not.
|
||||||
void set_tracer(bool is_tracer) {
|
void set_tracer(bool is_tracer) {
|
||||||
array_desc_->is_tracer = is_tracer;
|
array_desc_->is_tracer = is_tracer;
|
||||||
@@ -431,15 +430,6 @@ class array {
|
|||||||
|
|
||||||
void copy_shared_buffer(const array& other);
|
void copy_shared_buffer(const array& other);
|
||||||
|
|
||||||
void move_shared_buffer(
|
|
||||||
array other,
|
|
||||||
const Strides& strides,
|
|
||||||
Flags flags,
|
|
||||||
size_t data_size,
|
|
||||||
size_t offset = 0);
|
|
||||||
|
|
||||||
void move_shared_buffer(array other);
|
|
||||||
|
|
||||||
void overwrite_descriptor(const array& other) {
|
void overwrite_descriptor(const array& other) {
|
||||||
array_desc_ = other.array_desc_;
|
array_desc_ = other.array_desc_;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
target_sources(
|
target_sources(
|
||||||
mlx
|
mlx
|
||||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
|
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/broadcasting.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
|
||||||
|
|||||||
@@ -38,25 +38,20 @@ inline void set_binary_op_output_data(
|
|||||||
const array& a,
|
const array& a,
|
||||||
const array& b,
|
const array& b,
|
||||||
array& out,
|
array& out,
|
||||||
BinaryOpType bopt,
|
BinaryOpType bopt) {
|
||||||
bool donate_with_move = false) {
|
|
||||||
bool b_donatable = is_donatable(b, out);
|
bool b_donatable = is_donatable(b, out);
|
||||||
bool a_donatable = is_donatable(a, out);
|
bool a_donatable = is_donatable(a, out);
|
||||||
switch (bopt) {
|
switch (bopt) {
|
||||||
case BinaryOpType::ScalarScalar:
|
case BinaryOpType::ScalarScalar:
|
||||||
out.set_data(
|
out.set_data(
|
||||||
allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags());
|
allocator::malloc(out.itemsize()), 1, a.strides(), a.flags());
|
||||||
break;
|
break;
|
||||||
case BinaryOpType::ScalarVector:
|
case BinaryOpType::ScalarVector:
|
||||||
if (b_donatable) {
|
if (b_donatable) {
|
||||||
if (donate_with_move) {
|
|
||||||
out.move_shared_buffer(b);
|
|
||||||
} else {
|
|
||||||
out.copy_shared_buffer(b);
|
out.copy_shared_buffer(b);
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
out.set_data(
|
out.set_data(
|
||||||
allocator::malloc_or_wait(b.data_size() * out.itemsize()),
|
allocator::malloc(b.data_size() * out.itemsize()),
|
||||||
b.data_size(),
|
b.data_size(),
|
||||||
b.strides(),
|
b.strides(),
|
||||||
b.flags());
|
b.flags());
|
||||||
@@ -64,14 +59,10 @@ inline void set_binary_op_output_data(
|
|||||||
break;
|
break;
|
||||||
case BinaryOpType::VectorScalar:
|
case BinaryOpType::VectorScalar:
|
||||||
if (a_donatable) {
|
if (a_donatable) {
|
||||||
if (donate_with_move) {
|
|
||||||
out.move_shared_buffer(a);
|
|
||||||
} else {
|
|
||||||
out.copy_shared_buffer(a);
|
out.copy_shared_buffer(a);
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
out.set_data(
|
out.set_data(
|
||||||
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
|
allocator::malloc(a.data_size() * out.itemsize()),
|
||||||
a.data_size(),
|
a.data_size(),
|
||||||
a.strides(),
|
a.strides(),
|
||||||
a.flags());
|
a.flags());
|
||||||
@@ -79,20 +70,12 @@ inline void set_binary_op_output_data(
|
|||||||
break;
|
break;
|
||||||
case BinaryOpType::VectorVector:
|
case BinaryOpType::VectorVector:
|
||||||
if (a_donatable) {
|
if (a_donatable) {
|
||||||
if (donate_with_move) {
|
|
||||||
out.move_shared_buffer(a);
|
|
||||||
} else {
|
|
||||||
out.copy_shared_buffer(a);
|
out.copy_shared_buffer(a);
|
||||||
}
|
|
||||||
} else if (b_donatable) {
|
} else if (b_donatable) {
|
||||||
if (donate_with_move) {
|
|
||||||
out.move_shared_buffer(b);
|
|
||||||
} else {
|
|
||||||
out.copy_shared_buffer(b);
|
out.copy_shared_buffer(b);
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
out.set_data(
|
out.set_data(
|
||||||
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
|
allocator::malloc(a.data_size() * out.itemsize()),
|
||||||
a.data_size(),
|
a.data_size(),
|
||||||
a.strides(),
|
a.strides(),
|
||||||
a.flags());
|
a.flags());
|
||||||
@@ -100,20 +83,12 @@ inline void set_binary_op_output_data(
|
|||||||
break;
|
break;
|
||||||
case BinaryOpType::General:
|
case BinaryOpType::General:
|
||||||
if (a_donatable && a.flags().row_contiguous && a.size() == out.size()) {
|
if (a_donatable && a.flags().row_contiguous && a.size() == out.size()) {
|
||||||
if (donate_with_move) {
|
|
||||||
out.move_shared_buffer(a);
|
|
||||||
} else {
|
|
||||||
out.copy_shared_buffer(a);
|
out.copy_shared_buffer(a);
|
||||||
}
|
|
||||||
} else if (
|
} else if (
|
||||||
b_donatable && b.flags().row_contiguous && b.size() == out.size()) {
|
b_donatable && b.flags().row_contiguous && b.size() == out.size()) {
|
||||||
if (donate_with_move) {
|
|
||||||
out.move_shared_buffer(b);
|
|
||||||
} else {
|
|
||||||
out.copy_shared_buffer(b);
|
out.copy_shared_buffer(b);
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
out.set_data(allocator::malloc(out.nbytes()));
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|||||||
24
mlx/backend/common/broadcasting.cpp
Normal file
24
mlx/backend/common/broadcasting.cpp
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/common/utils.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
void broadcast(const array& in, array& out) {
|
||||||
|
if (out.size() == 0) {
|
||||||
|
out.set_data(nullptr);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
Strides strides(out.ndim(), 0);
|
||||||
|
int diff = out.ndim() - in.ndim();
|
||||||
|
for (int i = in.ndim() - 1; i >= 0; --i) {
|
||||||
|
strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
|
||||||
|
}
|
||||||
|
auto flags = in.flags();
|
||||||
|
if (out.size() > in.size()) {
|
||||||
|
flags.row_contiguous = flags.col_contiguous = false;
|
||||||
|
}
|
||||||
|
out.copy_shared_buffer(in, strides, flags, in.data_size());
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
@@ -1,10 +1,11 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlx/array.h"
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
void encode_wait(Event e);
|
void broadcast(const array& in, array& out);
|
||||||
|
|
||||||
void encode_signal(Event e);
|
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
157
mlx/backend/common/buffer_cache.h
Normal file
157
mlx/backend/common/buffer_cache.h
Normal file
@@ -0,0 +1,157 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <cassert>
|
||||||
|
#include <functional>
|
||||||
|
#include <map>
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
class BufferCache {
|
||||||
|
public:
|
||||||
|
BufferCache(
|
||||||
|
size_t page_size,
|
||||||
|
std::function<size_t(T*)> get_size,
|
||||||
|
std::function<void(T*)> free)
|
||||||
|
: page_size_(page_size),
|
||||||
|
get_size_(std::move(get_size)),
|
||||||
|
free_(std::move(free)) {}
|
||||||
|
|
||||||
|
~BufferCache() {
|
||||||
|
clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
BufferCache(const BufferCache&) = delete;
|
||||||
|
BufferCache& operator=(const BufferCache&) = delete;
|
||||||
|
|
||||||
|
T* reuse_from_cache(size_t size) {
|
||||||
|
// Find the closest buffer in pool.
|
||||||
|
auto it = buffer_pool_.lower_bound(size);
|
||||||
|
if (it == buffer_pool_.end() ||
|
||||||
|
it->first >= std::min(2 * size, size + 2 * page_size_)) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Collect from the cache.
|
||||||
|
T* buf = it->second->buf;
|
||||||
|
pool_size_ -= it->first;
|
||||||
|
|
||||||
|
// Remove from record.
|
||||||
|
remove_from_list(it->second);
|
||||||
|
buffer_pool_.erase(it);
|
||||||
|
return buf;
|
||||||
|
}
|
||||||
|
|
||||||
|
void recycle_to_cache(T* buf) {
|
||||||
|
assert(buf);
|
||||||
|
// Add to cache.
|
||||||
|
BufferHolder* bh = new BufferHolder(buf);
|
||||||
|
add_at_head(bh);
|
||||||
|
size_t size = get_size_(buf);
|
||||||
|
pool_size_ += size;
|
||||||
|
buffer_pool_.emplace(size, bh);
|
||||||
|
}
|
||||||
|
|
||||||
|
int release_cached_buffers(size_t min_bytes_to_free) {
|
||||||
|
if (min_bytes_to_free >= 0.9 * pool_size_) {
|
||||||
|
return clear();
|
||||||
|
} else {
|
||||||
|
int n_release = 0;
|
||||||
|
size_t total_bytes_freed = 0;
|
||||||
|
|
||||||
|
while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
|
||||||
|
// Release buffer.
|
||||||
|
size_t size = get_size_(tail_->buf);
|
||||||
|
total_bytes_freed += size;
|
||||||
|
free_(tail_->buf);
|
||||||
|
n_release++;
|
||||||
|
|
||||||
|
// Remove from record.
|
||||||
|
auto its = buffer_pool_.equal_range(size);
|
||||||
|
auto it = std::find_if(its.first, its.second, [this](const auto& el) {
|
||||||
|
return el.second == tail_;
|
||||||
|
});
|
||||||
|
assert(it != buffer_pool_.end());
|
||||||
|
buffer_pool_.erase(it);
|
||||||
|
remove_from_list(tail_);
|
||||||
|
}
|
||||||
|
|
||||||
|
pool_size_ -= total_bytes_freed;
|
||||||
|
return n_release;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int clear() {
|
||||||
|
int n_release = 0;
|
||||||
|
for (auto& [size, holder] : buffer_pool_) {
|
||||||
|
free_(holder->buf);
|
||||||
|
n_release++;
|
||||||
|
delete holder;
|
||||||
|
}
|
||||||
|
buffer_pool_.clear();
|
||||||
|
pool_size_ = 0;
|
||||||
|
head_ = nullptr;
|
||||||
|
tail_ = nullptr;
|
||||||
|
return n_release;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t cache_size() const {
|
||||||
|
return pool_size_;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t page_size() const {
|
||||||
|
return page_size_;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
struct BufferHolder {
|
||||||
|
public:
|
||||||
|
explicit BufferHolder(T* buf_) : buf(buf_) {}
|
||||||
|
|
||||||
|
BufferHolder* prev{nullptr};
|
||||||
|
BufferHolder* next{nullptr};
|
||||||
|
T* buf;
|
||||||
|
};
|
||||||
|
|
||||||
|
void add_at_head(BufferHolder* to_add) {
|
||||||
|
if (!head_) {
|
||||||
|
head_ = to_add;
|
||||||
|
tail_ = to_add;
|
||||||
|
} else {
|
||||||
|
head_->prev = to_add;
|
||||||
|
to_add->next = head_;
|
||||||
|
head_ = to_add;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void remove_from_list(BufferHolder* to_remove) {
|
||||||
|
if (to_remove->prev && to_remove->next) { // if middle
|
||||||
|
to_remove->prev->next = to_remove->next;
|
||||||
|
to_remove->next->prev = to_remove->prev;
|
||||||
|
} else if (to_remove->prev && to_remove == tail_) { // if tail
|
||||||
|
tail_ = to_remove->prev;
|
||||||
|
tail_->next = nullptr;
|
||||||
|
} else if (to_remove == head_ && to_remove->next) { // if head
|
||||||
|
head_ = to_remove->next;
|
||||||
|
head_->prev = nullptr;
|
||||||
|
} else if (to_remove == head_ && to_remove == tail_) { // if only element
|
||||||
|
head_ = nullptr;
|
||||||
|
tail_ = nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
delete to_remove;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::multimap<size_t, BufferHolder*> buffer_pool_;
|
||||||
|
BufferHolder* head_{nullptr};
|
||||||
|
BufferHolder* tail_{nullptr};
|
||||||
|
size_t pool_size_{0};
|
||||||
|
|
||||||
|
const size_t page_size_;
|
||||||
|
std::function<size_t(T*)> get_size_;
|
||||||
|
std::function<void(T*)> free_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
// Copyright © 2024 Apple Inc.
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
|
|
||||||
|
#include "mlx/backend/common/broadcasting.h"
|
||||||
#include "mlx/backend/common/utils.h"
|
#include "mlx/backend/common/utils.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
@@ -39,24 +40,7 @@ void AsStrided::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
// rely on data_size anyway.
|
// rely on data_size anyway.
|
||||||
size_t data_size = out.size();
|
size_t data_size = out.size();
|
||||||
|
|
||||||
return move_or_copy(in, out, strides_, flags, data_size, offset_);
|
return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
|
||||||
}
|
|
||||||
|
|
||||||
void broadcast(const array& in, array& out) {
|
|
||||||
if (out.size() == 0) {
|
|
||||||
out.set_data(nullptr);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
Strides strides(out.ndim(), 0);
|
|
||||||
int diff = out.ndim() - in.ndim();
|
|
||||||
for (int i = in.ndim() - 1; i >= 0; --i) {
|
|
||||||
strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
|
|
||||||
}
|
|
||||||
auto flags = in.flags();
|
|
||||||
if (out.size() > in.size()) {
|
|
||||||
flags.row_contiguous = flags.col_contiguous = false;
|
|
||||||
}
|
|
||||||
move_or_copy(in, out, strides, flags, in.data_size());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void Broadcast::eval(const std::vector<array>& inputs, array& out) {
|
void Broadcast::eval(const std::vector<array>& inputs, array& out) {
|
||||||
@@ -69,7 +53,7 @@ void BroadcastAxes::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
|
|
||||||
void Copy::eval(const std::vector<array>& inputs, array& out) {
|
void Copy::eval(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
move_or_copy(inputs[0], out);
|
out.copy_shared_buffer(inputs[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
void CustomTransforms::eval(
|
void CustomTransforms::eval(
|
||||||
@@ -78,7 +62,7 @@ void CustomTransforms::eval(
|
|||||||
assert(inputs.size() > outputs.size());
|
assert(inputs.size() > outputs.size());
|
||||||
for (int i = 0, j = inputs.size() - outputs.size(); i < outputs.size();
|
for (int i = 0, j = inputs.size() - outputs.size(); i < outputs.size();
|
||||||
i++, j++) {
|
i++, j++) {
|
||||||
move_or_copy(inputs[j], outputs[i]);
|
outputs[i].copy_shared_buffer(inputs[j]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -87,7 +71,7 @@ void Depends::eval(
|
|||||||
std::vector<array>& outputs) {
|
std::vector<array>& outputs) {
|
||||||
assert(inputs.size() > outputs.size());
|
assert(inputs.size() > outputs.size());
|
||||||
for (int i = 0; i < outputs.size(); i++) {
|
for (int i = 0; i < outputs.size(); i++) {
|
||||||
move_or_copy(inputs[i], outputs[i]);
|
outputs[i].copy_shared_buffer(inputs[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -98,12 +82,12 @@ void ExpandDims::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
for (auto ax : axes_) {
|
for (auto ax : axes_) {
|
||||||
strides.insert(strides.begin() + ax, 1);
|
strides.insert(strides.begin() + ax, 1);
|
||||||
}
|
}
|
||||||
move_or_copy(in, out, strides, in.flags(), in.data_size());
|
out.copy_shared_buffer(in, strides, in.flags(), in.data_size());
|
||||||
}
|
}
|
||||||
|
|
||||||
void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
|
void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
out.set_data(allocator::malloc(out.nbytes()));
|
||||||
|
|
||||||
double numel = 1;
|
double numel = 1;
|
||||||
for (auto ax : axes_) {
|
for (auto ax : axes_) {
|
||||||
@@ -210,7 +194,7 @@ void shared_buffer_reshape(
|
|||||||
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
|
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
|
||||||
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
|
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
|
||||||
}
|
}
|
||||||
move_or_copy(in, out, out_strides, flags, in.data_size());
|
out.copy_shared_buffer(in, out_strides, flags, in.data_size());
|
||||||
}
|
}
|
||||||
|
|
||||||
void Split::eval(
|
void Split::eval(
|
||||||
@@ -276,12 +260,12 @@ void Squeeze::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
strides.push_back(in.strides(i));
|
strides.push_back(in.strides(i));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
move_or_copy(in, out, strides, in.flags(), in.data_size());
|
out.copy_shared_buffer(in, strides, in.flags(), in.data_size());
|
||||||
}
|
}
|
||||||
|
|
||||||
void StopGradient::eval(const std::vector<array>& inputs, array& out) {
|
void StopGradient::eval(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
move_or_copy(inputs[0], out);
|
out.copy_shared_buffer(inputs[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Transpose::eval(const std::vector<array>& inputs, array& out) {
|
void Transpose::eval(const std::vector<array>& inputs, array& out) {
|
||||||
@@ -315,7 +299,7 @@ void Transpose::eval(const std::vector<array>& inputs, array& out) {
|
|||||||
b_stride *= out.shape(ri);
|
b_stride *= out.shape(ri);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
move_or_copy(in, out, out_strides, flags, in.data_size());
|
out.copy_shared_buffer(in, out_strides, flags, in.data_size());
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
// Copyright © 2023-2024 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
#include "mlx/backend/common/compiled.h"
|
#include "mlx/backend/common/compiled.h"
|
||||||
#include "mlx/graph_utils.h"
|
#include "mlx/backend/common/utils.h"
|
||||||
#include "mlx/primitives.h"
|
|
||||||
#include "mlx/utils.h"
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
@@ -15,6 +14,8 @@ void print_constant(std::ostream& os, const array& x) {
|
|||||||
return print_float_constant<float16_t>(os, x);
|
return print_float_constant<float16_t>(os, x);
|
||||||
case bfloat16:
|
case bfloat16:
|
||||||
return print_float_constant<bfloat16_t>(os, x);
|
return print_float_constant<bfloat16_t>(os, x);
|
||||||
|
case float64:
|
||||||
|
return print_float_constant<double>(os, x);
|
||||||
case complex64:
|
case complex64:
|
||||||
return print_complex_constant<complex64_t>(os, x);
|
return print_complex_constant<complex64_t>(os, x);
|
||||||
case int8:
|
case int8:
|
||||||
@@ -51,6 +52,8 @@ std::string get_type_string(Dtype d) {
|
|||||||
return "float16_t";
|
return "float16_t";
|
||||||
case bfloat16:
|
case bfloat16:
|
||||||
return "bfloat16_t";
|
return "bfloat16_t";
|
||||||
|
case float64:
|
||||||
|
return "double";
|
||||||
case complex64:
|
case complex64:
|
||||||
return "complex64_t";
|
return "complex64_t";
|
||||||
case bool_:
|
case bool_:
|
||||||
@@ -79,55 +82,6 @@ std::string get_type_string(Dtype d) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string build_lib_name(
|
|
||||||
const std::vector<array>& inputs,
|
|
||||||
const std::vector<array>& outputs,
|
|
||||||
const std::vector<array>& tape,
|
|
||||||
const std::unordered_set<uintptr_t>& constant_ids) {
|
|
||||||
NodeNamer namer;
|
|
||||||
std::ostringstream os;
|
|
||||||
std::ostringstream constant_hasher;
|
|
||||||
|
|
||||||
// Fill the input names. This is not really necessary, I just like having A,
|
|
||||||
// B, C, ... as the inputs.
|
|
||||||
for (auto& x : inputs) {
|
|
||||||
namer.get_name(x);
|
|
||||||
}
|
|
||||||
|
|
||||||
// The primitives describing the tape. For unary and binary primitives this
|
|
||||||
// must be enough to describe the full computation.
|
|
||||||
for (auto& a : tape) {
|
|
||||||
// name and type of output
|
|
||||||
os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize();
|
|
||||||
// computation performed
|
|
||||||
a.primitive().print(os);
|
|
||||||
// name of inputs to the function
|
|
||||||
for (auto& inp : a.inputs()) {
|
|
||||||
os << namer.get_name(inp);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
os << "_";
|
|
||||||
|
|
||||||
for (auto& x : inputs) {
|
|
||||||
if (constant_ids.find(x.id()) != constant_ids.end()) {
|
|
||||||
os << "C";
|
|
||||||
print_constant(constant_hasher, x);
|
|
||||||
} else {
|
|
||||||
os << (is_scalar(x) ? "S" : "V");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
os << "_";
|
|
||||||
for (auto& x : inputs) {
|
|
||||||
if (constant_ids.find(x.id()) != constant_ids.end()) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
os << kindof(x.dtype()) << x.itemsize();
|
|
||||||
}
|
|
||||||
os << "_" << std::hash<std::string>{}(constant_hasher.str());
|
|
||||||
|
|
||||||
return os.str();
|
|
||||||
}
|
|
||||||
|
|
||||||
bool compiled_check_contiguity(
|
bool compiled_check_contiguity(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
const Shape& shape) {
|
const Shape& shape) {
|
||||||
@@ -159,10 +113,8 @@ bool compiled_check_contiguity(
|
|||||||
void compiled_allocate_outputs(
|
void compiled_allocate_outputs(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
std::vector<array>& outputs,
|
std::vector<array>& outputs,
|
||||||
const std::vector<array>& inputs_,
|
const std::function<bool(size_t)>& is_constant,
|
||||||
const std::unordered_set<uintptr_t>& constant_ids_,
|
bool contiguous) {
|
||||||
bool contiguous,
|
|
||||||
bool move_buffers /* = false */) {
|
|
||||||
if (contiguous) {
|
if (contiguous) {
|
||||||
int o = 0;
|
int o = 0;
|
||||||
Strides strides;
|
Strides strides;
|
||||||
@@ -176,14 +128,9 @@ void compiled_allocate_outputs(
|
|||||||
// - Donatable
|
// - Donatable
|
||||||
// - Not a constant
|
// - Not a constant
|
||||||
if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) &&
|
if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) &&
|
||||||
in.is_donatable() &&
|
in.is_donatable() && is_constant(i)) {
|
||||||
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
|
|
||||||
if (move_buffers) {
|
|
||||||
outputs[o++].move_shared_buffer(in);
|
|
||||||
} else {
|
|
||||||
outputs[o++].copy_shared_buffer(in);
|
outputs[o++].copy_shared_buffer(in);
|
||||||
}
|
}
|
||||||
}
|
|
||||||
// Get representative input flags to properly set non-donated outputs
|
// Get representative input flags to properly set non-donated outputs
|
||||||
if (strides.empty() && in.size() == outputs[0].size()) {
|
if (strides.empty() && in.size() == outputs[0].size()) {
|
||||||
strides = in.strides();
|
strides = in.strides();
|
||||||
@@ -193,7 +140,7 @@ void compiled_allocate_outputs(
|
|||||||
}
|
}
|
||||||
for (; o < outputs.size(); ++o) {
|
for (; o < outputs.size(); ++o) {
|
||||||
outputs[o].set_data(
|
outputs[o].set_data(
|
||||||
allocator::malloc_or_wait(data_size * outputs[o].itemsize()),
|
allocator::malloc(data_size * outputs[o].itemsize()),
|
||||||
data_size,
|
data_size,
|
||||||
strides,
|
strides,
|
||||||
flags);
|
flags);
|
||||||
@@ -209,21 +156,86 @@ void compiled_allocate_outputs(
|
|||||||
// - Not a constant
|
// - Not a constant
|
||||||
if (in.flags().row_contiguous && in.size() == outputs[o].size() &&
|
if (in.flags().row_contiguous && in.size() == outputs[o].size() &&
|
||||||
in.itemsize() == outputs[o].itemsize() && in.is_donatable() &&
|
in.itemsize() == outputs[o].itemsize() && in.is_donatable() &&
|
||||||
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
|
is_constant(i)) {
|
||||||
if (move_buffers) {
|
|
||||||
outputs[o].move_shared_buffer(
|
|
||||||
in, outputs[o].strides(), in.flags(), in.data_size());
|
|
||||||
} else {
|
|
||||||
outputs[o].copy_shared_buffer(
|
outputs[o].copy_shared_buffer(
|
||||||
in, outputs[o].strides(), in.flags(), in.data_size());
|
in, outputs[o].strides(), in.flags(), in.data_size());
|
||||||
}
|
|
||||||
o++;
|
o++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (; o < outputs.size(); ++o) {
|
for (; o < outputs.size(); ++o) {
|
||||||
outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes()));
|
outputs[o].set_data(allocator::malloc(outputs[o].nbytes()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
const array& out,
|
||||||
|
const std::function<bool(size_t)>& is_constant) {
|
||||||
|
const Shape& shape = out.shape();
|
||||||
|
bool contiguous = compiled_check_contiguity(inputs, shape);
|
||||||
|
if (contiguous) {
|
||||||
|
return {true, shape, {}};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<Strides> strides_vec{out.strides()};
|
||||||
|
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||||
|
// Skip constants.
|
||||||
|
if (is_constant(i)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip scalar inputs.
|
||||||
|
const auto& x = inputs[i];
|
||||||
|
if (is_scalar(x)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Broadcast the inputs to the output shape.
|
||||||
|
Strides xstrides;
|
||||||
|
size_t j = 0;
|
||||||
|
for (; j < shape.size() - x.ndim(); ++j) {
|
||||||
|
if (shape[j] == 1) {
|
||||||
|
xstrides.push_back(out.strides()[j]);
|
||||||
|
} else {
|
||||||
|
xstrides.push_back(0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < x.ndim(); ++i, ++j) {
|
||||||
|
if (x.shape(i) == 1) {
|
||||||
|
if (shape[j] == 1) {
|
||||||
|
xstrides.push_back(out.strides()[j]);
|
||||||
|
} else {
|
||||||
|
xstrides.push_back(0);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
xstrides.push_back(x.strides()[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
strides_vec.push_back(std::move(xstrides));
|
||||||
|
}
|
||||||
|
|
||||||
|
auto tup = collapse_contiguous_dims(shape, strides_vec, INT32_MAX);
|
||||||
|
return {false, std::move(std::get<0>(tup)), std::move(std::get<1>(tup))};
|
||||||
|
}
|
||||||
|
|
||||||
|
bool compiled_use_large_index(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
const std::vector<array>& outputs,
|
||||||
|
bool contiguous) {
|
||||||
|
if (contiguous) {
|
||||||
|
size_t max_size = 0;
|
||||||
|
for (const auto& in : inputs) {
|
||||||
|
max_size = std::max(max_size, in.data_size());
|
||||||
|
}
|
||||||
|
return max_size > UINT32_MAX;
|
||||||
|
} else {
|
||||||
|
size_t max_size = 0;
|
||||||
|
for (const auto& o : outputs) {
|
||||||
|
max_size = std::max(max_size, o.size());
|
||||||
|
}
|
||||||
|
return max_size > UINT32_MAX;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
|||||||
@@ -1,9 +1,8 @@
|
|||||||
// Copyright © 2023-2024 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <functional>
|
||||||
#include <iomanip>
|
#include <iomanip>
|
||||||
#include <sstream>
|
|
||||||
#include <unordered_set>
|
|
||||||
|
|
||||||
#include "mlx/array.h"
|
#include "mlx/array.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
@@ -14,19 +13,17 @@ inline bool is_static_cast(const Primitive& p) {
|
|||||||
return (typeid(p) == typeid(Broadcast) || typeid(p) == typeid(AsType));
|
return (typeid(p) == typeid(Broadcast) || typeid(p) == typeid(AsType));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string build_lib_name(
|
|
||||||
const std::vector<array>& inputs,
|
|
||||||
const std::vector<array>& outputs,
|
|
||||||
const std::vector<array>& tape,
|
|
||||||
const std::unordered_set<uintptr_t>& constant_ids);
|
|
||||||
|
|
||||||
std::string get_type_string(Dtype d);
|
std::string get_type_string(Dtype d);
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void print_float_constant(std::ostream& os, const array& x) {
|
void print_float_constant(std::ostream& os, const array& x) {
|
||||||
auto old_precision = os.precision();
|
auto old_precision = os.precision();
|
||||||
os << std::setprecision(std::numeric_limits<float>::digits10 + 1)
|
if constexpr (std::is_same_v<T, double>) {
|
||||||
<< x.item<T>() << std::setprecision(old_precision);
|
os << std::setprecision(std::numeric_limits<double>::digits10 + 1);
|
||||||
|
} else {
|
||||||
|
os << std::setprecision(std::numeric_limits<float>::digits10 + 1);
|
||||||
|
}
|
||||||
|
os << x.item<T>() << std::setprecision(old_precision);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
@@ -60,9 +57,19 @@ bool compiled_check_contiguity(
|
|||||||
void compiled_allocate_outputs(
|
void compiled_allocate_outputs(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
std::vector<array>& outputs,
|
std::vector<array>& outputs,
|
||||||
const std::vector<array>& inputs_,
|
const std::function<bool(size_t)>& is_constant,
|
||||||
const std::unordered_set<uintptr_t>& constant_ids_,
|
bool contiguous);
|
||||||
bool contiguous,
|
|
||||||
bool move_buffers = false);
|
// Collapse contiguous dims ignoring scalars and constants.
|
||||||
|
std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
const array& out,
|
||||||
|
const std::function<bool(size_t)>& is_constant);
|
||||||
|
|
||||||
|
// Return whether the kernel should use large index.
|
||||||
|
bool compiled_use_large_index(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
const std::vector<array>& outputs,
|
||||||
|
bool contiguous);
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "mlx/array.h"
|
#include "mlx/backend/common/utils.h"
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
@@ -22,4 +22,25 @@ enum class CopyType {
|
|||||||
GeneralGeneral
|
GeneralGeneral
|
||||||
};
|
};
|
||||||
|
|
||||||
|
inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) {
|
||||||
|
if (ctype == CopyType::Vector) {
|
||||||
|
// If the input is donateable, we are doing a vector copy and the types
|
||||||
|
// have the same size, then the input buffer can hold the output.
|
||||||
|
if (is_donatable(in, out)) {
|
||||||
|
out.copy_shared_buffer(in);
|
||||||
|
return true;
|
||||||
|
} else {
|
||||||
|
out.set_data(
|
||||||
|
allocator::malloc(in.data_size() * out.itemsize()),
|
||||||
|
in.data_size(),
|
||||||
|
in.strides(),
|
||||||
|
in.flags());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
out.set_data(allocator::malloc(out.nbytes()));
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
|||||||
@@ -99,6 +99,10 @@ inline std::pair<int, int> decompose_hadamard(int n) {
|
|||||||
"[hadamard] Only supports n = m*2^k where m in (1, 12, 20, 28).");
|
"[hadamard] Only supports n = m*2^k where m in (1, 12, 20, 28).");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (n > (1 << 26)) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[hadamard] Only supports n = m*2^k where k <= 26");
|
||||||
|
}
|
||||||
return {n, m};
|
return {n, m};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,8 @@
|
|||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
#include "mlx/backend/common/load.h"
|
#include "mlx/primitives.h"
|
||||||
|
#include "mlx/scheduler.h"
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
@@ -26,26 +27,31 @@ void swap_endianness(uint8_t* data_bytes, size_t N) {
|
|||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
void load(
|
void Load::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
array& out,
|
out.set_data(allocator::malloc(out.nbytes()));
|
||||||
size_t offset,
|
auto read_task = [out_ptr = out.data<char>(),
|
||||||
const std::shared_ptr<io::Reader>& reader,
|
size = out.size(),
|
||||||
bool swap_endianness_) {
|
itemsize = out.itemsize(),
|
||||||
reader->read(out.data<char>(), out.nbytes(), offset);
|
offset = offset_,
|
||||||
|
reader = reader_,
|
||||||
|
swap_endianness_ = swap_endianness_]() mutable {
|
||||||
|
reader->read(out_ptr, size * itemsize, offset);
|
||||||
if (swap_endianness_) {
|
if (swap_endianness_) {
|
||||||
switch (out.itemsize()) {
|
switch (itemsize) {
|
||||||
case 2:
|
case 2:
|
||||||
swap_endianness<2>(out.data<uint8_t>(), out.data_size());
|
swap_endianness<2>(reinterpret_cast<uint8_t*>(out_ptr), size);
|
||||||
break;
|
break;
|
||||||
case 4:
|
case 4:
|
||||||
swap_endianness<4>(out.data<uint8_t>(), out.data_size());
|
swap_endianness<4>(reinterpret_cast<uint8_t*>(out_ptr), size);
|
||||||
break;
|
break;
|
||||||
case 8:
|
case 8:
|
||||||
swap_endianness<8>(out.data<uint8_t>(), out.data_size());
|
swap_endianness<8>(reinterpret_cast<uint8_t*>(out_ptr), size);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
auto fut = io::thread_pool().enqueue(std::move(read_task)).share();
|
||||||
|
scheduler::enqueue(stream(), [fut = std::move(fut)]() { fut.wait(); });
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
|||||||
@@ -1,14 +0,0 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
|
||||||
|
|
||||||
#include "mlx/array.h"
|
|
||||||
#include "mlx/io/load.h"
|
|
||||||
|
|
||||||
namespace mlx::core {
|
|
||||||
|
|
||||||
void load(
|
|
||||||
array& out,
|
|
||||||
size_t offset,
|
|
||||||
const std::shared_ptr<io::Reader>& reader,
|
|
||||||
bool swap_endianess);
|
|
||||||
|
|
||||||
} // namespace mlx::core
|
|
||||||
67
mlx/backend/common/matmul.h
Normal file
67
mlx/backend/common/matmul.h
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlx/backend/common/utils.h"
|
||||||
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
inline std::tuple<Shape, Strides, Strides> collapse_batches(
|
||||||
|
const array& a,
|
||||||
|
const array& b) {
|
||||||
|
if (a.ndim() == 2) {
|
||||||
|
return {{1}, {0}, {0}};
|
||||||
|
}
|
||||||
|
|
||||||
|
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
|
||||||
|
Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
|
||||||
|
Strides B_bstride{b.strides().begin(), b.strides().end() - 2};
|
||||||
|
|
||||||
|
auto [batch_shape, batch_strides] =
|
||||||
|
collapse_contiguous_dims(A_bshape, std::vector{A_bstride, B_bstride});
|
||||||
|
|
||||||
|
auto a_batch_strides = batch_strides[0];
|
||||||
|
auto b_batch_strides = batch_strides[1];
|
||||||
|
|
||||||
|
if (batch_shape.empty()) {
|
||||||
|
batch_shape.push_back(1);
|
||||||
|
a_batch_strides.push_back(0);
|
||||||
|
b_batch_strides.push_back(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
return std::make_tuple(batch_shape, a_batch_strides, b_batch_strides);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline std::tuple<Shape, Strides, Strides, Strides>
|
||||||
|
collapse_batches(const array& a, const array& b, const array& c) {
|
||||||
|
if (a.ndim() == 2) {
|
||||||
|
return {{1}, {0}, {0}, {0}};
|
||||||
|
}
|
||||||
|
|
||||||
|
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
|
||||||
|
Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
|
||||||
|
Strides B_bstride{b.strides().begin(), b.strides().end() - 2};
|
||||||
|
Strides C_bstride{c.strides().begin(), c.strides().end() - 2};
|
||||||
|
|
||||||
|
auto [batch_shape, batch_strides] = collapse_contiguous_dims(
|
||||||
|
A_bshape, std::vector{A_bstride, B_bstride, C_bstride});
|
||||||
|
|
||||||
|
auto A_batch_stride = batch_strides[0];
|
||||||
|
auto B_batch_stride = batch_strides[1];
|
||||||
|
auto C_batch_stride = batch_strides[2];
|
||||||
|
|
||||||
|
if (batch_shape.empty()) {
|
||||||
|
batch_shape.push_back(1);
|
||||||
|
A_batch_stride.push_back(0);
|
||||||
|
B_batch_stride.push_back(0);
|
||||||
|
C_batch_stride.push_back(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
return std::make_tuple(
|
||||||
|
batch_shape, A_batch_stride, B_batch_stride, C_batch_stride);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
@@ -5,11 +5,9 @@
|
|||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
std::pair<Shape, Strides> shapes_without_reduction_axes(
|
std::pair<Shape, Strides> shapes_without_reduction_axes(
|
||||||
const array& x,
|
Shape shape,
|
||||||
|
Strides strides,
|
||||||
const std::vector<int>& axes) {
|
const std::vector<int>& axes) {
|
||||||
auto shape = x.shape();
|
|
||||||
auto strides = x.strides();
|
|
||||||
|
|
||||||
for (int i = axes.size() - 1; i >= 0; i--) {
|
for (int i = axes.size() - 1; i >= 0; i--) {
|
||||||
int a = axes[i];
|
int a = axes[i];
|
||||||
shape.erase(shape.begin() + a);
|
shape.erase(shape.begin() + a);
|
||||||
@@ -19,6 +17,15 @@ std::pair<Shape, Strides> shapes_without_reduction_axes(
|
|||||||
return std::make_pair(shape, strides);
|
return std::make_pair(shape, strides);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::pair<Shape, Strides> shapes_without_reduction_axes(
|
||||||
|
const array& x,
|
||||||
|
const std::vector<int>& axes) {
|
||||||
|
auto shape = x.shape();
|
||||||
|
auto strides = x.strides();
|
||||||
|
return shapes_without_reduction_axes(
|
||||||
|
std::move(shape), std::move(strides), axes);
|
||||||
|
}
|
||||||
|
|
||||||
ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
|
ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
|
||||||
// The data is all there and we are reducing over everything
|
// The data is all there and we are reducing over everything
|
||||||
if (x.size() == x.data_size() && axes.size() == x.ndim() &&
|
if (x.size() == x.data_size() && axes.size() == x.ndim() &&
|
||||||
|
|||||||
@@ -51,5 +51,9 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes);
|
|||||||
std::pair<Shape, Strides> shapes_without_reduction_axes(
|
std::pair<Shape, Strides> shapes_without_reduction_axes(
|
||||||
const array& x,
|
const array& x,
|
||||||
const std::vector<int>& axes);
|
const std::vector<int>& axes);
|
||||||
|
std::pair<Shape, Strides> shapes_without_reduction_axes(
|
||||||
|
Shape shape,
|
||||||
|
Strides strides,
|
||||||
|
const std::vector<int>& axes);
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ void shared_buffer_slice(
|
|||||||
flags.col_contiguous = is_col_contiguous;
|
flags.col_contiguous = is_col_contiguous;
|
||||||
flags.contiguous = (no_bsx_size == data_size);
|
flags.contiguous = (no_bsx_size == data_size);
|
||||||
|
|
||||||
move_or_copy(in, out, out_strides, flags, data_size, data_offset);
|
out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset);
|
||||||
}
|
}
|
||||||
|
|
||||||
void slice(
|
void slice(
|
||||||
|
|||||||
@@ -36,15 +36,10 @@ inline void set_ternary_op_output_data(
|
|||||||
const array& b,
|
const array& b,
|
||||||
const array& c,
|
const array& c,
|
||||||
array& out,
|
array& out,
|
||||||
TernaryOpType topt,
|
TernaryOpType topt) {
|
||||||
bool donate_with_move = false) {
|
auto maybe_donate = [&out](const array& x) {
|
||||||
auto maybe_donate = [&out, donate_with_move](const array& x) {
|
|
||||||
if (is_donatable(x, out)) {
|
if (is_donatable(x, out)) {
|
||||||
if (donate_with_move) {
|
|
||||||
out.move_shared_buffer(x);
|
|
||||||
} else {
|
|
||||||
out.copy_shared_buffer(x);
|
out.copy_shared_buffer(x);
|
||||||
}
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
@@ -53,12 +48,12 @@ inline void set_ternary_op_output_data(
|
|||||||
switch (topt) {
|
switch (topt) {
|
||||||
case TernaryOpType::ScalarScalarScalar:
|
case TernaryOpType::ScalarScalarScalar:
|
||||||
out.set_data(
|
out.set_data(
|
||||||
allocator::malloc_or_wait(out.itemsize()), 1, b.strides(), b.flags());
|
allocator::malloc(out.itemsize()), 1, b.strides(), b.flags());
|
||||||
break;
|
break;
|
||||||
case TernaryOpType::VectorVectorVector:
|
case TernaryOpType::VectorVectorVector:
|
||||||
if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) {
|
if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) {
|
||||||
out.set_data(
|
out.set_data(
|
||||||
allocator::malloc_or_wait(out.itemsize() * b.data_size()),
|
allocator::malloc(out.itemsize() * b.data_size()),
|
||||||
b.data_size(),
|
b.data_size(),
|
||||||
b.strides(),
|
b.strides(),
|
||||||
b.flags());
|
b.flags());
|
||||||
@@ -69,7 +64,7 @@ inline void set_ternary_op_output_data(
|
|||||||
if (!((a.flags().row_contiguous && maybe_donate(a)) ||
|
if (!((a.flags().row_contiguous && maybe_donate(a)) ||
|
||||||
(b.flags().row_contiguous && maybe_donate(b)) ||
|
(b.flags().row_contiguous && maybe_donate(b)) ||
|
||||||
(c.flags().row_contiguous && maybe_donate(c)))) {
|
(c.flags().row_contiguous && maybe_donate(c)))) {
|
||||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
out.set_data(allocator::malloc(out.nbytes()));
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|||||||
26
mlx/backend/common/unary.h
Normal file
26
mlx/backend/common/unary.h
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlx/allocator.h"
|
||||||
|
#include "mlx/backend/common/utils.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
inline void set_unary_output_data(const array& in, array& out) {
|
||||||
|
if (in.flags().contiguous) {
|
||||||
|
if (is_donatable(in, out)) {
|
||||||
|
out.copy_shared_buffer(in);
|
||||||
|
} else {
|
||||||
|
out.set_data(
|
||||||
|
allocator::malloc(in.data_size() * out.itemsize()),
|
||||||
|
in.data_size(),
|
||||||
|
in.strides(),
|
||||||
|
in.flags());
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
out.set_data(allocator::malloc(out.nbytes()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
@@ -1,29 +1,20 @@
|
|||||||
// Copyright © 2023-2024 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
|
#include <dlfcn.h>
|
||||||
|
|
||||||
#include "mlx/backend/common/utils.h"
|
#include "mlx/backend/common/utils.h"
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
void move_or_copy(const array& in, array& out) {
|
std::filesystem::path current_binary_dir() {
|
||||||
if (in.is_donatable()) {
|
static std::filesystem::path binary_dir = []() {
|
||||||
out.move_shared_buffer(in);
|
Dl_info info;
|
||||||
} else {
|
if (!dladdr(reinterpret_cast<void*>(¤t_binary_dir), &info)) {
|
||||||
out.copy_shared_buffer(in);
|
throw std::runtime_error("Unable to get current binary dir.");
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void move_or_copy(
|
|
||||||
const array& in,
|
|
||||||
array& out,
|
|
||||||
const Strides& strides,
|
|
||||||
array::Flags flags,
|
|
||||||
size_t data_size,
|
|
||||||
size_t offset /* = 0 */) {
|
|
||||||
if (in.is_donatable()) {
|
|
||||||
out.move_shared_buffer(in, strides, flags, data_size, offset);
|
|
||||||
} else {
|
|
||||||
out.copy_shared_buffer(in, strides, flags, data_size, offset);
|
|
||||||
}
|
}
|
||||||
|
return std::filesystem::path(info.dli_fname).parent_path();
|
||||||
|
}();
|
||||||
|
return binary_dir;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
|
std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
|
||||||
@@ -123,4 +114,145 @@ std::pair<Shape, Strides> collapse_contiguous_dims(
|
|||||||
return collapse_contiguous_dims(a.shape(), a.strides(), size_cap);
|
return collapse_contiguous_dims(a.shape(), a.strides(), size_cap);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Dims get_block_dims_common(int dim0, int dim1, int dim2, int pow2 /* = 10 */) {
|
||||||
|
int pows[3] = {0, 0, 0};
|
||||||
|
int sum = 0;
|
||||||
|
while (true) {
|
||||||
|
int presum = sum;
|
||||||
|
// Check all the pows
|
||||||
|
if (dim0 >= (1 << (pows[0] + 1))) {
|
||||||
|
pows[0]++;
|
||||||
|
sum++;
|
||||||
|
}
|
||||||
|
if (sum == 10) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if (dim1 >= (1 << (pows[1] + 1))) {
|
||||||
|
pows[1]++;
|
||||||
|
sum++;
|
||||||
|
}
|
||||||
|
if (sum == 10) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if (dim2 >= (1 << (pows[2] + 1))) {
|
||||||
|
pows[2]++;
|
||||||
|
sum++;
|
||||||
|
}
|
||||||
|
if (sum == presum || sum == pow2) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return std::make_tuple(1ul << pows[0], 1ul << pows[1], 1ul << pows[2]);
|
||||||
|
}
|
||||||
|
|
||||||
|
Dims get_2d_grid_dims_common(const Shape& shape, const Strides& strides) {
|
||||||
|
// Dims with strides of 0 are ignored as they
|
||||||
|
// correspond to broadcasted dimensions
|
||||||
|
size_t grid_x = 1;
|
||||||
|
size_t grid_y = 1;
|
||||||
|
for (int i = 0; i < shape.size(); ++i) {
|
||||||
|
if (strides[i] == 0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (grid_x * shape[i] < UINT32_MAX) {
|
||||||
|
grid_x *= shape[i];
|
||||||
|
} else {
|
||||||
|
grid_y *= shape[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) {
|
||||||
|
throw std::runtime_error("Unable to safely factor shape.");
|
||||||
|
}
|
||||||
|
if (grid_y > grid_x) {
|
||||||
|
std::swap(grid_x, grid_y);
|
||||||
|
}
|
||||||
|
return std::make_tuple(
|
||||||
|
static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
Dims get_2d_grid_dims_common(
|
||||||
|
const Shape& shape,
|
||||||
|
const Strides& strides,
|
||||||
|
size_t divisor) {
|
||||||
|
// Compute the 2d grid dimensions such that the total size of the grid is
|
||||||
|
// divided by divisor.
|
||||||
|
size_t grid_x = 1;
|
||||||
|
size_t grid_y = 1;
|
||||||
|
for (int i = 0; i < shape.size(); ++i) {
|
||||||
|
if (strides[i] == 0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// No need to add this shape we can just remove it from the divisor.
|
||||||
|
if (divisor % shape[i] == 0) {
|
||||||
|
divisor /= shape[i];
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (grid_x * shape[i] < UINT32_MAX) {
|
||||||
|
grid_x *= shape[i];
|
||||||
|
} else {
|
||||||
|
grid_y *= shape[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (divisor > 1) {
|
||||||
|
if (grid_x % divisor == 0) {
|
||||||
|
grid_x /= divisor;
|
||||||
|
divisor = 1;
|
||||||
|
} else if (grid_y % divisor == 0) {
|
||||||
|
grid_y /= divisor;
|
||||||
|
divisor = 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) {
|
||||||
|
throw std::runtime_error("Unable to safely factor shape.");
|
||||||
|
}
|
||||||
|
if (grid_y > grid_x) {
|
||||||
|
std::swap(grid_x, grid_y);
|
||||||
|
}
|
||||||
|
if (divisor > 1) {
|
||||||
|
grid_x = ((grid_x + divisor - 1) / divisor) * divisor;
|
||||||
|
}
|
||||||
|
return std::make_tuple(
|
||||||
|
static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, int dim2) {
|
||||||
|
auto [bx, by, bz] = get_block_dims_common(dim0, dim1, dim2);
|
||||||
|
auto gx = (dim0 + bx - 1) / bx;
|
||||||
|
auto gy = (dim1 + by - 1) / by;
|
||||||
|
auto gz = (dim2 + bz - 1) / bz;
|
||||||
|
|
||||||
|
return std::make_pair(
|
||||||
|
std::make_tuple(gx, gy, gz), std::make_tuple(bx, by, bz));
|
||||||
|
}
|
||||||
|
|
||||||
|
array swapaxes_in_eval(const array& x, int axis1, int axis2) {
|
||||||
|
int ndim = x.ndim();
|
||||||
|
if (axis1 < 0) {
|
||||||
|
axis1 += ndim;
|
||||||
|
}
|
||||||
|
if (axis2 < 0) {
|
||||||
|
axis2 += ndim;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto shape = x.shape();
|
||||||
|
std::swap(shape[axis1], shape[axis2]);
|
||||||
|
auto strides = x.strides();
|
||||||
|
std::swap(strides[axis1], strides[axis2]);
|
||||||
|
|
||||||
|
auto [data_size, row_contiguous, col_contiguous] =
|
||||||
|
check_contiguity(shape, strides);
|
||||||
|
bool contiguous = data_size == x.data_size();
|
||||||
|
|
||||||
|
array out(std::move(shape), x.dtype(), nullptr, {});
|
||||||
|
out.copy_shared_buffer(
|
||||||
|
x,
|
||||||
|
std::move(strides),
|
||||||
|
{contiguous, row_contiguous, col_contiguous},
|
||||||
|
x.data_size());
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
|||||||
@@ -2,12 +2,17 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <filesystem>
|
||||||
|
#include <tuple>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "mlx/array.h"
|
#include "mlx/array.h"
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
|
// Return the directory that contains current shared library.
|
||||||
|
std::filesystem::path current_binary_dir();
|
||||||
|
|
||||||
inline int64_t
|
inline int64_t
|
||||||
elem_to_loc(int elem, const Shape& shape, const Strides& strides) {
|
elem_to_loc(int elem, const Shape& shape, const Strides& strides) {
|
||||||
int64_t loc = 0;
|
int64_t loc = 0;
|
||||||
@@ -70,6 +75,31 @@ std::pair<Shape, Strides> collapse_contiguous_dims(
|
|||||||
const array& a,
|
const array& a,
|
||||||
int64_t size_cap = std::numeric_limits<int32_t>::max());
|
int64_t size_cap = std::numeric_limits<int32_t>::max());
|
||||||
|
|
||||||
|
// Compute the thread block dimensions which fit the given
|
||||||
|
// input dimensions.
|
||||||
|
// - The thread block dimensions will be powers of two
|
||||||
|
// - The thread block size will be less than 2^pow2
|
||||||
|
using Dims = std::tuple<uint32_t, uint32_t, uint32_t>;
|
||||||
|
Dims get_block_dims_common(int dim0, int dim1, int dim2, int pow2 = 10);
|
||||||
|
|
||||||
|
// Computes a 2D grid where each element is < UINT_MAX
|
||||||
|
// Assumes:
|
||||||
|
// - overall size (product of non-broadcasted dimensions) is < UINT_MAX^2
|
||||||
|
// - shape and strides correspond to a contiguous (no holes) but
|
||||||
|
// possibly broadcasted array
|
||||||
|
Dims get_2d_grid_dims_common(const Shape& shape, const Strides& strides);
|
||||||
|
|
||||||
|
// Same as above but we do an implicit division with divisor.
|
||||||
|
// Basically, equivalent to factorizing
|
||||||
|
// Prod(s \forall s in shape if strides[s] > 0) / divisor.
|
||||||
|
Dims get_2d_grid_dims_common(
|
||||||
|
const Shape& shape,
|
||||||
|
const Strides& strides,
|
||||||
|
size_t divisor);
|
||||||
|
|
||||||
|
// Get both the block and a grid of blocks that covers dim0, dim1 and dim2.
|
||||||
|
std::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, int dim2);
|
||||||
|
|
||||||
struct ContiguousIterator {
|
struct ContiguousIterator {
|
||||||
inline void step() {
|
inline void step() {
|
||||||
int dims = shape_.size();
|
int dims = shape_.size();
|
||||||
@@ -159,19 +189,20 @@ inline bool is_donatable(const array& in, const array& out) {
|
|||||||
in.buffer_size() <= out.nbytes() + donation_extra;
|
in.buffer_size() <= out.nbytes() + donation_extra;
|
||||||
}
|
}
|
||||||
|
|
||||||
void move_or_copy(const array& in, array& out);
|
|
||||||
void move_or_copy(
|
|
||||||
const array& in,
|
|
||||||
array& out,
|
|
||||||
const Strides& strides,
|
|
||||||
array::Flags flags,
|
|
||||||
size_t data_size,
|
|
||||||
size_t offset = 0);
|
|
||||||
|
|
||||||
std::pair<bool, Strides> prepare_reshape(const array& in, const array& out);
|
std::pair<bool, Strides> prepare_reshape(const array& in, const array& out);
|
||||||
|
|
||||||
void shared_buffer_reshape(
|
void shared_buffer_reshape(
|
||||||
const array& in,
|
const array& in,
|
||||||
const Strides& out_strides,
|
const Strides& out_strides,
|
||||||
array& out);
|
array& out);
|
||||||
|
|
||||||
|
// Like the swapaxes op but safe to call in eval_gpu.
|
||||||
|
array swapaxes_in_eval(const array& x, int axis1, int axis2);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline SmallVector<T> remove_index(SmallVector<T> vec, size_t index) {
|
||||||
|
vec.erase(std::next(vec.begin(), index));
|
||||||
|
return vec;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
|||||||
@@ -40,11 +40,15 @@ add_dependencies(mlx cpu_compiled_preamble)
|
|||||||
|
|
||||||
target_sources(
|
target_sources(
|
||||||
mlx
|
mlx
|
||||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
|
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/available.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/eig.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/eigh.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/eigh.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/encoder.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
|
||||||
@@ -56,6 +60,7 @@ target_sources(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||||
@@ -65,13 +70,14 @@ target_sources(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
|
||||||
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp)
|
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp)
|
||||||
|
|
||||||
if(MLX_BUILD_ACCELERATE)
|
if(MLX_BUILD_ACCELERATE)
|
||||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/bnns.cpp)
|
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/bnns.cpp)
|
||||||
else()
|
else()
|
||||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/no_fp16.cpp
|
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/simd_fp16.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/gemms/no_bf16.cpp)
|
${CMAKE_CURRENT_SOURCE_DIR}/gemms/simd_bf16.cpp)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if(IOS)
|
if(IOS)
|
||||||
|
|||||||
@@ -2,76 +2,27 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "mlx/allocator.h"
|
|
||||||
#include "mlx/array.h"
|
#include "mlx/array.h"
|
||||||
|
#include "mlx/backend/cpu/encoder.h"
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void arange(T start, T next, array& out, size_t size) {
|
void arange(T start, T next, array& out, size_t size, Stream stream) {
|
||||||
auto ptr = out.data<T>();
|
auto ptr = out.data<T>();
|
||||||
auto step_size = next - start;
|
auto step_size = next - start;
|
||||||
|
auto& encoder = cpu::get_command_encoder(stream);
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
encoder.dispatch([ptr, start, step_size, size]() mutable {
|
||||||
for (int i = 0; i < size; ++i) {
|
for (int i = 0; i < size; ++i) {
|
||||||
ptr[i] = start;
|
ptr[i] = start;
|
||||||
start += step_size;
|
start += step_size;
|
||||||
}
|
}
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void arange(
|
|
||||||
const std::vector<array>& inputs,
|
|
||||||
array& out,
|
|
||||||
double start,
|
|
||||||
double step) {
|
|
||||||
assert(inputs.size() == 0);
|
|
||||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
|
||||||
switch (out.dtype()) {
|
|
||||||
case bool_:
|
|
||||||
throw std::runtime_error("Bool type unsupported for arange.");
|
|
||||||
break;
|
|
||||||
case uint8:
|
|
||||||
arange<uint8_t>(start, start + step, out, out.size());
|
|
||||||
break;
|
|
||||||
case uint16:
|
|
||||||
arange<uint16_t>(start, start + step, out, out.size());
|
|
||||||
break;
|
|
||||||
case uint32:
|
|
||||||
arange<uint32_t>(start, start + step, out, out.size());
|
|
||||||
break;
|
|
||||||
case uint64:
|
|
||||||
arange<uint64_t>(start, start + step, out, out.size());
|
|
||||||
break;
|
|
||||||
case int8:
|
|
||||||
arange<int8_t>(start, start + step, out, out.size());
|
|
||||||
break;
|
|
||||||
case int16:
|
|
||||||
arange<int16_t>(start, start + step, out, out.size());
|
|
||||||
break;
|
|
||||||
case int32:
|
|
||||||
arange<int32_t>(start, start + step, out, out.size());
|
|
||||||
break;
|
|
||||||
case int64:
|
|
||||||
arange<int64_t>(start, start + step, out, out.size());
|
|
||||||
break;
|
|
||||||
case float16:
|
|
||||||
arange<float16_t>(start, start + step, out, out.size());
|
|
||||||
break;
|
|
||||||
case float32:
|
|
||||||
arange<float>(start, start + step, out, out.size());
|
|
||||||
break;
|
|
||||||
case float64:
|
|
||||||
arange<double>(start, start + step, out, out.size());
|
|
||||||
break;
|
|
||||||
case bfloat16:
|
|
||||||
arange<bfloat16_t>(start, start + step, out, out.size());
|
|
||||||
break;
|
|
||||||
case complex64:
|
|
||||||
arange<complex64_t>(start, start + step, out, out.size());
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
#include <cassert>
|
#include <cassert>
|
||||||
|
|
||||||
#include "mlx/backend/common/utils.h"
|
#include "mlx/backend/common/utils.h"
|
||||||
|
#include "mlx/backend/cpu/encoder.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
@@ -13,19 +14,20 @@ template <typename InT, typename OpT>
|
|||||||
void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
|
void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
|
||||||
auto axis_size = in.shape()[axis];
|
auto axis_size = in.shape()[axis];
|
||||||
auto axis_stride = in.strides()[axis];
|
auto axis_stride = in.strides()[axis];
|
||||||
Strides strides = in.strides();
|
Strides strides = remove_index(in.strides(), axis);
|
||||||
Shape shape = in.shape();
|
Shape shape = remove_index(in.shape(), axis);
|
||||||
strides.erase(strides.begin() + axis);
|
auto in_ptr = in.data<InT>();
|
||||||
shape.erase(shape.begin() + axis);
|
auto out_ptr = out.data<uint32_t>();
|
||||||
|
|
||||||
for (uint32_t i = 0; i < out.size(); ++i) {
|
for (uint32_t i = 0; i < out.size(); ++i) {
|
||||||
auto loc = elem_to_loc(i, shape, strides);
|
auto loc = elem_to_loc(i, shape, strides);
|
||||||
auto in_ptr = in.data<InT>() + loc;
|
auto local_in_ptr = in_ptr + loc;
|
||||||
uint32_t ind_v = 0;
|
uint32_t ind_v = 0;
|
||||||
InT v = (*in_ptr);
|
InT v = (*local_in_ptr);
|
||||||
for (uint32_t j = 0; j < axis_size; ++j, in_ptr += axis_stride) {
|
for (uint32_t j = 0; j < axis_size; ++j, local_in_ptr += axis_stride) {
|
||||||
op(j, (*in_ptr), &ind_v, &v);
|
op(j, (*local_in_ptr), &ind_v, &v);
|
||||||
}
|
}
|
||||||
out.data<uint32_t>()[i] = ind_v;
|
out_ptr[i] = ind_v;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -64,8 +66,14 @@ void arg_reduce_dispatch(
|
|||||||
void ArgReduce::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void ArgReduce::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
auto& in = inputs[0];
|
auto& in = inputs[0];
|
||||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
out.set_data(allocator::malloc(out.nbytes()));
|
||||||
|
auto& encoder = cpu::get_command_encoder(stream());
|
||||||
|
encoder.set_input_array(in);
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
encoder.dispatch([in = array::unsafe_weak_copy(in),
|
||||||
|
out = array::unsafe_weak_copy(out),
|
||||||
|
reduce_type_ = reduce_type_,
|
||||||
|
axis_ = axis_]() mutable {
|
||||||
switch (in.dtype()) {
|
switch (in.dtype()) {
|
||||||
case bool_:
|
case bool_:
|
||||||
arg_reduce_dispatch<bool>(in, out, reduce_type_, axis_);
|
arg_reduce_dispatch<bool>(in, out, reduce_type_, axis_);
|
||||||
@@ -110,6 +118,7 @@ void ArgReduce::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
arg_reduce_dispatch<complex64_t>(in, out, reduce_type_, axis_);
|
arg_reduce_dispatch<complex64_t>(in, out, reduce_type_, axis_);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
|||||||
11
mlx/backend/cpu/available.cpp
Normal file
11
mlx/backend/cpu/available.cpp
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cpu/available.h"
|
||||||
|
|
||||||
|
namespace mlx::core::cpu {
|
||||||
|
|
||||||
|
bool is_available() {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::cpu
|
||||||
9
mlx/backend/cpu/available.h
Normal file
9
mlx/backend/cpu/available.h
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
namespace mlx::core::cpu {
|
||||||
|
|
||||||
|
bool is_available();
|
||||||
|
|
||||||
|
} // namespace mlx::core::cpu
|
||||||
@@ -8,6 +8,7 @@
|
|||||||
#include "mlx/backend/cpu/binary.h"
|
#include "mlx/backend/cpu/binary.h"
|
||||||
#include "mlx/backend/cpu/binary_ops.h"
|
#include "mlx/backend/cpu/binary_ops.h"
|
||||||
#include "mlx/backend/cpu/binary_two.h"
|
#include "mlx/backend/cpu/binary_two.h"
|
||||||
|
#include "mlx/backend/cpu/encoder.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
#include "mlx/utils.h"
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
@@ -16,51 +17,221 @@ namespace mlx::core {
|
|||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
template <typename Op>
|
template <typename Op>
|
||||||
void comparison_op(const array& a, const array& b, array& out, Op op) {
|
void binary(const array& a, const array& b, array& out, Op op, Stream stream) {
|
||||||
switch (a.dtype()) {
|
auto bopt = get_binary_op_type(a, b);
|
||||||
|
set_binary_op_output_data(a, b, out, bopt);
|
||||||
|
|
||||||
|
auto& encoder = cpu::get_command_encoder(stream);
|
||||||
|
encoder.set_input_array(a);
|
||||||
|
encoder.set_input_array(b);
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||||
|
b = array::unsafe_weak_copy(b),
|
||||||
|
out = array::unsafe_weak_copy(out),
|
||||||
|
bopt]() mutable {
|
||||||
|
switch (out.dtype()) {
|
||||||
case bool_:
|
case bool_:
|
||||||
binary_op<bool, bool>(a, b, out, op);
|
binary_op<bool, Op>(a, b, out, bopt);
|
||||||
break;
|
break;
|
||||||
case uint8:
|
case uint8:
|
||||||
binary_op<uint8_t, bool>(a, b, out, op);
|
binary_op<uint8_t, Op>(a, b, out, bopt);
|
||||||
break;
|
break;
|
||||||
case uint16:
|
case uint16:
|
||||||
binary_op<uint16_t, bool>(a, b, out, op);
|
binary_op<uint16_t, Op>(a, b, out, bopt);
|
||||||
break;
|
break;
|
||||||
case uint32:
|
case uint32:
|
||||||
binary_op<uint32_t, bool>(a, b, out, op);
|
binary_op<uint32_t, Op>(a, b, out, bopt);
|
||||||
break;
|
break;
|
||||||
case uint64:
|
case uint64:
|
||||||
binary_op<uint64_t, bool>(a, b, out, op);
|
binary_op<uint64_t, Op>(a, b, out, bopt);
|
||||||
break;
|
break;
|
||||||
case int8:
|
case int8:
|
||||||
binary_op<int8_t, bool>(a, b, out, op);
|
binary_op<int8_t, Op>(a, b, out, bopt);
|
||||||
break;
|
break;
|
||||||
case int16:
|
case int16:
|
||||||
binary_op<int16_t, bool>(a, b, out, op);
|
binary_op<int16_t, Op>(a, b, out, bopt);
|
||||||
break;
|
break;
|
||||||
case int32:
|
case int32:
|
||||||
binary_op<int32_t, bool>(a, b, out, op);
|
binary_op<int32_t, Op>(a, b, out, bopt);
|
||||||
break;
|
break;
|
||||||
case int64:
|
case int64:
|
||||||
binary_op<int64_t, bool>(a, b, out, op);
|
binary_op<int64_t, Op>(a, b, out, bopt);
|
||||||
break;
|
break;
|
||||||
case float16:
|
case float16:
|
||||||
binary_op<float16_t, bool>(a, b, out, op);
|
binary_op<float16_t, Op>(a, b, out, bopt);
|
||||||
break;
|
break;
|
||||||
case float32:
|
case float32:
|
||||||
binary_op<float, bool>(a, b, out, op);
|
binary_op<float, Op>(a, b, out, bopt);
|
||||||
break;
|
break;
|
||||||
case float64:
|
case float64:
|
||||||
binary_op<double, bool>(a, b, out, op);
|
binary_op<double, Op>(a, b, out, bopt);
|
||||||
break;
|
break;
|
||||||
case bfloat16:
|
case bfloat16:
|
||||||
binary_op<bfloat16_t, bool>(a, b, out, op);
|
binary_op<bfloat16_t, Op>(a, b, out, bopt);
|
||||||
break;
|
break;
|
||||||
case complex64:
|
case complex64:
|
||||||
binary_op<complex64_t, bool>(a, b, out, op);
|
binary_op<complex64_t, Op>(a, b, out, bopt);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Op>
|
||||||
|
void comparison_op(
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
array& out,
|
||||||
|
Op op,
|
||||||
|
Stream stream) {
|
||||||
|
auto bopt = get_binary_op_type(a, b);
|
||||||
|
set_binary_op_output_data(a, b, out, bopt);
|
||||||
|
|
||||||
|
auto& encoder = cpu::get_command_encoder(stream);
|
||||||
|
encoder.set_input_array(a);
|
||||||
|
encoder.set_input_array(b);
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||||
|
b = array::unsafe_weak_copy(b),
|
||||||
|
out = array::unsafe_weak_copy(out),
|
||||||
|
bopt]() mutable {
|
||||||
|
switch (a.dtype()) {
|
||||||
|
case bool_:
|
||||||
|
binary_op<bool, bool, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case uint8:
|
||||||
|
binary_op<uint8_t, bool, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case uint16:
|
||||||
|
binary_op<uint16_t, bool, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case uint32:
|
||||||
|
binary_op<uint32_t, bool, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case uint64:
|
||||||
|
binary_op<uint64_t, bool, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case int8:
|
||||||
|
binary_op<int8_t, bool, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case int16:
|
||||||
|
binary_op<int16_t, bool, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case int32:
|
||||||
|
binary_op<int32_t, bool, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case int64:
|
||||||
|
binary_op<int64_t, bool, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case float16:
|
||||||
|
binary_op<float16_t, bool, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case float32:
|
||||||
|
binary_op<float, bool, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case float64:
|
||||||
|
binary_op<double, bool, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case bfloat16:
|
||||||
|
binary_op<bfloat16_t, bool, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case complex64:
|
||||||
|
binary_op<complex64_t, bool, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Op>
|
||||||
|
void binary_float(
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
array& out,
|
||||||
|
Op op,
|
||||||
|
Stream stream) {
|
||||||
|
auto bopt = get_binary_op_type(a, b);
|
||||||
|
set_binary_op_output_data(a, b, out, bopt);
|
||||||
|
|
||||||
|
auto& encoder = cpu::get_command_encoder(stream);
|
||||||
|
encoder.set_input_array(a);
|
||||||
|
encoder.set_input_array(b);
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||||
|
b = array::unsafe_weak_copy(b),
|
||||||
|
out = array::unsafe_weak_copy(out),
|
||||||
|
bopt]() mutable {
|
||||||
|
switch (out.dtype()) {
|
||||||
|
case float16:
|
||||||
|
binary_op<float16_t, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case float32:
|
||||||
|
binary_op<float, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case float64:
|
||||||
|
binary_op<double, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case bfloat16:
|
||||||
|
binary_op<bfloat16_t, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case complex64:
|
||||||
|
binary_op<complex64_t, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw std::runtime_error(
|
||||||
|
"[binary_float] Only supports floating point types.");
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Op>
|
||||||
|
void binary_int(
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
array& out,
|
||||||
|
Op op,
|
||||||
|
Stream stream) {
|
||||||
|
auto bopt = get_binary_op_type(a, b);
|
||||||
|
set_binary_op_output_data(a, b, out, bopt);
|
||||||
|
|
||||||
|
auto& encoder = cpu::get_command_encoder(stream);
|
||||||
|
encoder.set_input_array(a);
|
||||||
|
encoder.set_input_array(b);
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||||
|
b = array::unsafe_weak_copy(b),
|
||||||
|
out = array::unsafe_weak_copy(out),
|
||||||
|
bopt]() mutable {
|
||||||
|
switch (out.dtype()) {
|
||||||
|
case bool_:
|
||||||
|
binary_op<bool, Op>(a, b, out, bopt);
|
||||||
|
case uint8:
|
||||||
|
binary_op<uint8_t, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case uint16:
|
||||||
|
binary_op<uint16_t, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case uint32:
|
||||||
|
binary_op<uint32_t, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case uint64:
|
||||||
|
binary_op<uint64_t, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case int8:
|
||||||
|
binary_op<int8_t, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case int16:
|
||||||
|
binary_op<int16_t, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case int32:
|
||||||
|
binary_op<int32_t, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
case int64:
|
||||||
|
binary_op<int64_t, Op>(a, b, out, bopt);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw std::runtime_error("[binary_int] Type not supported");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
@@ -69,7 +240,7 @@ void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
auto& a = inputs[0];
|
auto& a = inputs[0];
|
||||||
auto& b = inputs[1];
|
auto& b = inputs[1];
|
||||||
binary(a, b, out, detail::Add());
|
binary(a, b, out, detail::Add(), stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
void DivMod::eval_cpu(
|
void DivMod::eval_cpu(
|
||||||
@@ -78,70 +249,89 @@ void DivMod::eval_cpu(
|
|||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
auto& a = inputs[0];
|
auto& a = inputs[0];
|
||||||
auto& b = inputs[1];
|
auto& b = inputs[1];
|
||||||
|
auto bopt = get_binary_op_type(a, b);
|
||||||
|
auto& out_a = outputs[0];
|
||||||
|
auto& out_b = outputs[1];
|
||||||
|
set_binary_op_output_data(a, b, out_a, bopt);
|
||||||
|
set_binary_op_output_data(a, b, out_b, bopt);
|
||||||
|
|
||||||
|
auto& encoder = cpu::get_command_encoder(stream());
|
||||||
|
encoder.set_input_array(a);
|
||||||
|
encoder.set_input_array(b);
|
||||||
|
encoder.set_output_array(out_a);
|
||||||
|
encoder.set_output_array(out_b);
|
||||||
|
|
||||||
|
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||||
|
b = array::unsafe_weak_copy(b),
|
||||||
|
out_a = array::unsafe_weak_copy(out_a),
|
||||||
|
out_b = array::unsafe_weak_copy(out_b),
|
||||||
|
bopt]() mutable {
|
||||||
auto integral_op = [](auto x, auto y) {
|
auto integral_op = [](auto x, auto y) {
|
||||||
return std::make_pair(x / y, x % y);
|
return std::make_pair(x / y, x % y);
|
||||||
};
|
};
|
||||||
auto float_op = [](auto x, auto y) {
|
auto float_op = [](auto x, auto y) {
|
||||||
return std::make_pair(std::trunc(x / y), std::fmod(x, y));
|
return std::make_pair(std::trunc(x / y), std::fmod(x, y));
|
||||||
};
|
};
|
||||||
switch (outputs[0].dtype()) {
|
|
||||||
|
switch (out_a.dtype()) {
|
||||||
case bool_:
|
case bool_:
|
||||||
binary_op<bool>(a, b, outputs, integral_op);
|
binary_op<bool>(a, b, out_a, out_b, integral_op, bopt);
|
||||||
case uint8:
|
case uint8:
|
||||||
binary_op<uint8_t>(a, b, outputs, integral_op);
|
binary_op<uint8_t>(a, b, out_a, out_b, integral_op, bopt);
|
||||||
break;
|
break;
|
||||||
case uint16:
|
case uint16:
|
||||||
binary_op<uint16_t>(a, b, outputs, integral_op);
|
binary_op<uint16_t>(a, b, out_a, out_b, integral_op, bopt);
|
||||||
break;
|
break;
|
||||||
case uint32:
|
case uint32:
|
||||||
binary_op<uint32_t>(a, b, outputs, integral_op);
|
binary_op<uint32_t>(a, b, out_a, out_b, integral_op, bopt);
|
||||||
break;
|
break;
|
||||||
case uint64:
|
case uint64:
|
||||||
binary_op<uint64_t>(a, b, outputs, integral_op);
|
binary_op<uint64_t>(a, b, out_a, out_b, integral_op, bopt);
|
||||||
break;
|
break;
|
||||||
case int8:
|
case int8:
|
||||||
binary_op<int8_t>(a, b, outputs, integral_op);
|
binary_op<int8_t>(a, b, out_a, out_b, integral_op, bopt);
|
||||||
break;
|
break;
|
||||||
case int16:
|
case int16:
|
||||||
binary_op<int16_t>(a, b, outputs, integral_op);
|
binary_op<int16_t>(a, b, out_a, out_b, integral_op, bopt);
|
||||||
break;
|
break;
|
||||||
case int32:
|
case int32:
|
||||||
binary_op<int32_t>(a, b, outputs, integral_op);
|
binary_op<int32_t>(a, b, out_a, out_b, integral_op, bopt);
|
||||||
break;
|
break;
|
||||||
case int64:
|
case int64:
|
||||||
binary_op<int64_t>(a, b, outputs, integral_op);
|
binary_op<int64_t>(a, b, out_a, out_b, integral_op, bopt);
|
||||||
break;
|
break;
|
||||||
case float16:
|
case float16:
|
||||||
binary_op<float16_t>(a, b, outputs, float_op);
|
binary_op<float16_t>(a, b, out_a, out_b, float_op, bopt);
|
||||||
break;
|
break;
|
||||||
case float32:
|
case float32:
|
||||||
binary_op<float>(a, b, outputs, float_op);
|
binary_op<float>(a, b, out_a, out_b, float_op, bopt);
|
||||||
break;
|
break;
|
||||||
case float64:
|
case float64:
|
||||||
binary_op<double>(a, b, outputs, float_op);
|
binary_op<double>(a, b, out_a, out_b, float_op, bopt);
|
||||||
break;
|
break;
|
||||||
case bfloat16:
|
case bfloat16:
|
||||||
binary_op<bfloat16_t>(a, b, outputs, float_op);
|
binary_op<bfloat16_t>(a, b, out_a, out_b, float_op, bopt);
|
||||||
break;
|
break;
|
||||||
case complex64:
|
case complex64:
|
||||||
// Should never get here
|
// Should never get here
|
||||||
throw std::runtime_error("[DivMod] Complex type not supported");
|
throw std::runtime_error("[DivMod] Complex type not supported");
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
auto& a = inputs[0];
|
auto& a = inputs[0];
|
||||||
auto& b = inputs[1];
|
auto& b = inputs[1];
|
||||||
binary(a, b, out, detail::Divide());
|
binary(a, b, out, detail::Divide(), stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
void Remainder::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void Remainder::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
auto& a = inputs[0];
|
auto& a = inputs[0];
|
||||||
auto& b = inputs[1];
|
auto& b = inputs[1];
|
||||||
binary(a, b, out, detail::Remainder());
|
binary(a, b, out, detail::Remainder(), stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
void Equal::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void Equal::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
@@ -149,181 +339,143 @@ void Equal::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
auto& a = inputs[0];
|
auto& a = inputs[0];
|
||||||
auto& b = inputs[1];
|
auto& b = inputs[1];
|
||||||
if (equal_nan_) {
|
if (equal_nan_) {
|
||||||
|
auto bopt = get_binary_op_type(a, b);
|
||||||
|
set_binary_op_output_data(a, b, out, bopt);
|
||||||
|
|
||||||
|
auto& encoder = cpu::get_command_encoder(stream());
|
||||||
|
encoder.set_input_array(a);
|
||||||
|
encoder.set_input_array(b);
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||||
|
b = array::unsafe_weak_copy(b),
|
||||||
|
out = array::unsafe_weak_copy(out),
|
||||||
|
bopt]() mutable {
|
||||||
switch (a.dtype()) {
|
switch (a.dtype()) {
|
||||||
case float16:
|
case float16:
|
||||||
binary_op<float16_t, bool>(a, b, out, detail::NaNEqual());
|
binary_op<float16_t, bool, detail::NaNEqual>(a, b, out, bopt);
|
||||||
break;
|
break;
|
||||||
case float32:
|
case float32:
|
||||||
binary_op<float, bool>(a, b, out, detail::NaNEqual());
|
binary_op<float, bool, detail::NaNEqual>(a, b, out, bopt);
|
||||||
break;
|
break;
|
||||||
case float64:
|
case float64:
|
||||||
binary_op<double, bool>(a, b, out, detail::NaNEqual());
|
binary_op<double, bool, detail::NaNEqual>(a, b, out, bopt);
|
||||||
break;
|
break;
|
||||||
case bfloat16:
|
case bfloat16:
|
||||||
binary_op<bfloat16_t, bool>(a, b, out, detail::NaNEqual());
|
binary_op<bfloat16_t, bool, detail::NaNEqual>(a, b, out, bopt);
|
||||||
break;
|
break;
|
||||||
case complex64:
|
case complex64:
|
||||||
binary_op<complex64_t, bool>(a, b, out, detail::NaNEqual());
|
binary_op<complex64_t, bool, detail::NaNEqual>(a, b, out, bopt);
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"[NanEqual::eval_cpu] Only for floating point types.");
|
"[NanEqual::eval_cpu] Only for floating point types.");
|
||||||
}
|
}
|
||||||
|
});
|
||||||
} else {
|
} else {
|
||||||
comparison_op(a, b, out, detail::Equal());
|
comparison_op(a, b, out, detail::Equal(), stream());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void Greater::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void Greater::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
comparison_op(inputs[0], inputs[1], out, detail::Greater());
|
comparison_op(inputs[0], inputs[1], out, detail::Greater(), stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
void GreaterEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void GreaterEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
comparison_op(inputs[0], inputs[1], out, detail::GreaterEqual());
|
comparison_op(inputs[0], inputs[1], out, detail::GreaterEqual(), stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
void Less::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void Less::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
comparison_op(inputs[0], inputs[1], out, detail::Less());
|
comparison_op(inputs[0], inputs[1], out, detail::Less(), stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
void LessEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void LessEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
comparison_op(inputs[0], inputs[1], out, detail::LessEqual());
|
comparison_op(inputs[0], inputs[1], out, detail::LessEqual(), stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
void LogAddExp::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void LogAddExp::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
auto& a = inputs[0];
|
auto& a = inputs[0];
|
||||||
auto& b = inputs[1];
|
auto& b = inputs[1];
|
||||||
switch (out.dtype()) {
|
binary_float(a, b, out, detail::LogAddExp(), stream());
|
||||||
case float16:
|
|
||||||
binary_op<float16_t>(a, b, out, detail::LogAddExp());
|
|
||||||
break;
|
|
||||||
case float32:
|
|
||||||
binary_op<float>(a, b, out, detail::LogAddExp());
|
|
||||||
break;
|
|
||||||
case float64:
|
|
||||||
binary_op<double>(a, b, out, detail::LogAddExp());
|
|
||||||
break;
|
|
||||||
case bfloat16:
|
|
||||||
binary_op<bfloat16_t>(a, b, out, detail::LogAddExp());
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
throw std::runtime_error(
|
|
||||||
"[LogAddExp::eval_cpu] Only supports non-complex floating point types.");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void LogicalAnd::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void LogicalAnd::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2); // LogicalAnd requires two input arrays
|
assert(inputs.size() == 2); // LogicalAnd requires two input arrays
|
||||||
auto& in1 = inputs[0];
|
auto& in1 = inputs[0];
|
||||||
auto& in2 = inputs[1];
|
auto& in2 = inputs[1];
|
||||||
binary(in1, in2, out, detail::LogicalAnd());
|
binary(in1, in2, out, detail::LogicalAnd(), stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
void LogicalOr::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void LogicalOr::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2); // LogicalOr requires two input arrays
|
assert(inputs.size() == 2); // LogicalOr requires two input arrays
|
||||||
auto& in1 = inputs[0];
|
auto& in1 = inputs[0];
|
||||||
auto& in2 = inputs[1];
|
auto& in2 = inputs[1];
|
||||||
binary(in1, in2, out, detail::LogicalOr());
|
binary(in1, in2, out, detail::LogicalOr(), stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
void Maximum::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void Maximum::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
auto& a = inputs[0];
|
auto& a = inputs[0];
|
||||||
auto& b = inputs[1];
|
auto& b = inputs[1];
|
||||||
binary(a, b, out, detail::Maximum());
|
binary(a, b, out, detail::Maximum(), stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
void Minimum::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void Minimum::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
auto& a = inputs[0];
|
auto& a = inputs[0];
|
||||||
auto& b = inputs[1];
|
auto& b = inputs[1];
|
||||||
binary(a, b, out, detail::Minimum());
|
binary(a, b, out, detail::Minimum(), stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
auto& a = inputs[0];
|
auto& a = inputs[0];
|
||||||
auto& b = inputs[1];
|
auto& b = inputs[1];
|
||||||
binary(a, b, out, detail::Multiply());
|
binary(a, b, out, detail::Multiply(), stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
void NotEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void NotEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
comparison_op(inputs[0], inputs[1], out, detail::NotEqual());
|
comparison_op(inputs[0], inputs[1], out, detail::NotEqual(), stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
void Power::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void Power::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
auto& a = inputs[0];
|
auto& a = inputs[0];
|
||||||
auto& b = inputs[1];
|
auto& b = inputs[1];
|
||||||
binary(a, b, out, detail::Power());
|
binary(a, b, out, detail::Power(), stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
auto& a = inputs[0];
|
auto& a = inputs[0];
|
||||||
auto& b = inputs[1];
|
auto& b = inputs[1];
|
||||||
binary(a, b, out, detail::Subtract());
|
binary(a, b, out, detail::Subtract(), stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
auto& a = inputs[0];
|
auto& a = inputs[0];
|
||||||
auto& b = inputs[1];
|
auto& b = inputs[1];
|
||||||
auto dispatch_type = [&a, &b, &out](auto op) {
|
|
||||||
switch (out.dtype()) {
|
|
||||||
case bool_:
|
|
||||||
binary_op<bool>(a, b, out, op);
|
|
||||||
case uint8:
|
|
||||||
binary_op<uint8_t>(a, b, out, op);
|
|
||||||
break;
|
|
||||||
case uint16:
|
|
||||||
binary_op<uint16_t>(a, b, out, op);
|
|
||||||
break;
|
|
||||||
case uint32:
|
|
||||||
binary_op<uint32_t>(a, b, out, op);
|
|
||||||
break;
|
|
||||||
case uint64:
|
|
||||||
binary_op<uint64_t>(a, b, out, op);
|
|
||||||
break;
|
|
||||||
case int8:
|
|
||||||
binary_op<int8_t>(a, b, out, op);
|
|
||||||
break;
|
|
||||||
case int16:
|
|
||||||
binary_op<int16_t>(a, b, out, op);
|
|
||||||
break;
|
|
||||||
case int32:
|
|
||||||
binary_op<int32_t>(a, b, out, op);
|
|
||||||
break;
|
|
||||||
case int64:
|
|
||||||
binary_op<int64_t>(a, b, out, op);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
throw std::runtime_error(
|
|
||||||
"[BitwiseBinary::eval_cpu] Type not supported");
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
switch (op_) {
|
switch (op_) {
|
||||||
case BitwiseBinary::And:
|
case BitwiseBinary::And:
|
||||||
dispatch_type(detail::BitwiseAnd());
|
binary_int(a, b, out, detail::BitwiseAnd(), stream());
|
||||||
break;
|
break;
|
||||||
case BitwiseBinary::Or:
|
case BitwiseBinary::Or:
|
||||||
dispatch_type(detail::BitwiseOr());
|
binary_int(a, b, out, detail::BitwiseOr(), stream());
|
||||||
break;
|
break;
|
||||||
case BitwiseBinary::Xor:
|
case BitwiseBinary::Xor:
|
||||||
dispatch_type(detail::BitwiseXor());
|
binary_int(a, b, out, detail::BitwiseXor(), stream());
|
||||||
break;
|
break;
|
||||||
case BitwiseBinary::LeftShift:
|
case BitwiseBinary::LeftShift:
|
||||||
dispatch_type(detail::LeftShift());
|
binary_int(a, b, out, detail::LeftShift(), stream());
|
||||||
break;
|
break;
|
||||||
case BitwiseBinary::RightShift:
|
case BitwiseBinary::RightShift:
|
||||||
dispatch_type(detail::RightShift());
|
binary_int(a, b, out, detail::RightShift(), stream());
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -332,23 +484,7 @@ void ArcTan2::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
assert(inputs.size() == 2);
|
assert(inputs.size() == 2);
|
||||||
const auto& a = inputs[0];
|
const auto& a = inputs[0];
|
||||||
const auto& b = inputs[1];
|
const auto& b = inputs[1];
|
||||||
switch (out.dtype()) {
|
binary_float(a, b, out, detail::ArcTan2(), stream());
|
||||||
case float16:
|
|
||||||
binary_op<float16_t>(a, b, out, detail::ArcTan2());
|
|
||||||
break;
|
|
||||||
case float32:
|
|
||||||
binary_op<float>(a, b, out, detail::ArcTan2());
|
|
||||||
break;
|
|
||||||
case float64:
|
|
||||||
binary_op<double>(a, b, out, detail::ArcTan2());
|
|
||||||
break;
|
|
||||||
case bfloat16:
|
|
||||||
binary_op<bfloat16_t>(a, b, out, detail::ArcTan2());
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
throw std::runtime_error(
|
|
||||||
"[ArcTan2::eval_cpu] Only supports non-complex floating point types.");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
|||||||
@@ -3,7 +3,6 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
|
|
||||||
#include "mlx/allocator.h"
|
|
||||||
#include "mlx/array.h"
|
#include "mlx/array.h"
|
||||||
#include "mlx/backend/common/binary.h"
|
#include "mlx/backend/common/binary.h"
|
||||||
#include "mlx/backend/common/utils.h"
|
#include "mlx/backend/common/utils.h"
|
||||||
@@ -14,22 +13,18 @@ namespace mlx::core {
|
|||||||
|
|
||||||
template <typename Op>
|
template <typename Op>
|
||||||
struct VectorScalar {
|
struct VectorScalar {
|
||||||
Op op;
|
|
||||||
|
|
||||||
VectorScalar(Op op_) : op(op_) {}
|
|
||||||
|
|
||||||
template <typename T, typename U>
|
template <typename T, typename U>
|
||||||
void operator()(const T* a, const T* b, U* dst, int size) {
|
void operator()(const T* a, const T* b, U* dst, int size) {
|
||||||
T scalar = *b;
|
T scalar = *b;
|
||||||
constexpr int N = simd::max_size<T>;
|
constexpr int N = simd::max_size<T>;
|
||||||
while (size >= N) {
|
while (size >= N) {
|
||||||
simd::store(dst, op(simd::load<T, N>(a), simd::Simd<T, N>(scalar)));
|
simd::store(dst, Op{}(simd::load<T, N>(a), simd::Simd<T, N>(scalar)));
|
||||||
dst += N;
|
dst += N;
|
||||||
a += N;
|
a += N;
|
||||||
size -= N;
|
size -= N;
|
||||||
}
|
}
|
||||||
while (size-- > 0) {
|
while (size-- > 0) {
|
||||||
*dst = op(*a, scalar);
|
*dst = Op{}(*a, scalar);
|
||||||
dst++;
|
dst++;
|
||||||
a++;
|
a++;
|
||||||
}
|
}
|
||||||
@@ -38,22 +33,18 @@ struct VectorScalar {
|
|||||||
|
|
||||||
template <typename Op>
|
template <typename Op>
|
||||||
struct ScalarVector {
|
struct ScalarVector {
|
||||||
Op op;
|
|
||||||
|
|
||||||
ScalarVector(Op op_) : op(op_) {}
|
|
||||||
|
|
||||||
template <typename T, typename U>
|
template <typename T, typename U>
|
||||||
void operator()(const T* a, const T* b, U* dst, int size) {
|
void operator()(const T* a, const T* b, U* dst, int size) {
|
||||||
T scalar = *a;
|
T scalar = *a;
|
||||||
constexpr int N = simd::max_size<T>;
|
constexpr int N = simd::max_size<T>;
|
||||||
while (size >= N) {
|
while (size >= N) {
|
||||||
simd::store(dst, op(simd::Simd<T, N>(scalar), simd::load<T, N>(b)));
|
simd::store(dst, Op{}(simd::Simd<T, N>(scalar), simd::load<T, N>(b)));
|
||||||
dst += N;
|
dst += N;
|
||||||
b += N;
|
b += N;
|
||||||
size -= N;
|
size -= N;
|
||||||
}
|
}
|
||||||
while (size-- > 0) {
|
while (size-- > 0) {
|
||||||
*dst = op(scalar, *b);
|
*dst = Op{}(scalar, *b);
|
||||||
dst++;
|
dst++;
|
||||||
b++;
|
b++;
|
||||||
}
|
}
|
||||||
@@ -62,22 +53,18 @@ struct ScalarVector {
|
|||||||
|
|
||||||
template <typename Op>
|
template <typename Op>
|
||||||
struct VectorVector {
|
struct VectorVector {
|
||||||
Op op;
|
|
||||||
|
|
||||||
VectorVector(Op op_) : op(op_) {}
|
|
||||||
|
|
||||||
template <typename T, typename U>
|
template <typename T, typename U>
|
||||||
void operator()(const T* a, const T* b, U* dst, int size) {
|
void operator()(const T* a, const T* b, U* dst, int size) {
|
||||||
constexpr int N = simd::max_size<T>;
|
constexpr int N = simd::max_size<T>;
|
||||||
while (size >= N) {
|
while (size >= N) {
|
||||||
simd::store(dst, op(simd::load<T, N>(a), simd::load<T, N>(b)));
|
simd::store(dst, Op{}(simd::load<T, N>(a), simd::load<T, N>(b)));
|
||||||
dst += N;
|
dst += N;
|
||||||
a += N;
|
a += N;
|
||||||
b += N;
|
b += N;
|
||||||
size -= N;
|
size -= N;
|
||||||
}
|
}
|
||||||
while (size-- > 0) {
|
while (size-- > 0) {
|
||||||
*dst = op(*a, *b);
|
*dst = Op{}(*a, *b);
|
||||||
dst++;
|
dst++;
|
||||||
a++;
|
a++;
|
||||||
b++;
|
b++;
|
||||||
@@ -90,7 +77,6 @@ void binary_op_dims(
|
|||||||
const T* a,
|
const T* a,
|
||||||
const T* b,
|
const T* b,
|
||||||
U* out,
|
U* out,
|
||||||
Op op,
|
|
||||||
const Shape& shape,
|
const Shape& shape,
|
||||||
const Strides& a_strides,
|
const Strides& a_strides,
|
||||||
const Strides& b_strides,
|
const Strides& b_strides,
|
||||||
@@ -104,12 +90,12 @@ void binary_op_dims(
|
|||||||
for (int i = 0; i < N; i++) {
|
for (int i = 0; i < N; i++) {
|
||||||
if constexpr (D > 1) {
|
if constexpr (D > 1) {
|
||||||
binary_op_dims<T, U, Op, D - 1, Strided>(
|
binary_op_dims<T, U, Op, D - 1, Strided>(
|
||||||
a, b, out, op, shape, a_strides, b_strides, out_strides, axis + 1);
|
a, b, out, shape, a_strides, b_strides, out_strides, axis + 1);
|
||||||
} else {
|
} else {
|
||||||
if constexpr (Strided) {
|
if constexpr (Strided) {
|
||||||
op(a, b, out, stride_out);
|
Op{}(a, b, out, stride_out);
|
||||||
} else {
|
} else {
|
||||||
*out = op(*a, *b);
|
*out = Op{}(*a, *b);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
out += stride_out;
|
out += stride_out;
|
||||||
@@ -120,66 +106,38 @@ void binary_op_dims(
|
|||||||
|
|
||||||
template <typename T, typename U, bool Strided, typename Op>
|
template <typename T, typename U, bool Strided, typename Op>
|
||||||
void binary_op_dispatch_dims(
|
void binary_op_dispatch_dims(
|
||||||
const array& a,
|
const T* a,
|
||||||
const array& b,
|
const T* b,
|
||||||
array& out,
|
U* out,
|
||||||
Op op,
|
|
||||||
int dim,
|
int dim,
|
||||||
|
int size,
|
||||||
const Shape& shape,
|
const Shape& shape,
|
||||||
const Strides& a_strides,
|
const Strides& a_strides,
|
||||||
const Strides& b_strides,
|
const Strides& b_strides,
|
||||||
const Strides& out_strides) {
|
const Strides& out_strides) {
|
||||||
const T* a_ptr = a.data<T>();
|
|
||||||
const T* b_ptr = b.data<T>();
|
|
||||||
U* out_ptr = out.data<U>();
|
|
||||||
switch (dim) {
|
switch (dim) {
|
||||||
case 1:
|
case 1:
|
||||||
binary_op_dims<T, U, Op, 1, Strided>(
|
binary_op_dims<T, U, Op, 1, Strided>(
|
||||||
a_ptr,
|
a, b, out, shape, a_strides, b_strides, out_strides, 0);
|
||||||
b_ptr,
|
|
||||||
out_ptr,
|
|
||||||
op,
|
|
||||||
shape,
|
|
||||||
a_strides,
|
|
||||||
b_strides,
|
|
||||||
out_strides,
|
|
||||||
0);
|
|
||||||
return;
|
return;
|
||||||
case 2:
|
case 2:
|
||||||
binary_op_dims<T, U, Op, 2, Strided>(
|
binary_op_dims<T, U, Op, 2, Strided>(
|
||||||
a_ptr,
|
a, b, out, shape, a_strides, b_strides, out_strides, 0);
|
||||||
b_ptr,
|
|
||||||
out_ptr,
|
|
||||||
op,
|
|
||||||
shape,
|
|
||||||
a_strides,
|
|
||||||
b_strides,
|
|
||||||
out_strides,
|
|
||||||
0);
|
|
||||||
return;
|
return;
|
||||||
case 3:
|
case 3:
|
||||||
binary_op_dims<T, U, Op, 3, Strided>(
|
binary_op_dims<T, U, Op, 3, Strided>(
|
||||||
a_ptr,
|
a, b, out, shape, a_strides, b_strides, out_strides, 0);
|
||||||
b_ptr,
|
|
||||||
out_ptr,
|
|
||||||
op,
|
|
||||||
shape,
|
|
||||||
a_strides,
|
|
||||||
b_strides,
|
|
||||||
out_strides,
|
|
||||||
0);
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
ContiguousIterator a_it(shape, a_strides, dim - 3);
|
ContiguousIterator a_it(shape, a_strides, dim - 3);
|
||||||
ContiguousIterator b_it(shape, b_strides, dim - 3);
|
ContiguousIterator b_it(shape, b_strides, dim - 3);
|
||||||
auto stride = out_strides[dim - 4];
|
auto stride = out_strides[dim - 4];
|
||||||
for (int64_t elem = 0; elem < a.size(); elem += stride) {
|
for (int64_t elem = 0; elem < size; elem += stride) {
|
||||||
binary_op_dims<T, U, Op, 3, Strided>(
|
binary_op_dims<T, U, Op, 3, Strided>(
|
||||||
a_ptr + a_it.loc,
|
a + a_it.loc,
|
||||||
b_ptr + b_it.loc,
|
b + b_it.loc,
|
||||||
out_ptr + elem,
|
out + elem,
|
||||||
op,
|
|
||||||
shape,
|
shape,
|
||||||
a_strides,
|
a_strides,
|
||||||
b_strides,
|
b_strides,
|
||||||
@@ -191,40 +149,41 @@ void binary_op_dispatch_dims(
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, typename U, typename Op>
|
template <typename T, typename U, typename Op>
|
||||||
void binary_op(const array& a, const array& b, array& out, Op op) {
|
void binary_op(const array& a, const array& b, array& out, BinaryOpType bopt) {
|
||||||
auto bopt = get_binary_op_type(a, b);
|
|
||||||
set_binary_op_output_data(a, b, out, bopt);
|
|
||||||
|
|
||||||
// The full computation is scalar scalar so call the base op once
|
// The full computation is scalar scalar so call the base op once
|
||||||
|
auto a_ptr = a.data<T>();
|
||||||
|
auto b_ptr = b.data<T>();
|
||||||
|
|
||||||
|
auto out_ptr = out.data<U>();
|
||||||
if (bopt == BinaryOpType::ScalarScalar) {
|
if (bopt == BinaryOpType::ScalarScalar) {
|
||||||
*(out.data<U>()) = op(*a.data<T>(), *b.data<T>());
|
*out_ptr = Op{}(*a_ptr, *b_ptr);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// The full computation is scalar vector so delegate to the op
|
// The full computation is scalar vector so delegate to the op
|
||||||
if (bopt == BinaryOpType::ScalarVector) {
|
if (bopt == BinaryOpType::ScalarVector) {
|
||||||
ScalarVector{op}(a.data<T>(), b.data<T>(), out.data<U>(), b.data_size());
|
ScalarVector<Op>{}(a_ptr, b_ptr, out_ptr, b.data_size());
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// The full computation is vector scalar so delegate to the op
|
// The full computation is vector scalar so delegate to the op
|
||||||
if (bopt == BinaryOpType::VectorScalar) {
|
if (bopt == BinaryOpType::VectorScalar) {
|
||||||
VectorScalar{op}(a.data<T>(), b.data<T>(), out.data<U>(), a.data_size());
|
VectorScalar<Op>{}(a_ptr, b_ptr, out_ptr, a.data_size());
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// The full computation is vector vector so delegate to the op
|
// The full computation is vector vector so delegate to the op
|
||||||
if (bopt == BinaryOpType::VectorVector) {
|
if (bopt == BinaryOpType::VectorVector) {
|
||||||
VectorVector{op}(a.data<T>(), b.data<T>(), out.data<U>(), out.size());
|
VectorVector<Op>{}(a_ptr, b_ptr, out_ptr, a.size());
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// General computation so let's try to optimize
|
// General computation so let's try to optimize
|
||||||
auto [new_shape, new_strides] = collapse_contiguous_dims(
|
auto [new_shape, new_strides] = collapse_contiguous_dims(
|
||||||
a.shape(), {a.strides(), b.strides(), out.strides()});
|
a.shape(), {a.strides(), b.strides(), out.strides()});
|
||||||
const auto& a_strides = new_strides[0];
|
auto& a_strides = new_strides[0];
|
||||||
const auto& b_strides = new_strides[1];
|
auto& b_strides = new_strides[1];
|
||||||
const auto& strides = new_strides[2];
|
auto& strides = new_strides[2];
|
||||||
|
|
||||||
// Get the left-most dim such that the array is row contiguous after
|
// Get the left-most dim such that the array is row contiguous after
|
||||||
auto leftmost_rc_dim = [&strides](const auto& arr_strides) {
|
auto leftmost_rc_dim = [&strides](const auto& arr_strides) {
|
||||||
@@ -248,7 +207,8 @@ void binary_op(const array& a, const array& b, array& out, Op op) {
|
|||||||
|
|
||||||
auto ndim = new_shape.size();
|
auto ndim = new_shape.size();
|
||||||
|
|
||||||
// Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
|
// Case 1: LxM and FxM where L and F are broadcastable and M is row
|
||||||
|
// contiguous
|
||||||
int dim = ndim;
|
int dim = ndim;
|
||||||
if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
|
if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
|
||||||
bopt = BinaryOpType::VectorVector;
|
bopt = BinaryOpType::VectorVector;
|
||||||
@@ -275,99 +235,59 @@ void binary_op(const array& a, const array& b, array& out, Op op) {
|
|||||||
|
|
||||||
switch (bopt) {
|
switch (bopt) {
|
||||||
case BinaryOpType::VectorVector:
|
case BinaryOpType::VectorVector:
|
||||||
binary_op_dispatch_dims<T, U, true>(
|
binary_op_dispatch_dims<T, U, true, VectorVector<Op>>(
|
||||||
a,
|
a_ptr,
|
||||||
b,
|
b_ptr,
|
||||||
out,
|
out_ptr,
|
||||||
VectorVector{op},
|
|
||||||
dim,
|
dim,
|
||||||
|
a.size(),
|
||||||
new_shape,
|
new_shape,
|
||||||
a_strides,
|
a_strides,
|
||||||
b_strides,
|
b_strides,
|
||||||
strides);
|
strides);
|
||||||
break;
|
break;
|
||||||
case BinaryOpType::VectorScalar:
|
case BinaryOpType::VectorScalar:
|
||||||
binary_op_dispatch_dims<T, U, true>(
|
binary_op_dispatch_dims<T, U, true, VectorScalar<Op>>(
|
||||||
a,
|
a_ptr,
|
||||||
b,
|
b_ptr,
|
||||||
out,
|
out_ptr,
|
||||||
VectorScalar{op},
|
|
||||||
dim,
|
dim,
|
||||||
|
a.size(),
|
||||||
new_shape,
|
new_shape,
|
||||||
a_strides,
|
a_strides,
|
||||||
b_strides,
|
b_strides,
|
||||||
strides);
|
strides);
|
||||||
break;
|
break;
|
||||||
case BinaryOpType::ScalarVector:
|
case BinaryOpType::ScalarVector:
|
||||||
binary_op_dispatch_dims<T, U, true>(
|
binary_op_dispatch_dims<T, U, true, ScalarVector<Op>>(
|
||||||
a,
|
a_ptr,
|
||||||
b,
|
b_ptr,
|
||||||
out,
|
out_ptr,
|
||||||
ScalarVector{op},
|
|
||||||
dim,
|
dim,
|
||||||
|
a.size(),
|
||||||
new_shape,
|
new_shape,
|
||||||
a_strides,
|
a_strides,
|
||||||
b_strides,
|
b_strides,
|
||||||
strides);
|
strides);
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
binary_op_dispatch_dims<T, U, false>(
|
binary_op_dispatch_dims<T, U, false, Op>(
|
||||||
a, b, out, op, dim, new_shape, a_strides, b_strides, strides);
|
a_ptr,
|
||||||
|
b_ptr,
|
||||||
|
out_ptr,
|
||||||
|
dim,
|
||||||
|
a.size(),
|
||||||
|
new_shape,
|
||||||
|
a_strides,
|
||||||
|
b_strides,
|
||||||
|
strides);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, typename Op>
|
template <typename T, typename Op>
|
||||||
void binary_op(const array& a, const array& b, array& out, Op op) {
|
void binary_op(const array& a, const array& b, array& out, BinaryOpType bopt) {
|
||||||
binary_op<T, T>(a, b, out, op);
|
binary_op<T, T, Op>(a, b, out, bopt);
|
||||||
}
|
|
||||||
|
|
||||||
template <typename Op>
|
|
||||||
void binary(const array& a, const array& b, array& out, Op op) {
|
|
||||||
switch (out.dtype()) {
|
|
||||||
case bool_:
|
|
||||||
binary_op<bool>(a, b, out, op);
|
|
||||||
break;
|
|
||||||
case uint8:
|
|
||||||
binary_op<uint8_t>(a, b, out, op);
|
|
||||||
break;
|
|
||||||
case uint16:
|
|
||||||
binary_op<uint16_t>(a, b, out, op);
|
|
||||||
break;
|
|
||||||
case uint32:
|
|
||||||
binary_op<uint32_t>(a, b, out, op);
|
|
||||||
break;
|
|
||||||
case uint64:
|
|
||||||
binary_op<uint64_t>(a, b, out, op);
|
|
||||||
break;
|
|
||||||
case int8:
|
|
||||||
binary_op<int8_t>(a, b, out, op);
|
|
||||||
break;
|
|
||||||
case int16:
|
|
||||||
binary_op<int16_t>(a, b, out, op);
|
|
||||||
break;
|
|
||||||
case int32:
|
|
||||||
binary_op<int32_t>(a, b, out, op);
|
|
||||||
break;
|
|
||||||
case int64:
|
|
||||||
binary_op<int64_t>(a, b, out, op);
|
|
||||||
break;
|
|
||||||
case float16:
|
|
||||||
binary_op<float16_t>(a, b, out, op);
|
|
||||||
break;
|
|
||||||
case float32:
|
|
||||||
binary_op<float>(a, b, out, op);
|
|
||||||
break;
|
|
||||||
case float64:
|
|
||||||
binary_op<double>(a, b, out, op);
|
|
||||||
break;
|
|
||||||
case bfloat16:
|
|
||||||
binary_op<bfloat16_t>(a, b, out, op);
|
|
||||||
break;
|
|
||||||
case complex64:
|
|
||||||
binary_op<complex64_t>(a, b, out, op);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
|||||||
@@ -58,14 +58,14 @@ void binary_op_dispatch_dims(
|
|||||||
Op op) {
|
Op op) {
|
||||||
auto [shape, strides] = collapse_contiguous_dims(
|
auto [shape, strides] = collapse_contiguous_dims(
|
||||||
a.shape(), {a.strides(), b.strides(), out_a.strides()});
|
a.shape(), {a.strides(), b.strides(), out_a.strides()});
|
||||||
const auto& a_strides = strides[0];
|
|
||||||
const auto& b_strides = strides[1];
|
|
||||||
const auto& out_strides = strides[2];
|
|
||||||
const T* a_ptr = a.data<T>();
|
const T* a_ptr = a.data<T>();
|
||||||
const T* b_ptr = b.data<T>();
|
const T* b_ptr = b.data<T>();
|
||||||
U* out_a_ptr = out_a.data<U>();
|
U* out_a_ptr = out_a.data<U>();
|
||||||
U* out_b_ptr = out_b.data<U>();
|
U* out_b_ptr = out_b.data<U>();
|
||||||
|
|
||||||
|
const auto& a_strides = strides[0];
|
||||||
|
const auto& b_strides = strides[1];
|
||||||
|
const auto& out_strides = strides[2];
|
||||||
int ndim = shape.size();
|
int ndim = shape.size();
|
||||||
switch (ndim) {
|
switch (ndim) {
|
||||||
case 1:
|
case 1:
|
||||||
@@ -120,14 +120,10 @@ template <typename T, typename U = T, typename Op>
|
|||||||
void binary_op(
|
void binary_op(
|
||||||
const array& a,
|
const array& a,
|
||||||
const array& b,
|
const array& b,
|
||||||
std::vector<array>& outputs,
|
array& out_a,
|
||||||
Op op) {
|
array& out_b,
|
||||||
auto bopt = get_binary_op_type(a, b);
|
Op op,
|
||||||
auto& out_a = outputs[0];
|
BinaryOpType bopt) {
|
||||||
auto& out_b = outputs[1];
|
|
||||||
set_binary_op_output_data(a, b, out_a, bopt);
|
|
||||||
set_binary_op_output_data(a, b, out_b, bopt);
|
|
||||||
|
|
||||||
// The full computation is scalar scalar so call the base op once
|
// The full computation is scalar scalar so call the base op once
|
||||||
if (bopt == BinaryOpType::General) {
|
if (bopt == BinaryOpType::General) {
|
||||||
binary_op_dispatch_dims<T, U, Op>(a, b, out_a, out_b, op);
|
binary_op_dispatch_dims<T, U, Op>(a, b, out_a, out_b, op);
|
||||||
@@ -141,14 +137,14 @@ void binary_op(
|
|||||||
if (bopt == BinaryOpType::ScalarScalar) {
|
if (bopt == BinaryOpType::ScalarScalar) {
|
||||||
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
|
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
|
||||||
} else if (bopt == BinaryOpType::ScalarVector) {
|
} else if (bopt == BinaryOpType::ScalarVector) {
|
||||||
for (size_t i = 0; i < b.size(); ++i) {
|
for (size_t i = 0; i < b.data_size(); ++i) {
|
||||||
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
|
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
|
||||||
out_a_ptr++;
|
out_a_ptr++;
|
||||||
out_b_ptr++;
|
out_b_ptr++;
|
||||||
b_ptr++;
|
b_ptr++;
|
||||||
}
|
}
|
||||||
} else if (bopt == BinaryOpType::VectorScalar) {
|
} else if (bopt == BinaryOpType::VectorScalar) {
|
||||||
for (size_t i = 0; i < a.size(); ++i) {
|
for (size_t i = 0; i < a.data_size(); ++i) {
|
||||||
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
|
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
|
||||||
out_a_ptr++;
|
out_a_ptr++;
|
||||||
out_b_ptr++;
|
out_b_ptr++;
|
||||||
@@ -165,58 +161,6 @@ void binary_op(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename Op>
|
|
||||||
void binary(
|
|
||||||
const array& a,
|
|
||||||
const array& b,
|
|
||||||
std::vector<array>& outputs,
|
|
||||||
Op op) {
|
|
||||||
switch (outputs[0].dtype()) {
|
|
||||||
case bool_:
|
|
||||||
binary_op<bool>(a, b, outputs, op);
|
|
||||||
break;
|
|
||||||
case uint8:
|
|
||||||
binary_op<uint8_t>(a, b, outputs, op);
|
|
||||||
break;
|
|
||||||
case uint16:
|
|
||||||
binary_op<uint16_t>(a, b, outputs, op);
|
|
||||||
break;
|
|
||||||
case uint32:
|
|
||||||
binary_op<uint32_t>(a, b, outputs, op);
|
|
||||||
break;
|
|
||||||
case uint64:
|
|
||||||
binary_op<uint64_t>(a, b, outputs, op);
|
|
||||||
break;
|
|
||||||
case int8:
|
|
||||||
binary_op<int8_t>(a, b, outputs, op);
|
|
||||||
break;
|
|
||||||
case int16:
|
|
||||||
binary_op<int16_t>(a, b, outputs, op);
|
|
||||||
break;
|
|
||||||
case int32:
|
|
||||||
binary_op<int32_t>(a, b, outputs, op);
|
|
||||||
break;
|
|
||||||
case int64:
|
|
||||||
binary_op<int64_t>(a, b, outputs, op);
|
|
||||||
break;
|
|
||||||
case float16:
|
|
||||||
binary_op<float16_t>(a, b, outputs, op);
|
|
||||||
break;
|
|
||||||
case float32:
|
|
||||||
binary_op<float>(a, b, outputs, op);
|
|
||||||
break;
|
|
||||||
case float64:
|
|
||||||
binary_op<double>(a, b, outputs, op);
|
|
||||||
break;
|
|
||||||
case bfloat16:
|
|
||||||
binary_op<bfloat16_t>(a, b, outputs, op);
|
|
||||||
break;
|
|
||||||
case complex64:
|
|
||||||
binary_op<complex64_t>(a, b, outputs, op);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
#include "mlx/allocator.h"
|
#include "mlx/allocator.h"
|
||||||
#include "mlx/backend/cpu/copy.h"
|
#include "mlx/backend/cpu/copy.h"
|
||||||
|
#include "mlx/backend/cpu/encoder.h"
|
||||||
#include "mlx/backend/cpu/lapack.h"
|
#include "mlx/backend/cpu/lapack.h"
|
||||||
#include "mlx/linalg.h"
|
#include "mlx/linalg.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
@@ -9,7 +10,7 @@
|
|||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void cholesky_impl(const array& a, array& factor, bool upper) {
|
void cholesky_impl(const array& a, array& factor, bool upper, Stream stream) {
|
||||||
// Lapack uses the column-major convention. We take advantage of the fact that
|
// Lapack uses the column-major convention. We take advantage of the fact that
|
||||||
// the matrix should be symmetric:
|
// the matrix should be symmetric:
|
||||||
// (A)ᵀ = A
|
// (A)ᵀ = A
|
||||||
@@ -17,20 +18,22 @@ void cholesky_impl(const array& a, array& factor, bool upper) {
|
|||||||
// triangular matrix, so uplo is the opposite of what we would expect from
|
// triangular matrix, so uplo is the opposite of what we would expect from
|
||||||
// upper
|
// upper
|
||||||
|
|
||||||
char uplo = (upper) ? 'L' : 'U';
|
|
||||||
|
|
||||||
// The decomposition is computed in place, so just copy the input to the
|
// The decomposition is computed in place, so just copy the input to the
|
||||||
// output.
|
// output.
|
||||||
copy(
|
copy_cpu(
|
||||||
a,
|
a,
|
||||||
factor,
|
factor,
|
||||||
a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
|
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
|
||||||
|
stream);
|
||||||
const int N = a.shape(-1);
|
|
||||||
const size_t num_matrices = a.size() / (N * N);
|
|
||||||
|
|
||||||
T* matrix = factor.data<T>();
|
|
||||||
|
|
||||||
|
auto& encoder = cpu::get_command_encoder(stream);
|
||||||
|
encoder.set_output_array(factor);
|
||||||
|
encoder.dispatch([matrix = factor.data<T>(),
|
||||||
|
upper,
|
||||||
|
N = a.shape(-1),
|
||||||
|
size = a.size()]() mutable {
|
||||||
|
char uplo = (upper) ? 'L' : 'U';
|
||||||
|
size_t num_matrices = size / (N * N);
|
||||||
for (int i = 0; i < num_matrices; i++) {
|
for (int i = 0; i < num_matrices; i++) {
|
||||||
// Compute Cholesky factorization.
|
// Compute Cholesky factorization.
|
||||||
int info;
|
int info;
|
||||||
@@ -46,7 +49,7 @@ void cholesky_impl(const array& a, array& factor, bool upper) {
|
|||||||
// to catch errors from the implementation we should throw.
|
// to catch errors from the implementation we should throw.
|
||||||
if (info < 0) {
|
if (info < 0) {
|
||||||
std::stringstream msg;
|
std::stringstream msg;
|
||||||
msg << "[cholesky] Cholesky decomposition failed with error code "
|
msg << "[Cholesky::eval_cpu] Cholesky decomposition failed with error code "
|
||||||
<< info;
|
<< info;
|
||||||
throw std::runtime_error(msg.str());
|
throw std::runtime_error(msg.str());
|
||||||
}
|
}
|
||||||
@@ -62,15 +65,16 @@ void cholesky_impl(const array& a, array& factor, bool upper) {
|
|||||||
matrix += N;
|
matrix += N;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
void Cholesky::eval_cpu(const std::vector<array>& inputs, array& output) {
|
void Cholesky::eval_cpu(const std::vector<array>& inputs, array& output) {
|
||||||
switch (inputs[0].dtype()) {
|
switch (inputs[0].dtype()) {
|
||||||
case float32:
|
case float32:
|
||||||
cholesky_impl<float>(inputs[0], output, upper_);
|
cholesky_impl<float>(inputs[0], output, upper_, stream());
|
||||||
break;
|
break;
|
||||||
case float64:
|
case float64:
|
||||||
cholesky_impl<double>(inputs[0], output, upper_);
|
cholesky_impl<double>(inputs[0], output, upper_, stream());
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
|
|||||||
@@ -11,6 +11,7 @@
|
|||||||
|
|
||||||
#include "mlx/backend/common/compiled.h"
|
#include "mlx/backend/common/compiled.h"
|
||||||
#include "mlx/backend/cpu/compiled_preamble.h"
|
#include "mlx/backend/cpu/compiled_preamble.h"
|
||||||
|
#include "mlx/backend/cpu/encoder.h"
|
||||||
#include "mlx/backend/cpu/jit_compiler.h"
|
#include "mlx/backend/cpu/jit_compiler.h"
|
||||||
#include "mlx/device.h"
|
#include "mlx/device.h"
|
||||||
#include "mlx/graph_utils.h"
|
#include "mlx/graph_utils.h"
|
||||||
@@ -39,7 +40,10 @@ struct CompilerCache {
|
|||||||
std::shared_mutex mtx;
|
std::shared_mutex mtx;
|
||||||
};
|
};
|
||||||
|
|
||||||
static CompilerCache cache{};
|
static CompilerCache& cache() {
|
||||||
|
static CompilerCache cache_;
|
||||||
|
return cache_;
|
||||||
|
};
|
||||||
|
|
||||||
// GPU compile is always available if the GPU is available and since we are in
|
// GPU compile is always available if the GPU is available and since we are in
|
||||||
// this file CPU compile is also available.
|
// this file CPU compile is also available.
|
||||||
@@ -55,14 +59,16 @@ void* compile(
|
|||||||
const std::string& kernel_name,
|
const std::string& kernel_name,
|
||||||
const std::function<std::string(void)>& source_builder) {
|
const std::function<std::string(void)>& source_builder) {
|
||||||
{
|
{
|
||||||
std::shared_lock lock(cache.mtx);
|
std::shared_lock lock(cache().mtx);
|
||||||
if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
|
if (auto it = cache().kernels.find(kernel_name);
|
||||||
|
it != cache().kernels.end()) {
|
||||||
return it->second;
|
return it->second;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_lock lock(cache.mtx);
|
std::unique_lock lock(cache().mtx);
|
||||||
if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
|
if (auto it = cache().kernels.find(kernel_name);
|
||||||
|
it != cache().kernels.end()) {
|
||||||
return it->second;
|
return it->second;
|
||||||
}
|
}
|
||||||
std::string source_code = source_builder();
|
std::string source_code = source_builder();
|
||||||
@@ -119,10 +125,10 @@ void* compile(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// load library
|
// load library
|
||||||
cache.libs.emplace_back(shared_lib_path);
|
cache().libs.emplace_back(shared_lib_path);
|
||||||
|
|
||||||
// Load function
|
// Load function
|
||||||
void* fun = dlsym(cache.libs.back().lib, kernel_name.c_str());
|
void* fun = dlsym(cache().libs.back().lib, kernel_name.c_str());
|
||||||
if (!fun) {
|
if (!fun) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[Compile::eval_cpu] Failed to load compiled function "
|
msg << "[Compile::eval_cpu] Failed to load compiled function "
|
||||||
@@ -130,7 +136,7 @@ void* compile(
|
|||||||
<< dlerror();
|
<< dlerror();
|
||||||
throw std::runtime_error(msg.str());
|
throw std::runtime_error(msg.str());
|
||||||
}
|
}
|
||||||
cache.kernels.insert({kernel_name, fun});
|
cache().kernels.insert({kernel_name, fun});
|
||||||
return fun;
|
return fun;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -140,18 +146,9 @@ inline void build_kernel(
|
|||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
const std::vector<array>& outputs,
|
const std::vector<array>& outputs,
|
||||||
const std::vector<array>& tape,
|
const std::vector<array>& tape,
|
||||||
const std::unordered_set<uintptr_t>& constant_ids,
|
const std::function<bool(size_t)>& is_constant,
|
||||||
bool contiguous,
|
bool contiguous,
|
||||||
int ndim) {
|
int ndim) {
|
||||||
// All outputs should have the exact same shape and will be row contiguous
|
|
||||||
auto output_shape = outputs[0].shape();
|
|
||||||
auto output_strides = outputs[0].strides();
|
|
||||||
|
|
||||||
// Constants are scalars that are captured by value and cannot change
|
|
||||||
auto is_constant = [&constant_ids](const array& x) {
|
|
||||||
return constant_ids.find(x.id()) != constant_ids.end();
|
|
||||||
};
|
|
||||||
|
|
||||||
NodeNamer namer;
|
NodeNamer namer;
|
||||||
|
|
||||||
#ifdef _MSC_VER
|
#ifdef _MSC_VER
|
||||||
@@ -164,14 +161,15 @@ inline void build_kernel(
|
|||||||
|
|
||||||
// Add the input arguments
|
// Add the input arguments
|
||||||
int cnt = 0;
|
int cnt = 0;
|
||||||
for (auto& x : inputs) {
|
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||||
auto& xname = namer.get_name(x);
|
|
||||||
|
|
||||||
// Skip constants from the input list
|
// Skip constants from the input list
|
||||||
if (is_constant(x)) {
|
if (is_constant(i)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const auto& x = inputs[i];
|
||||||
|
auto& xname = namer.get_name(x);
|
||||||
|
|
||||||
auto tstr = get_type_string(x.dtype());
|
auto tstr = get_type_string(x.dtype());
|
||||||
os << " " << tstr << "* " << xname << " = (" << tstr << "*)args[" << cnt++
|
os << " " << tstr << "* " << xname << " = (" << tstr << "*)args[" << cnt++
|
||||||
<< "];" << std::endl;
|
<< "];" << std::endl;
|
||||||
@@ -205,10 +203,11 @@ inline void build_kernel(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Read the inputs in tmps
|
// Read the inputs in tmps
|
||||||
for (auto& x : inputs) {
|
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||||
|
const auto& x = inputs[i];
|
||||||
auto& xname = namer.get_name(x);
|
auto& xname = namer.get_name(x);
|
||||||
|
|
||||||
if (is_constant(x)) {
|
if (is_constant(i)) {
|
||||||
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = ";
|
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = ";
|
||||||
print_constant(os, x);
|
print_constant(os, x);
|
||||||
os << ";" << std::endl;
|
os << ";" << std::endl;
|
||||||
@@ -232,7 +231,7 @@ inline void build_kernel(
|
|||||||
os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_"
|
os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_"
|
||||||
<< namer.get_name(x.inputs()[0]) << ");" << std::endl;
|
<< namer.get_name(x.inputs()[0]) << ");" << std::endl;
|
||||||
} else {
|
} else {
|
||||||
x.primitive().print(os);
|
os << x.primitive().name();
|
||||||
os << "()(";
|
os << "()(";
|
||||||
for (int i = 0; i < x.inputs().size() - 1; i++) {
|
for (int i = 0; i < x.inputs().size() - 1; i++) {
|
||||||
os << "tmp_" << namer.get_name(x.inputs()[i]) << ", ";
|
os << "tmp_" << namer.get_name(x.inputs()[i]) << ", ";
|
||||||
@@ -258,8 +257,9 @@ inline void build_kernel(
|
|||||||
} else {
|
} else {
|
||||||
for (int d = ndim - 1; d >= 0; --d) {
|
for (int d = ndim - 1; d >= 0; --d) {
|
||||||
// Update pointers
|
// Update pointers
|
||||||
for (auto& x : inputs) {
|
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||||
if (is_constant(x) || is_scalar(x)) {
|
const auto& x = inputs[i];
|
||||||
|
if (is_constant(i) || is_scalar(x)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
auto& xname = namer.get_name(x);
|
auto& xname = namer.get_name(x);
|
||||||
@@ -281,63 +281,45 @@ inline void build_kernel(
|
|||||||
void Compiled::eval_cpu(
|
void Compiled::eval_cpu(
|
||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
std::vector<array>& outputs) {
|
std::vector<array>& outputs) {
|
||||||
if (kernel_lib_.empty()) {
|
auto& encoder = cpu::get_command_encoder(stream());
|
||||||
kernel_lib_ = build_lib_name(inputs_, outputs_, tape_, constant_ids_);
|
|
||||||
|
// Collapse contiguous dims to route to a faster kernel if possible. Also
|
||||||
|
// handle all broadcasting.
|
||||||
|
auto [contiguous, shape, strides] =
|
||||||
|
compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);
|
||||||
|
|
||||||
|
// Force allocating shape/strides on heap so we can take their data() first
|
||||||
|
// and then std::move them.
|
||||||
|
// TODO: Refactor code to avoid heap allocation.
|
||||||
|
shape.grow();
|
||||||
|
for (auto& s : strides) {
|
||||||
|
s.grow();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Figure out which kernel we are using
|
// Collect function input arguments.
|
||||||
auto& shape = outputs[0].shape();
|
|
||||||
auto contiguous = compiled_check_contiguity(inputs, shape);
|
|
||||||
|
|
||||||
// Handle all broadcasting and collect function input arguments
|
|
||||||
std::vector<void*> args;
|
std::vector<void*> args;
|
||||||
std::vector<std::vector<size_t>> strides;
|
int strides_index = 1;
|
||||||
for (int i = 0; i < inputs.size(); i++) {
|
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||||
// Skip constants.
|
if (is_constant_(i)) {
|
||||||
if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
auto& x = inputs[i];
|
const auto& x = inputs[i];
|
||||||
|
encoder.set_input_array(x);
|
||||||
args.push_back((void*)x.data<void>());
|
args.push_back((void*)x.data<void>());
|
||||||
|
if (!contiguous && !is_scalar(x)) {
|
||||||
if (contiguous || is_scalar(x)) {
|
args.push_back(strides[strides_index++].data());
|
||||||
continue;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Broadcast the input to the output shape.
|
|
||||||
std::vector<size_t> xstrides;
|
|
||||||
int j = 0;
|
|
||||||
for (; j < shape.size() - x.ndim(); j++) {
|
|
||||||
if (shape[j] == 1) {
|
|
||||||
xstrides.push_back(outputs[0].strides()[j]);
|
|
||||||
} else {
|
|
||||||
xstrides.push_back(0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for (int i = 0; i < x.ndim(); i++, j++) {
|
|
||||||
if (x.shape(i) == 1) {
|
|
||||||
if (shape[j] == 1) {
|
|
||||||
xstrides.push_back(outputs[0].strides()[j]);
|
|
||||||
} else {
|
|
||||||
xstrides.push_back(0);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
xstrides.push_back(x.strides()[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
strides.push_back(std::move(xstrides));
|
|
||||||
args.push_back(strides.back().data());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the kernel name from the lib
|
// Get the kernel name from the lib
|
||||||
int ndim = shape.size();
|
int ndim = shape.size();
|
||||||
auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_");
|
auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_");
|
||||||
if (!contiguous) {
|
if (!contiguous) {
|
||||||
kernel_name += std::to_string(shape.size());
|
kernel_name += std::to_string(ndim);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the function
|
// Get the function
|
||||||
auto fn_ptr = compile(kernel_name, [&]() {
|
auto fn_ptr = compile(kernel_name, [&, contiguous = contiguous]() {
|
||||||
std::ostringstream kernel;
|
std::ostringstream kernel;
|
||||||
kernel << get_kernel_preamble() << std::endl;
|
kernel << get_kernel_preamble() << std::endl;
|
||||||
kernel << "extern \"C\" {" << std::endl;
|
kernel << "extern \"C\" {" << std::endl;
|
||||||
@@ -347,7 +329,7 @@ void Compiled::eval_cpu(
|
|||||||
inputs_,
|
inputs_,
|
||||||
outputs_,
|
outputs_,
|
||||||
tape_,
|
tape_,
|
||||||
constant_ids_,
|
is_constant_,
|
||||||
contiguous,
|
contiguous,
|
||||||
ndim);
|
ndim);
|
||||||
// Close extern "C"
|
// Close extern "C"
|
||||||
@@ -355,19 +337,22 @@ void Compiled::eval_cpu(
|
|||||||
return kernel.str();
|
return kernel.str();
|
||||||
});
|
});
|
||||||
|
|
||||||
compiled_allocate_outputs(
|
compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous);
|
||||||
inputs, outputs, inputs_, constant_ids_, contiguous, false);
|
|
||||||
|
|
||||||
for (auto& x : outputs) {
|
for (auto& x : outputs) {
|
||||||
args.push_back(x.data<void>());
|
args.push_back(x.data<void>());
|
||||||
|
encoder.set_output_array(x);
|
||||||
}
|
}
|
||||||
if (!contiguous) {
|
if (!contiguous) {
|
||||||
args.push_back((void*)outputs[0].shape().data());
|
args.push_back((void*)shape.data());
|
||||||
} else {
|
} else {
|
||||||
args.push_back((void*)outputs[0].data_size());
|
args.push_back((void*)outputs[0].data_size());
|
||||||
}
|
}
|
||||||
auto fun = (void (*)(void**))fn_ptr;
|
auto fun = (void (*)(void**))fn_ptr;
|
||||||
fun(args.data());
|
encoder.dispatch([fun,
|
||||||
|
args = std::move(args),
|
||||||
|
strides = std::move(strides),
|
||||||
|
shape = std::move(shape)]() mutable { fun(args.data()); });
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -5,6 +5,7 @@
|
|||||||
#include "mlx/allocator.h"
|
#include "mlx/allocator.h"
|
||||||
#include "mlx/backend/common/utils.h"
|
#include "mlx/backend/common/utils.h"
|
||||||
#include "mlx/backend/cpu/copy.h"
|
#include "mlx/backend/cpu/copy.h"
|
||||||
|
#include "mlx/backend/cpu/encoder.h"
|
||||||
#include "mlx/backend/cpu/simd/simd.h"
|
#include "mlx/backend/cpu/simd/simd.h"
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
@@ -13,19 +14,19 @@ namespace {
|
|||||||
|
|
||||||
template <typename SrcT, typename DstT>
|
template <typename SrcT, typename DstT>
|
||||||
void copy_single(const array& src, array& dst) {
|
void copy_single(const array& src, array& dst) {
|
||||||
auto val = static_cast<DstT>(src.data<SrcT>()[0]);
|
auto src_ptr = src.data<SrcT>();
|
||||||
auto dst_ptr = dst.data<DstT>();
|
auto dst_ptr = dst.data<DstT>();
|
||||||
for (int i = 0; i < dst.size(); ++i) {
|
auto size = dst.size();
|
||||||
dst_ptr[i] = val;
|
auto val = static_cast<DstT>(src_ptr[0]);
|
||||||
}
|
std::fill_n(dst_ptr, size, val);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename SrcT, typename DstT>
|
template <typename SrcT, typename DstT>
|
||||||
void copy_vector(const array& src, array& dst) {
|
void copy_vector(const array& src, array& dst) {
|
||||||
auto src_ptr = src.data<SrcT>();
|
auto src_ptr = src.data<SrcT>();
|
||||||
auto dst_ptr = dst.data<DstT>();
|
auto dst_ptr = dst.data<DstT>();
|
||||||
size_t size = src.data_size();
|
auto size = src.data_size();
|
||||||
std::copy(src_ptr, src_ptr + src.data_size(), dst_ptr);
|
std::copy(src_ptr, src_ptr + size, dst_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename SrcT, typename DstT, int D>
|
template <typename SrcT, typename DstT, int D>
|
||||||
@@ -60,36 +61,57 @@ void copy_general_general(
|
|||||||
const Strides& i_strides,
|
const Strides& i_strides,
|
||||||
const Strides& o_strides,
|
const Strides& o_strides,
|
||||||
int64_t i_offset,
|
int64_t i_offset,
|
||||||
int64_t o_offset) {
|
int64_t o_offset,
|
||||||
if (data_shape.empty()) {
|
const std::optional<array>& dynamic_i_offset,
|
||||||
auto val = static_cast<DstT>(*(src.data<SrcT>() + i_offset));
|
const std::optional<array>& dynamic_o_offset) {
|
||||||
|
auto src_ptr = src.data<SrcT>() + i_offset;
|
||||||
auto dst_ptr = dst.data<DstT>() + o_offset;
|
auto dst_ptr = dst.data<DstT>() + o_offset;
|
||||||
|
auto i_offset_ptr =
|
||||||
|
dynamic_i_offset ? dynamic_i_offset->data<int64_t>() : nullptr;
|
||||||
|
auto o_offset_ptr =
|
||||||
|
dynamic_o_offset ? dynamic_o_offset->data<int64_t>() : nullptr;
|
||||||
|
auto size = src.size();
|
||||||
|
if (data_shape.empty()) {
|
||||||
|
auto val = static_cast<DstT>(*src_ptr);
|
||||||
*dst_ptr = val;
|
*dst_ptr = val;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
auto [shape, strides] =
|
auto [shape, strides] =
|
||||||
collapse_contiguous_dims(data_shape, {i_strides, o_strides});
|
collapse_contiguous_dims(data_shape, {i_strides, o_strides});
|
||||||
auto src_ptr = src.data<SrcT>() + i_offset;
|
|
||||||
auto dst_ptr = dst.data<DstT>() + o_offset;
|
|
||||||
int ndim = shape.size();
|
int ndim = shape.size();
|
||||||
|
if (ndim < 3) {
|
||||||
|
if (i_offset_ptr) {
|
||||||
|
src_ptr += i_offset_ptr[0];
|
||||||
|
}
|
||||||
|
if (o_offset_ptr) {
|
||||||
|
dst_ptr += o_offset_ptr[0];
|
||||||
|
}
|
||||||
|
|
||||||
if (ndim == 1) {
|
if (ndim == 1) {
|
||||||
copy_dims<SrcT, DstT, 1>(
|
copy_dims<SrcT, DstT, 1>(
|
||||||
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
|
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
|
||||||
return;
|
|
||||||
} else if (ndim == 2) {
|
} else if (ndim == 2) {
|
||||||
copy_dims<SrcT, DstT, 2>(
|
copy_dims<SrcT, DstT, 2>(
|
||||||
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
|
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
|
||||||
return;
|
|
||||||
} else if (ndim == 3) {
|
} else if (ndim == 3) {
|
||||||
copy_dims<SrcT, DstT, 3>(
|
copy_dims<SrcT, DstT, 3>(
|
||||||
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
|
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
|
||||||
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
if (i_offset_ptr) {
|
||||||
|
src_ptr += i_offset_ptr[0];
|
||||||
|
}
|
||||||
|
if (o_offset_ptr) {
|
||||||
|
dst_ptr += o_offset_ptr[0];
|
||||||
|
}
|
||||||
|
|
||||||
ContiguousIterator in(shape, strides[0], ndim - 3);
|
ContiguousIterator in(shape, strides[0], ndim - 3);
|
||||||
ContiguousIterator out(shape, strides[1], ndim - 3);
|
ContiguousIterator out(shape, strides[1], ndim - 3);
|
||||||
auto stride = std::accumulate(
|
auto stride = std::accumulate(
|
||||||
shape.end() - 3, shape.end(), 1, std::multiplies<int64_t>());
|
shape.end() - 3, shape.end(), 1, std::multiplies<int64_t>());
|
||||||
for (int64_t elem = 0; elem < src.size(); elem += stride) {
|
for (int64_t elem = 0; elem < size; elem += stride) {
|
||||||
copy_dims<SrcT, DstT, 3>(
|
copy_dims<SrcT, DstT, 3>(
|
||||||
src_ptr + in.loc,
|
src_ptr + in.loc,
|
||||||
dst_ptr + out.loc,
|
dst_ptr + out.loc,
|
||||||
@@ -105,7 +127,15 @@ void copy_general_general(
|
|||||||
template <typename SrcT, typename DstT>
|
template <typename SrcT, typename DstT>
|
||||||
inline void copy_general_general(const array& src, array& dst) {
|
inline void copy_general_general(const array& src, array& dst) {
|
||||||
copy_general_general<SrcT, DstT>(
|
copy_general_general<SrcT, DstT>(
|
||||||
src, dst, src.shape(), src.strides(), dst.strides(), 0, 0);
|
src,
|
||||||
|
dst,
|
||||||
|
src.shape(),
|
||||||
|
src.strides(),
|
||||||
|
dst.strides(),
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
std::nullopt,
|
||||||
|
std::nullopt);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename SrcT, typename DstT>
|
template <typename SrcT, typename DstT>
|
||||||
@@ -116,7 +146,9 @@ void copy_general(
|
|||||||
const Strides& i_strides,
|
const Strides& i_strides,
|
||||||
const Strides&,
|
const Strides&,
|
||||||
int64_t i_offset,
|
int64_t i_offset,
|
||||||
int64_t o_offset) {
|
int64_t o_offset,
|
||||||
|
const std::optional<array>& dynamic_i_offset,
|
||||||
|
const std::optional<array>& dynamic_o_offset) {
|
||||||
copy_general_general<SrcT, DstT>(
|
copy_general_general<SrcT, DstT>(
|
||||||
src,
|
src,
|
||||||
dst,
|
dst,
|
||||||
@@ -124,7 +156,9 @@ void copy_general(
|
|||||||
i_strides,
|
i_strides,
|
||||||
make_contiguous_strides(data_shape),
|
make_contiguous_strides(data_shape),
|
||||||
i_offset,
|
i_offset,
|
||||||
o_offset);
|
o_offset,
|
||||||
|
dynamic_i_offset,
|
||||||
|
dynamic_o_offset);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename SrcT, typename DstT>
|
template <typename SrcT, typename DstT>
|
||||||
@@ -136,7 +170,9 @@ inline void copy_general(const array& src, array& dst) {
|
|||||||
src.strides(),
|
src.strides(),
|
||||||
make_contiguous_strides(src.shape()),
|
make_contiguous_strides(src.shape()),
|
||||||
0,
|
0,
|
||||||
0);
|
0,
|
||||||
|
std::nullopt,
|
||||||
|
std::nullopt);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename SrcT, typename DstT, typename... Args>
|
template <typename SrcT, typename DstT, typename... Args>
|
||||||
@@ -259,38 +295,34 @@ inline void copy_inplace_dispatch(
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void copy_inplace(const array& src, array& dst, CopyType ctype) {
|
void copy_cpu_inplace(
|
||||||
copy_inplace_dispatch(src, dst, ctype);
|
const array& src,
|
||||||
|
array& dst,
|
||||||
|
CopyType ctype,
|
||||||
|
Stream stream) {
|
||||||
|
auto& encoder = cpu::get_command_encoder(stream);
|
||||||
|
encoder.set_input_array(src);
|
||||||
|
encoder.set_output_array(dst);
|
||||||
|
encoder.dispatch(
|
||||||
|
[src = array::unsafe_weak_copy(src),
|
||||||
|
dst = array::unsafe_weak_copy(dst),
|
||||||
|
ctype]() mutable { copy_inplace_dispatch(src, dst, ctype); });
|
||||||
}
|
}
|
||||||
|
|
||||||
void copy(const array& src, array& dst, CopyType ctype) {
|
void copy_cpu(const array& src, array& dst, CopyType ctype, Stream stream) {
|
||||||
// Allocate the output
|
bool donated = set_copy_output_data(src, dst, ctype);
|
||||||
switch (ctype) {
|
if (donated && src.dtype() == dst.dtype()) {
|
||||||
case CopyType::Vector:
|
// If the output has the same type as the input then there is nothing to
|
||||||
if (src.is_donatable() && src.itemsize() == dst.itemsize()) {
|
// copy, just use the buffer.
|
||||||
dst.copy_shared_buffer(src);
|
return;
|
||||||
} else {
|
|
||||||
auto size = src.data_size();
|
|
||||||
dst.set_data(
|
|
||||||
allocator::malloc_or_wait(size * dst.itemsize()),
|
|
||||||
size,
|
|
||||||
src.strides(),
|
|
||||||
src.flags());
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case CopyType::Scalar:
|
|
||||||
case CopyType::General:
|
|
||||||
case CopyType::GeneralGeneral:
|
|
||||||
dst.set_data(allocator::malloc_or_wait(dst.nbytes()));
|
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
if (ctype == CopyType::GeneralGeneral) {
|
if (ctype == CopyType::GeneralGeneral) {
|
||||||
ctype = CopyType::General;
|
ctype = CopyType::General;
|
||||||
}
|
}
|
||||||
copy_inplace(src, dst, ctype);
|
copy_cpu_inplace(src, dst, ctype, stream);
|
||||||
}
|
}
|
||||||
|
|
||||||
void copy_inplace(
|
void copy_cpu_inplace(
|
||||||
const array& src,
|
const array& src,
|
||||||
array& dst,
|
array& dst,
|
||||||
const Shape& data_shape,
|
const Shape& data_shape,
|
||||||
@@ -298,7 +330,31 @@ void copy_inplace(
|
|||||||
const Strides& o_strides,
|
const Strides& o_strides,
|
||||||
int64_t i_offset,
|
int64_t i_offset,
|
||||||
int64_t o_offset,
|
int64_t o_offset,
|
||||||
CopyType ctype) {
|
CopyType ctype,
|
||||||
|
Stream stream,
|
||||||
|
const std::optional<array>& dynamic_i_offset, /* = std::nullopt */
|
||||||
|
const std::optional<array>& dynamic_o_offset /* = std::nullopt */) {
|
||||||
|
auto& encoder = cpu::get_command_encoder(stream);
|
||||||
|
encoder.set_input_array(src);
|
||||||
|
encoder.set_output_array(dst);
|
||||||
|
auto weak_copy_if_set = [](auto x) -> std::optional<array> {
|
||||||
|
if (x) {
|
||||||
|
return array::unsafe_weak_copy(*x);
|
||||||
|
} else {
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
encoder.dispatch(
|
||||||
|
[src = array::unsafe_weak_copy(src),
|
||||||
|
dst = array::unsafe_weak_copy(dst),
|
||||||
|
data_shape,
|
||||||
|
i_strides,
|
||||||
|
o_strides,
|
||||||
|
i_offset,
|
||||||
|
o_offset,
|
||||||
|
ctype,
|
||||||
|
dynamic_i_offset = weak_copy_if_set(dynamic_i_offset),
|
||||||
|
dynamic_o_offset = weak_copy_if_set(dynamic_o_offset)]() mutable {
|
||||||
switch (ctype) {
|
switch (ctype) {
|
||||||
case CopyType::General:
|
case CopyType::General:
|
||||||
case CopyType::GeneralGeneral:
|
case CopyType::GeneralGeneral:
|
||||||
@@ -310,12 +366,21 @@ void copy_inplace(
|
|||||||
i_strides,
|
i_strides,
|
||||||
o_strides,
|
o_strides,
|
||||||
i_offset,
|
i_offset,
|
||||||
o_offset);
|
o_offset,
|
||||||
|
dynamic_i_offset,
|
||||||
|
dynamic_o_offset);
|
||||||
break;
|
break;
|
||||||
case CopyType::Scalar:
|
case CopyType::Scalar:
|
||||||
case CopyType::Vector:
|
case CopyType::Vector:
|
||||||
copy_inplace_dispatch(src, dst, ctype);
|
copy_inplace_dispatch(src, dst, ctype);
|
||||||
}
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
array contiguous_copy_cpu(const array& arr, Stream stream) {
|
||||||
|
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||||
|
copy_cpu(arr, arr_copy, CopyType::General, stream);
|
||||||
|
return arr_copy;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
|||||||
@@ -2,16 +2,22 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <optional>
|
||||||
|
|
||||||
#include "mlx/array.h"
|
#include "mlx/array.h"
|
||||||
#include "mlx/backend/common/copy.h"
|
#include "mlx/backend/common/copy.h"
|
||||||
#include "mlx/backend/common/utils.h"
|
#include "mlx/backend/common/utils.h"
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
void copy(const array& src, array& dst, CopyType ctype);
|
void copy_cpu(const array& src, array& dst, CopyType ctype, Stream stream);
|
||||||
void copy_inplace(const array& src, array& dst, CopyType ctype);
|
void copy_cpu_inplace(
|
||||||
|
const array& src,
|
||||||
|
array& dst,
|
||||||
|
CopyType ctype,
|
||||||
|
Stream stream);
|
||||||
|
|
||||||
void copy_inplace(
|
void copy_cpu_inplace(
|
||||||
const array& src,
|
const array& src,
|
||||||
array& dst,
|
array& dst,
|
||||||
const Shape& data_shape,
|
const Shape& data_shape,
|
||||||
@@ -19,6 +25,12 @@ void copy_inplace(
|
|||||||
const Strides& o_strides,
|
const Strides& o_strides,
|
||||||
int64_t i_offset,
|
int64_t i_offset,
|
||||||
int64_t o_offset,
|
int64_t o_offset,
|
||||||
CopyType ctype);
|
CopyType ctype,
|
||||||
|
Stream stream,
|
||||||
|
const std::optional<array>& dynamic_i_offset = std::nullopt,
|
||||||
|
const std::optional<array>& dynamic_o_offset = std::nullopt);
|
||||||
|
|
||||||
|
// Return a contiguous array with same shape that copies the data of |arr|.
|
||||||
|
array contiguous_copy_cpu(const array& arr, Stream stream);
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
|||||||
98
mlx/backend/cpu/distributed.cpp
Normal file
98
mlx/backend/cpu/distributed.cpp
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
#include <cassert>
|
||||||
|
|
||||||
|
#include "mlx/allocator.h"
|
||||||
|
#include "mlx/backend/cpu/copy.h"
|
||||||
|
#include "mlx/backend/cpu/encoder.h"
|
||||||
|
#include "mlx/distributed/primitives.h"
|
||||||
|
|
||||||
|
namespace mlx::core::distributed {
|
||||||
|
|
||||||
|
std::pair<array, bool> ensure_row_contiguous(const array& arr, Stream stream) {
|
||||||
|
if (arr.flags().row_contiguous) {
|
||||||
|
return {arr, false};
|
||||||
|
} else {
|
||||||
|
return {contiguous_copy_cpu(arr, stream), true};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
void AllReduce::eval_cpu(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
std::vector<array>& outputs) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
assert(outputs.size() == 1);
|
||||||
|
|
||||||
|
auto donate_or_copy = [s = stream()](const array& in, array& out) {
|
||||||
|
if (in.flags().row_contiguous) {
|
||||||
|
if (in.is_donatable()) {
|
||||||
|
out.copy_shared_buffer(in);
|
||||||
|
} else {
|
||||||
|
out.set_data(allocator::malloc(out.nbytes()));
|
||||||
|
}
|
||||||
|
return in;
|
||||||
|
} else {
|
||||||
|
array arr_copy = contiguous_copy_cpu(in, s);
|
||||||
|
out.copy_shared_buffer(arr_copy);
|
||||||
|
return arr_copy;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
auto in = donate_or_copy(inputs[0], outputs[0]);
|
||||||
|
switch (reduce_type_) {
|
||||||
|
case Sum:
|
||||||
|
distributed::detail::all_sum(group(), in, outputs[0], stream());
|
||||||
|
break;
|
||||||
|
case Max:
|
||||||
|
distributed::detail::all_max(group(), in, outputs[0], stream());
|
||||||
|
break;
|
||||||
|
case Min:
|
||||||
|
distributed::detail::all_min(group(), in, outputs[0], stream());
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw std::runtime_error(
|
||||||
|
"Only all reduce sum, min and max are supported for now");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void AllGather::eval_cpu(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
std::vector<array>& outputs) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
assert(outputs.size() == 1);
|
||||||
|
|
||||||
|
auto [in, copied] = ensure_row_contiguous(inputs[0], stream());
|
||||||
|
outputs[0].set_data(allocator::malloc(outputs[0].nbytes()));
|
||||||
|
distributed::detail::all_gather(group(), in, outputs[0], stream());
|
||||||
|
if (copied) {
|
||||||
|
auto& enc = cpu::get_command_encoder(stream());
|
||||||
|
enc.add_temporary(in);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Send::eval_cpu(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
std::vector<array>& outputs) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
assert(outputs.size() == 1);
|
||||||
|
|
||||||
|
auto [in, copied] = ensure_row_contiguous(inputs[0], stream());
|
||||||
|
distributed::detail::send(group(), in, dst_, stream());
|
||||||
|
outputs[0].copy_shared_buffer(inputs[0]);
|
||||||
|
if (copied) {
|
||||||
|
auto& enc = cpu::get_command_encoder(stream());
|
||||||
|
enc.add_temporary(in);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Recv::eval_cpu(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
std::vector<array>& outputs) {
|
||||||
|
assert(inputs.size() == 0);
|
||||||
|
assert(outputs.size() == 1);
|
||||||
|
|
||||||
|
outputs[0].set_data(allocator::malloc(outputs[0].nbytes()));
|
||||||
|
distributed::detail::recv(group(), outputs[0], src_, stream());
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::distributed
|
||||||
174
mlx/backend/cpu/eig.cpp
Normal file
174
mlx/backend/cpu/eig.cpp
Normal file
@@ -0,0 +1,174 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/allocator.h"
|
||||||
|
#include "mlx/array.h"
|
||||||
|
#include "mlx/backend/cpu/copy.h"
|
||||||
|
#include "mlx/backend/cpu/encoder.h"
|
||||||
|
#include "mlx/backend/cpu/lapack.h"
|
||||||
|
#include "mlx/linalg.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void eig_impl(
|
||||||
|
array& a,
|
||||||
|
array& vectors,
|
||||||
|
array& values,
|
||||||
|
bool compute_eigenvectors,
|
||||||
|
Stream stream) {
|
||||||
|
using OT = std::complex<T>;
|
||||||
|
auto a_ptr = a.data<T>();
|
||||||
|
auto eig_ptr = values.data<OT>();
|
||||||
|
|
||||||
|
auto& encoder = cpu::get_command_encoder(stream);
|
||||||
|
encoder.set_input_array(a);
|
||||||
|
encoder.set_output_array(values);
|
||||||
|
OT* vec_ptr = nullptr;
|
||||||
|
if (compute_eigenvectors) {
|
||||||
|
encoder.set_output_array(vectors);
|
||||||
|
vec_ptr = vectors.data<OT>();
|
||||||
|
}
|
||||||
|
encoder.dispatch([a_ptr,
|
||||||
|
vec_ptr,
|
||||||
|
eig_ptr,
|
||||||
|
compute_eigenvectors,
|
||||||
|
N = vectors.shape(-1),
|
||||||
|
size = vectors.size()]() mutable {
|
||||||
|
// Work query
|
||||||
|
char jobr = 'N';
|
||||||
|
char jobl = compute_eigenvectors ? 'V' : 'N';
|
||||||
|
int n_vecs_r = 1;
|
||||||
|
int n_vecs_l = compute_eigenvectors ? N : 1;
|
||||||
|
int lwork = -1;
|
||||||
|
int info;
|
||||||
|
{
|
||||||
|
T work;
|
||||||
|
int iwork;
|
||||||
|
geev<T>(
|
||||||
|
&jobl,
|
||||||
|
&jobr,
|
||||||
|
&N,
|
||||||
|
nullptr,
|
||||||
|
&N,
|
||||||
|
nullptr,
|
||||||
|
nullptr,
|
||||||
|
nullptr,
|
||||||
|
&n_vecs_l,
|
||||||
|
nullptr,
|
||||||
|
&n_vecs_r,
|
||||||
|
&work,
|
||||||
|
&lwork,
|
||||||
|
&info);
|
||||||
|
lwork = static_cast<int>(work);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto eig_tmp_data = array::Data{allocator::malloc(sizeof(T) * N * 2)};
|
||||||
|
auto vec_tmp_data =
|
||||||
|
array::Data{allocator::malloc(vec_ptr ? sizeof(T) * N * N * 2 : 0)};
|
||||||
|
auto eig_tmp = static_cast<T*>(eig_tmp_data.buffer.raw_ptr());
|
||||||
|
auto vec_tmp = static_cast<T*>(vec_tmp_data.buffer.raw_ptr());
|
||||||
|
auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)};
|
||||||
|
for (size_t i = 0; i < size / (N * N); ++i) {
|
||||||
|
geev<T>(
|
||||||
|
&jobl,
|
||||||
|
&jobr,
|
||||||
|
&N,
|
||||||
|
a_ptr,
|
||||||
|
&N,
|
||||||
|
eig_tmp,
|
||||||
|
eig_tmp + N,
|
||||||
|
vec_tmp,
|
||||||
|
&n_vecs_l,
|
||||||
|
nullptr,
|
||||||
|
&n_vecs_r,
|
||||||
|
static_cast<T*>(work_buf.buffer.raw_ptr()),
|
||||||
|
&lwork,
|
||||||
|
&info);
|
||||||
|
for (int i = 0; i < N; ++i) {
|
||||||
|
eig_ptr[i] = {eig_tmp[i], eig_tmp[N + i]};
|
||||||
|
}
|
||||||
|
if (vec_ptr) {
|
||||||
|
for (int i = 0; i < N; ++i) {
|
||||||
|
if (eig_ptr[i].imag() != 0) {
|
||||||
|
// This vector and the next are a pair
|
||||||
|
for (int j = 0; j < N; ++j) {
|
||||||
|
vec_ptr[i * N + j] = {
|
||||||
|
vec_tmp[i * N + j], -vec_tmp[(i + 1) * N + j]};
|
||||||
|
vec_ptr[(i + 1) * N + j] = {
|
||||||
|
vec_tmp[i * N + j], vec_tmp[(i + 1) * N + j]};
|
||||||
|
}
|
||||||
|
i += 1;
|
||||||
|
} else {
|
||||||
|
for (int j = 0; j < N; ++j) {
|
||||||
|
vec_ptr[i * N + j] = {vec_tmp[i * N + j], 0};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
vec_ptr += N * N;
|
||||||
|
}
|
||||||
|
a_ptr += N * N;
|
||||||
|
eig_ptr += N;
|
||||||
|
if (info != 0) {
|
||||||
|
std::stringstream msg;
|
||||||
|
msg << "[Eig::eval_cpu] Eigenvalue decomposition failed with error code "
|
||||||
|
<< info;
|
||||||
|
throw std::runtime_error(msg.str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
encoder.add_temporary(a);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void Eig::eval_cpu(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
std::vector<array>& outputs) {
|
||||||
|
const auto& a = inputs[0];
|
||||||
|
auto& values = outputs[0];
|
||||||
|
|
||||||
|
auto vectors = compute_eigenvectors_
|
||||||
|
? outputs[1]
|
||||||
|
: array(a.shape(), complex64, nullptr, {});
|
||||||
|
|
||||||
|
auto a_copy = array(a.shape(), a.dtype(), nullptr, {});
|
||||||
|
copy_cpu(
|
||||||
|
a,
|
||||||
|
a_copy,
|
||||||
|
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
|
||||||
|
stream());
|
||||||
|
|
||||||
|
values.set_data(allocator::malloc(values.nbytes()));
|
||||||
|
|
||||||
|
if (compute_eigenvectors_) {
|
||||||
|
// Set the strides and flags so the eigenvectors
|
||||||
|
// are in the columns of the output
|
||||||
|
auto flags = vectors.flags();
|
||||||
|
auto strides = vectors.strides();
|
||||||
|
auto ndim = a.ndim();
|
||||||
|
std::swap(strides[ndim - 1], strides[ndim - 2]);
|
||||||
|
|
||||||
|
if (a.size() > 1) {
|
||||||
|
flags.row_contiguous = false;
|
||||||
|
if (ndim > 2) {
|
||||||
|
flags.col_contiguous = false;
|
||||||
|
} else {
|
||||||
|
flags.col_contiguous = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
vectors.set_data(
|
||||||
|
allocator::malloc(vectors.nbytes()), vectors.size(), strides, flags);
|
||||||
|
}
|
||||||
|
switch (a.dtype()) {
|
||||||
|
case float32:
|
||||||
|
eig_impl<float>(a_copy, vectors, values, compute_eigenvectors_, stream());
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw std::runtime_error("[Eig::eval_cpu] only supports float32.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
@@ -3,6 +3,7 @@
|
|||||||
#include "mlx/allocator.h"
|
#include "mlx/allocator.h"
|
||||||
#include "mlx/array.h"
|
#include "mlx/array.h"
|
||||||
#include "mlx/backend/cpu/copy.h"
|
#include "mlx/backend/cpu/copy.h"
|
||||||
|
#include "mlx/backend/cpu/encoder.h"
|
||||||
#include "mlx/backend/cpu/lapack.h"
|
#include "mlx/backend/cpu/lapack.h"
|
||||||
#include "mlx/linalg.h"
|
#include "mlx/linalg.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
@@ -11,28 +12,30 @@ namespace mlx::core {
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
template <typename T, class Enable = void>
|
||||||
|
struct EighWork {};
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void eigh_impl(
|
struct EighWork<
|
||||||
array& vectors,
|
T,
|
||||||
array& values,
|
typename std::enable_if<std::is_floating_point<T>::value>::type> {
|
||||||
const std::string& uplo,
|
using R = T;
|
||||||
bool compute_eigenvectors) {
|
|
||||||
auto vec_ptr = vectors.data<T>();
|
|
||||||
auto eig_ptr = values.data<T>();
|
|
||||||
|
|
||||||
char jobz = compute_eigenvectors ? 'V' : 'N';
|
char jobz;
|
||||||
auto N = vectors.shape(-1);
|
char uplo;
|
||||||
|
int N;
|
||||||
// Work query
|
int lwork;
|
||||||
int lwork = -1;
|
int liwork;
|
||||||
int liwork = -1;
|
|
||||||
int info;
|
int info;
|
||||||
{
|
std::vector<array::Data> buffers;
|
||||||
|
|
||||||
|
EighWork(char jobz_, char uplo_, int N_)
|
||||||
|
: jobz(jobz_), uplo(uplo_), N(N_), lwork(-1), liwork(-1) {
|
||||||
T work;
|
T work;
|
||||||
int iwork;
|
int iwork;
|
||||||
syevd<T>(
|
syevd<T>(
|
||||||
&jobz,
|
&jobz,
|
||||||
uplo.c_str(),
|
&uplo,
|
||||||
&N,
|
&N,
|
||||||
nullptr,
|
nullptr,
|
||||||
&N,
|
&N,
|
||||||
@@ -44,32 +47,139 @@ void eigh_impl(
|
|||||||
&info);
|
&info);
|
||||||
lwork = static_cast<int>(work);
|
lwork = static_cast<int>(work);
|
||||||
liwork = iwork;
|
liwork = iwork;
|
||||||
|
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
|
||||||
|
buffers.emplace_back(allocator::malloc(sizeof(int) * liwork));
|
||||||
}
|
}
|
||||||
|
|
||||||
auto work_buf = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};
|
void run(T* vectors, T* values) {
|
||||||
auto iwork_buf = array::Data{allocator::malloc_or_wait(sizeof(int) * liwork)};
|
|
||||||
for (size_t i = 0; i < vectors.size() / (N * N); ++i) {
|
|
||||||
syevd<T>(
|
syevd<T>(
|
||||||
&jobz,
|
&jobz,
|
||||||
uplo.c_str(),
|
&uplo,
|
||||||
&N,
|
&N,
|
||||||
vec_ptr,
|
vectors,
|
||||||
&N,
|
&N,
|
||||||
eig_ptr,
|
values,
|
||||||
static_cast<T*>(work_buf.buffer.raw_ptr()),
|
static_cast<T*>(buffers[0].buffer.raw_ptr()),
|
||||||
&lwork,
|
&lwork,
|
||||||
static_cast<int*>(iwork_buf.buffer.raw_ptr()),
|
static_cast<int*>(buffers[1].buffer.raw_ptr()),
|
||||||
&liwork,
|
&liwork,
|
||||||
&info);
|
&info);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct EighWork<std::complex<float>> {
|
||||||
|
using T = std::complex<float>;
|
||||||
|
using R = float;
|
||||||
|
|
||||||
|
char jobz;
|
||||||
|
char uplo;
|
||||||
|
int N;
|
||||||
|
int lwork;
|
||||||
|
int lrwork;
|
||||||
|
int liwork;
|
||||||
|
int info;
|
||||||
|
std::vector<array::Data> buffers;
|
||||||
|
|
||||||
|
EighWork(char jobz_, char uplo_, int N_)
|
||||||
|
: jobz(jobz_), uplo(uplo_), N(N_), lwork(-1), lrwork(-1), liwork(-1) {
|
||||||
|
T work;
|
||||||
|
R rwork;
|
||||||
|
int iwork;
|
||||||
|
heevd<T>(
|
||||||
|
&jobz,
|
||||||
|
&uplo,
|
||||||
|
&N,
|
||||||
|
nullptr,
|
||||||
|
&N,
|
||||||
|
nullptr,
|
||||||
|
&work,
|
||||||
|
&lwork,
|
||||||
|
&rwork,
|
||||||
|
&lrwork,
|
||||||
|
&iwork,
|
||||||
|
&liwork,
|
||||||
|
&info);
|
||||||
|
lwork = static_cast<int>(work.real());
|
||||||
|
lrwork = static_cast<int>(rwork);
|
||||||
|
liwork = iwork;
|
||||||
|
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
|
||||||
|
buffers.emplace_back(allocator::malloc(sizeof(R) * lrwork));
|
||||||
|
buffers.emplace_back(allocator::malloc(sizeof(int) * liwork));
|
||||||
|
}
|
||||||
|
|
||||||
|
void run(T* vectors, R* values) {
|
||||||
|
heevd<T>(
|
||||||
|
&jobz,
|
||||||
|
&uplo,
|
||||||
|
&N,
|
||||||
|
vectors,
|
||||||
|
&N,
|
||||||
|
values,
|
||||||
|
static_cast<T*>(buffers[0].buffer.raw_ptr()),
|
||||||
|
&lwork,
|
||||||
|
static_cast<R*>(buffers[1].buffer.raw_ptr()),
|
||||||
|
&lrwork,
|
||||||
|
static_cast<int*>(buffers[2].buffer.raw_ptr()),
|
||||||
|
&liwork,
|
||||||
|
&info);
|
||||||
|
if (jobz == 'V') {
|
||||||
|
// We have pre-transposed the vectors but we also must conjugate them
|
||||||
|
// when they are complex.
|
||||||
|
//
|
||||||
|
// We could vectorize this but it is so fast in comparison to heevd that
|
||||||
|
// it doesn't really matter.
|
||||||
|
for (int i = 0; i < N; i++) {
|
||||||
|
for (int j = 0; j < N; j++) {
|
||||||
|
*vectors = std::conj(*vectors);
|
||||||
|
vectors++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void eigh_impl(
|
||||||
|
array& vectors,
|
||||||
|
array& values,
|
||||||
|
const std::string& uplo,
|
||||||
|
bool compute_eigenvectors,
|
||||||
|
Stream stream) {
|
||||||
|
using R = typename EighWork<T>::R;
|
||||||
|
|
||||||
|
auto vec_ptr = vectors.data<T>();
|
||||||
|
auto eig_ptr = values.data<R>();
|
||||||
|
char jobz = compute_eigenvectors ? 'V' : 'N';
|
||||||
|
|
||||||
|
auto& encoder = cpu::get_command_encoder(stream);
|
||||||
|
encoder.set_output_array(vectors);
|
||||||
|
encoder.set_output_array(values);
|
||||||
|
encoder.dispatch([vec_ptr,
|
||||||
|
eig_ptr,
|
||||||
|
jobz,
|
||||||
|
uplo = uplo[0],
|
||||||
|
N = vectors.shape(-1),
|
||||||
|
size = vectors.size()]() mutable {
|
||||||
|
// Work query
|
||||||
|
EighWork<T> work(jobz, uplo, N);
|
||||||
|
|
||||||
|
// Work loop
|
||||||
|
for (size_t i = 0; i < size / (N * N); ++i) {
|
||||||
|
work.run(vec_ptr, eig_ptr);
|
||||||
vec_ptr += N * N;
|
vec_ptr += N * N;
|
||||||
eig_ptr += N;
|
eig_ptr += N;
|
||||||
if (info != 0) {
|
if (work.info != 0) {
|
||||||
std::stringstream msg;
|
std::stringstream msg;
|
||||||
msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
|
msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
|
||||||
<< info;
|
<< work.info;
|
||||||
throw std::runtime_error(msg.str());
|
throw std::runtime_error(msg.str());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
});
|
||||||
|
if (!compute_eigenvectors) {
|
||||||
|
encoder.add_temporary(vectors);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
@@ -84,12 +194,13 @@ void Eigh::eval_cpu(
|
|||||||
? outputs[1]
|
? outputs[1]
|
||||||
: array(a.shape(), a.dtype(), nullptr, {});
|
: array(a.shape(), a.dtype(), nullptr, {});
|
||||||
|
|
||||||
values.set_data(allocator::malloc_or_wait(values.nbytes()));
|
values.set_data(allocator::malloc(values.nbytes()));
|
||||||
|
|
||||||
copy(
|
copy_cpu(
|
||||||
a,
|
a,
|
||||||
vectors,
|
vectors,
|
||||||
a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
|
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
|
||||||
|
stream());
|
||||||
|
|
||||||
if (compute_eigenvectors_) {
|
if (compute_eigenvectors_) {
|
||||||
// Set the strides and flags so the eigenvectors
|
// Set the strides and flags so the eigenvectors
|
||||||
@@ -107,14 +218,19 @@ void Eigh::eval_cpu(
|
|||||||
flags.col_contiguous = true;
|
flags.col_contiguous = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
vectors.move_shared_buffer(vectors, strides, flags, vectors.data_size());
|
vectors.copy_shared_buffer(vectors, strides, flags, vectors.data_size());
|
||||||
}
|
}
|
||||||
switch (a.dtype()) {
|
switch (a.dtype()) {
|
||||||
case float32:
|
case float32:
|
||||||
eigh_impl<float>(vectors, values, uplo_, compute_eigenvectors_);
|
eigh_impl<float>(vectors, values, uplo_, compute_eigenvectors_, stream());
|
||||||
break;
|
break;
|
||||||
case float64:
|
case float64:
|
||||||
eigh_impl<double>(vectors, values, uplo_, compute_eigenvectors_);
|
eigh_impl<double>(
|
||||||
|
vectors, values, uplo_, compute_eigenvectors_, stream());
|
||||||
|
break;
|
||||||
|
case complex64:
|
||||||
|
eigh_impl<std::complex<float>>(
|
||||||
|
vectors, values, uplo_, compute_eigenvectors_, stream());
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
|
|||||||
16
mlx/backend/cpu/encoder.cpp
Normal file
16
mlx/backend/cpu/encoder.cpp
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cpu/encoder.h"
|
||||||
|
|
||||||
|
namespace mlx::core::cpu {
|
||||||
|
|
||||||
|
CommandEncoder& get_command_encoder(Stream stream) {
|
||||||
|
static std::unordered_map<int, CommandEncoder> encoder_map;
|
||||||
|
auto it = encoder_map.find(stream.index);
|
||||||
|
if (it == encoder_map.end()) {
|
||||||
|
it = encoder_map.emplace(stream.index, stream).first;
|
||||||
|
}
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::cpu
|
||||||
67
mlx/backend/cpu/encoder.h
Normal file
67
mlx/backend/cpu/encoder.h
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
|
#include "mlx/array.h"
|
||||||
|
#include "mlx/scheduler.h"
|
||||||
|
|
||||||
|
namespace mlx::core::cpu {
|
||||||
|
|
||||||
|
// Number of dispatches per scheduler task
|
||||||
|
constexpr int DISPATCHES_PER_TASK = 10;
|
||||||
|
|
||||||
|
struct CommandEncoder {
|
||||||
|
CommandEncoder(Stream stream) : stream_(stream) {}
|
||||||
|
|
||||||
|
CommandEncoder(const CommandEncoder&) = delete;
|
||||||
|
CommandEncoder& operator=(const CommandEncoder&) = delete;
|
||||||
|
CommandEncoder(CommandEncoder&&) = delete;
|
||||||
|
CommandEncoder& operator=(CommandEncoder&&) = delete;
|
||||||
|
|
||||||
|
void set_input_array(const array& a) {}
|
||||||
|
void set_output_array(array& a) {}
|
||||||
|
|
||||||
|
// Hold onto a temporary until any already scheduled tasks which use it as
|
||||||
|
// an input are complete.
|
||||||
|
void add_temporary(array arr) {
|
||||||
|
temporaries_.push_back(std::move(arr));
|
||||||
|
}
|
||||||
|
|
||||||
|
void add_temporaries(std::vector<array> arrays) {
|
||||||
|
temporaries_.insert(
|
||||||
|
temporaries_.end(),
|
||||||
|
std::make_move_iterator(arrays.begin()),
|
||||||
|
std::make_move_iterator(arrays.end()));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<array>& temporaries() {
|
||||||
|
return temporaries_;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class F, class... Args>
|
||||||
|
void dispatch(F&& f, Args&&... args) {
|
||||||
|
num_ops_ = (num_ops_ + 1) % DISPATCHES_PER_TASK;
|
||||||
|
auto task = std::bind(std::forward<F>(f), std::forward<Args>(args)...);
|
||||||
|
if (num_ops_ == 0) {
|
||||||
|
scheduler::notify_new_task(stream_);
|
||||||
|
auto task_wrap = [s = stream_, task = std::move(task)]() mutable {
|
||||||
|
task();
|
||||||
|
scheduler::notify_task_completion(s);
|
||||||
|
};
|
||||||
|
scheduler::enqueue(stream_, std::move(task_wrap));
|
||||||
|
} else {
|
||||||
|
scheduler::enqueue(stream_, std::move(task));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
Stream stream_;
|
||||||
|
std::vector<array> temporaries_;
|
||||||
|
int num_ops_{0};
|
||||||
|
};
|
||||||
|
|
||||||
|
CommandEncoder& get_command_encoder(Stream stream);
|
||||||
|
|
||||||
|
} // namespace mlx::core::cpu
|
||||||
40
mlx/backend/cpu/eval.cpp
Normal file
40
mlx/backend/cpu/eval.cpp
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
#include "mlx/backend/cpu/eval.h"
|
||||||
|
#include "mlx/backend/cpu/encoder.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
#include "mlx/scheduler.h"
|
||||||
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
|
namespace mlx::core::cpu {
|
||||||
|
|
||||||
|
void eval(array& arr) {
|
||||||
|
auto s = arr.primitive().stream();
|
||||||
|
|
||||||
|
auto outputs = arr.outputs();
|
||||||
|
{
|
||||||
|
// If the array is a tracer hold a reference
|
||||||
|
// to its inputs so they don't get donated
|
||||||
|
std::vector<array> inputs;
|
||||||
|
if (arr.is_tracer()) {
|
||||||
|
inputs = arr.inputs();
|
||||||
|
}
|
||||||
|
arr.primitive().eval_cpu(arr.inputs(), outputs);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unordered_set<std::shared_ptr<array::Data>> buffers;
|
||||||
|
for (auto& in : arr.inputs()) {
|
||||||
|
buffers.insert(in.data_shared_ptr());
|
||||||
|
}
|
||||||
|
for (auto& s : arr.siblings()) {
|
||||||
|
buffers.insert(s.data_shared_ptr());
|
||||||
|
}
|
||||||
|
// Remove the output if it was donated to by an input
|
||||||
|
if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) {
|
||||||
|
buffers.erase(it);
|
||||||
|
}
|
||||||
|
auto& encoder = cpu::get_command_encoder(s);
|
||||||
|
encoder.dispatch([buffers = std::move(buffers),
|
||||||
|
temps = std::move(encoder.temporaries())]() {});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::cpu
|
||||||
12
mlx/backend/cpu/eval.h
Normal file
12
mlx/backend/cpu/eval.h
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlx/array.h"
|
||||||
|
#include "mlx/stream.h"
|
||||||
|
|
||||||
|
namespace mlx::core::cpu {
|
||||||
|
|
||||||
|
void eval(array& arr);
|
||||||
|
|
||||||
|
} // namespace mlx::core::cpu
|
||||||
@@ -4,6 +4,7 @@
|
|||||||
|
|
||||||
#include "mlx/3rdparty/pocketfft.h"
|
#include "mlx/3rdparty/pocketfft.h"
|
||||||
#include "mlx/allocator.h"
|
#include "mlx/allocator.h"
|
||||||
|
#include "mlx/backend/cpu/encoder.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
@@ -21,7 +22,7 @@ void FFT::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
s *= out.itemsize();
|
s *= out.itemsize();
|
||||||
}
|
}
|
||||||
|
|
||||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
out.set_data(allocator::malloc(out.nbytes()));
|
||||||
|
|
||||||
std::vector<size_t> shape;
|
std::vector<size_t> shape;
|
||||||
if (out.dtype() == float32) {
|
if (out.dtype() == float32) {
|
||||||
@@ -38,46 +39,78 @@ void FFT::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
});
|
});
|
||||||
scale /= nelem;
|
scale /= nelem;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto& encoder = cpu::get_command_encoder(stream());
|
||||||
|
encoder.set_input_array(in);
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
|
||||||
if (in.dtype() == complex64 && out.dtype() == complex64) {
|
if (in.dtype() == complex64 && out.dtype() == complex64) {
|
||||||
auto in_ptr =
|
auto in_ptr =
|
||||||
reinterpret_cast<const std::complex<float>*>(in.data<complex64_t>());
|
reinterpret_cast<const std::complex<float>*>(in.data<complex64_t>());
|
||||||
auto out_ptr =
|
auto out_ptr =
|
||||||
reinterpret_cast<std::complex<float>*>(out.data<complex64_t>());
|
reinterpret_cast<std::complex<float>*>(out.data<complex64_t>());
|
||||||
|
encoder.dispatch([shape = std::move(shape),
|
||||||
|
strides_in = std::move(strides_in),
|
||||||
|
strides_out = std::move(strides_out),
|
||||||
|
axes = axes_,
|
||||||
|
inverse = inverse_,
|
||||||
|
in_ptr,
|
||||||
|
out_ptr,
|
||||||
|
scale]() {
|
||||||
pocketfft::c2c(
|
pocketfft::c2c(
|
||||||
shape,
|
shape,
|
||||||
strides_in,
|
strides_in,
|
||||||
strides_out,
|
strides_out,
|
||||||
axes_,
|
axes,
|
||||||
!inverse_,
|
!inverse,
|
||||||
in_ptr,
|
in_ptr,
|
||||||
out_ptr,
|
out_ptr,
|
||||||
scale);
|
scale);
|
||||||
|
});
|
||||||
} else if (in.dtype() == float32 && out.dtype() == complex64) {
|
} else if (in.dtype() == float32 && out.dtype() == complex64) {
|
||||||
auto in_ptr = in.data<float>();
|
auto in_ptr = in.data<float>();
|
||||||
auto out_ptr =
|
auto out_ptr =
|
||||||
reinterpret_cast<std::complex<float>*>(out.data<complex64_t>());
|
reinterpret_cast<std::complex<float>*>(out.data<complex64_t>());
|
||||||
|
encoder.dispatch([shape = std::move(shape),
|
||||||
|
strides_in = std::move(strides_in),
|
||||||
|
strides_out = std::move(strides_out),
|
||||||
|
axes = axes_,
|
||||||
|
inverse = inverse_,
|
||||||
|
in_ptr,
|
||||||
|
out_ptr,
|
||||||
|
scale]() {
|
||||||
pocketfft::r2c(
|
pocketfft::r2c(
|
||||||
shape,
|
shape,
|
||||||
strides_in,
|
strides_in,
|
||||||
strides_out,
|
strides_out,
|
||||||
axes_,
|
axes,
|
||||||
!inverse_,
|
!inverse,
|
||||||
in_ptr,
|
in_ptr,
|
||||||
out_ptr,
|
out_ptr,
|
||||||
scale);
|
scale);
|
||||||
|
});
|
||||||
} else if (in.dtype() == complex64 && out.dtype() == float32) {
|
} else if (in.dtype() == complex64 && out.dtype() == float32) {
|
||||||
auto in_ptr =
|
auto in_ptr =
|
||||||
reinterpret_cast<const std::complex<float>*>(in.data<complex64_t>());
|
reinterpret_cast<const std::complex<float>*>(in.data<complex64_t>());
|
||||||
auto out_ptr = out.data<float>();
|
auto out_ptr = out.data<float>();
|
||||||
|
encoder.dispatch([shape = std::move(shape),
|
||||||
|
strides_in = std::move(strides_in),
|
||||||
|
strides_out = std::move(strides_out),
|
||||||
|
axes = axes_,
|
||||||
|
inverse = inverse_,
|
||||||
|
in_ptr,
|
||||||
|
out_ptr,
|
||||||
|
scale]() {
|
||||||
pocketfft::c2r(
|
pocketfft::c2r(
|
||||||
shape,
|
shape,
|
||||||
strides_in,
|
strides_in,
|
||||||
strides_out,
|
strides_out,
|
||||||
axes_,
|
axes,
|
||||||
!inverse_,
|
!inverse,
|
||||||
in_ptr,
|
in_ptr,
|
||||||
out_ptr,
|
out_ptr,
|
||||||
scale);
|
scale);
|
||||||
|
});
|
||||||
} else {
|
} else {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"[FFT] Received unexpected input and output type combination.");
|
"[FFT] Received unexpected input and output type combination.");
|
||||||
|
|||||||
@@ -7,14 +7,20 @@ namespace mlx::core {
|
|||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void matmul(
|
void matmul(
|
||||||
const array& a,
|
const T* a,
|
||||||
const array& b,
|
const T* b,
|
||||||
array& out,
|
T* out,
|
||||||
bool a_transposed,
|
bool a_transposed,
|
||||||
bool b_transposed,
|
bool b_transposed,
|
||||||
size_t lda,
|
size_t lda,
|
||||||
size_t ldb,
|
size_t ldb,
|
||||||
|
size_t ldc,
|
||||||
float alpha,
|
float alpha,
|
||||||
float beta);
|
float beta,
|
||||||
|
size_t batch_size,
|
||||||
|
const Shape& a_shape,
|
||||||
|
const Strides& a_strides,
|
||||||
|
const Shape& b_shape,
|
||||||
|
const Strides& b_strides);
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
|||||||
@@ -9,39 +9,46 @@
|
|||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
BNNSDataType to_bnns_dtype(Dtype mlx_dtype) {
|
template <typename T>
|
||||||
uint32_t size_bits = size_of(mlx_dtype) * 8;
|
constexpr BNNSDataType to_bnns_dtype();
|
||||||
switch (kindof(mlx_dtype)) {
|
|
||||||
case Dtype::Kind::b:
|
template <>
|
||||||
return BNNSDataTypeBoolean;
|
constexpr BNNSDataType to_bnns_dtype<float>() {
|
||||||
case Dtype::Kind::u:
|
return BNNSDataType(BNNSDataTypeFloatBit | 32);
|
||||||
return BNNSDataType(BNNSDataTypeUIntBit | size_bits);
|
}
|
||||||
case Dtype::Kind::i:
|
template <>
|
||||||
return BNNSDataType(BNNSDataTypeIntBit | size_bits);
|
constexpr BNNSDataType to_bnns_dtype<float16_t>() {
|
||||||
case Dtype::Kind::f:
|
return BNNSDataType(BNNSDataTypeFloatBit | 16);
|
||||||
return BNNSDataType(BNNSDataTypeFloatBit | size_bits);
|
|
||||||
case Dtype::Kind::V:
|
|
||||||
return BNNSDataTypeBFloat16;
|
|
||||||
case Dtype::Kind::c:
|
|
||||||
throw std::invalid_argument("BNNS does not support complex types");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
constexpr BNNSDataType to_bnns_dtype<bfloat16_t>() {
|
||||||
|
return BNNSDataTypeBFloat16;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
void matmul_bnns(
|
void matmul_bnns(
|
||||||
const array& a,
|
const T* a,
|
||||||
const array& b,
|
const T* b,
|
||||||
array& out,
|
T* out,
|
||||||
bool a_transposed,
|
bool a_transposed,
|
||||||
bool b_transposed,
|
bool b_transposed,
|
||||||
size_t lda,
|
size_t lda,
|
||||||
size_t ldb,
|
size_t ldb,
|
||||||
|
size_t ldc,
|
||||||
float alpha,
|
float alpha,
|
||||||
float beta) {
|
float beta,
|
||||||
size_t M = a.shape(-2);
|
size_t batch_size,
|
||||||
size_t N = b.shape(-1);
|
const Shape& a_shape,
|
||||||
size_t K = a.shape(-1);
|
const Strides& a_strides,
|
||||||
|
const Shape& b_shape,
|
||||||
|
const Strides& b_strides) {
|
||||||
|
auto ndim = a_shape.size();
|
||||||
|
size_t M = a_shape[ndim - 2];
|
||||||
|
size_t N = b_shape[ndim - 1];
|
||||||
|
size_t K = a_shape[ndim - 1];
|
||||||
|
|
||||||
BNNSDataType bnns_dtype = to_bnns_dtype(out.dtype());
|
BNNSDataType bnns_dtype = to_bnns_dtype<T>();
|
||||||
|
|
||||||
#pragma GCC diagnostic push
|
#pragma GCC diagnostic push
|
||||||
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
|
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
|
||||||
@@ -115,14 +122,14 @@ void matmul_bnns(
|
|||||||
auto bnns_filter =
|
auto bnns_filter =
|
||||||
BNNSFilterCreateLayerBroadcastMatMul(&gemm_params, nullptr);
|
BNNSFilterCreateLayerBroadcastMatMul(&gemm_params, nullptr);
|
||||||
|
|
||||||
for (int i = 0; i < (a.size() / (M * K)); ++i) {
|
for (int i = 0; i < batch_size; ++i) {
|
||||||
BNNSFilterApplyTwoInput(
|
BNNSFilterApplyTwoInput(
|
||||||
bnns_filter,
|
bnns_filter,
|
||||||
a.data<uint8_t>() +
|
reinterpret_cast<const uint8_t*>(
|
||||||
elem_to_loc(M * K * i, a.shape(), a.strides()) * a.itemsize(),
|
a + elem_to_loc(M * K * i, a_shape, a_strides)),
|
||||||
b.data<uint8_t>() +
|
reinterpret_cast<const uint8_t*>(
|
||||||
elem_to_loc(K * N * i, b.shape(), b.strides()) * b.itemsize(),
|
b + elem_to_loc(K * N * i, b_shape, b_strides)),
|
||||||
out.data<uint8_t>() + M * N * i * out.itemsize());
|
reinterpret_cast<uint8_t*>(out + M * N * i));
|
||||||
}
|
}
|
||||||
|
|
||||||
BNNSFilterDestroy(bnns_filter);
|
BNNSFilterDestroy(bnns_filter);
|
||||||
@@ -131,30 +138,72 @@ void matmul_bnns(
|
|||||||
|
|
||||||
template <>
|
template <>
|
||||||
void matmul<float16_t>(
|
void matmul<float16_t>(
|
||||||
const array& a,
|
const float16_t* a,
|
||||||
const array& b,
|
const float16_t* b,
|
||||||
array& out,
|
float16_t* out,
|
||||||
bool a_transposed,
|
bool a_transposed,
|
||||||
bool b_transposed,
|
bool b_transposed,
|
||||||
size_t lda,
|
size_t lda,
|
||||||
size_t ldb,
|
size_t ldb,
|
||||||
|
size_t ldc,
|
||||||
float alpha,
|
float alpha,
|
||||||
float beta) {
|
float beta,
|
||||||
matmul_bnns(a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta);
|
size_t batch_size,
|
||||||
|
const Shape& a_shape,
|
||||||
|
const Strides& a_strides,
|
||||||
|
const Shape& b_shape,
|
||||||
|
const Strides& b_strides) {
|
||||||
|
matmul_bnns(
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
out,
|
||||||
|
a_transposed,
|
||||||
|
b_transposed,
|
||||||
|
lda,
|
||||||
|
ldb,
|
||||||
|
ldc,
|
||||||
|
alpha,
|
||||||
|
beta,
|
||||||
|
batch_size,
|
||||||
|
a_shape,
|
||||||
|
a_strides,
|
||||||
|
b_shape,
|
||||||
|
b_strides);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
void matmul<bfloat16_t>(
|
void matmul<bfloat16_t>(
|
||||||
const array& a,
|
const bfloat16_t* a,
|
||||||
const array& b,
|
const bfloat16_t* b,
|
||||||
array& out,
|
bfloat16_t* out,
|
||||||
bool a_transposed,
|
bool a_transposed,
|
||||||
bool b_transposed,
|
bool b_transposed,
|
||||||
size_t lda,
|
size_t lda,
|
||||||
size_t ldb,
|
size_t ldb,
|
||||||
|
size_t ldc,
|
||||||
float alpha,
|
float alpha,
|
||||||
float beta) {
|
float beta,
|
||||||
matmul_bnns(a, b, out, a_transposed, b_transposed, lda, ldb, alpha, beta);
|
size_t batch_size,
|
||||||
|
const Shape& a_shape,
|
||||||
|
const Strides& a_strides,
|
||||||
|
const Shape& b_shape,
|
||||||
|
const Strides& b_strides) {
|
||||||
|
matmul_bnns(
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
out,
|
||||||
|
a_transposed,
|
||||||
|
b_transposed,
|
||||||
|
lda,
|
||||||
|
ldb,
|
||||||
|
ldc,
|
||||||
|
alpha,
|
||||||
|
beta,
|
||||||
|
batch_size,
|
||||||
|
a_shape,
|
||||||
|
a_strides,
|
||||||
|
b_shape,
|
||||||
|
b_strides);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
|||||||
@@ -8,20 +8,27 @@ namespace mlx::core {
|
|||||||
|
|
||||||
template <>
|
template <>
|
||||||
void matmul<float>(
|
void matmul<float>(
|
||||||
const array& a,
|
const float* a,
|
||||||
const array& b,
|
const float* b,
|
||||||
array& out,
|
float* out,
|
||||||
bool a_transposed,
|
bool a_transposed,
|
||||||
bool b_transposed,
|
bool b_transposed,
|
||||||
size_t lda,
|
size_t lda,
|
||||||
size_t ldb,
|
size_t ldb,
|
||||||
|
size_t ldc,
|
||||||
float alpha,
|
float alpha,
|
||||||
float beta) {
|
float beta,
|
||||||
size_t M = a.shape(-2);
|
size_t batch_size,
|
||||||
size_t N = b.shape(-1);
|
const Shape& a_shape,
|
||||||
size_t K = a.shape(-1);
|
const Strides& a_strides,
|
||||||
|
const Shape& b_shape,
|
||||||
|
const Strides& b_strides) {
|
||||||
|
auto ndim = a_shape.size();
|
||||||
|
size_t M = a_shape[ndim - 2];
|
||||||
|
size_t N = b_shape[ndim - 1];
|
||||||
|
size_t K = a_shape[ndim - 1];
|
||||||
|
|
||||||
for (int i = 0; i < (a.size() / (M * K)); ++i) {
|
for (int i = 0; i < batch_size; ++i) {
|
||||||
cblas_sgemm(
|
cblas_sgemm(
|
||||||
CblasRowMajor,
|
CblasRowMajor,
|
||||||
a_transposed ? CblasTrans : CblasNoTrans, // transA
|
a_transposed ? CblasTrans : CblasNoTrans, // transA
|
||||||
@@ -29,34 +36,40 @@ void matmul<float>(
|
|||||||
M,
|
M,
|
||||||
N,
|
N,
|
||||||
K,
|
K,
|
||||||
alpha, // alpha
|
alpha,
|
||||||
a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides()),
|
a + elem_to_loc(M * K * i, a_shape, a_strides),
|
||||||
lda,
|
lda,
|
||||||
b.data<float>() + elem_to_loc(K * N * i, b.shape(), b.strides()),
|
b + elem_to_loc(K * N * i, b_shape, b_strides),
|
||||||
ldb,
|
ldb,
|
||||||
beta, // beta
|
beta,
|
||||||
out.data<float>() + M * N * i,
|
out + M * N * i,
|
||||||
out.shape(-1) // ldc
|
ldc);
|
||||||
);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
void matmul<double>(
|
void matmul<double>(
|
||||||
const array& a,
|
const double* a,
|
||||||
const array& b,
|
const double* b,
|
||||||
array& out,
|
double* out,
|
||||||
bool a_transposed,
|
bool a_transposed,
|
||||||
bool b_transposed,
|
bool b_transposed,
|
||||||
size_t lda,
|
size_t lda,
|
||||||
size_t ldb,
|
size_t ldb,
|
||||||
|
size_t ldc,
|
||||||
float alpha,
|
float alpha,
|
||||||
float beta) {
|
float beta,
|
||||||
size_t M = a.shape(-2);
|
size_t batch_size,
|
||||||
size_t N = b.shape(-1);
|
const Shape& a_shape,
|
||||||
size_t K = a.shape(-1);
|
const Strides& a_strides,
|
||||||
|
const Shape& b_shape,
|
||||||
|
const Strides& b_strides) {
|
||||||
|
auto ndim = a_shape.size();
|
||||||
|
size_t M = a_shape[ndim - 2];
|
||||||
|
size_t N = b_shape[ndim - 1];
|
||||||
|
size_t K = a_shape[ndim - 1];
|
||||||
|
|
||||||
for (int i = 0; i < (a.size() / (M * K)); ++i) {
|
for (int i = 0; i < batch_size; ++i) {
|
||||||
cblas_dgemm(
|
cblas_dgemm(
|
||||||
CblasRowMajor,
|
CblasRowMajor,
|
||||||
a_transposed ? CblasTrans : CblasNoTrans, // transA
|
a_transposed ? CblasTrans : CblasNoTrans, // transA
|
||||||
@@ -64,15 +77,14 @@ void matmul<double>(
|
|||||||
M,
|
M,
|
||||||
N,
|
N,
|
||||||
K,
|
K,
|
||||||
alpha, // alpha
|
alpha,
|
||||||
a.data<double>() + elem_to_loc(M * K * i, a.shape(), a.strides()),
|
a + elem_to_loc(M * K * i, a_shape, a_strides),
|
||||||
lda,
|
lda,
|
||||||
b.data<double>() + elem_to_loc(K * N * i, b.shape(), b.strides()),
|
b + elem_to_loc(K * N * i, b_shape, b_strides),
|
||||||
ldb,
|
ldb,
|
||||||
beta, // beta
|
beta,
|
||||||
out.data<double>() + M * N * i,
|
out + M * N * i,
|
||||||
out.shape(-1) // ldc
|
ldc);
|
||||||
);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,21 +0,0 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
|
||||||
|
|
||||||
#include "mlx/backend/cpu/gemm.h"
|
|
||||||
|
|
||||||
namespace mlx::core {
|
|
||||||
|
|
||||||
template <>
|
|
||||||
void matmul<bfloat16_t>(
|
|
||||||
const array&,
|
|
||||||
const array&,
|
|
||||||
array&,
|
|
||||||
bool,
|
|
||||||
bool,
|
|
||||||
size_t,
|
|
||||||
size_t,
|
|
||||||
float,
|
|
||||||
float) {
|
|
||||||
throw std::runtime_error("[Matmul::eval_cpu] bfloat16 not supported.");
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace mlx::core
|
|
||||||
@@ -1,21 +0,0 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
|
||||||
|
|
||||||
#include "mlx/backend/cpu/gemm.h"
|
|
||||||
|
|
||||||
namespace mlx::core {
|
|
||||||
|
|
||||||
template <>
|
|
||||||
void matmul<float16_t>(
|
|
||||||
const array&,
|
|
||||||
const array&,
|
|
||||||
array&,
|
|
||||||
bool,
|
|
||||||
bool,
|
|
||||||
size_t,
|
|
||||||
size_t,
|
|
||||||
float,
|
|
||||||
float) {
|
|
||||||
throw std::runtime_error("[Matmul::eval_cpu] float16 not supported.");
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace mlx::core
|
|
||||||
45
mlx/backend/cpu/gemms/simd_bf16.cpp
Normal file
45
mlx/backend/cpu/gemms/simd_bf16.cpp
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/common/utils.h"
|
||||||
|
#include "mlx/backend/cpu/gemm.h"
|
||||||
|
#include "mlx/backend/cpu/gemms/simd_gemm.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
template <>
|
||||||
|
void matmul<bfloat16_t>(
|
||||||
|
const bfloat16_t* a,
|
||||||
|
const bfloat16_t* b,
|
||||||
|
bfloat16_t* out,
|
||||||
|
bool a_transposed,
|
||||||
|
bool b_transposed,
|
||||||
|
size_t lda,
|
||||||
|
size_t ldb,
|
||||||
|
size_t ldc,
|
||||||
|
float alpha,
|
||||||
|
float beta,
|
||||||
|
size_t batch_size,
|
||||||
|
const Shape& a_shape,
|
||||||
|
const Strides& a_strides,
|
||||||
|
const Shape& b_shape,
|
||||||
|
const Strides& b_strides) {
|
||||||
|
auto ndim = a_shape.size();
|
||||||
|
size_t M = a_shape[ndim - 2];
|
||||||
|
size_t N = b_shape[ndim - 1];
|
||||||
|
size_t K = a_shape[ndim - 1];
|
||||||
|
for (int i = 0; i < batch_size; ++i) {
|
||||||
|
simd_gemm<bfloat16_t, float>(
|
||||||
|
a + elem_to_loc(M * K * i, a_shape, a_strides),
|
||||||
|
b + elem_to_loc(K * N * i, b_shape, b_strides),
|
||||||
|
out + M * N * i,
|
||||||
|
a_transposed,
|
||||||
|
b_transposed,
|
||||||
|
M,
|
||||||
|
N,
|
||||||
|
K,
|
||||||
|
alpha,
|
||||||
|
beta);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
45
mlx/backend/cpu/gemms/simd_fp16.cpp
Normal file
45
mlx/backend/cpu/gemms/simd_fp16.cpp
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/common/utils.h"
|
||||||
|
#include "mlx/backend/cpu/gemm.h"
|
||||||
|
#include "mlx/backend/cpu/gemms/simd_gemm.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
template <>
|
||||||
|
void matmul<float16_t>(
|
||||||
|
const float16_t* a,
|
||||||
|
const float16_t* b,
|
||||||
|
float16_t* out,
|
||||||
|
bool a_transposed,
|
||||||
|
bool b_transposed,
|
||||||
|
size_t lda,
|
||||||
|
size_t ldb,
|
||||||
|
size_t ldc,
|
||||||
|
float alpha,
|
||||||
|
float beta,
|
||||||
|
size_t batch_size,
|
||||||
|
const Shape& a_shape,
|
||||||
|
const Strides& a_strides,
|
||||||
|
const Shape& b_shape,
|
||||||
|
const Strides& b_strides) {
|
||||||
|
auto ndim = a_shape.size();
|
||||||
|
size_t M = a_shape[ndim - 2];
|
||||||
|
size_t N = b_shape[ndim - 1];
|
||||||
|
size_t K = a_shape[ndim - 1];
|
||||||
|
for (int i = 0; i < batch_size; ++i) {
|
||||||
|
simd_gemm<float16_t, float>(
|
||||||
|
a + elem_to_loc(M * K * i, a_shape, a_strides),
|
||||||
|
b + elem_to_loc(K * N * i, b_shape, b_strides),
|
||||||
|
out + M * N * i,
|
||||||
|
a_transposed,
|
||||||
|
b_transposed,
|
||||||
|
M,
|
||||||
|
N,
|
||||||
|
K,
|
||||||
|
alpha,
|
||||||
|
beta);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
139
mlx/backend/cpu/gemms/simd_gemm.h
Normal file
139
mlx/backend/cpu/gemms/simd_gemm.h
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlx/backend/cpu/simd/simd.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
inline int ceildiv(int a, int b) {
|
||||||
|
return (a + b - 1) / b;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int block_size, typename T, typename AccT>
|
||||||
|
void load_block(
|
||||||
|
const T* in,
|
||||||
|
AccT* out,
|
||||||
|
int M,
|
||||||
|
int N,
|
||||||
|
int i,
|
||||||
|
int j,
|
||||||
|
bool transpose) {
|
||||||
|
if (transpose) {
|
||||||
|
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
|
||||||
|
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
|
||||||
|
out[jj * block_size + ii] =
|
||||||
|
in[(i * block_size + ii) * N + j * block_size + jj];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
|
||||||
|
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
|
||||||
|
out[ii * block_size + jj] =
|
||||||
|
in[(i * block_size + ii) * N + j * block_size + jj];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename AccT>
|
||||||
|
void simd_gemm(
|
||||||
|
const T* a,
|
||||||
|
const T* b,
|
||||||
|
T* c,
|
||||||
|
bool a_trans,
|
||||||
|
bool b_trans,
|
||||||
|
int M,
|
||||||
|
int N,
|
||||||
|
int K,
|
||||||
|
float alpha,
|
||||||
|
float beta) {
|
||||||
|
constexpr int block_size = 16;
|
||||||
|
constexpr int simd_size = simd::max_size<AccT>;
|
||||||
|
static_assert(
|
||||||
|
(block_size % simd_size) == 0,
|
||||||
|
"Block size must be divisible by SIMD size");
|
||||||
|
|
||||||
|
int last_k_block_size = K - block_size * (K / block_size);
|
||||||
|
int last_k_simd_block = (last_k_block_size / simd_size) * simd_size;
|
||||||
|
for (int i = 0; i < ceildiv(M, block_size); i++) {
|
||||||
|
for (int j = 0; j < ceildiv(N, block_size); j++) {
|
||||||
|
AccT c_block[block_size * block_size] = {0.0};
|
||||||
|
AccT a_block[block_size * block_size];
|
||||||
|
AccT b_block[block_size * block_size];
|
||||||
|
|
||||||
|
int k = 0;
|
||||||
|
for (; k < K / block_size; k++) {
|
||||||
|
// Load a and b blocks
|
||||||
|
if (a_trans) {
|
||||||
|
load_block<block_size>(a, a_block, K, M, k, i, true);
|
||||||
|
} else {
|
||||||
|
load_block<block_size>(a, a_block, M, K, i, k, false);
|
||||||
|
}
|
||||||
|
if (b_trans) {
|
||||||
|
load_block<block_size>(b, b_block, N, K, j, k, false);
|
||||||
|
} else {
|
||||||
|
load_block<block_size>(b, b_block, K, N, k, j, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Multiply and accumulate
|
||||||
|
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
|
||||||
|
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
|
||||||
|
for (int kk = 0; kk < block_size; kk += simd_size) {
|
||||||
|
auto av =
|
||||||
|
simd::load<AccT, simd_size>(a_block + ii * block_size + kk);
|
||||||
|
auto bv =
|
||||||
|
simd::load<AccT, simd_size>(b_block + jj * block_size + kk);
|
||||||
|
c_block[ii * block_size + jj] += simd::sum(av * bv);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (last_k_block_size) {
|
||||||
|
// Load a and b blocks
|
||||||
|
if (a_trans) {
|
||||||
|
load_block<block_size>(a, a_block, K, M, k, i, true);
|
||||||
|
} else {
|
||||||
|
load_block<block_size>(a, a_block, M, K, i, k, false);
|
||||||
|
}
|
||||||
|
if (b_trans) {
|
||||||
|
load_block<block_size>(b, b_block, N, K, j, k, false);
|
||||||
|
} else {
|
||||||
|
load_block<block_size>(b, b_block, K, N, k, j, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Multiply and accumulate
|
||||||
|
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
|
||||||
|
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
|
||||||
|
int kk = 0;
|
||||||
|
for (; kk < last_k_simd_block; kk += simd_size) {
|
||||||
|
auto av =
|
||||||
|
simd::load<AccT, simd_size>(a_block + ii * block_size + kk);
|
||||||
|
auto bv =
|
||||||
|
simd::load<AccT, simd_size>(b_block + jj * block_size + kk);
|
||||||
|
c_block[ii * block_size + jj] += simd::sum(av * bv);
|
||||||
|
}
|
||||||
|
for (; kk < last_k_block_size; ++kk) {
|
||||||
|
c_block[ii * block_size + jj] +=
|
||||||
|
a_block[ii * block_size + kk] * b_block[jj * block_size + kk];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store
|
||||||
|
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
|
||||||
|
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
|
||||||
|
auto c_idx = (i * block_size + ii) * N + j * block_size + jj;
|
||||||
|
if (beta != 0) {
|
||||||
|
c[c_idx] = static_cast<T>(
|
||||||
|
alpha * c_block[ii * block_size + jj] + beta * c[c_idx]);
|
||||||
|
} else {
|
||||||
|
c[c_idx] = static_cast<T>(alpha * c_block[ii * block_size + jj]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
@@ -4,16 +4,17 @@
|
|||||||
|
|
||||||
#include "mlx/backend/common/hadamard.h"
|
#include "mlx/backend/common/hadamard.h"
|
||||||
#include "mlx/backend/cpu/copy.h"
|
#include "mlx/backend/cpu/copy.h"
|
||||||
|
#include "mlx/backend/cpu/encoder.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
// n = 2^k component
|
// n = 2^k component
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void hadamard_n(array& out, int n, int m, float scale) {
|
void hadamard_n(T* out, int n, int m, float scale, size_t size) {
|
||||||
for (int b = 0; b < out.size() / n; b++) {
|
for (int b = 0; b < size / n; b++) {
|
||||||
size_t loc = b * n;
|
size_t loc = b * n;
|
||||||
T* data_ptr = out.data<T>() + loc;
|
T* data_ptr = out + loc;
|
||||||
int h = 1;
|
int h = 1;
|
||||||
int n_over_2 = n / 2;
|
int n_over_2 = n / 2;
|
||||||
while (h < n) {
|
while (h < n) {
|
||||||
@@ -36,7 +37,7 @@ void hadamard_n(array& out, int n, int m, float scale) {
|
|||||||
|
|
||||||
// m component
|
// m component
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void hadamard_m(array& out, int n, int m, float scale) {
|
void hadamard_m(T* out, int n, int m, float scale, size_t size) {
|
||||||
auto h_matrices = hadamard_matrices();
|
auto h_matrices = hadamard_matrices();
|
||||||
auto& matrix = h_matrices[m];
|
auto& matrix = h_matrices[m];
|
||||||
auto start = 1;
|
auto start = 1;
|
||||||
@@ -51,9 +52,9 @@ void hadamard_m(array& out, int n, int m, float scale) {
|
|||||||
end = matrix.find('\n', start);
|
end = matrix.find('\n', start);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int b = 0; b < out.size() / m / n; b++) {
|
for (int b = 0; b < size / m / n; b++) {
|
||||||
size_t loc = b * n * m;
|
size_t loc = b * n * m;
|
||||||
T* data_ptr = out.data<T>() + loc;
|
T* data_ptr = out + loc;
|
||||||
for (int i = 0; i < n; i++) {
|
for (int i = 0; i < n; i++) {
|
||||||
std::vector<float> out(m);
|
std::vector<float> out(m);
|
||||||
for (int j = 0; j < m; j++) {
|
for (int j = 0; j < m; j++) {
|
||||||
@@ -74,12 +75,17 @@ void hadamard_m(array& out, int n, int m, float scale) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void hadamard(array& out, int n, int m, float scale) {
|
void hadamard(array& out, int n, int m, float scale, Stream stream) {
|
||||||
|
auto& encoder = cpu::get_command_encoder(stream);
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
auto out_ptr = out.data<T>();
|
||||||
|
encoder.dispatch([out_ptr, size = out.size(), n, m, scale]() {
|
||||||
float n_scale = m > 1 ? 1.0 : scale;
|
float n_scale = m > 1 ? 1.0 : scale;
|
||||||
hadamard_n<T>(out, n, m, n_scale);
|
hadamard_n<T>(out_ptr, n, m, n_scale, size);
|
||||||
if (m > 1) {
|
if (m > 1) {
|
||||||
hadamard_m<T>(out, n, m, scale);
|
hadamard_m<T>(out_ptr, n, m, scale, size);
|
||||||
}
|
}
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
void Hadamard::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void Hadamard::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
@@ -87,18 +93,26 @@ void Hadamard::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
auto& in = inputs[0];
|
auto& in = inputs[0];
|
||||||
|
|
||||||
// Copy input to output
|
// Copy input to output
|
||||||
copy(in, out, CopyType::General);
|
if (in.flags().row_contiguous && in.is_donatable()) {
|
||||||
|
out.copy_shared_buffer(in);
|
||||||
|
} else {
|
||||||
|
copy_cpu(
|
||||||
|
in,
|
||||||
|
out,
|
||||||
|
in.flags().row_contiguous ? CopyType::Vector : CopyType::General,
|
||||||
|
stream());
|
||||||
|
}
|
||||||
|
|
||||||
int axis = out.ndim() - 1;
|
int axis = out.ndim() - 1;
|
||||||
auto [n, m] = decompose_hadamard(out.shape(axis));
|
auto [n, m] = decompose_hadamard(out.shape(axis));
|
||||||
|
|
||||||
switch (in.dtype()) {
|
switch (in.dtype()) {
|
||||||
case float32:
|
case float32:
|
||||||
return hadamard<float>(out, n, m, scale_);
|
return hadamard<float>(out, n, m, scale_, stream());
|
||||||
case float16:
|
case float16:
|
||||||
return hadamard<float16_t>(out, n, m, scale_);
|
return hadamard<float16_t>(out, n, m, scale_, stream());
|
||||||
case bfloat16:
|
case bfloat16:
|
||||||
return hadamard<bfloat16_t>(out, n, m, scale_);
|
return hadamard<bfloat16_t>(out, n, m, scale_, stream());
|
||||||
default:
|
default:
|
||||||
throw std::invalid_argument("[hadamard] Unsupported type.");
|
throw std::invalid_argument("[hadamard] Unsupported type.");
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,6 +8,7 @@
|
|||||||
|
|
||||||
#include "mlx/backend/common/utils.h"
|
#include "mlx/backend/common/utils.h"
|
||||||
#include "mlx/backend/cpu/copy.h"
|
#include "mlx/backend/cpu/copy.h"
|
||||||
|
#include "mlx/backend/cpu/encoder.h"
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
@@ -21,6 +22,40 @@ inline size_t offset_neg_idx(uint32_t idx, size_t) {
|
|||||||
return idx;
|
return idx;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct None {
|
||||||
|
template <typename T>
|
||||||
|
void operator()(T x, T* y) {
|
||||||
|
(*y) = x;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
struct Sum {
|
||||||
|
template <typename T>
|
||||||
|
void operator()(T x, T* y) {
|
||||||
|
(*y) += x;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Prod {
|
||||||
|
template <typename T>
|
||||||
|
void operator()(T x, T* y) {
|
||||||
|
(*y) *= x;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Max {
|
||||||
|
template <typename T>
|
||||||
|
void operator()(T x, T* y) {
|
||||||
|
(*y) = (*y > x) ? *y : x;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Min {
|
||||||
|
template <typename T>
|
||||||
|
void operator()(T x, T* y) {
|
||||||
|
(*y) = (*y < x) ? *y : x;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
template <typename T, typename IdxT>
|
template <typename T, typename IdxT>
|
||||||
void gather(
|
void gather(
|
||||||
const array& src,
|
const array& src,
|
||||||
@@ -73,13 +108,14 @@ void gather(
|
|||||||
size_t ind_size = slice_size == 0 ? 0 : out.size() / slice_size;
|
size_t ind_size = slice_size == 0 ? 0 : out.size() / slice_size;
|
||||||
const T* src_ptr = src.data<T>();
|
const T* src_ptr = src.data<T>();
|
||||||
T* dst_ptr = out.data<T>();
|
T* dst_ptr = out.data<T>();
|
||||||
size_t out_idx = 0;
|
|
||||||
|
|
||||||
std::vector<ContiguousIterator> its(inds.begin(), inds.end());
|
std::vector<ContiguousIterator> its(inds.begin(), inds.end());
|
||||||
ContiguousIterator src_it;
|
ContiguousIterator src_it;
|
||||||
if (!can_copy && src.ndim() > 0) {
|
if (!can_copy && src.ndim() > 0) {
|
||||||
src_it = ContiguousIterator(slice_sizes, src.strides(), src.ndim());
|
src_it = ContiguousIterator(slice_sizes, src.strides(), src.ndim());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
size_t out_idx = 0;
|
||||||
for (int idx = 0; idx < ind_size; idx++) {
|
for (int idx = 0; idx < ind_size; idx++) {
|
||||||
size_t src_idx = 0;
|
size_t src_idx = 0;
|
||||||
for (int ii = 0; ii < inds.size(); ++ii) {
|
for (int ii = 0; ii < inds.size(); ++ii) {
|
||||||
@@ -161,11 +197,23 @@ void dispatch_gather(
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Gather::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void Gather::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
out.set_data(allocator::malloc(out.nbytes()));
|
||||||
|
|
||||||
auto& src = inputs[0];
|
auto& src = inputs[0];
|
||||||
std::vector<array> inds(inputs.begin() + 1, inputs.end());
|
std::vector<array> inds;
|
||||||
|
for (auto it = inputs.begin() + 1; it < inputs.end(); ++it) {
|
||||||
|
inds.push_back(array::unsafe_weak_copy(*it));
|
||||||
|
}
|
||||||
|
auto& encoder = cpu::get_command_encoder(stream());
|
||||||
|
for (auto& in : inputs) {
|
||||||
|
encoder.set_input_array(in);
|
||||||
|
}
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
encoder.dispatch([axes_ = axes_,
|
||||||
|
slice_sizes_ = slice_sizes_,
|
||||||
|
src = array::unsafe_weak_copy(src),
|
||||||
|
inds = std::move(inds),
|
||||||
|
out = array::unsafe_weak_copy(out)]() mutable {
|
||||||
if (inds.empty()) {
|
if (inds.empty()) {
|
||||||
dispatch_gather<uint8_t>(src, inds, out, axes_, slice_sizes_);
|
dispatch_gather<uint8_t>(src, inds, out, axes_, slice_sizes_);
|
||||||
return;
|
return;
|
||||||
@@ -201,6 +249,7 @@ void Gather::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
"[Gather::eval_cpu] Cannot gather with indices type.");
|
"[Gather::eval_cpu] Cannot gather with indices type.");
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
});
|
||||||
}
|
}
|
||||||
template <typename T, typename IdxT>
|
template <typename T, typename IdxT>
|
||||||
void gather_axis(
|
void gather_axis(
|
||||||
@@ -208,15 +257,11 @@ void gather_axis(
|
|||||||
const array& ind,
|
const array& ind,
|
||||||
array& out,
|
array& out,
|
||||||
const int axis) {
|
const int axis) {
|
||||||
auto strides = ind.strides();
|
auto shape = remove_index(ind.shape(), axis);
|
||||||
strides.erase(strides.begin() + axis);
|
ContiguousIterator ind_it(
|
||||||
auto shape = ind.shape();
|
shape, remove_index(ind.strides(), axis), src.ndim() - 1);
|
||||||
shape.erase(shape.begin() + axis);
|
ContiguousIterator src_it(
|
||||||
ContiguousIterator ind_it(shape, strides, src.ndim() - 1);
|
shape, remove_index(src.strides(), axis), src.ndim() - 1);
|
||||||
|
|
||||||
strides = src.strides();
|
|
||||||
strides.erase(strides.begin() + axis);
|
|
||||||
ContiguousIterator src_it(shape, strides, src.ndim() - 1);
|
|
||||||
|
|
||||||
auto ind_ptr = ind.data<IdxT>();
|
auto ind_ptr = ind.data<IdxT>();
|
||||||
auto src_ptr = src.data<T>();
|
auto src_ptr = src.data<T>();
|
||||||
@@ -235,6 +280,7 @@ void gather_axis(
|
|||||||
for (int i = axis + 1; i < ind.ndim(); ++i) {
|
for (int i = axis + 1; i < ind.ndim(); ++i) {
|
||||||
size_post *= ind.shape(i);
|
size_post *= ind.shape(i);
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t stride_pre = size_post * ind_ax_size;
|
size_t stride_pre = size_post * ind_ax_size;
|
||||||
for (size_t i = 0; i < size_pre; i++) {
|
for (size_t i = 0; i < size_pre; i++) {
|
||||||
for (size_t k = 0; k < size_post; k++) {
|
for (size_t k = 0; k < size_post; k++) {
|
||||||
@@ -304,9 +350,18 @@ void dispatch_gather_axis(
|
|||||||
}
|
}
|
||||||
|
|
||||||
void GatherAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void GatherAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
out.set_data(allocator::malloc(out.nbytes()));
|
||||||
|
|
||||||
auto& src = inputs[0];
|
auto& src = inputs[0];
|
||||||
auto& inds = inputs[1];
|
auto& inds = inputs[1];
|
||||||
|
auto& encoder = cpu::get_command_encoder(stream());
|
||||||
|
encoder.set_input_array(src);
|
||||||
|
encoder.set_input_array(inds);
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
encoder.dispatch([axis_ = axis_,
|
||||||
|
src = array::unsafe_weak_copy(src),
|
||||||
|
inds = array::unsafe_weak_copy(inds),
|
||||||
|
out = array::unsafe_weak_copy(out)]() mutable {
|
||||||
switch (inds.dtype()) {
|
switch (inds.dtype()) {
|
||||||
case uint8:
|
case uint8:
|
||||||
dispatch_gather_axis<uint8_t>(src, inds, out, axis_);
|
dispatch_gather_axis<uint8_t>(src, inds, out, axis_);
|
||||||
@@ -337,6 +392,7 @@ void GatherAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
"[GatherAxis::eval_cpu] Cannot gather with indices type.");
|
"[GatherAxis::eval_cpu] Cannot gather with indices type.");
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename InT, typename IdxT, typename OpT>
|
template <typename InT, typename IdxT, typename OpT>
|
||||||
@@ -344,8 +400,7 @@ void scatter(
|
|||||||
const array& updates,
|
const array& updates,
|
||||||
array& out,
|
array& out,
|
||||||
const std::vector<array>& inds,
|
const std::vector<array>& inds,
|
||||||
const std::vector<int>& axes,
|
const std::vector<int>& axes) {
|
||||||
const OpT& op) {
|
|
||||||
int nind = inds.size();
|
int nind = inds.size();
|
||||||
auto inds_ndim = updates.ndim() - out.ndim();
|
auto inds_ndim = updates.ndim() - out.ndim();
|
||||||
size_t n_updates = nind ? inds[0].size() : 1;
|
size_t n_updates = nind ? inds[0].size() : 1;
|
||||||
@@ -361,9 +416,11 @@ void scatter(
|
|||||||
ContiguousIterator update_it(updates);
|
ContiguousIterator update_it(updates);
|
||||||
ContiguousIterator out_it(update_shape, out.strides(), out.ndim());
|
ContiguousIterator out_it(update_shape, out.strides(), out.ndim());
|
||||||
|
|
||||||
|
auto out_ptr = out.data<InT>();
|
||||||
|
auto upd_ptr = updates.data<InT>();
|
||||||
for (int i = 0; i < n_updates; ++i) {
|
for (int i = 0; i < n_updates; ++i) {
|
||||||
size_t out_offset = 0;
|
size_t out_offset = 0;
|
||||||
for (int j = 0; j < nind; ++j) {
|
for (int j = 0; j < inds.size(); ++j) {
|
||||||
auto ax = axes[j];
|
auto ax = axes[j];
|
||||||
auto idx_loc = its[j].loc;
|
auto idx_loc = its[j].loc;
|
||||||
its[j].step();
|
its[j].step();
|
||||||
@@ -373,8 +430,7 @@ void scatter(
|
|||||||
}
|
}
|
||||||
update_it.seek(i * update_size);
|
update_it.seek(i * update_size);
|
||||||
for (int j = 0; j < update_size; ++j) {
|
for (int j = 0; j < update_size; ++j) {
|
||||||
op(updates.data<InT>()[update_it.loc],
|
OpT{}(upd_ptr[update_it.loc], out_ptr + out_offset + out_it.loc);
|
||||||
out.data<InT>() + out_offset + out_it.loc);
|
|
||||||
update_it.step();
|
update_it.step();
|
||||||
out_it.step();
|
out_it.step();
|
||||||
}
|
}
|
||||||
@@ -392,26 +448,19 @@ void dispatch_scatter_inds(
|
|||||||
Scatter::ReduceType rtype) {
|
Scatter::ReduceType rtype) {
|
||||||
switch (rtype) {
|
switch (rtype) {
|
||||||
case Scatter::None:
|
case Scatter::None:
|
||||||
scatter<InT, IdxT>(
|
scatter<InT, IdxT, None>(updates, out, indices, axes);
|
||||||
updates, out, indices, axes, [](auto x, auto* y) { (*y) = x; });
|
|
||||||
break;
|
break;
|
||||||
case Scatter::Sum:
|
case Scatter::Sum:
|
||||||
scatter<InT, IdxT>(
|
scatter<InT, IdxT, Sum>(updates, out, indices, axes);
|
||||||
updates, out, indices, axes, [](auto x, auto* y) { (*y) += x; });
|
|
||||||
break;
|
break;
|
||||||
case Scatter::Prod:
|
case Scatter::Prod:
|
||||||
scatter<InT, IdxT>(
|
scatter<InT, IdxT, Prod>(updates, out, indices, axes);
|
||||||
updates, out, indices, axes, [](auto x, auto* y) { (*y) *= x; });
|
|
||||||
break;
|
break;
|
||||||
case Scatter::Max:
|
case Scatter::Max:
|
||||||
scatter<InT, IdxT>(updates, out, indices, axes, [](auto x, auto* y) {
|
scatter<InT, IdxT, Max>(updates, out, indices, axes);
|
||||||
(*y) = (*y > x) ? *y : x;
|
|
||||||
});
|
|
||||||
break;
|
break;
|
||||||
case Scatter::Min:
|
case Scatter::Min:
|
||||||
scatter<InT, IdxT>(updates, out, indices, axes, [](auto x, auto* y) {
|
scatter<InT, IdxT, Min>(updates, out, indices, axes);
|
||||||
(*y) = (*y < x) ? *y : x;
|
|
||||||
});
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -463,15 +512,27 @@ void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
assert(inputs.size() >= 2);
|
assert(inputs.size() >= 2);
|
||||||
|
|
||||||
auto& src = inputs[0];
|
auto& src = inputs[0];
|
||||||
std::vector<array> inds(inputs.begin() + 1, inputs.end() - 1);
|
|
||||||
auto& updates = inputs.back();
|
auto& updates = inputs.back();
|
||||||
|
|
||||||
// Copy src into out (copy allocates memory for out)
|
// Copy src into out (copy allocates memory for out)
|
||||||
auto ctype =
|
auto ctype =
|
||||||
src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
|
src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
|
||||||
copy(src, out, ctype);
|
copy_cpu(src, out, ctype, stream());
|
||||||
|
|
||||||
switch (src.dtype()) {
|
auto& encoder = cpu::get_command_encoder(stream());
|
||||||
|
std::vector<array> inds;
|
||||||
|
for (auto it = inputs.begin() + 1; it < inputs.end() - 1; ++it) {
|
||||||
|
encoder.set_input_array(*it);
|
||||||
|
inds.push_back(array::unsafe_weak_copy(*it));
|
||||||
|
}
|
||||||
|
encoder.set_input_array(updates);
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
encoder.dispatch([axes_ = axes_,
|
||||||
|
reduce_type_ = reduce_type_,
|
||||||
|
updates = array::unsafe_weak_copy(updates),
|
||||||
|
inds = std::move(inds),
|
||||||
|
out = array::unsafe_weak_copy(out)]() mutable {
|
||||||
|
switch (out.dtype()) {
|
||||||
case bool_:
|
case bool_:
|
||||||
dispatch_scatter<bool>(out, inds, updates, axes_, reduce_type_);
|
dispatch_scatter<bool>(out, inds, updates, axes_, reduce_type_);
|
||||||
break;
|
break;
|
||||||
@@ -515,24 +576,16 @@ void Scatter::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
dispatch_scatter<complex64_t>(out, inds, updates, axes_, reduce_type_);
|
dispatch_scatter<complex64_t>(out, inds, updates, axes_, reduce_type_);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, typename IdxT, typename OpT>
|
template <typename T, typename IdxT, typename OpT>
|
||||||
void scatter_axis(
|
void scatter_axis(array& out, const array idx, const array& upd, int axis) {
|
||||||
array& out,
|
auto shape = remove_index(idx.shape(), axis);
|
||||||
const array idx,
|
ContiguousIterator idx_it(
|
||||||
const array& upd,
|
shape, remove_index(idx.strides(), axis), upd.ndim() - 1);
|
||||||
int axis,
|
ContiguousIterator upd_it(
|
||||||
const OpT& op) {
|
shape, remove_index(upd.strides(), axis), upd.ndim() - 1);
|
||||||
auto strides = idx.strides();
|
|
||||||
strides.erase(strides.begin() + axis);
|
|
||||||
auto shape = idx.shape();
|
|
||||||
shape.erase(shape.begin() + axis);
|
|
||||||
ContiguousIterator idx_it(shape, strides, upd.ndim() - 1);
|
|
||||||
|
|
||||||
strides = upd.strides();
|
|
||||||
strides.erase(strides.begin() + axis);
|
|
||||||
ContiguousIterator upd_it(shape, strides, upd.ndim() - 1);
|
|
||||||
|
|
||||||
auto idx_ptr = idx.data<IdxT>();
|
auto idx_ptr = idx.data<IdxT>();
|
||||||
auto upd_ptr = upd.data<T>();
|
auto upd_ptr = upd.data<T>();
|
||||||
@@ -557,7 +610,8 @@ void scatter_axis(
|
|||||||
for (int j = 0; j < idx_ax_size; ++j) {
|
for (int j = 0; j < idx_ax_size; ++j) {
|
||||||
auto ind_val = offset_neg_idx(
|
auto ind_val = offset_neg_idx(
|
||||||
idx_ptr[idx_it.loc + j * idx_ax_stride], dst_ax_size);
|
idx_ptr[idx_it.loc + j * idx_ax_stride], dst_ax_size);
|
||||||
op(upd_ptr[upd_it.loc + j * upd_ax_stride],
|
OpT{}(
|
||||||
|
upd_ptr[upd_it.loc + j * upd_ax_stride],
|
||||||
dst_ptr + k + ind_val * dst_ax_stride);
|
dst_ptr + k + ind_val * dst_ax_stride);
|
||||||
}
|
}
|
||||||
idx_it.step();
|
idx_it.step();
|
||||||
@@ -576,12 +630,10 @@ void dispatch_scatter_axis_op(
|
|||||||
ScatterAxis::ReduceType rtype) {
|
ScatterAxis::ReduceType rtype) {
|
||||||
switch (rtype) {
|
switch (rtype) {
|
||||||
case ScatterAxis::None:
|
case ScatterAxis::None:
|
||||||
scatter_axis<InT, IdxT>(
|
scatter_axis<InT, IdxT, None>(out, idx, updates, axis);
|
||||||
out, idx, updates, axis, [](auto x, auto* y) { (*y) = x; });
|
|
||||||
break;
|
break;
|
||||||
case ScatterAxis::Sum:
|
case ScatterAxis::Sum:
|
||||||
scatter_axis<InT, IdxT>(
|
scatter_axis<InT, IdxT, Sum>(out, idx, updates, axis);
|
||||||
out, idx, updates, axis, [](auto x, auto* y) { (*y) += x; });
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -634,9 +686,18 @@ void ScatterAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
// Copy src into out (copy allocates memory for out)
|
// Copy src into out (copy allocates memory for out)
|
||||||
auto ctype =
|
auto ctype =
|
||||||
src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
|
src.flags().row_contiguous ? CopyType::Vector : CopyType::General;
|
||||||
copy(src, out, ctype);
|
copy_cpu(src, out, ctype, stream());
|
||||||
|
|
||||||
switch (src.dtype()) {
|
auto& encoder = cpu::get_command_encoder(stream());
|
||||||
|
encoder.set_input_array(idx);
|
||||||
|
encoder.set_input_array(updates);
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
encoder.dispatch([axis_ = axis_,
|
||||||
|
reduce_type_ = reduce_type_,
|
||||||
|
idx = array::unsafe_weak_copy(idx),
|
||||||
|
updates = array::unsafe_weak_copy(updates),
|
||||||
|
out = array::unsafe_weak_copy(out)]() mutable {
|
||||||
|
switch (out.dtype()) {
|
||||||
case bool_:
|
case bool_:
|
||||||
dispatch_scatter_axis<bool>(out, idx, updates, axis_, reduce_type_);
|
dispatch_scatter_axis<bool>(out, idx, updates, axis_, reduce_type_);
|
||||||
break;
|
break;
|
||||||
@@ -665,7 +726,8 @@ void ScatterAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
dispatch_scatter_axis<int64_t>(out, idx, updates, axis_, reduce_type_);
|
dispatch_scatter_axis<int64_t>(out, idx, updates, axis_, reduce_type_);
|
||||||
break;
|
break;
|
||||||
case float16:
|
case float16:
|
||||||
dispatch_scatter_axis<float16_t>(out, idx, updates, axis_, reduce_type_);
|
dispatch_scatter_axis<float16_t>(
|
||||||
|
out, idx, updates, axis_, reduce_type_);
|
||||||
break;
|
break;
|
||||||
case float32:
|
case float32:
|
||||||
dispatch_scatter_axis<float>(out, idx, updates, axis_, reduce_type_);
|
dispatch_scatter_axis<float>(out, idx, updates, axis_, reduce_type_);
|
||||||
@@ -674,13 +736,15 @@ void ScatterAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
dispatch_scatter_axis<double>(out, idx, updates, axis_, reduce_type_);
|
dispatch_scatter_axis<double>(out, idx, updates, axis_, reduce_type_);
|
||||||
break;
|
break;
|
||||||
case bfloat16:
|
case bfloat16:
|
||||||
dispatch_scatter_axis<bfloat16_t>(out, idx, updates, axis_, reduce_type_);
|
dispatch_scatter_axis<bfloat16_t>(
|
||||||
|
out, idx, updates, axis_, reduce_type_);
|
||||||
break;
|
break;
|
||||||
case complex64:
|
case complex64:
|
||||||
dispatch_scatter_axis<complex64_t>(
|
dispatch_scatter_axis<complex64_t>(
|
||||||
out, idx, updates, axis_, reduce_type_);
|
out, idx, updates, axis_, reduce_type_);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
|||||||
@@ -2,20 +2,21 @@
|
|||||||
|
|
||||||
#include "mlx/allocator.h"
|
#include "mlx/allocator.h"
|
||||||
#include "mlx/backend/cpu/copy.h"
|
#include "mlx/backend/cpu/copy.h"
|
||||||
|
#include "mlx/backend/cpu/encoder.h"
|
||||||
#include "mlx/backend/cpu/lapack.h"
|
#include "mlx/backend/cpu/lapack.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void general_inv(array& inv, int N, int i) {
|
void general_inv(T* inv, int N) {
|
||||||
int info;
|
int info;
|
||||||
auto ipiv = array::Data{allocator::malloc_or_wait(sizeof(int) * N)};
|
auto ipiv = array::Data{allocator::malloc(sizeof(int) * N)};
|
||||||
// Compute LU factorization.
|
// Compute LU factorization.
|
||||||
getrf<T>(
|
getrf<T>(
|
||||||
/* m = */ &N,
|
/* m = */ &N,
|
||||||
/* n = */ &N,
|
/* n = */ &N,
|
||||||
/* a = */ inv.data<T>() + N * N * i,
|
/* a = */ inv,
|
||||||
/* lda = */ &N,
|
/* lda = */ &N,
|
||||||
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
|
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
|
||||||
/* info = */ &info);
|
/* info = */ &info);
|
||||||
@@ -48,12 +49,12 @@ void general_inv(array& inv, int N, int i) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const int lwork = workspace_size;
|
const int lwork = workspace_size;
|
||||||
auto scratch = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};
|
auto scratch = array::Data{allocator::malloc(sizeof(T) * lwork)};
|
||||||
|
|
||||||
// Compute inverse.
|
// Compute inverse.
|
||||||
getri<T>(
|
getri<T>(
|
||||||
/* m = */ &N,
|
/* m = */ &N,
|
||||||
/* a = */ inv.data<T>() + N * N * i,
|
/* a = */ inv,
|
||||||
/* lda = */ &N,
|
/* lda = */ &N,
|
||||||
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
|
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
|
||||||
/* work = */ static_cast<T*>(scratch.buffer.raw_ptr()),
|
/* work = */ static_cast<T*>(scratch.buffer.raw_ptr()),
|
||||||
@@ -68,29 +69,28 @@ void general_inv(array& inv, int N, int i) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void tri_inv(array& inv, int N, int i, bool upper) {
|
void tri_inv(T* inv, int N, bool upper) {
|
||||||
const char uplo = upper ? 'L' : 'U';
|
const char uplo = upper ? 'L' : 'U';
|
||||||
const char diag = 'N';
|
const char diag = 'N';
|
||||||
T* data = inv.data<T>() + N * N * i;
|
|
||||||
int info;
|
int info;
|
||||||
trtri<T>(
|
trtri<T>(
|
||||||
/* uplo = */ &uplo,
|
/* uplo = */ &uplo,
|
||||||
/* diag = */ &diag,
|
/* diag = */ &diag,
|
||||||
/* N = */ &N,
|
/* N = */ &N,
|
||||||
/* a = */ data,
|
/* a = */ inv,
|
||||||
/* lda = */ &N,
|
/* lda = */ &N,
|
||||||
/* info = */ &info);
|
/* info = */ &info);
|
||||||
|
|
||||||
// zero out the other triangle
|
// zero out the other triangle
|
||||||
if (upper) {
|
if (upper) {
|
||||||
for (int i = 0; i < N; i++) {
|
for (int i = 0; i < N; i++) {
|
||||||
std::fill(data, data + i, 0.0f);
|
std::fill(inv, inv + i, 0.0f);
|
||||||
data += N;
|
inv += N;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (int i = 0; i < N; i++) {
|
for (int i = 0; i < N; i++) {
|
||||||
std::fill(data + i + 1, data + N, 0.0f);
|
std::fill(inv + i + 1, inv + N, 0.0f);
|
||||||
data += N;
|
inv += N;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -103,34 +103,53 @@ void tri_inv(array& inv, int N, int i, bool upper) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void inverse_impl(const array& a, array& inv, bool tri, bool upper) {
|
void inverse_impl(
|
||||||
|
const array& a,
|
||||||
|
array& inv,
|
||||||
|
bool tri,
|
||||||
|
bool upper,
|
||||||
|
Stream stream) {
|
||||||
// Lapack uses the column-major convention. We take advantage of the following
|
// Lapack uses the column-major convention. We take advantage of the following
|
||||||
// identity to avoid transposing (see
|
// identity to avoid transposing (see
|
||||||
// https://math.stackexchange.com/a/340234):
|
// https://math.stackexchange.com/a/340234):
|
||||||
// (A⁻¹)ᵀ = (Aᵀ)⁻¹
|
// (A⁻¹)ᵀ = (Aᵀ)⁻¹
|
||||||
|
|
||||||
// The inverse is computed in place, so just copy the input to the output.
|
// The inverse is computed in place, so just copy the input to the output.
|
||||||
copy(a, inv, a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
|
copy_cpu(
|
||||||
|
a,
|
||||||
|
inv,
|
||||||
|
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
|
||||||
|
stream);
|
||||||
|
|
||||||
const int N = a.shape(-1);
|
const int N = a.shape(-1);
|
||||||
const size_t num_matrices = a.size() / (N * N);
|
const size_t num_matrices = a.size() / (N * N);
|
||||||
|
|
||||||
for (int i = 0; i < num_matrices; i++) {
|
auto& encoder = cpu::get_command_encoder(stream);
|
||||||
|
encoder.set_output_array(inv);
|
||||||
|
|
||||||
|
auto inv_ptr = inv.data<T>();
|
||||||
if (tri) {
|
if (tri) {
|
||||||
tri_inv<T>(inv, N, i, upper);
|
encoder.dispatch([inv_ptr, N, num_matrices, upper]() {
|
||||||
} else {
|
for (int i = 0; i < num_matrices; i++) {
|
||||||
general_inv<T>(inv, N, i);
|
tri_inv<T>(inv_ptr + N * N * i, N, upper);
|
||||||
}
|
}
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
encoder.dispatch([inv_ptr, N, num_matrices]() {
|
||||||
|
for (int i = 0; i < num_matrices; i++) {
|
||||||
|
general_inv<T>(inv_ptr + N * N * i, N);
|
||||||
|
}
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void Inverse::eval_cpu(const std::vector<array>& inputs, array& output) {
|
void Inverse::eval_cpu(const std::vector<array>& inputs, array& output) {
|
||||||
switch (inputs[0].dtype()) {
|
switch (inputs[0].dtype()) {
|
||||||
case float32:
|
case float32:
|
||||||
inverse_impl<float>(inputs[0], output, tri_, upper_);
|
inverse_impl<float>(inputs[0], output, tri_, upper_, stream());
|
||||||
break;
|
break;
|
||||||
case float64:
|
case float64:
|
||||||
inverse_impl<double>(inputs[0], output, tri_, upper_);
|
inverse_impl<double>(inputs[0], output, tri_, upper_, stream());
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
#include "mlx/backend/cpu/jit_compiler.h"
|
#include "mlx/backend/cpu/jit_compiler.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user