Mirror of https://github.com/ml-explore/mlx.git (synced 2025-09-11 14:34:37 +08:00)

Compare commits: v0.20.0...ab-nf4-qua (1 commit)

Commit: 152092957c
@@ -13,62 +13,8 @@ parameters:
  test_release:
  type: boolean
  default: false
- linux_release:
- type: boolean
- default: false

  jobs:
- build_documentation:
- parameters:
- upload-docs:
- type: boolean
- default: false
- macos:
- xcode: "15.2.0"
- resource_class: macos.m1.medium.gen1
- steps:
- - checkout
- - run:
- name: Install
- command: |
- brew install python@3.9
- brew install doxygen
- python3.9 -m venv env
- source env/bin/activate
- pip install --upgrade pip
- pip install --upgrade cmake
- pip install -r docs/requirements.txt
- CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install . -v
- - when:
- condition:
- not: << parameters.upload-docs >>
- steps:
- - run:
- name: Build documentation
- command: |
- source env/bin/activate
- cd docs && doxygen && make html O=-W
- - when:
- condition: << parameters.upload-docs >>
- steps:
- - add_ssh_keys:
- fingerprints:
- - "SHA256:OhcVVMovbT0pkgMeiVRyxMnjV9R2t+hKBsNcuxq9h+0"
- - run:
- name: Upload documentation
- command: |
- source env/bin/activate
- git config user.email "mlx@group.apple.com"
- git config user.name "CircleCI Docs"
- git checkout gh-pages
- git rebase main
- cd docs
- git rm -rf build/html
- doxygen && make html O=-W
- git add -f build/html
- git commit -m "rebase"
- git push -f origin gh-pages
-
  linux_build_and_test:
  docker:
  - image: cimg/python:3.9
@@ -85,24 +31,19 @@ jobs:
  name: Install dependencies
  command: |
  pip install --upgrade cmake
- pip install nanobind==2.2.0
+ pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
  pip install numpy
  sudo apt-get update
  sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
  - run:
  name: Install Python package
  command: |
- CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
- CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
- python3 setup.py build_ext --inplace
- CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
- CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
- python3 setup.py develop
+ CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" CMAKE_BUILD_PARALLEL_LEVEL="" python3 setup.py build_ext --inplace
+ CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" CMAKE_BUILD_PARALLEL_LEVEL="" python3 setup.py develop
  - run:
  name: Generate package stubs
  command: |
  echo "stubs"
- pip install typing_extensions
  python setup.py generate_stubs
  - run:
  name: Run Python tests
@@ -111,9 +52,7 @@ jobs:
  - run:
  name: Build CPP only
  command: |
- mkdir -p build && cd build
- cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
- make -j `nproc`
+ mkdir -p build && cd build && cmake .. -DMLX_BUILD_METAL=OFF && make -j
  - run:
  name: Run CPP tests
  command: ./build/tests/tests
@@ -131,13 +70,13 @@ jobs:
  - run:
  name: Install dependencies
  command: |
- brew install python@3.9
+ brew install python@3.8
  brew install openmpi
- python3.9 -m venv env
+ python3.8 -m venv env
  source env/bin/activate
  pip install --upgrade pip
  pip install --upgrade cmake
- pip install nanobind==2.2.0
+ pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
  pip install numpy
  pip install torch
  pip install tensorflow
@@ -146,12 +85,11 @@ jobs:
  name: Install Python package
  command: |
  source env/bin/activate
- DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install -e . -v
+ CMAKE_BUILD_PARALLEL_LEVEL="" pip install -e . -v
  - run:
  name: Generate package stubs
  command: |
  source env/bin/activate
- pip install typing_extensions
  python setup.py generate_stubs
  - run:
  name: Run Python tests
@@ -159,7 +97,7 @@ jobs:
  source env/bin/activate
  LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
  LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
- mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
+ mpirun -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
  - run:
  name: Build example extension
  command: |
@@ -173,7 +111,7 @@ jobs:
  name: Build CPP only
  command: |
  source env/bin/activate
- mkdir -p build && cd build && cmake .. && make -j `sysctl -n hw.ncpu`
+ mkdir -p build && cd build && cmake .. && make -j
  - run:
  name: Run CPP tests
  command: |
@@ -183,23 +121,8 @@ jobs:
  command: |
  source env/bin/activate
  cd build/
- cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
- -DBUILD_SHARED_LIBS=ON \
- -DMLX_BUILD_CPU=OFF \
- -DMLX_BUILD_SAFETENSORS=OFF \
- -DMLX_BUILD_GGUF=OFF \
- -DMLX_METAL_JIT=ON
- make -j `sysctl -n hw.ncpu`
- - run:
- name: Run Python tests with JIT
- command: |
- source env/bin/activate
- CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
- CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
- pip install -e . -v
- LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
- METAL_DEBUG_ERROR_MODE=0 \
- python -m xmlrunner discover -v python/tests -o test-results/gpu_jit
+ cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel -DBUILD_SHARED_LIBS=ON -DMLX_BUILD_CPU=OFF -DMLX_BUILD_SAFETENSORS=OFF -DMLX_BUILD_GGUF=OFF -DMLX_METAL_JIT=ON
+ make -j

  build_release:
  parameters:
@@ -226,7 +149,7 @@ jobs:
  source env/bin/activate
  pip install --upgrade pip
  pip install --upgrade cmake
- pip install nanobind==2.2.0
+ pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
  pip install --upgrade setuptools
  pip install numpy
  pip install twine
@@ -236,20 +159,19 @@ jobs:
  command: |
  source env/bin/activate
  DEV_RELEASE=1 \
- CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
+ CMAKE_BUILD_PARALLEL_LEVEL="" \
  pip install . -v
  - run:
  name: Generate package stubs
  command: |
  source env/bin/activate
- pip install typing_extensions
  python setup.py generate_stubs
  - run:
  name: Build Python package
  command: |
  source env/bin/activate
  << parameters.build_env >> \
- CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
+ CMAKE_BUILD_PARALLEL_LEVEL="" \
  python -m build -w
  - when:
  condition: << parameters.build_env >>
@@ -262,7 +184,7 @@ jobs:
  - store_artifacts:
  path: dist/

- build_linux_release:
+ build_linux_test_release:
  parameters:
  python_version:
  type: string
@@ -291,28 +213,21 @@ jobs:
  source env/bin/activate
  pip install --upgrade pip
  pip install --upgrade cmake
- pip install nanobind==2.2.0
+ pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
  pip install --upgrade setuptools
  pip install numpy
  pip install auditwheel
  pip install patchelf
  pip install build
- pip install twine
  << parameters.extra_env >> \
- CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
+ CMAKE_BUILD_PARALLEL_LEVEL="" \
  pip install . -v
- pip install typing_extensions
  python setup.py generate_stubs
  << parameters.extra_env >> \
- CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
+ CMAKE_BUILD_PARALLEL_LEVEL="" \
  python -m build --wheel
  auditwheel show dist/*
  auditwheel repair dist/* --plat manylinux_2_31_x86_64
- - run:
- name: Upload package
- command: |
- source env/bin/activate
- twine upload wheelhouse/*
  - store_artifacts:
  path: wheelhouse/

@@ -330,9 +245,8 @@ workflows:
  - mac_build_and_test:
  matrix:
  parameters:
- xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
+ xcode_version: ["15.0.0", "15.2.0"]
  - linux_build_and_test
- - build_documentation

  build_pypi_release:
  when:
@@ -349,17 +263,9 @@ workflows:
  ignore: /.*/
  matrix:
  parameters:
- python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
+ python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
  xcode_version: ["15.0.0", "15.2.0"]
  build_env: ["PYPI_RELEASE=1"]
- - build_documentation:
- filters:
- tags:
- only: /^v.*/
- branches:
- ignore: /.*/
- upload-docs: true

  prb:
  when:
  matches:
@@ -374,7 +280,7 @@ workflows:
  requires: [ hold ]
  matrix:
  parameters:
- xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
+ xcode_version: ["15.0.0", "15.2.0"]
  - linux_build_and_test:
  requires: [ hold ]
  nightly_build:
@@ -386,7 +292,7 @@ workflows:
  - build_release:
  matrix:
  parameters:
- python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
+ python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
  xcode_version: ["15.0.0", "15.2.0"]
  weekly_build:
  when:
@@ -397,17 +303,17 @@ workflows:
  - build_release:
  matrix:
  parameters:
- python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
+ python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
- xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
+ xcode_version: ["15.0.0", "15.2.0"]
  build_env: ["DEV_RELEASE=1"]
  linux_test_release:
  when:
  and:
  - equal: [ main, << pipeline.git.branch >> ]
- - << pipeline.parameters.linux_release >>
+ - << pipeline.parameters.test_release >>
  jobs:
- - build_linux_release:
+ - build_linux_test_release:
  matrix:
  parameters:
- python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
+ python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
  extra_env: ["PYPI_RELEASE=1"]
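A recurring change in the CI diff above is swapping an explicit CMAKE_BUILD_PARALLEL_LEVEL (set from `nproc` or `sysctl -n hw.ncpu`) for an empty value. This environment variable caps how many compile jobs CMake launches while pip builds the extension. As a minimal local sketch of that knob (an illustration only, not part of the diff; it assumes you are in a checkout of the repository and that a plain pip install is acceptable):

import os
import subprocess
import sys

env = dict(os.environ)
# CPU-only build, as in the Linux CI job shown above.
env["CMAKE_ARGS"] = "-DMLX_BUILD_METAL=OFF"
# Number of parallel compile jobs CMake may use during the pip build.
env["CMAKE_BUILD_PARALLEL_LEVEL"] = str(os.cpu_count())

subprocess.run([sys.executable, "-m", "pip", "install", ".", "-v"], env=env, check=True)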
@@ -1,11 +1,11 @@
  repos:
  - repo: https://github.com/pre-commit/mirrors-clang-format
- rev: v18.1.8
+ rev: v18.1.4
  hooks:
  - id: clang-format
  # Using this mirror lets us use mypyc-compiled black, which is about 2x faster
  - repo: https://github.com/psf/black-pre-commit-mirror
- rev: 24.8.0
+ rev: 24.4.2
  hooks:
  - id: black
  - repo: https://github.com/pycqa/isort
@@ -14,7 +14,3 @@ repos:
  - id: isort
  args:
  - --profile=black
- - repo: https://github.com/cheshirekow/cmake-format-precommit
- rev: v0.6.13
- hooks:
- - id: cmake-format
@@ -7,18 +7,16 @@ with a short description of your contribution(s) below. For example:

  MLX was developed with contributions from the following individuals:

- - Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`. Added `cross`.
+ - Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`.
  - Juarez Bochi: Fixed bug in cross attention.
  - Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
- - Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream`, safetensors support, `einsum`, and `einsum_path`.
+ - Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support.
  - Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented pooling layers and ``Upsample``.
  - Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops.
  - Luca Arnaboldi: Added `Ceil` and `Floor` ops; implemented pickling, copy and deepcopy for mlx arrays.
  - Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention`
  - AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`. Comparison works for non array objects in `mlx.core.array`. Exception handling for invalid operations in `mlx.core.array`.
  - Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
- - Paul Paczuski: Improved stability of BCE loss calculation
- - Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.

  <a href="https://github.com/ml-explore/mlx/graphs/contributors">
  <img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
CITATION.cff (24 lines removed)
@@ -1,24 +0,0 @@
- cff-version: 1.2.0
- title: mlx
- message: >-
- If you use this software, please cite it using the
- metadata from this file.
- type: software
- authors:
- - given-names: Awni
- family-names: Hannun
- affiliation: Apple
- - given-names: Jagrit
- family-names: Digani
- affiliation: Apple
- - given-names: Angelos
- family-names: Katharopoulos
- affiliation: Apple
- - given-names: Ronan
- family-names: Collobert
- affiliation: Apple
- repository-code: 'https://github.com/ml-explore'
- abstract: >-
- MLX: efficient and flexible machine learning on Apple
- silicon
- license: MIT
CMakeLists.txt (234 changed lines)
@@ -24,43 +24,35 @@ option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
  option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)

  if(NOT MLX_VERSION)
- set(MLX_VERSION 0.20.0)
+ set(MLX_VERSION 0.15.1)
  endif()

  # --------------------- Processor tests -------------------------

- message(
- STATUS
- "Building MLX for ${CMAKE_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}"
- )
+ message(STATUS "Building MLX for ${CMAKE_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}")

  set(MLX_BUILD_ARM OFF)

- if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
+ if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
  if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64")
  if(NOT MLX_ENABLE_X64_MAC)
- message(
- FATAL_ERROR
- "Building for x86_64 on macOS is not supported."
- " If you are on an Apple silicon system, check the build"
- " documentation for possible fixes: "
- "https://ml-explore.github.io/mlx/build/html/install.html#build-from-source"
- )
+ message(FATAL_ERROR
+ "Building for x86_64 on macOS is not supported."
+ " If you are on an Apple silicon system, check the build"
+ " documentation for possible fixes: "
+ "https://ml-explore.github.io/mlx/build/html/install.html#build-from-source")
  else()
- set(MLX_BUILD_METAL OFF)
  message(WARNING "Building for x86_64 arch is not officially supported.")
  endif()
+ set(MLX_BUILD_METAL OFF)
+ elseif(${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm64")
+ set(MLX_BUILD_ARM ON)
  endif()

  else()
- set(MLX_BUILD_METAL OFF)
  message(WARNING "MLX is prioritised for Apple silicon systems using macOS.")
  endif()

- if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm64")
- set(MLX_BUILD_ARM ON)
- endif()

  # ----------------------------- Lib -----------------------------

  include(FetchContent)
@@ -69,59 +61,64 @@ cmake_policy(SET CMP0135 NEW)

  add_library(mlx)

- if(MLX_BUILD_METAL)
+ if (MLX_BUILD_METAL)
- set(METAL_LIB "-framework Metal")
- set(FOUNDATION_LIB "-framework Foundation")
- set(QUARTZ_LIB "-framework QuartzCore")
+ find_library(METAL_LIB Metal)
+ find_library(FOUNDATION_LIB Foundation)
+ find_library(QUARTZ_LIB QuartzCore)
  endif()

- if(MLX_BUILD_METAL AND NOT METAL_LIB)
+ if (MLX_BUILD_METAL AND NOT METAL_LIB)
  message(STATUS "Metal not found. Unable to build GPU")
  set(MLX_BUILD_METAL OFF)
  set(MLX_METAL_DEBUG OFF)
- elseif(MLX_BUILD_METAL)
+ elseif (MLX_BUILD_METAL)
  message(STATUS "Building METAL sources")

- if(MLX_METAL_DEBUG)
+ if (MLX_METAL_DEBUG)
  add_compile_definitions(MLX_METAL_DEBUG)
  endif()

  # Throw an error if xcrun not found
- execute_process(
- COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
- OUTPUT_VARIABLE MACOS_VERSION COMMAND_ERROR_IS_FATAL ANY)
+ execute_process(COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
+ OUTPUT_VARIABLE MACOS_VERSION
+ COMMAND_ERROR_IS_FATAL ANY)

- if(${MACOS_VERSION} LESS 14.0)
- message(
- FATAL_ERROR
- "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON")
- endif()
  message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}")

- set(METAL_CPP_URL
- https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18-beta.zip
- )
- # Get the metal version
- execute_process(
- COMMAND
- zsh "-c"
- "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal -E -x metal -P - | tail -1 | tr -d '\n'"
- OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
+ set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18-beta.zip)
+ if (${MACOS_VERSION} GREATER_EQUAL 15.0)
+ set(MLX_METAL_VERSION METAL_3_2)
+ elseif (${MACOS_VERSION} GREATER_EQUAL 14.2)
+ set(MLX_METAL_VERSION METAL_3_1)
+ elseif (${MACOS_VERSION} GREATER_EQUAL 14.0)
+ set(MLX_METAL_VERSION METAL_3_0)
+ else()
+ message(FATAL_ERROR "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON" )
+ endif()

- FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})
+ FetchContent_Declare(
+ metal_cpp
+ URL ${METAL_CPP_URL}
+ )

  FetchContent_MakeAvailable(metal_cpp)
  target_include_directories(
- mlx PUBLIC $<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>
- $<INSTALL_INTERFACE:include/metal_cpp>)
- target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
+ mlx PUBLIC
+ $<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>
+ $<INSTALL_INTERFACE:include/metal_cpp>
+ )
+ target_link_libraries(
+ mlx PUBLIC
+ ${METAL_LIB}
+ ${FOUNDATION_LIB}
+ ${QUARTZ_LIB})

- add_compile_definitions("MLX_METAL_VERSION=${MLX_METAL_VERSION}")
+ add_compile_definitions(${MLX_METAL_VERSION})
  endif()

- if(MLX_BUILD_CPU)
+ if (MLX_BUILD_CPU)
  find_library(ACCELERATE_LIBRARY Accelerate)
- if(MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
+ if (MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
  message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
  set(MLX_BUILD_ACCELERATE ON)
  target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY})
@@ -133,29 +130,32 @@ if(MLX_BUILD_CPU)
  # The blas shipped in macOS SDK is not supported, search homebrew for
  # openblas instead.
  set(BLA_VENDOR OpenBLAS)
- set(LAPACK_ROOT
- "${LAPACK_ROOT};$ENV{LAPACK_ROOT};/usr/local/opt/openblas")
+ set(LAPACK_ROOT "${LAPACK_ROOT};$ENV{LAPACK_ROOT};/usr/local/opt/openblas")
  endif()
  # Search and link with lapack.
  find_package(LAPACK REQUIRED)
- if(NOT LAPACK_FOUND)
+ if (NOT LAPACK_FOUND)
  message(FATAL_ERROR "Must have LAPACK installed")
  endif()
- find_path(LAPACK_INCLUDE_DIRS lapacke.h /usr/include /usr/local/include
- /usr/local/opt/openblas/include)
+ find_path(LAPACK_INCLUDE_DIRS lapacke.h
+ /usr/include
+ /usr/local/include
+ /usr/local/opt/openblas/include)
  message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
  message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
  target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
  target_link_libraries(mlx PUBLIC ${LAPACK_LIBRARIES})
- # List blas after lapack otherwise we may accidentally incldue an old
- # version of lapack.h from the include dirs of blas.
+ # List blas after lapack otherwise we may accidentally incldue an old version
+ # of lapack.h from the include dirs of blas.
  find_package(BLAS REQUIRED)
- if(NOT BLAS_FOUND)
+ if (NOT BLAS_FOUND)
  message(FATAL_ERROR "Must have BLAS installed")
  endif()
  # TODO find a cleaner way to do this
- find_path(BLAS_INCLUDE_DIRS cblas.h /usr/include /usr/local/include
- $ENV{BLAS_HOME}/include)
+ find_path(BLAS_INCLUDE_DIRS cblas.h
+ /usr/include
+ /usr/local/include
+ $ENV{BLAS_HOME}/include)
  message(STATUS "Blas lib " ${BLAS_LIBRARIES})
  message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
  target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
@@ -166,95 +166,96 @@ else()
  endif()

  find_package(MPI)
- if(MPI_FOUND)
+ if (MPI_FOUND)
  execute_process(
  COMMAND zsh "-c" "mpirun --version"
  OUTPUT_VARIABLE MPI_VERSION
- ERROR_QUIET)
- if(${MPI_VERSION} MATCHES ".*Open MPI.*")
+ COMMAND_ERROR_IS_FATAL ANY
+ )
+ if (${MPI_VERSION} MATCHES ".*Open MPI.*")
  target_include_directories(mlx PRIVATE ${MPI_INCLUDE_PATH})
- elseif(MPI_VERSION STREQUAL "")
- set(MPI_FOUND FALSE)
- message(
- WARNING "MPI found but mpirun is not available. Building without MPI.")
  else()
- set(MPI_FOUND FALSE)
- message(WARNING "MPI which is not OpenMPI found. Building without MPI.")
- endif()
+ message(
+ WARNING
+ "MPI which is not OpenMPI found. Building without MPI."
+ )
+ endif()
  endif()

  add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)

  target_include_directories(
- mlx PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
- $<INSTALL_INTERFACE:include>)
+ mlx
+ PUBLIC
+ $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
+ $<INSTALL_INTERFACE:include>
+ )

- FetchContent_Declare(
- fmt
+ FetchContent_Declare(fmt
  GIT_REPOSITORY https://github.com/fmtlib/fmt.git
  GIT_TAG 10.2.1
- EXCLUDE_FROM_ALL)
+ EXCLUDE_FROM_ALL
+ )
  FetchContent_MakeAvailable(fmt)
- target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)
+ target_link_libraries(mlx PRIVATE fmt::fmt-header-only)

- if(MLX_BUILD_PYTHON_BINDINGS)
+ if (MLX_BUILD_PYTHON_BINDINGS)
  message(STATUS "Building Python bindings.")
- find_package(
- Python 3.8
- COMPONENTS Interpreter Development.Module
- REQUIRED)
+ find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
  execute_process(
  COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
- OUTPUT_STRIP_TRAILING_WHITESPACE
- OUTPUT_VARIABLE NB_DIR)
+ OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE NB_DIR)
  list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
  find_package(nanobind CONFIG REQUIRED)
  add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src)
  endif()

- if(MLX_BUILD_TESTS)
+ if (MLX_BUILD_TESTS)
  include(CTest)
  add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/tests)
  endif()

- if(MLX_BUILD_EXAMPLES)
+ if (MLX_BUILD_EXAMPLES)
  add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/examples/cpp)
  endif()

- if(MLX_BUILD_BENCHMARKS)
+ if (MLX_BUILD_BENCHMARKS)
  add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/benchmarks/cpp)
  endif()

  # ----------------------------- Installation -----------------------------
  include(GNUInstallDirs)

  # Install library
  install(
  TARGETS mlx
  EXPORT MLXTargets
  LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
  ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
  RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
- INCLUDES
- DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})
+ INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
+ )

  # Install headers
  install(
  DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/mlx
  DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
  COMPONENT headers
- FILES_MATCHING
- PATTERN "*.h"
- PATTERN "backend/metal/kernels.h" EXCLUDE)
+ FILES_MATCHING PATTERN "*.h"
+ )

  # Install metal dependencies
- if(MLX_BUILD_METAL)
+ if (MLX_BUILD_METAL)

  # Install metal cpp
  install(
  DIRECTORY ${metal_cpp_SOURCE_DIR}/
  DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/metal_cpp
- COMPONENT metal_cpp_source)
+ COMPONENT metal_cpp_source
+ )

  endif()

@@ -266,24 +267,31 @@ set(MLX_CMAKE_INSTALL_MODULE_DIR share/cmake/MLX)
  install(
  EXPORT MLXTargets
  FILE MLXTargets.cmake
- DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})
+ DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
+ )

  include(CMakePackageConfigHelpers)

  write_basic_package_version_file(
  ${MLX_CMAKE_BUILD_VERSION_CONFIG}
  COMPATIBILITY SameMajorVersion
- VERSION ${MLX_VERSION})
+ VERSION ${MLX_VERSION}
+ )

  configure_package_config_file(
- ${CMAKE_CURRENT_LIST_DIR}/mlx.pc.in ${MLX_CMAKE_BUILD_CONFIG}
+ ${CMAKE_CURRENT_LIST_DIR}/mlx.pc.in
+ ${MLX_CMAKE_BUILD_CONFIG}
  INSTALL_DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
  NO_CHECK_REQUIRED_COMPONENTS_MACRO
- PATH_VARS CMAKE_INSTALL_LIBDIR CMAKE_INSTALL_INCLUDEDIR
- MLX_CMAKE_INSTALL_MODULE_DIR)
+ PATH_VARS CMAKE_INSTALL_LIBDIR CMAKE_INSTALL_INCLUDEDIR MLX_CMAKE_INSTALL_MODULE_DIR
+ )

- install(FILES ${MLX_CMAKE_BUILD_CONFIG} ${MLX_CMAKE_BUILD_VERSION_CONFIG}
- DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})
+ install(
+ FILES ${MLX_CMAKE_BUILD_CONFIG} ${MLX_CMAKE_BUILD_VERSION_CONFIG}
+ DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
+ )

- install(DIRECTORY ${CMAKE_MODULE_PATH}/
- DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})
+ install(
+ DIRECTORY ${CMAKE_MODULE_PATH}/
+ DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
+ )
@@ -6,7 +6,7 @@

  [](https://circleci.com/gh/ml-explore/mlx)

- MLX is an array framework for machine learning on Apple silicon,
+ MLX is an array framework for machine learning research on Apple silicon,
  brought to you by Apple machine learning research.

  Some key features of MLX include:
@@ -62,10 +62,17 @@ def matmul(x, y):

  def _quant_matmul(x, w, s, b, transpose, group_size, bits):
  ys = []
- for i in range(10):
+ for i in range(100):
  ys.append(
  mx.quantized_matmul(
- x, w, s, b, transpose=transpose, group_size=group_size, bits=bits
+ x,
+ w,
+ s,
+ b,
+ transpose=transpose,
+ group_size=group_size,
+ bits=bits,
+ mode=mx.QuantizationMode.DEFAULT,
  )
  )
  mx.eval(ys)
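For context, the hunk above exercises MLX's quantized matmul path. A minimal standalone sketch of the same pattern (quantize a weight matrix, multiply repeatedly, then force evaluation before reading the clock) might look like the following; the shapes and `group_size`/`bits` values are illustrative, and the `mode=` argument shown on the branch side of the diff is omitted because it is specific to that branch.

import time

import mlx.core as mx

x = mx.random.normal((32, 512))
w = mx.random.normal((1024, 512))
# Quantize the weights once; returns packed weights plus per-group scales and biases.
w_q, scales, biases = mx.quantize(w, group_size=64, bits=4)

tic = time.perf_counter()
ys = [
    mx.quantized_matmul(x, w_q, scales, biases, transpose=True, group_size=64, bits=4)
    for _ in range(100)
]
mx.eval(ys)  # force the lazy computations before stopping the timer
print(f"{time.perf_counter() - tic:.4f} s")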
@@ -144,13 +151,6 @@ def reduction(op, axis, x):
  mx.eval(ys)


- def sum_and_add(axis, x, y):
- z = x.sum(axis=axis, keepdims=True)
- for i in range(50):
- z = (z + y).sum(axis=axis, keepdims=True)
- mx.eval(z)
-
-
  def softmax(axis, x):
  ys = []
  for i in range(100):
@@ -512,8 +512,5 @@ if __name__ == "__main__":
  elif args.benchmark == "selu":
  print(bench(selu, x))

- elif args.benchmark == "sum_and_add":
- print(bench(sum_and_add, axis, *xs))
-
  else:
  raise ValueError("Unknown benchmark")
@@ -1,127 +0,0 @@
(entire file removed; its previous contents were:)

import argparse
import math
import time

import mlx.core as mx
import numpy as np
import torch

N_warmup = 1
N_iter_bench = 10
N_iter_func = 5
mx.set_default_device(mx.cpu)


def bench(f, a, b):
    for i in range(N_warmup):
        f(a, b)

    s = time.perf_counter_ns()
    for i in range(N_iter_bench):
        f(a, b)
    e = time.perf_counter_ns()
    return (e - s) * 1e-9


def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
    def mx_conv_2D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
            ys.append(y)
        mx.eval(ys)
        return ys

    return mx_conv_2D


def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
    @torch.no_grad()
    def pt_conv_2D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
            ys.append(y)
        return ys

    return pt_conv_2D


def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
    scale = 1.0 / math.sqrt(kH * kH * C)
    a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
        np_dtype
    )

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("cpu")
    b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("cpu")

    f_mx = make_mx_conv_2D(strides, padding, groups)
    f_pt = make_pt_conv_2D(strides, padding, groups)

    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
    out_pt = torch.conv2d(
        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
    )
    out_pt = torch.permute(out_pt, (0, 2, 3, 1))
    out_pt = out_pt.numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run conv benchmarks")

    dtypes = ("float32",)
    shapes = (
        (4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
        (4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
        (4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
        # (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 2),
        # (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 16),
        # (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 64),
        (4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
        (4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
        (4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
    )

    for dtype in dtypes:
        print(
            "(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
        )
        for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
            np_dtype = getattr(np, dtype)
            time_mlx, time_torch = bench_shape(
                N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
            )
            diff = time_torch / time_mlx - 1.0

            print(
                f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
            )
            if time_mlx >= 2.0 * time_torch:
                print("ATTENTION ^^^^^^^")
@@ -1,143 +0,0 @@
(entire file removed; its previous contents were:)

import time

import mlx.core as mx
import mlx.nn
import mlx.optimizers as opt
import torch


def bench_mlx(steps: int = 20) -> float:
    mx.set_default_device(mx.cpu)

    class BenchNetMLX(mlx.nn.Module):
        # simple encoder-decoder net

        def __init__(self, in_channels, hidden_channels=32):
            super().__init__()

            self.net = mlx.nn.Sequential(
                mlx.nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),
                mlx.nn.ReLU(),
                mlx.nn.Conv2d(
                    hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
                ),
                mlx.nn.ReLU(),
                mlx.nn.ConvTranspose2d(
                    2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
                ),
                mlx.nn.ReLU(),
                mlx.nn.ConvTranspose2d(
                    hidden_channels, in_channels, kernel_size=3, padding=1
                ),
            )

        def __call__(self, input):
            return self.net(input)

    benchNet = BenchNetMLX(3)
    mx.eval(benchNet.parameters())
    optim = opt.Adam(learning_rate=1e-3)

    inputs = mx.random.normal([10, 256, 256, 3])

    params = benchNet.parameters()
    optim.init(params)

    state = [benchNet.state, optim.state]

    def loss_fn(params, image):
        benchNet.update(params)
        pred_image = benchNet(image)
        return (pred_image - image).abs().mean()

    def step(params, image):
        loss, grads = mx.value_and_grad(loss_fn)(params, image)
        optim.update(benchNet, grads)
        return loss

    total_time = 0.0
    print("MLX:")
    for i in range(steps):
        start_time = time.perf_counter()

        step(benchNet.parameters(), inputs)
        mx.eval(state)
        end_time = time.perf_counter()

        print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
        total_time += (end_time - start_time) * 1000

    return total_time


def bench_torch(steps: int = 20) -> float:
    device = torch.device("cpu")

    class BenchNetTorch(torch.nn.Module):
        # simple encoder-decoder net

        def __init__(self, in_channels, hidden_channels=32):
            super().__init__()

            self.net = torch.nn.Sequential(
                torch.nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),
                torch.nn.ReLU(),
                torch.nn.Conv2d(
                    hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
                ),
                torch.nn.ReLU(),
                torch.nn.ConvTranspose2d(
                    2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
                ),
                torch.nn.ReLU(),
                torch.nn.ConvTranspose2d(
                    hidden_channels, in_channels, kernel_size=3, padding=1
                ),
            )

        def forward(self, input):
            return self.net(input)

    benchNet = BenchNetTorch(3).to(device)
    optim = torch.optim.Adam(benchNet.parameters(), lr=1e-3)

    inputs = torch.randn(10, 3, 256, 256, device=device)

    def loss_fn(pred_image, image):
        return (pred_image - image).abs().mean()

    total_time = 0.0
    print("PyTorch:")
    for i in range(steps):
        start_time = time.perf_counter()

        optim.zero_grad()
        pred_image = benchNet(inputs)
        loss = loss_fn(pred_image, inputs)
        loss.backward()
        optim.step()

        end_time = time.perf_counter()

        print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
        total_time += (end_time - start_time) * 1000

    return total_time


def main():
    steps = 20
    time_mlx = bench_mlx(steps)
    time_torch = bench_torch(steps)

    print(f"average time of MLX: {time_mlx/steps:9.2f} ms")
    print(f"total time of MLX: {time_mlx:9.2f} ms")
    print(f"average time of PyTorch: {time_torch/steps:9.2f} ms")
    print(f"total time of PyTorch: {time_torch:9.2f} ms")

    diff = time_torch / time_mlx - 1.0
    print(f"torch/mlx diff: {100. * diff:+5.2f}%")


if __name__ == "__main__":
    main()
@@ -1,129 +0,0 @@
(entire file removed; its previous contents were:)

import argparse
import math
import time

import mlx.core as mx
import numpy as np
import torch

N_warmup = 1
N_iter_bench = 10
N_iter_func = 5


def bench(f, a, b):
    for i in range(N_warmup):
        f(a, b)

    s = time.perf_counter_ns()
    for i in range(N_iter_bench):
        f(a, b)
    e = time.perf_counter_ns()
    return (e - s) * 1e-9


def make_mx_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
    def mx_conv_transpose_2D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = mx.conv_transpose2d(
                a, b, stride=strides, padding=padding, groups=groups, stream=mx.cpu
            )
            ys.append(y)
        mx.eval(ys)
        return ys

    return mx_conv_transpose_2D


def make_pt_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
    @torch.no_grad()
    def pt_conv_transpose_2D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = torch.conv_transpose2d(
                a, b, stride=strides, padding=padding, groups=groups
            )
            ys.append(y)
        return ys

    return pt_conv_transpose_2D


def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
    scale = 1.0 / math.sqrt(kH * kH * C)
    a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, (int(O / groups), kH, kW, C)).astype(
        np_dtype
    )

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("cpu")
    b_pt = torch.from_numpy(b_np.transpose((3, 0, 1, 2))).to("cpu")

    f_mx = make_mx_conv_transpose_2D(strides, padding, groups)
    f_pt = make_pt_conv_transpose_2D(strides, padding, groups)

    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    out_mx = mx.conv_transpose2d(
        a_mx, b_mx, stride=strides, padding=padding, groups=groups, stream=mx.cpu
    )
    out_pt = torch.conv_transpose2d(
        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
    )
    out_pt = torch.permute(out_pt, (0, 2, 3, 1))
    out_pt = out_pt.numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run conv benchmarks")

    dtypes = ("float32",)
    shapes = (
        (4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
        (4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
        (4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
        (4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
        (4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
        (4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
    )

    for dtype in dtypes:
        print(
            "(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
        )
        for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
            np_dtype = getattr(np, dtype)
            time_mlx, time_torch = bench_shape(
                N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
            )
            diff = time_torch / time_mlx - 1.0

            print(
                f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
            )
            if time_mlx >= 2.0 * time_torch:
                print("ATTENTION ^^^^^^^")
@@ -1,110 +0,0 @@
(entire file removed; its previous contents were:)

import argparse
import math
import time

import mlx.core as mx
import numpy as np
import torch

N_warmup = 1
N_iter_bench = 10
N_iter_func = 5
mx.set_default_device(mx.cpu)


def bench(f, a, b):
    for i in range(N_warmup):
        f(a, b)

    s = time.perf_counter_ns()
    for i in range(N_iter_bench):
        f(a, b)
    e = time.perf_counter_ns()
    return (e - s) * 1e-9


def make_mx_conv_3D(strides=(1, 1), padding=(0, 0), groups=1):
    def mx_conv_3D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = mx.conv3d(a, b, stride=strides, padding=padding, groups=groups)
            ys.append(y)
        mx.eval(ys)
        return ys

    return mx_conv_3D


def make_pt_conv_3D(strides=(1, 1), padding=(0, 0), groups=1):
    @torch.no_grad()
    def pt_conv_3D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = torch.conv3d(a, b, stride=strides, padding=padding, groups=groups)
            ys.append(y)
        return ys

    return pt_conv_3D


def bench_shape(N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype):
    scale = 1.0 / math.sqrt(kD * kH * kW * C)
    a_np = np.random.uniform(0, 0.5, (N, D, H, W, C)).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, (O, kD, kH, kW, int(C / groups))).astype(
        np_dtype
    )

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    a_pt = torch.from_numpy(a_np.transpose((0, 4, 1, 2, 3))).to("cpu")
    b_pt = torch.from_numpy(b_np.transpose((0, 4, 1, 2, 3))).to("cpu")

    f_mx = make_mx_conv_3D(strides, padding, groups)
    f_pt = make_pt_conv_3D(strides, padding, groups)

    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    out_mx = mx.conv3d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
    out_pt = torch.conv3d(
        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
    )
    out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1))
    out_pt = out_pt.numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, D, H, W, C)}, {(O, kD, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run conv benchmarks")

    dtypes = ("float32",)
    shapes = (
        (4, 16, 16, 16, 16, 5, 5, 5, 16, (1, 1, 1), (2, 2, 2), 1),
        (4, 16, 16, 16, 32, 5, 5, 5, 32, (1, 1, 1), (2, 2, 2), 1),
    )

    for dtype in dtypes:
        print(
            "(N, D, H, W, C), ( O, kD, kH, kW, C), dtype, stride, pads, groups, diff%"
        )
        for N, D, H, W, C, kD, kH, kW, O, strides, padding, groups in shapes:
            np_dtype = getattr(np, dtype)
            time_mlx, time_torch = bench_shape(
                N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype
            )
            diff = time_torch / time_mlx - 1.0

            print(
                f"({N}, {D:3d}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kD:2d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
            )
            if time_mlx >= 2.0 * time_torch:
                print("ATTENTION ^^^^^^^")
@@ -1,143 +0,0 @@
|
|||||||
import time
|
|
||||||
|
|
||||||
import mlx.core as mx
|
|
||||||
import mlx.nn
|
|
||||||
import mlx.optimizers as opt
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
def bench_mlx(steps: int = 20, shape=(10, 32, 32, 32, 3)) -> float:
|
|
||||||
mx.set_default_device(mx.cpu)
|
|
||||||
|
|
||||||
class BenchNetMLX(mlx.nn.Module):
|
|
||||||
# simple encoder-decoder net
|
|
||||||
|
|
||||||
def __init__(self, in_channels, hidden_channels=16):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.net = mlx.nn.Sequential(
|
|
||||||
mlx.nn.Conv3d(in_channels, hidden_channels, kernel_size=3, padding=1),
|
|
||||||
mlx.nn.ReLU(),
|
|
||||||
mlx.nn.Conv3d(
|
|
||||||
hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
|
|
||||||
),
|
|
||||||
mlx.nn.ReLU(),
|
|
||||||
mlx.nn.ConvTranspose3d(
|
|
||||||
2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
|
|
||||||
),
|
|
||||||
mlx.nn.ReLU(),
|
|
||||||
mlx.nn.ConvTranspose3d(
|
|
||||||
hidden_channels, in_channels, kernel_size=3, padding=1
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
def __call__(self, input):
|
|
||||||
return self.net(input)
|
|
||||||
|
|
||||||
benchNet = BenchNetMLX(3)
|
|
||||||
mx.eval(benchNet.parameters())
|
|
||||||
optim = opt.Adam(learning_rate=1e-3)
|
|
||||||
|
|
||||||
inputs = mx.random.normal(shape)
|
|
||||||
|
|
||||||
params = benchNet.parameters()
|
|
||||||
optim.init(params)
|
|
||||||
|
|
||||||
state = [benchNet.state, optim.state]
|
|
||||||
|
|
||||||
def loss_fn(params, image):
|
|
||||||
benchNet.update(params)
|
|
||||||
pred_image = benchNet(image)
|
|
||||||
return (pred_image - image).abs().mean()
|
|
||||||
|
|
||||||
def step(params, image):
|
|
||||||
loss, grads = mx.value_and_grad(loss_fn)(params, image)
|
|
||||||
optim.update(benchNet, grads)
|
|
||||||
return loss
|
|
||||||
|
|
||||||
total_time = 0.0
|
|
||||||
print("MLX:")
|
|
||||||
for i in range(steps):
|
|
||||||
start_time = time.perf_counter()
|
|
||||||
|
|
||||||
step(benchNet.parameters(), inputs)
|
|
||||||
mx.eval(state)
|
|
||||||
end_time = time.perf_counter()
|
|
||||||
|
|
||||||
print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
|
|
||||||
total_time += (end_time - start_time) * 1000
|
|
||||||
|
|
||||||
return total_time
|
|
||||||
|
|
||||||
|
|
||||||
def bench_torch(steps: int = 20, shape=(10, 3, 32, 32, 32)) -> float:
|
|
||||||
device = torch.device("cpu")
|
|
||||||
|
|
||||||
class BenchNetTorch(torch.nn.Module):
|
|
||||||
# simple encoder-decoder net
|
|
||||||
|
|
||||||
def __init__(self, in_channels, hidden_channels=16):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.net = torch.nn.Sequential(
|
|
||||||
torch.nn.Conv3d(in_channels, hidden_channels, kernel_size=3, padding=1),
|
|
||||||
torch.nn.ReLU(),
|
|
||||||
torch.nn.Conv3d(
|
|
||||||
hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
|
|
||||||
),
|
|
||||||
torch.nn.ReLU(),
|
|
||||||
torch.nn.ConvTranspose3d(
|
|
||||||
2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
|
|
||||||
),
|
|
||||||
torch.nn.ReLU(),
|
|
||||||
torch.nn.ConvTranspose3d(
|
|
||||||
hidden_channels, in_channels, kernel_size=3, padding=1
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, input):
|
|
||||||
return self.net(input)
|
|
||||||
|
|
||||||
benchNet = BenchNetTorch(3).to(device)
|
|
||||||
optim = torch.optim.Adam(benchNet.parameters(), lr=1e-3)
|
|
||||||
|
|
||||||
inputs = torch.randn(*shape, device=device)
|
|
||||||
|
|
||||||
def loss_fn(pred_image, image):
|
|
||||||
return (pred_image - image).abs().mean()
|
|
||||||
|
|
||||||
total_time = 0.0
|
|
||||||
print("PyTorch:")
|
|
||||||
for i in range(steps):
|
|
||||||
start_time = time.perf_counter()
|
|
||||||
|
|
||||||
optim.zero_grad()
|
|
||||||
pred_image = benchNet(inputs)
|
|
||||||
loss = loss_fn(pred_image, inputs)
|
|
||||||
loss.backward()
|
|
||||||
optim.step()
|
|
||||||
|
|
||||||
end_time = time.perf_counter()
|
|
||||||
|
|
||||||
print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
|
|
||||||
total_time += (end_time - start_time) * 1000
|
|
||||||
|
|
||||||
return total_time
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
steps = 10
|
|
||||||
time_mlx = bench_mlx(steps)
|
|
||||||
time_torch = bench_torch(steps)
|
|
||||||
|
|
||||||
print(f"average time of MLX: {time_mlx/steps:9.2f} ms")
|
|
||||||
print(f"total time of MLX: {time_mlx:9.2f} ms")
|
|
||||||
print(f"average time of PyTorch: {time_torch/steps:9.2f} ms")
|
|
||||||
print(f"total time of PyTorch: {time_torch:9.2f} ms")
|
|
||||||
|
|
||||||
diff = time_torch / time_mlx - 1.0
|
|
||||||
print(f"torch/mlx diff: {100. * diff:+5.2f}%")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
@@ -1,116 +0,0 @@
|
|||||||
import argparse
|
|
||||||
import math
|
|
||||||
import time
|
|
||||||
|
|
||||||
import mlx.core as mx
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
|
|
||||||
N_warmup = 1
|
|
||||||
N_iter_bench = 10
|
|
||||||
N_iter_func = 5
|
|
||||||
mx.set_default_device(mx.cpu)
|
|
||||||
|
|
||||||
|
|
||||||
def bench(f, a, b):
|
|
||||||
for i in range(N_warmup):
|
|
||||||
f(a, b)
|
|
||||||
|
|
||||||
s = time.perf_counter_ns()
|
|
||||||
for i in range(N_iter_bench):
|
|
||||||
f(a, b)
|
|
||||||
e = time.perf_counter_ns()
|
|
||||||
return (e - s) * 1e-9
|
|
||||||
|
|
||||||
|
|
||||||
def make_mx_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1):
|
|
||||||
def mx_conv_3D(a, b):
|
|
||||||
ys = []
|
|
||||||
for i in range(N_iter_func):
|
|
||||||
y = mx.conv_transpose3d(
|
|
||||||
a, b, stride=strides, padding=padding, groups=groups
|
|
||||||
)
|
|
||||||
ys.append(y)
|
|
||||||
mx.eval(ys)
|
|
||||||
return ys
|
|
||||||
|
|
||||||
return mx_conv_3D
|
|
||||||
|
|
||||||
|
|
||||||
def make_pt_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1):
|
|
||||||
@torch.no_grad()
|
|
||||||
def pt_conv_3D(a, b):
|
|
||||||
ys = []
|
|
||||||
for i in range(N_iter_func):
|
|
||||||
y = torch.conv_transpose3d(
|
|
||||||
a, b, stride=strides, padding=padding, groups=groups
|
|
||||||
)
|
|
||||||
ys.append(y)
|
|
||||||
return ys
|
|
||||||
|
|
||||||
return pt_conv_3D
|
|
||||||
|
|
||||||
|
|
||||||
def bench_shape(N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype):
|
|
||||||
scale = 1.0 / math.sqrt(kD * kH * kW * C)
|
|
||||||
a_np = np.random.uniform(0, 0.5, (N, D, H, W, C)).astype(np_dtype)
|
|
||||||
b_np = np.random.uniform(-scale, scale, (O, kD, kH, kW, int(C / groups))).astype(
|
|
||||||
np_dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
a_mx = mx.array(a_np)
|
|
||||||
b_mx = mx.array(b_np)
|
|
||||||
|
|
||||||
a_pt = torch.from_numpy(a_np.transpose((0, 4, 1, 2, 3))).to("cpu")
|
|
||||||
b_pt = torch.from_numpy(b_np.transpose((4, 0, 1, 2, 3))).to("cpu")
|
|
||||||
|
|
||||||
f_mx = make_mx_conv_3D(strides, padding, groups)
|
|
||||||
f_pt = make_pt_conv_3D(strides, padding, groups)
|
|
||||||
|
|
||||||
time_torch = bench(f_pt, a_pt, b_pt)
|
|
||||||
time_mlx = bench(f_mx, a_mx, b_mx)
|
|
||||||
|
|
||||||
out_mx = mx.conv_transpose3d(
|
|
||||||
a_mx, b_mx, stride=strides, padding=padding, groups=groups
|
|
||||||
)
|
|
||||||
out_pt = torch.conv_transpose3d(
|
|
||||||
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
|
|
||||||
)
|
|
||||||
out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1))
|
|
||||||
out_pt = out_pt.numpy(force=True)
|
|
||||||
|
|
||||||
atol = 2e-5 if np_dtype == np.float32 else 1e-4
|
|
||||||
|
|
||||||
if not np.allclose(out_pt, out_mx, atol=atol):
|
|
||||||
print(
|
|
||||||
f"Failed at {(N, D, H, W, C)}, {(O, kD, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return time_mlx, time_torch
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = argparse.ArgumentParser(description="Run conv benchmarks")
|
|
||||||
|
|
||||||
dtypes = ("float32",)
|
|
||||||
shapes = (
|
|
||||||
(4, 16, 16, 16, 16, 5, 5, 5, 16, (1, 1, 1), (2, 2, 2), 1),
|
|
||||||
(4, 16, 16, 16, 32, 5, 5, 5, 32, (1, 1, 1), (2, 2, 2), 1),
|
|
||||||
)
|
|
||||||
|
|
||||||
for dtype in dtypes:
|
|
||||||
print(
|
|
||||||
"(N, D, H, W, C), ( O, kD, kH, kW, C), dtype, stride, pads, groups, diff%"
|
|
||||||
)
|
|
||||||
for N, D, H, W, C, kD, kH, kW, O, strides, padding, groups in shapes:
|
|
||||||
np_dtype = getattr(np, dtype)
|
|
||||||
time_mlx, time_torch = bench_shape(
|
|
||||||
N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype
|
|
||||||
)
|
|
||||||
diff = time_torch / time_mlx - 1.0
|
|
||||||
|
|
||||||
print(
|
|
||||||
f"({N}, {D:3d}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kD:2d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
|
|
||||||
)
|
|
||||||
if time_mlx >= 2.0 * time_torch:
|
|
||||||
print("ATTENTION ^^^^^^^")
|
|
@@ -1,135 +0,0 @@
|
|||||||
import argparse
|
|
||||||
import math
|
|
||||||
import os
|
|
||||||
import subprocess
|
|
||||||
import time
|
|
||||||
|
|
||||||
import mlx.core as mx
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
|
|
||||||
N_warmup = 10
|
|
||||||
N_iter_bench = 100
|
|
||||||
N_iter_func = 5
|
|
||||||
|
|
||||||
|
|
||||||
def bench(f, a, b):
|
|
||||||
for i in range(N_warmup):
|
|
||||||
f(a, b)
|
|
||||||
torch.mps.synchronize()
|
|
||||||
|
|
||||||
s = time.perf_counter_ns()
|
|
||||||
for i in range(N_iter_bench):
|
|
||||||
f(a, b)
|
|
||||||
e = time.perf_counter_ns()
|
|
||||||
return (e - s) * 1e-9
|
|
||||||
|
|
||||||
|
|
||||||
def make_mx_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
|
|
||||||
def mx_conv_transpose_2D(a, b):
|
|
||||||
ys = []
|
|
||||||
for i in range(N_iter_func):
|
|
||||||
y = mx.conv_transpose2d(
|
|
||||||
a, b, stride=strides, padding=padding, groups=groups
|
|
||||||
)
|
|
||||||
ys.append(y)
|
|
||||||
mx.eval(ys)
|
|
||||||
return ys
|
|
||||||
|
|
||||||
return mx_conv_transpose_2D
|
|
||||||
|
|
||||||
|
|
||||||
def make_pt_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
|
|
||||||
@torch.no_grad()
|
|
||||||
def pt_conv_transpose_2D(a, b):
|
|
||||||
ys = []
|
|
||||||
for i in range(N_iter_func):
|
|
||||||
y = torch.conv_transpose2d(
|
|
||||||
a, b, stride=strides, padding=padding, groups=groups
|
|
||||||
)
|
|
||||||
ys.append(y)
|
|
||||||
torch.mps.synchronize()
|
|
||||||
return ys
|
|
||||||
|
|
||||||
return pt_conv_transpose_2D
|
|
||||||
|
|
||||||
|
|
||||||
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
|
|
||||||
scale = 1.0 / math.sqrt(kH * kH * C)
|
|
||||||
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
|
|
||||||
b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
|
|
||||||
np_dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
a_mx = mx.array(a_np)
|
|
||||||
b_mx = mx.array(b_np)
|
|
||||||
|
|
||||||
a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
|
|
||||||
b_pt = torch.from_numpy(b_np.transpose((3, 0, 1, 2))).to("mps")
|
|
||||||
|
|
||||||
torch.mps.synchronize()
|
|
||||||
|
|
||||||
f_mx = make_mx_conv_transpose_2D(strides, padding, groups)
|
|
||||||
f_pt = make_pt_conv_transpose_2D(strides, padding, groups)
|
|
||||||
|
|
||||||
time_torch = bench(f_pt, a_pt, b_pt)
|
|
||||||
time_mlx = bench(f_mx, a_mx, b_mx)
|
|
||||||
|
|
||||||
out_mx = mx.conv_transpose2d(
|
|
||||||
a_mx, b_mx, stride=strides, padding=padding, groups=groups
|
|
||||||
)
|
|
||||||
out_pt = torch.conv_transpose2d(
|
|
||||||
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
|
|
||||||
)
|
|
||||||
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
|
|
||||||
out_pt = out_pt.numpy(force=True)
|
|
||||||
|
|
||||||
atol = 2e-5 if np_dtype == np.float32 else 1e-4
|
|
||||||
|
|
||||||
if not np.allclose(out_pt, out_mx, atol=atol):
|
|
||||||
print(
|
|
||||||
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return time_mlx, time_torch
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = argparse.ArgumentParser(description="Run conv benchmarks")
|
|
||||||
|
|
||||||
dtypes = ("float32",)
|
|
||||||
shapes = (
|
|
||||||
(4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
|
|
||||||
(4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
|
|
||||||
(4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
|
|
||||||
(4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
|
|
||||||
(4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
|
|
||||||
(4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
|
|
||||||
(4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
|
|
||||||
(4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
|
|
||||||
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
|
|
||||||
(4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
|
|
||||||
(4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
|
|
||||||
(4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
|
|
||||||
(4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
|
|
||||||
(4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
|
|
||||||
(4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
|
|
||||||
(4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
|
|
||||||
)
|
|
||||||
|
|
||||||
for dtype in dtypes:
|
|
||||||
print(
|
|
||||||
"(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
|
|
||||||
)
|
|
||||||
for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
|
|
||||||
np_dtype = getattr(np, dtype)
|
|
||||||
time_mlx, time_torch = bench_shape(
|
|
||||||
N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
|
|
||||||
)
|
|
||||||
diff = time_torch / time_mlx - 1.0
|
|
||||||
|
|
||||||
print(
|
|
||||||
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
|
|
||||||
)
|
|
||||||
if time_mlx >= 2.0 * time_torch:
|
|
||||||
print("ATTENTION ^^^^^^^")
|
|
@@ -1,66 +0,0 @@
# Copyright © 2024 Apple Inc.

"""
Run with:
mpirun -n 2 python /path/to/distributed_bench.py
"""

import time

import mlx.core as mx


def time_fn(fn, *args, **kwargs):
    msg = kwargs.pop("msg", None)
    world = mx.distributed.init()
    if world.rank() == 0:
        if msg:
            print(f"Timing {msg} ...", end=" ")
        else:
            print(f"Timing {fn.__name__} ...", end=" ")

    # warmup
    for _ in range(5):
        mx.eval(fn(*args, **kwargs))

    num_iters = 100
    tic = time.perf_counter()
    for _ in range(num_iters):
        x = mx.eval(fn(*args, **kwargs))
    toc = time.perf_counter()

    msec = 1e3 * (toc - tic) / num_iters
    if world.rank() == 0:
        print(f"{msec:.5f} msec")


def time_all_sum():
    shape = (4096,)
    x = mx.random.uniform(shape=shape)
    mx.eval(x)

    def sine(x):
        for _ in range(20):
            x = mx.sin(x)
        return x

    time_fn(sine, x)

    def all_sum_plain(x):
        for _ in range(20):
            x = mx.distributed.all_sum(x)
        return x

    time_fn(all_sum_plain, x)

    def all_sum_with_sine(x):
        for _ in range(20):
            x = mx.sin(x)
            x = mx.distributed.all_sum(x)
        return x

    time_fn(all_sum_with_sine, x)


if __name__ == "__main__":
    time_all_sum()
@@ -1,84 +0,0 @@
|
|||||||
# Copyright © 2024 Apple Inc.
|
|
||||||
|
|
||||||
import time
|
|
||||||
|
|
||||||
import mlx.core as mx
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
|
||||||
def timeit(fn, its=100, args=[]):
|
|
||||||
for _ in range(5):
|
|
||||||
fn(*args)
|
|
||||||
tic = time.perf_counter()
|
|
||||||
for _ in range(its):
|
|
||||||
fn(*args)
|
|
||||||
toc = time.perf_counter()
|
|
||||||
return 1e3 * (toc - tic) / its
|
|
||||||
|
|
||||||
|
|
||||||
def time_little_einsum_path():
|
|
||||||
subscripts = "ik,kj->ij"
|
|
||||||
x = mx.ones((32, 32))
|
|
||||||
y = mx.ones((32, 32))
|
|
||||||
mx_time = timeit(mx.einsum_path, args=(subscripts, x, y))
|
|
||||||
|
|
||||||
x = np.array(x)
|
|
||||||
y = np.array(y)
|
|
||||||
np_time = timeit(np.einsum_path, args=(subscripts, x, y))
|
|
||||||
print("Timing little einsum path...")
|
|
||||||
print(f"MLX ... {mx_time:.3f} ms")
|
|
||||||
print(f"NumPy... {np_time:.3f} ms")
|
|
||||||
|
|
||||||
|
|
||||||
def time_big_einsum_path():
|
|
||||||
chars = list("abcdefgh")
|
|
||||||
char_to_dim = {c: v for v, c in enumerate(chars)}
|
|
||||||
|
|
||||||
num_inputs = 10
|
|
||||||
inputs = []
|
|
||||||
subscripts = []
|
|
||||||
for _ in range(num_inputs):
|
|
||||||
subscript = np.random.choice(chars, size=5, replace=False).tolist()
|
|
||||||
subscripts.append("".join(subscript))
|
|
||||||
inputs.append(np.ones(list(char_to_dim[c] for c in subscript)))
|
|
||||||
subscripts = ",".join(subscripts)
|
|
||||||
|
|
||||||
np_time = timeit(np.einsum_path, args=(subscripts, *inputs))
|
|
||||||
|
|
||||||
inputs = [mx.array(x) for x in inputs]
|
|
||||||
mx_time = timeit(mx.einsum_path, args=(subscripts, *inputs))
|
|
||||||
print("Timing big einsum path...")
|
|
||||||
print(f"MLX ... {mx_time:.3f} ms")
|
|
||||||
print(f"NumPy... {np_time:.3f} ms")
|
|
||||||
|
|
||||||
|
|
||||||
def time_attention():
|
|
||||||
def regular_attention(x):
|
|
||||||
# shape [batch, sequence, num_heads, head_dim]
|
|
||||||
queries, keys, values = x, x, x
|
|
||||||
scores = queries.transpose(0, 2, 1, 3) @ keys.transpose(0, 2, 3, 1)
|
|
||||||
scores = mx.softmax(scores, axis=-1)
|
|
||||||
output = (scores @ values.transpose(0, 2, 1, 3)).swapaxes(1, 2)
|
|
||||||
mx.eval(output)
|
|
||||||
|
|
||||||
def einsum_attention(x):
|
|
||||||
# shape [batch, sequence, num_heads, head_dim]
|
|
||||||
queries, keys, values = x, x, x
|
|
||||||
scores = mx.einsum("itjk,iujk->ijtu", queries, keys)
|
|
||||||
scores = mx.softmax(scores, axis=-1)
|
|
||||||
output = mx.einsum("ijtu,iujk->itjk", scores, values)
|
|
||||||
mx.eval(output)
|
|
||||||
|
|
||||||
x = mx.random.uniform(shape=(8, 512, 32, 128))
|
|
||||||
|
|
||||||
regular_time = timeit(regular_attention, args=(x,))
|
|
||||||
ein_time = timeit(einsum_attention, args=(x,))
|
|
||||||
print("Timing einsum attention...")
|
|
||||||
print(f"Regular ... {regular_time:.3f} ms")
|
|
||||||
print(f"Einsum ... {ein_time:.3f} ms")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
time_little_einsum_path()
|
|
||||||
time_big_einsum_path()
|
|
||||||
time_attention()
|
|
@@ -1,70 +0,0 @@
import argparse

import matplotlib
import mlx.core as mx
import numpy as np
from time_utils import measure_runtime

matplotlib.use("Agg")
import matplotlib.pyplot as plt


def had(x):
    y = mx.hadamard_transform(x)
    mx.eval(y)


def copy(x):
    y = x + 1.0
    mx.eval(y)


def run(dtype):
    system_size = 2**26
    outputs = {}
    for test_fn in (had, copy):
        for m in [1, 12, 20, 28]:
            if test_fn == copy:
                key = "copy"
            elif m == 1:
                key = "had_2^k"
            else:
                key = "had_m*2^k"
            outputs.setdefault(key, {})
            for k in range(7, 14):
                n = m * 2**k
                if n > 2**15:
                    continue
                x_np = np.random.normal(size=(system_size // n, n)).astype(dtype)
                x = mx.array(x_np)
                runtime_ms = measure_runtime(test_fn, x=x)
                bytes_per_gb = 1e9
                ms_per_s = 1e3
                bytes_per_had = np.dtype(x_np.dtype).itemsize * 2
                bandwidth_gb = (
                    system_size * bytes_per_had / runtime_ms * ms_per_s / bytes_per_gb
                )
                print(n, bandwidth_gb)
                outputs[key][n] = bandwidth_gb

    colors = {
        "copy": "black",
        "had_2^k": "steelblue",
        "had_m*2^k": "skyblue",
    }
    for key, output in outputs.items():
        plt.scatter(output.keys(), output.values(), color=colors[key], label=key)
    plt.title(f"MLX Hadamard Benchmark -- {dtype.__name__}")
    plt.xlabel("N")
    plt.ylabel("Bandwidth (GB/s)")
    plt.legend()
    plt.savefig(f"bench_{dtype.__name__}.png")
    plt.clf()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--fp16", action="store_true")
    args = parser.parse_args()
    dtype = np.float16 if args.fp16 else np.float32
    run(dtype)
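Editor's note: the benchmark above gets `measure_runtime` from the repository's `time_utils` helper, which is not part of this diff. As a rough mental model only, and not the actual helper, a minimal stand-in could look like this:

    import time

    def measure_runtime(fn, **kwargs):
        # Assumed behavior: one warmup call, then the mean wall-clock time
        # of a few calls, reported in milliseconds.
        fn(**kwargs)
        iters = 5
        tic = time.perf_counter()
        for _ in range(iters):
            fn(**kwargs)
        toc = time.perf_counter()
        return 1e3 * (toc - tic) / iters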
@@ -9,7 +9,7 @@ from time_utils import measure_runtime
 
 def benchmark_scatter_mlx(dst_shape, x_shape, idx_shapes):
     def scatter(dst, x, idx):
-        dst[tuple(idx)] = x
+        dst[*idx] = x
         mx.eval(dst)
 
     idx = []
@@ -23,8 +23,8 @@ def benchmark_scatter_mlx(dst_shape, x_shape, idx_shapes):
 
 
 def benchmark_scatter_torch(dst_shape, x_shape, idx_shapes, device):
-    def scatter(dst, x, idx, device):
-        dst[tuple(idx)] = x
+    def gather(dst, x, idx, device):
+        dst[*idx] = x
         if device == torch.device("mps"):
             torch.mps.synchronize()
 
@@ -34,7 +34,7 @@ def benchmark_scatter_torch(dst_shape, x_shape, idx_shapes, device):
     x = torch.randn(x_shape, dtype=torch.float32).to(device)
     dst = torch.randn(dst_shape, dtype=torch.float32).to(device)
 
-    runtime = measure_runtime(scatter, dst=dst, x=x, idx=idx, device=device)
+    runtime = measure_runtime(gather, dst=dst, x=x, idx=idx, device=device)
     print(f"PyTorch: {runtime:.3f}ms")
 
 
@@ -54,7 +54,7 @@ if __name__ == "__main__":
         (100_000, 64),
         (1_000_000, 64),
         (100_000,),
-        (200_000,),
+        (2_000_00,),
         (20_000_000,),
         (10000, 64),
         (100, 64),
@@ -91,6 +91,6 @@ if __name__ == "__main__":
 
     for dst_shape, x_shape, idx_shape in zip(dst_shapes, x_shapes, idx_shapes):
         print("=" * 20)
-        print(f"Dst: {dst_shape}, X {x_shape}, Indices {idx_shape}")
+        print(f"X {x_shape}, Indices {idx_shape}")
         benchmark_scatter_mlx(dst_shape, x_shape, idx_shape)
         benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device=device)
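Editor's note on the scatter hunks above: v0.20.0 writes `dst[tuple(idx)] = x` where the older branch used `dst[*idx] = x`. Unpacking a list directly inside a subscript is only accepted by Python 3.11 and newer, so the tuple form is the portable spelling. A standalone sketch, not part of the diff:

    import mlx.core as mx

    dst = mx.zeros((4, 4))
    x = mx.ones((2,))
    idx = [mx.array([0, 1]), mx.array([2, 3])]

    # Portable on every Python version MLX supports.
    dst[tuple(idx)] = x

    # Equivalent, but the starred subscript needs Python >= 3.11:
    # dst[*idx] = x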
@@ -1,49 +0,0 @@
import argparse
import math

import mlx.core as mx
from time_utils import time_fn

L = 1024
H = 32
H_k = 32 // 4
D = 128


def attention(q, k, v):
    B, Hq, L, D = q.shape
    _, Hk, S, _ = k.shape
    q = q.reshape(B, Hk, Hq // Hk, L, D)
    k = k[:, :, None, :, :]
    v = v[:, :, None, :, :]
    s = q @ k.transpose(0, 1, 2, 4, 3)
    p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
    o = p @ v
    return o.reshape(B, Hq, L, D)


def sdpa(q, k, v):
    return mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0)


def time_self_attention_primitives():
    mx.random.seed(3)
    q = mx.random.uniform(shape=(1, H, 1, D))
    k = mx.random.uniform(shape=(1, H_k, L, D))
    v = mx.random.uniform(shape=(1, H_k, L, D))
    mx.eval(q, k, v)
    time_fn(attention, q, k, v)


def time_self_attention_sdpa():
    mx.random.seed(3)
    q = mx.random.uniform(shape=(1, H, 1, D))
    k = mx.random.uniform(shape=(1, H_k, L, D))
    v = mx.random.uniform(shape=(1, H_k, L, D))
    mx.eval(q, k, v)
    time_fn(sdpa, q, k, v)


if __name__ == "__main__":
    time_self_attention_sdpa()
    time_self_attention_primitives()
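Editor's note: the `attention` reference above folds the query heads into `Hq // Hk` groups so grouped-query attention can be written with plain matmuls; with `scale=1.0` it computes the same result as the fused kernel. A quick sanity check using the definitions above (editor's sketch; the tolerance is an assumption):

    q = mx.random.uniform(shape=(1, H, 1, D))
    k = mx.random.uniform(shape=(1, H_k, L, D))
    v = mx.random.uniform(shape=(1, H_k, L, D))
    assert mx.allclose(attention(q, k, v), sdpa(q, k, v), atol=1e-4)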
@@ -1,41 +1,56 @@
 include(CMakeParseArguments)
 
-# ##############################################################################
+###############################################################################
 # Build metal library
 #
 # Adds a custom target ${TARGET} to build ${OUTPUT_DIRECTORY}/{TITLE}.metallib
 # from list ${SOURCES}, including list ${INCLUDE_DIRS}, depends on list ${DEPS}
 #
-# Args: TARGET: Custom target to be added for the metal library TITLE: Name of
-# the .metallib OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib SOURCES: List
-# of source files INCLUDE_DIRS: List of include dirs DEPS: List of dependency
-# files (like headers)
+# Args:
+# TARGET: Custom target to be added for the metal library
+# TITLE: Name of the .metallib
+# OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib
+# SOURCES: List of source files
+# INCLUDE_DIRS: List of include dirs
+# DEPS: List of dependency files (like headers)
 #
 macro(mlx_build_metallib)
   # Parse args
   set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY)
   set(multiValueArgs SOURCES INCLUDE_DIRS DEPS)
-  cmake_parse_arguments(MTLLIB "" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
+  cmake_parse_arguments(
+    MTLLIB
+    ""
+    "${oneValueArgs}"
+    "${multiValueArgs}"
+    ${ARGN}
+  )
 
   # Set output
   set(MTLLIB_BUILD_TARGET "${MTLLIB_OUTPUT_DIRECTORY}/${MTLLIB_TITLE}.metallib")
 
   # Collect compile options
   set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math)
 
   # Prepare metallib build command
   add_custom_command(
     OUTPUT ${MTLLIB_BUILD_TARGET}
-    COMMAND
-      xcrun -sdk macosx metal
-      "$<LIST:TRANSFORM,${MTLLIB_INCLUDE_DIRS},PREPEND,-I>"
-      ${MTLLIB_COMPILE_OPTIONS} ${MTLLIB_SOURCES} -o ${MTLLIB_BUILD_TARGET}
+    COMMAND xcrun -sdk macosx metal
+            "$<LIST:TRANSFORM,${MTLLIB_INCLUDE_DIRS},PREPEND,-I>"
+            ${MTLLIB_COMPILE_OPTIONS}
+            ${MTLLIB_SOURCES}
+            -o ${MTLLIB_BUILD_TARGET}
     DEPENDS ${MTLLIB_DEPS} ${MTLLIB_SOURCES}
     COMMAND_EXPAND_LISTS
     COMMENT "Building ${MTLLIB_TITLE}.metallib"
-    VERBATIM)
+    VERBATIM
+  )
 
   # Add metallib custom target
-  add_custom_target(${MTLLIB_TARGET} DEPENDS ${MTLLIB_BUILD_TARGET})
+  add_custom_target(
+    ${MTLLIB_TARGET}
+    DEPENDS
+    ${MTLLIB_BUILD_TARGET}
+  )
 
 endmacro(mlx_build_metallib)
@@ -1,4 +1,3 @@
 sphinx
 breathe
 sphinx-book-theme
-mlx
@@ -60,7 +60,6 @@ html_theme_options = {
     },
 }
 
-html_favicon = html_theme_options["logo"]["image_light"]
 
 # -- Options for HTMLHelp output ---------------------------------------------
 
@@ -84,15 +83,3 @@ def setup(app):
 # -- Options for LaTeX output ------------------------------------------------
 
 latex_documents = [(main_doc, "MLX.tex", "MLX Documentation", author, "manual")]
-latex_elements = {
-    "preamble": r"""
-    \usepackage{enumitem}
-    \setlistdepth{5}
-    \setlist[itemize,1]{label=$\bullet$}
-    \setlist[itemize,2]{label=$\bullet$}
-    \setlist[itemize,3]{label=$\bullet$}
-    \setlist[itemize,4]{label=$\bullet$}
-    \setlist[itemize,5]{label=$\bullet$}
-    \renewlist{itemize}{itemize}{5}
-    """,
-}
@@ -1,427 +0,0 @@
|
|||||||
.. _custom_metal_kernels:
|
|
||||||
|
|
||||||
Custom Metal Kernels
|
|
||||||
====================
|
|
||||||
|
|
||||||
MLX supports writing custom Metal kernels through the Python and C++ APIs.
|
|
||||||
|
|
||||||
Simple Example
|
|
||||||
--------------
|
|
||||||
|
|
||||||
Let's write a custom kernel that computes ``exp`` elementwise:
|
|
||||||
|
|
||||||
.. code-block:: python
|
|
||||||
|
|
||||||
def exp_elementwise(a: mx.array):
|
|
||||||
source = """
|
|
||||||
uint elem = thread_position_in_grid.x;
|
|
||||||
T tmp = inp[elem];
|
|
||||||
out[elem] = metal::exp(tmp);
|
|
||||||
"""
|
|
||||||
|
|
||||||
kernel = mx.fast.metal_kernel(
|
|
||||||
name="myexp",
|
|
||||||
input_names=["inp"],
|
|
||||||
output_names=["out"],
|
|
||||||
source=source,
|
|
||||||
)
|
|
||||||
outputs = kernel(
|
|
||||||
inputs=[a],
|
|
||||||
template=[("T", mx.float32)],
|
|
||||||
grid=(a.size, 1, 1),
|
|
||||||
threadgroup=(256, 1, 1),
|
|
||||||
output_shapes=[a.shape],
|
|
||||||
output_dtypes=[a.dtype],
|
|
||||||
)
|
|
||||||
return outputs[0]
|
|
||||||
|
|
||||||
a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
|
|
||||||
b = exp_elementwise(a)
|
|
||||||
assert mx.allclose(b, mx.exp(a))
|
|
||||||
|
|
||||||
.. note::
|
|
||||||
We are only required to pass the body of the Metal kernel in ``source``.
|
|
||||||
|
|
||||||
The full function signature will be generated using:
|
|
||||||
|
|
||||||
* The shapes/dtypes of ``inputs``
|
|
||||||
In the above, ``a`` is an ``mx.array`` of type ``mx.float16`` and we pass it with the key ``inp``
|
|
||||||
so we will add ``const device float16_t* inp`` to the signature.
|
|
||||||
``inp_shape``, ``inp_strides`` and ``inp_ndim`` are also added for convenience if they are present
|
|
||||||
in ``source``.
|
|
||||||
* The list of ``output_dtypes``
|
|
||||||
In the above, ``out`` is an ``mx.array`` of type ``mx.float16``
|
|
||||||
so we add ``device float16_t* out``.
|
|
||||||
* Template parameters passed using ``template``
|
|
||||||
In the above, ``template=[("T", mx.float32)]`` adds a template of ``template <typename T>`` to the function
|
|
||||||
and instantiates the template with ``custom_kernel_myexp_float<float>``.
|
|
||||||
Template parameters can be ``mx.core.Dtype``, ``int`` or ``bool``.
|
|
||||||
* Metal attributes used in ``source`` such as ``[[thread_position_in_grid]]``
|
|
||||||
These will be added as function arguments.
|
|
||||||
All the attributes defined in Table 5.8 of the `Metal Shading Language Specification <https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf>`_ are supported.
|
|
||||||
|
|
||||||
Putting this all together, the generated function signature for ``myexp`` is as follows:
|
|
||||||
|
|
||||||
.. code-block:: cpp
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
[[kernel]] void custom_kernel_myexp_float(
|
|
||||||
const device float16_t* inp [[buffer(0)]],
|
|
||||||
device float16_t* out [[buffer(1)]],
|
|
||||||
uint3 thread_position_in_grid [[thread_position_in_grid]]) {
|
|
||||||
|
|
||||||
uint elem = thread_position_in_grid.x;
|
|
||||||
T tmp = inp[elem];
|
|
||||||
out[elem] = metal::exp(tmp);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;
|
|
||||||
|
|
||||||
Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads <https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_ function.
|
|
||||||
This means we will launch ``mx.prod(grid)`` threads, subdivided into ``threadgroup`` size threadgroups.
|
|
||||||
For optimal performance, each thread group dimension should be less than or equal to the corresponding grid dimension.
|
|
||||||
|
|
||||||
Passing ``verbose=True`` to ``mx.fast.metal_kernel.__call__`` will print the generated code for debugging purposes.
|
|
||||||
|
|
||||||
Using Shape/Strides
|
|
||||||
-------------------
|
|
||||||
|
|
||||||
``mx.fast.metal_kernel`` supports an argument ``ensure_row_contiguous`` which is ``True`` by default.
|
|
||||||
This will copy the ``mx.array`` inputs if needed before the kernel is launched to ensure that the memory layout is row contiguous.
|
|
||||||
Generally this makes writing the kernel easier, since we don't have to worry about gaps or the ordering of the dims
|
|
||||||
when indexing.
|
|
||||||
|
|
||||||
If we want to avoid this copy, ``metal_kernel`` automatically passes ``a_shape``, ``a_strides`` and ``a_ndim`` for each
|
|
||||||
input array ``a`` if any are present in ``source``.
|
|
||||||
We can then use MLX's built in indexing utils to fetch the right elements for each thread.
|
|
||||||
|
|
||||||
Let's convert ``myexp`` above to support arbitrarily strided arrays without relying on a copy from ``ensure_row_contiguous``:
|
|
||||||
|
|
||||||
.. code-block:: python
|
|
||||||
|
|
||||||
def exp_elementwise(a: mx.array):
|
|
||||||
source = """
|
|
||||||
uint elem = thread_position_in_grid.x;
|
|
||||||
// Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
|
|
||||||
uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
|
|
||||||
T tmp = inp[loc];
|
|
||||||
// Output arrays are always row contiguous
|
|
||||||
out[elem] = metal::exp(tmp);
|
|
||||||
"""
|
|
||||||
|
|
||||||
kernel = mx.fast.metal_kernel(
|
|
||||||
name="myexp_strided",
|
|
||||||
input_names=["inp"],
|
|
||||||
output_names=["out"],
|
|
||||||
source=source
|
|
||||||
)
|
|
||||||
outputs = kernel(
|
|
||||||
inputs=[a],
|
|
||||||
template=[("T", mx.float32)],
|
|
||||||
grid=(a.size, 1, 1),
|
|
||||||
threadgroup=(256, 1, 1),
|
|
||||||
output_shapes=[a.shape],
|
|
||||||
output_dtypes=[a.dtype],
|
|
||||||
ensure_row_contiguous=False,
|
|
||||||
)
|
|
||||||
return outputs[0]
|
|
||||||
|
|
||||||
a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
|
|
||||||
# make non-contiguous
|
|
||||||
a = a[::2]
|
|
||||||
b = exp_elementwise(a)
|
|
||||||
assert mx.allclose(b, mx.exp(a))
|
|
||||||
|
|
||||||
Complex Example
|
|
||||||
-----------------------------
|
|
||||||
|
|
||||||
Let's implement a more complex example: ``grid_sample`` in ``"bilinear"`` mode.
|
|
||||||
|
|
||||||
We'll start with the following MLX implementation using standard ops:
|
|
||||||
|
|
||||||
.. code-block:: python
|
|
||||||
|
|
||||||
def grid_sample_ref(x, grid):
|
|
||||||
N, H_in, W_in, _ = x.shape
|
|
||||||
ix = ((grid[..., 0] + 1) * W_in - 1) / 2
|
|
||||||
iy = ((grid[..., 1] + 1) * H_in - 1) / 2
|
|
||||||
|
|
||||||
ix_nw = mx.floor(ix).astype(mx.int32)
|
|
||||||
iy_nw = mx.floor(iy).astype(mx.int32)
|
|
||||||
|
|
||||||
ix_ne = ix_nw + 1
|
|
||||||
iy_ne = iy_nw
|
|
||||||
|
|
||||||
ix_sw = ix_nw
|
|
||||||
iy_sw = iy_nw + 1
|
|
||||||
|
|
||||||
ix_se = ix_nw + 1
|
|
||||||
iy_se = iy_nw + 1
|
|
||||||
|
|
||||||
nw = (ix_se - ix) * (iy_se - iy)
|
|
||||||
ne = (ix - ix_sw) * (iy_sw - iy)
|
|
||||||
sw = (ix_ne - ix) * (iy - iy_ne)
|
|
||||||
se = (ix - ix_nw) * (iy - iy_nw)
|
|
||||||
|
|
||||||
I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :]
|
|
||||||
I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :]
|
|
||||||
I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :]
|
|
||||||
I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :]
|
|
||||||
|
|
||||||
mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1)
|
|
||||||
mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1)
|
|
||||||
mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1)
|
|
||||||
mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1)
|
|
||||||
|
|
||||||
I_nw *= mask_nw[..., None]
|
|
||||||
I_ne *= mask_ne[..., None]
|
|
||||||
I_sw *= mask_sw[..., None]
|
|
||||||
I_se *= mask_se[..., None]
|
|
||||||
|
|
||||||
output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se
|
|
||||||
|
|
||||||
return output
|
|
||||||
|
|
||||||
Now let's use ``mx.custom_function`` together with ``mx.fast.metal_kernel``
|
|
||||||
to write a fast GPU kernel for both the forward and backward passes.
|
|
||||||
|
|
||||||
First we'll implement the forward pass as a fused kernel:
|
|
||||||
|
|
||||||
.. code-block:: python
|
|
||||||
|
|
||||||
@mx.custom_function
|
|
||||||
def grid_sample(x, grid):
|
|
||||||
|
|
||||||
assert x.ndim == 4, "`x` must be 4D."
|
|
||||||
assert grid.ndim == 4, "`grid` must be 4D."
|
|
||||||
|
|
||||||
B, _, _, C = x.shape
|
|
||||||
_, gN, gM, D = grid.shape
|
|
||||||
out_shape = (B, gN, gM, C)
|
|
||||||
|
|
||||||
assert D == 2, "Last dim of `grid` must be size 2."
|
|
||||||
|
|
||||||
source = """
|
|
||||||
uint elem = thread_position_in_grid.x;
|
|
||||||
int H = x_shape[1];
|
|
||||||
int W = x_shape[2];
|
|
||||||
int C = x_shape[3];
|
|
||||||
int gH = grid_shape[1];
|
|
||||||
int gW = grid_shape[2];
|
|
||||||
|
|
||||||
int w_stride = C;
|
|
||||||
int h_stride = W * w_stride;
|
|
||||||
int b_stride = H * h_stride;
|
|
||||||
|
|
||||||
uint grid_idx = elem / C * 2;
|
|
||||||
float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
|
|
||||||
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
|
|
||||||
|
|
||||||
int ix_nw = floor(ix);
|
|
||||||
int iy_nw = floor(iy);
|
|
||||||
|
|
||||||
int ix_ne = ix_nw + 1;
|
|
||||||
int iy_ne = iy_nw;
|
|
||||||
|
|
||||||
int ix_sw = ix_nw;
|
|
||||||
int iy_sw = iy_nw + 1;
|
|
||||||
|
|
||||||
int ix_se = ix_nw + 1;
|
|
||||||
int iy_se = iy_nw + 1;
|
|
||||||
|
|
||||||
T nw = (ix_se - ix) * (iy_se - iy);
|
|
||||||
T ne = (ix - ix_sw) * (iy_sw - iy);
|
|
||||||
T sw = (ix_ne - ix) * (iy - iy_ne);
|
|
||||||
T se = (ix - ix_nw) * (iy - iy_nw);
|
|
||||||
|
|
||||||
int batch_idx = elem / C / gH / gW * b_stride;
|
|
||||||
int channel_idx = elem % C;
|
|
||||||
int base_idx = batch_idx + channel_idx;
|
|
||||||
|
|
||||||
T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];
|
|
||||||
T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];
|
|
||||||
T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];
|
|
||||||
T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];
|
|
||||||
|
|
||||||
I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0;
|
|
||||||
I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0;
|
|
||||||
I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;
|
|
||||||
I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;
|
|
||||||
|
|
||||||
out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
|
|
||||||
"""
|
|
||||||
kernel = mx.fast.metal_kernel(
|
|
||||||
name="grid_sample",
|
|
||||||
input_names=["x", "grid"],
|
|
||||||
output_names=["out"],
|
|
||||||
source=source,
|
|
||||||
)
|
|
||||||
outputs = kernel(
|
|
||||||
inputs=[x, grid],
|
|
||||||
template=[("T", x.dtype)],
|
|
||||||
output_shapes=[out_shape],
|
|
||||||
output_dtypes=[x.dtype],
|
|
||||||
grid=(np.prod(out_shape), 1, 1),
|
|
||||||
threadgroup=(256, 1, 1),
|
|
||||||
)
|
|
||||||
return outputs[0]
|
|
||||||
|
|
||||||
For a reasonably sized input such as:
|
|
||||||
|
|
||||||
.. code-block:: python
|
|
||||||
|
|
||||||
x.shape = (8, 1024, 1024, 64)
|
|
||||||
grid.shape = (8, 256, 256, 2)
|
|
||||||
|
|
||||||
On an M1 Max, we see a big performance improvement:
|
|
||||||
|
|
||||||
``55.7ms -> 6.7ms => 8x speed up``
|
|
||||||
|
|
||||||
Grid Sample VJP
|
|
||||||
---------------
|
|
||||||
|
|
||||||
Since we decorated ``grid_sample`` with ``mx.custom_function``, we can now define
|
|
||||||
its custom vjp transform so MLX can differentiate it.
|
|
||||||
|
|
||||||
The backwards pass requires atomically updating ``x_grad``/``grid_grad`` and so
|
|
||||||
requires a few extra ``mx.fast.metal_kernel`` features:
|
|
||||||
|
|
||||||
* ``init_value=0``
|
|
||||||
Initialize all of the kernel's outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel.
|
|
||||||
|
|
||||||
* ``atomic_outputs=True``
|
|
||||||
Designate all of the kernel outputs as ``atomic`` in the function signature.
|
|
||||||
This means we can use Metal's ``atomic`` features to simultaneously update the ``x_grad`` and ``grid_grad`` arrays from multiple threadgroups.
|
|
||||||
See section 6.15 of the `Metal Shading Language Specification <https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf>`_ for more details.
|
|
||||||
|
|
||||||
We can then implement the backwards pass as follows:
|
|
||||||
|
|
||||||
.. code-block:: python
|
|
||||||
|
|
||||||
@grid_sample.vjp
|
|
||||||
def grid_sample_vjp(primals, cotangent, _):
|
|
||||||
x, grid = primals
|
|
||||||
B, _, _, C = x.shape
|
|
||||||
_, gN, gM, D = grid.shape
|
|
||||||
|
|
||||||
assert D == 2, "Last dim of `grid` must be size 2."
|
|
||||||
|
|
||||||
source = """
|
|
||||||
uint elem = thread_position_in_grid.x;
|
|
||||||
int H = x_shape[1];
|
|
||||||
int W = x_shape[2];
|
|
||||||
int C = x_shape[3];
|
|
||||||
// Pad C to the nearest larger simdgroup size multiple
|
|
||||||
int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;
|
|
||||||
|
|
||||||
int gH = grid_shape[1];
|
|
||||||
int gW = grid_shape[2];
|
|
||||||
|
|
||||||
int w_stride = C;
|
|
||||||
int h_stride = W * w_stride;
|
|
||||||
int b_stride = H * h_stride;
|
|
||||||
|
|
||||||
uint grid_idx = elem / C_padded * 2;
|
|
||||||
float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
|
|
||||||
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
|
|
||||||
|
|
||||||
int ix_nw = floor(ix);
|
|
||||||
int iy_nw = floor(iy);
|
|
||||||
|
|
||||||
int ix_ne = ix_nw + 1;
|
|
||||||
int iy_ne = iy_nw;
|
|
||||||
|
|
||||||
int ix_sw = ix_nw;
|
|
||||||
int iy_sw = iy_nw + 1;
|
|
||||||
|
|
||||||
int ix_se = ix_nw + 1;
|
|
||||||
int iy_se = iy_nw + 1;
|
|
||||||
|
|
||||||
T nw = (ix_se - ix) * (iy_se - iy);
|
|
||||||
T ne = (ix - ix_sw) * (iy_sw - iy);
|
|
||||||
T sw = (ix_ne - ix) * (iy - iy_ne);
|
|
||||||
T se = (ix - ix_nw) * (iy - iy_nw);
|
|
||||||
|
|
||||||
int batch_idx = elem / C_padded / gH / gW * b_stride;
|
|
||||||
int channel_idx = elem % C_padded;
|
|
||||||
int base_idx = batch_idx + channel_idx;
|
|
||||||
|
|
||||||
T gix = T(0);
|
|
||||||
T giy = T(0);
|
|
||||||
if (channel_idx < C) {
|
|
||||||
int cot_index = elem / C_padded * C + channel_idx;
|
|
||||||
T cot = cotangent[cot_index];
|
|
||||||
if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {
|
|
||||||
int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;
|
|
||||||
atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);
|
|
||||||
|
|
||||||
T I_nw = x[offset];
|
|
||||||
gix -= I_nw * (iy_se - iy) * cot;
|
|
||||||
giy -= I_nw * (ix_se - ix) * cot;
|
|
||||||
}
|
|
||||||
if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) {
|
|
||||||
int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;
|
|
||||||
atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed);
|
|
||||||
|
|
||||||
T I_ne = x[offset];
|
|
||||||
gix += I_ne * (iy_sw - iy) * cot;
|
|
||||||
giy -= I_ne * (ix - ix_sw) * cot;
|
|
||||||
}
|
|
||||||
if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) {
|
|
||||||
int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;
|
|
||||||
atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed);
|
|
||||||
|
|
||||||
T I_sw = x[offset];
|
|
||||||
gix -= I_sw * (iy - iy_ne) * cot;
|
|
||||||
giy += I_sw * (ix_ne - ix) * cot;
|
|
||||||
}
|
|
||||||
if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) {
|
|
||||||
int offset = base_idx + iy_se * h_stride + ix_se * w_stride;
|
|
||||||
atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed);
|
|
||||||
|
|
||||||
T I_se = x[offset];
|
|
||||||
gix += I_se * (iy - iy_nw) * cot;
|
|
||||||
giy += I_se * (ix - ix_nw) * cot;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
T gix_mult = W / 2;
|
|
||||||
T giy_mult = H / 2;
|
|
||||||
|
|
||||||
// Reduce across each simdgroup first.
|
|
||||||
// This is much faster than relying purely on atomics.
|
|
||||||
gix = simd_sum(gix);
|
|
||||||
giy = simd_sum(giy);
|
|
||||||
|
|
||||||
if (thread_index_in_simdgroup == 0) {
|
|
||||||
atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);
|
|
||||||
atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
kernel = mx.fast.metal_kernel(
|
|
||||||
name="grid_sample_grad",
|
|
||||||
input_names=["x", "grid", "cotangent"],
|
|
||||||
output_names=["x_grad", "grid_grad"],
|
|
||||||
source=source,
|
|
||||||
atomic_outputs=True,
|
|
||||||
)
|
|
||||||
# pad the output channels to simd group size
|
|
||||||
# so that our `simd_sum`s don't overlap.
|
|
||||||
simdgroup_size = 32
|
|
||||||
C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
|
|
||||||
grid_size = B * gN * gM * C_padded
|
|
||||||
outputs = kernel(
|
|
||||||
inputs=[x, grid, cotangent],
|
|
||||||
template=[("T", x.dtype)],
|
|
||||||
output_shapes=[x.shape, grid.shape],
|
|
||||||
output_dtypes=[x.dtype, x.dtype],
|
|
||||||
grid=(grid_size, 1, 1),
|
|
||||||
threadgroup=(256, 1, 1),
|
|
||||||
init_value=0,
|
|
||||||
)
|
|
||||||
return outputs[0], outputs[1]
|
|
||||||
|
|
||||||
There's an even larger speed up for the vjp:
|
|
||||||
|
|
||||||
``676.4ms -> 16.7ms => 40x speed up``
|
|
@@ -486,8 +486,9 @@ below.
     std::ostringstream kname;
     kname << "axpby_" << "general_" << type_to_name(out);
 
-    // Make sure the metal library is available
-    d.register_library("mlx_ext");
+    // Make sure the metal library is available and look for it
+    // in the same folder as this executable if needed
+    d.register_library("mlx_ext", metal::get_colocated_mtllib_path);
 
     // Make a kernel from this metal library
     auto kernel = d.get_kernel(kname.str(), "mlx_ext");
@@ -15,7 +15,7 @@ module to concisely define the model architecture.
 Attention layer
 ^^^^^^^^^^^^^^^^
 
-We will start with the Llama attention layer which notably uses the RoPE
+We will start with the llama attention layer which notably uses the RoPE
 positional encoding. [1]_ In addition, our attention layer will optionally use a
 key/value cache that will be concatenated with the provided keys and values to
 support efficient inference.
@@ -64,7 +64,7 @@ set:
 Next, setup the problem parameters and load the data. To load the data, you need our
 `mnist data loader
 <https://github.com/ml-explore/mlx-examples/blob/main/mnist/mnist.py>`_, which
-we will import as ``mnist``.
+we will import as `mnist`.
 
 .. code-block:: python
 
@@ -85,4 +85,3 @@ are the CPU and GPU.
 
    dev/extensions
   dev/metal_debugger
-   dev/custom_metal_kernels
@@ -14,7 +14,7 @@ silicon computer is
 To install from PyPI you must meet the following requirements:
 
 - Using an M series chip (Apple silicon)
-- Using a native Python >= 3.9
+- Using a native Python >= 3.8
 - macOS >= 13.5
 
 .. note::
@@ -70,36 +70,36 @@ To build and install the MLX python library from source, first, clone MLX from
 
     git clone git@github.com:ml-explore/mlx.git mlx && cd mlx
 
+Install `nanobind <https://nanobind.readthedocs.io/en/latest/>`_ with:
+
+.. code-block:: shell
+
+    pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
+
 Then simply build and install MLX using pip:
 
 .. code-block:: shell
 
-    CMAKE_BUILD_PARALLEL_LEVEL=8 pip install .
+    env CMAKE_BUILD_PARALLEL_LEVEL="" pip install .
 
-For developing, install the package with development dependencies, and use an
-editable install:
+For developing use an editable install:
 
 .. code-block:: shell
 
-    CMAKE_BUILD_PARALLEL_LEVEL=8 pip install -e ".[dev]"
+    env CMAKE_BUILD_PARALLEL_LEVEL="" pip install -e .
 
-Once the development dependencies are installed, you can build faster with:
-
-.. code-block:: shell
-
-    CMAKE_BUILD_PARALLEL_LEVEL=8 python setup.py build_ext --inplace
-
-Run the tests with:
+To make sure the install is working run the tests with:
 
 .. code-block:: shell
 
+    pip install ".[testing]"
     python -m unittest discover python/tests
 
-Optional: Install stubs to enable auto completions and type checking from your
-IDE:
+Optional: Install stubs to enable auto completions and type checking from your IDE:
 
 .. code-block:: shell
 
+    pip install ".[dev]"
    python setup.py generate_stubs
 
 C++ API
@@ -195,7 +195,7 @@ GGUF, you can do:
 
 .. code-block:: shell
 
-    cmake .. \
+    cmake ..
       -DCMAKE_BUILD_TYPE=MinSizeRel \
      -DBUILD_SHARED_LIBS=ON \
       -DMLX_BUILD_CPU=OFF \
@@ -240,7 +240,7 @@ x86 Shell
 
 .. _build shell:
 
-If the output of ``uname -p`` is ``x86`` then your shell is running as x86 via
+If the ouptut of ``uname -p`` is ``x86`` then your shell is running as x86 via
 Rosetta instead of natively.
 
 To fix this, find the application in Finder (``/Applications`` for iTerm,
@@ -264,4 +264,4 @@ Also check that cmake is using the correct architecture:
 
 If you see ``"x86_64"``, try re-installing ``cmake``. If you see ``"arm64"``
 but the build errors out with "Building for x86_64 on macOS is not supported."
-wipe your build cache with ``rm -rf build/`` and try again.
+wipe your build cahce with ``rm -rf build/`` and try again.
@@ -24,7 +24,6 @@ Array
|
|||||||
array.any
|
array.any
|
||||||
array.argmax
|
array.argmax
|
||||||
array.argmin
|
array.argmin
|
||||||
array.conj
|
|
||||||
array.cos
|
array.cos
|
||||||
array.cummax
|
array.cummax
|
||||||
array.cummin
|
array.cummin
|
||||||
@@ -53,10 +52,8 @@ Array
|
|||||||
array.sqrt
|
array.sqrt
|
||||||
array.square
|
array.square
|
||||||
array.squeeze
|
array.squeeze
|
||||||
array.std
|
|
||||||
array.sum
|
|
||||||
array.swapaxes
|
array.swapaxes
|
||||||
|
array.sum
|
||||||
array.transpose
|
array.transpose
|
||||||
array.T
|
array.T
|
||||||
array.var
|
array.var
|
||||||
array.view
|
|
||||||
|
@@ -17,6 +17,3 @@ made available.
|
|||||||
init
|
init
|
||||||
all_sum
|
all_sum
|
||||||
all_gather
|
all_gather
|
||||||
send
|
|
||||||
recv
|
|
||||||
recv_like
|
|
||||||
|
@@ -12,5 +12,3 @@ Fast
|
|||||||
layer_norm
|
layer_norm
|
||||||
rope
|
rope
|
||||||
scaled_dot_product_attention
|
scaled_dot_product_attention
|
||||||
affine_quantize
|
|
||||||
metal_kernel
|
|
||||||
|
@@ -9,12 +9,7 @@ Linear Algebra

   :toctree: _autosummary

   inv
-  tri_inv
   norm
   cholesky
-  cholesky_inv
-  cross
   qr
   svd
-  eigvalsh
-  eigh

@@ -14,7 +14,6 @@ Metal

   get_cache_memory
   set_memory_limit
   set_cache_limit
-  set_wired_limit
   clear_cache
   start_capture
   stop_capture

@@ -13,7 +13,6 @@ simple functions.

   :template: nn-module-template.rst

   elu
-  celu
   gelu
   gelu_approx
   gelu_fast_approx
@@ -13,18 +13,13 @@ Layers

   AvgPool1d
   AvgPool2d
   BatchNorm
-  CELU
   Conv1d
   Conv2d
   Conv3d
-  ConvTranspose1d
-  ConvTranspose2d
-  ConvTranspose3d
   Dropout
   Dropout2d
   Dropout3d
   Embedding
-  ELU
   GELU
   GLU
   GroupNorm

@@ -36,8 +31,6 @@ Layers

   LayerNorm
   LeakyReLU
   Linear
-  LogSigmoid
-  LogSoftmax
   LSTM
   MaxPool1d
   MaxPool2d

@@ -53,7 +46,6 @@ Layers

   RoPE
   SELU
   Sequential
-  Sigmoid
   SiLU
   SinusoidalPositionalEncoding
   Softmin
@@ -44,10 +44,6 @@ Operations

   convolve
   conv1d
   conv2d
-  conv3d
-  conv_transpose1d
-  conv_transpose2d
-  conv_transpose3d
   conv_general
   cos
   cosh

@@ -61,8 +57,6 @@ Operations

   diagonal
   divide
   divmod
-  einsum
-  einsum_path
   equal
   erf
   erfinv

@@ -78,11 +72,8 @@ Operations

   gather_qmm
   greater
   greater_equal
-  hadamard_transform
   identity
-  imag
   inner
-  isfinite
   isclose
   isinf
   isnan

@@ -112,7 +103,6 @@ Operations

   minimum
   moveaxis
   multiply
-  nan_to_num
   negative
   not_equal
   ones

@@ -122,17 +112,14 @@ Operations

   pad
   power
   prod
-  put_along_axis
   quantize
   quantized_matmul
   radians
-  real
   reciprocal
   remainder
   repeat
   reshape
   right_shift
-  roll
   round
   rsqrt
   save
@@ -31,41 +31,6 @@ model's parameters and the **optimizer state**.

      # Compute the new parameters but also the optimizer state.
      mx.eval(model.parameters(), optimizer.state)

- Saving and Loading
- ------------------
-
- To serialize an optimizer, save its state. To load an optimizer, load and set
- the saved state. Here's a simple example:
-
- .. code-block:: python
-
-    import mlx.core as mx
-    from mlx.utils import tree_flatten, tree_unflatten
-    import mlx.optimizers as optim
-
-    optimizer = optim.Adam(learning_rate=1e-2)
-
-    # Perform some updates with the optimizer
-    model = {"w" : mx.zeros((5, 5))}
-    grads = {"w" : mx.ones((5, 5))}
-    optimizer.update(model, grads)
-
-    # Save the state
-    state = tree_flatten(optimizer.state)
-    mx.save_safetensors("optimizer.safetensors", dict(state))
-
-    # Later on, for example when loading from a checkpoint,
-    # recreate the optimizer and load the state
-    optimizer = optim.Adam(learning_rate=1e-2)
-
-    state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))
-    optimizer.state = state
-
- Note, not every optimizer configuration parameter is saved in the state. For
- example, for Adam the learning rate is saved but the ``betas`` and ``eps``
- parameters are not. A good rule of thumb is if the parameter can be scheduled
- then it will be included in the optimizer state.

.. toctree::

   optimizers/optimizer
@@ -44,5 +44,3 @@ we use a splittable version of Threefry, which is a counter-based PRNG.

   split
   truncated_normal
   uniform
-  laplace
-  permutation

@@ -10,7 +10,6 @@ Transforms

   eval
   compile
-  custom_function
   disable_compile
   enable_compile
   grad
@@ -33,12 +33,12 @@ Let's start with a simple example:

   # Compile the function
   compiled_fun = mx.compile(fun)

   # Prints: array(2.36788, dtype=float32)
   print(compiled_fun(x, y))

The output of both the regular function and the compiled function is the same
up to numerical precision.

The first time you call a compiled function, MLX will build the compute
graph, optimize it, and generate and compile code. This can be relatively
slow. However, MLX will cache compiled functions, so calling a compiled
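For reference, the printed value ``2.36788`` is consistent with a function such
as the following; this is a sketch rather than the exact ``fun`` from the page
above, with the inputs assumed:

.. code-block:: python

   import mlx.core as mx

   def fun(x, y):
       # exp(-1.0) + 2.0 == 2.36788..., matching the printed output above
       return mx.exp(-x) + y

   x = mx.array(1.0)
   y = mx.array(2.0)
   compiled_fun = mx.compile(fun)
   print(compiled_fun(x, y))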
@@ -96,7 +96,7 @@ element-wise operations:

.. code-block:: python

   def gelu(x):
       return x * (1 + mx.erf(x / math.sqrt(2))) / 2

If you use this function with small arrays, it will be overhead bound. If you
@@ -136,6 +136,13 @@ Now make an array, and benchmark both functions:

On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is
five times faster.

+ .. note::
+
+    As of the latest MLX, CPU functions are not fully compiled. Compiling CPU
+    functions can still be helpful, but won't typically result in as large a
+    speedup as compiling operations that run on the GPU.

Debugging
---------
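A rough way to reproduce such a comparison is to time the eager and the
compiled versions with ``timeit``; a sketch, with the input shape assumed and
the exact numbers depending on the machine:

.. code-block:: python

   import math
   import timeit
   import mlx.core as mx

   def gelu(x):
       return x * (1 + mx.erf(x / math.sqrt(2))) / 2

   x = mx.random.uniform(shape=(32, 1000, 4096))
   compiled_gelu = mx.compile(gelu)

   # mx.eval forces the lazy computation so the timing is meaningful
   print(timeit.timeit(lambda: mx.eval(gelu(x)), number=10))
   print(timeit.timeit(lambda: mx.eval(compiled_gelu(x)), number=10))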
@@ -280,7 +287,7 @@ to the function. In some cases this can be pretty inconvenient. Hence,

   print(fun(mx.array(1.0)))

Compiling Training Graphs
-------------------------

This section will step through how to use :func:`compile` with a simple example

@@ -290,7 +297,7 @@ full forward, backward, and update with :func:`compile`.

To start, here is the simple example without any compilation:

.. code-block:: python

   import mlx.core as mx
   import mlx.nn as nn
@@ -323,7 +330,7 @@ To start, here is the simple example without any compilation:

To compile the update we can put it all in a function and compile it with the
appropriate input and output captures. Here's the same example but compiled:

.. code-block:: python

   import mlx.core as mx
   import mlx.nn as nn

@@ -348,7 +355,7 @@ appropriate input and output captures. Here's the same example but compiled:

   # The state that will be captured as input and output
   state = [model.state, optimizer.state]

   @partial(mx.compile, inputs=state, outputs=state)
   def step(x, y):
       loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
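A compiled ``step`` like the one above is typically driven by a loop that
evaluates the captured state after each call; a sketch, with the data iterator
and batch shapes assumed:

.. code-block:: python

   for x, y in dataset:
       loss = step(x, y)
       # Evaluate the model and optimizer state captured by the compiled step
       mx.eval(state)
       print(loss.item())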
@@ -403,7 +410,7 @@ Compiling transformed functions works just as expected:

In order to compile as much as possible, a transformation of a compiled
function will not by default be compiled. To compile the transformed
function simply pass it through :func:`compile`.

You can also compile functions which themselves call compiled functions. A
good practice is to compile the outer most function to give :func:`compile`
@@ -25,7 +25,7 @@ Here is a simple example:

The output of :func:`grad` on :func:`sin` is simply another function. In this
case it is the gradient of the sine function which is exactly the cosine
function. To get the second derivative you can do:

.. code-block:: shell
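The second derivative mentioned here is obtained by composing :func:`grad` with
itself; a minimal sketch:

.. code-block:: python

   import mlx.core as mx

   d2fdx2 = mx.grad(mx.grad(mx.sin))
   # The second derivative of sin is -sin
   print(d2fdx2(mx.array(1.0)))  # approximately -0.841471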
@@ -50,7 +50,7 @@ Automatic Differentiation

.. _auto diff:

Automatic differentiation in MLX works on functions rather than on implicit
graphs.

.. note::

@@ -114,7 +114,7 @@ way to do that is the following:

   def loss_fn(params, x, y):
       w, b = params["weight"], params["bias"]
       h = w * x + b
       return mx.mean(mx.square(h - y))

   params = {"weight": mx.array(1.0), "bias": mx.array(0.0)}

@@ -132,7 +132,7 @@ way to do that is the following:

Notice the tree structure of the parameters is preserved in the gradients.

In some cases you may want to stop gradients from propagating through a
part of the function. You can use the :func:`stop_gradient` for that.
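A sketch of both points, with the inputs assumed: gradients come back with the
same dict structure as ``params``, and :func:`stop_gradient` blocks the flow
through a sub-expression:

.. code-block:: python

   import mlx.core as mx

   def loss_fn(params, x, y):
       w, b = params["weight"], params["bias"]
       h = w * x + b
       return mx.mean(mx.square(h - y))

   params = {"weight": mx.array(1.0), "bias": mx.array(0.0)}
   x, y = mx.array([1.0, 2.0]), mx.array([2.0, 4.0])

   grads = mx.grad(loss_fn)(params, x, y)
   print(grads["weight"], grads["bias"])  # same tree structure as params

   def loss_no_bias_grad(params, x, y):
       h = params["weight"] * x + mx.stop_gradient(params["bias"])
       return mx.mean(mx.square(h - y))

   print(mx.grad(loss_no_bias_grad)(params, x, y)["bias"])  # zero gradient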
@@ -161,19 +161,19 @@ A naive way to add the elements from two sets of vectors is with a loop:

   ys = mx.random.uniform(shape=(100, 4096))

   def naive_add(xs, ys):
-      return [xs[i] + ys[:, i] for i in range(xs.shape[0])]
+      return [xs[i] + ys[:, i] for i in range(xs.shape[1])]

Instead you can use :func:`vmap` to automatically vectorize the addition:

.. code-block:: python

   # Vectorize over the second dimension of x and the
   # first dimension of y
-  vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(0, 1))
+  vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(1, 0))

The ``in_axes`` parameter can be used to specify which dimensions of the
corresponding input to vectorize over. Similarly, use ``out_axes`` to specify
where the vectorized axes should be in the outputs.

Let's time these two different versions:
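To make the axis bookkeeping concrete, here is a small self-contained sketch of
``in_axes`` and ``out_axes``, with the shapes assumed:

.. code-block:: python

   import mlx.core as mx

   xs = mx.zeros((4096, 100))
   ys = mx.zeros((100, 4096))

   # Map over axis 0 of xs and axis 1 of ys; both mapped axes have size 4096
   vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(0, 1))
   print(vmap_add(xs, ys).shape)  # (4096, 100); mapped axis first by default

   # Put the mapped axis last in the output instead
   vmap_add_t = mx.vmap(lambda x, y: x + y, in_axes=(0, 1), out_axes=1)
   print(vmap_add_t(xs, ys).shape)  # (100, 4096)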
@@ -51,7 +51,7 @@ You can also use an :obj:`array` to index another :obj:`array`:

.. code-block:: shell

   >>> arr = mx.arange(10)
   >>> idx = mx.array([5, 7])
   >>> arr[idx]
   array([5, 7], dtype=int32)

@@ -77,12 +77,12 @@ from the GPU. Performing bounds checking for array indices before launching the

kernel would be extremely inefficient.

Indexing with boolean masks is something that MLX may support in the future. In
general, MLX has limited support for operations for which output
*shapes* are dependent on input *data*. Other examples of these types of
operations which MLX does not yet support include :func:`numpy.nonzero` and the
single input version of :func:`numpy.where`.

In Place Updates
----------------

In place updates to indexed arrays are possible in MLX. For example:
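A minimal sketch of such an update, with the values assumed:

.. code-block:: python

   import mlx.core as mx

   a = mx.array([1, 2, 3])
   a[2] = 0           # in-place update of a single index
   idx = mx.array([0, 1])
   a[idx] = 7         # in-place update through an index array
   print(a)           # array([7, 7, 0], dtype=int32)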
@@ -13,7 +13,7 @@ compute graph is recorded. The actual computation only happens if an

:func:`eval` is performed.

MLX uses lazy evaluation because it has some nice features, some of which we
describe below.

Transforming Compute Graphs
^^^^^^^^^^^^^^^^^^^^^^^^^^^

@@ -109,14 +109,14 @@ Here is a concrete example:

An important behavior to be aware of is when the graph will be implicitly
evaluated. Anytime you ``print`` an array, convert it to an
:obj:`numpy.ndarray`, or otherwise access its memory via :obj:`memoryview`,
the graph will be evaluated. Saving arrays via :func:`save` (or any other MLX
saving functions) will also evaluate the array.

Calling :func:`array.item` on a scalar array will also evaluate it. In the
example above, printing the loss (``print(loss)``) or adding the loss scalar to
a list (``losses.append(loss.item())``) would cause a graph evaluation. If
these lines are before ``mx.eval(loss, model.parameters())`` then this
will be a partial evaluation, computing only the forward pass.
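A small sketch of the distinction between recording a graph and implicitly
evaluating it:

.. code-block:: python

   import mlx.core as mx

   a = mx.ones((2, 2))
   b = a + a          # no computation yet, only the graph is recorded
   print(b)           # printing implicitly evaluates the graph
   c = (b * 2).sum()
   mx.eval(c)         # or evaluate explicitly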
@@ -3,10 +3,10 @@

Conversion to NumPy and Other Frameworks
========================================

MLX array supports conversion between other frameworks with either:

* The `Python Buffer Protocol <https://docs.python.org/3/c-api/buffer.html>`_.
* `DLPack <https://dmlc.github.io/dlpack/latest/>`_.

Let's convert an array to NumPy and back.
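A minimal round-trip sketch using the buffer protocol:

.. code-block:: python

   import mlx.core as mx
   import numpy as np

   a = mx.arange(3.0)
   b = np.array(a)    # MLX array -> NumPy array (copies through the buffer protocol)
   c = mx.array(b)    # NumPy array -> MLX array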
@@ -66,7 +66,7 @@ even though no in-place operations on MLX memory are executed.

PyTorch
-------

.. warning::

   PyTorch Support for :obj:`memoryview` is experimental and can break for
   multi-dimensional arrays. Casting to NumPy first is advised for now.
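Following the warning above, going through NumPy is the safer path; a sketch,
assuming ``torch`` is installed:

.. code-block:: python

   import mlx.core as mx
   import numpy as np
   import torch

   a = mx.arange(3.0)
   t = torch.tensor(np.array(a))   # MLX -> NumPy -> PyTorch
   back = mx.array(t.numpy())      # PyTorch -> NumPy -> MLX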
@@ -64,4 +64,4 @@ Other gradient transformations include :func:`vjp` for vector-Jacobian products

and :func:`jvp` for Jacobian-vector products.

Use :func:`value_and_grad` to efficiently compute both a function's output and
gradient with respect to the function's input.
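A minimal sketch of :func:`value_and_grad`:

.. code-block:: python

   import mlx.core as mx

   def fn(x):
       return (x ** 2).sum()

   value, grad = mx.value_and_grad(fn)(mx.array([1.0, 2.0]))
   # value is 5.0 and grad is [2.0, 4.0]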
@@ -8,33 +8,33 @@ Saving and Loading Arrays

MLX supports multiple array serialization formats.

.. list-table:: Serialization Formats
   :widths: 20 8 25 25
   :header-rows: 1

   * - Format
     - Extension
     - Function
     - Notes
   * - NumPy
     - ``.npy``
     - :func:`save`
     - Single arrays only
   * - NumPy archive
     - ``.npz``
     - :func:`savez` and :func:`savez_compressed`
     - Multiple arrays
   * - Safetensors
     - ``.safetensors``
     - :func:`save_safetensors`
     - Multiple arrays
   * - GGUF
     - ``.gguf``
     - :func:`save_gguf`
     - Multiple arrays

The :func:`load` function will load any of the supported serialization
formats. It determines the format from the extensions. The output of
:func:`load` depends on the format.

Here's an example of saving a single array to a file:
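A short sketch covering the formats in the table, with the file names assumed:

.. code-block:: python

   import mlx.core as mx

   a = mx.array([1.0, 2.0, 3.0])

   mx.save("array", a)                    # writes array.npy, a single array
   b = mx.load("array.npy")

   mx.savez("arrays", a=a, b=b)           # .npz archive; load returns a dict
   mx.save_safetensors("arrays.safetensors", {"a": a, "b": b})
   arrays = mx.load("arrays.safetensors")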
@@ -20,7 +20,7 @@ Both ``a`` and ``b`` live in unified memory.

In MLX, rather than moving arrays to devices, you specify the device when you
run the operation. Any device can perform any operation on ``a`` and ``b``
without needing to move them from one memory location to another. For example:

.. code-block:: python
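A sketch of specifying the device per operation rather than per array:

.. code-block:: python

   import mlx.core as mx

   a = mx.random.normal((100,))
   b = mx.random.normal((100,))

   mx.add(a, b, stream=mx.cpu)   # run this op on the CPU
   mx.add(a, b, stream=mx.gpu)   # run the same op on the GPU; no copies needed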
@@ -11,14 +11,10 @@ option(BUILD_SHARED_LIBS "Build extensions as a shared library" ON)

# ----------------------------- Dependencies -----------------------------
find_package(MLX CONFIG REQUIRED)
- find_package(
-   Python 3.8
-   COMPONENTS Interpreter Development.Module
-   REQUIRED)
+ find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
execute_process(
  COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
- OUTPUT_STRIP_TRAILING_WHITESPACE
- OUTPUT_VARIABLE NB_DIR)
+ OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE NB_DIR)
list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
find_package(nanobind CONFIG REQUIRED)
@@ -28,10 +24,16 @@ find_package(nanobind CONFIG REQUIRED)

add_library(mlx_ext)

# Add sources
- target_sources(mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.cpp)
+ target_sources(
+   mlx_ext
+   PUBLIC
+   ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.cpp
+ )

# Add include headers
- target_include_directories(mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR})
+ target_include_directories(
+   mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR}
+ )

# Link to mlx
target_link_libraries(mlx_ext PUBLIC mlx)
@@ -41,32 +43,27 @@ target_link_libraries(mlx_ext PUBLIC mlx)

# Build metallib
if(MLX_BUILD_METAL)
  mlx_build_metallib(
-   TARGET
-   mlx_ext_metallib
-   TITLE
-   mlx_ext
-   SOURCES
-   ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.metal
-   INCLUDE_DIRS
-   ${PROJECT_SOURCE_DIR}
-   ${MLX_INCLUDE_DIRS}
-   OUTPUT_DIRECTORY
-   ${CMAKE_LIBRARY_OUTPUT_DIRECTORY})
+   TARGET mlx_ext_metallib
+   TITLE mlx_ext
+   SOURCES ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.metal
+   INCLUDE_DIRS ${PROJECT_SOURCE_DIR} ${MLX_INCLUDE_DIRS}
+   OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}
+ )

- add_dependencies(mlx_ext mlx_ext_metallib)
+ add_dependencies(
+   mlx_ext
+   mlx_ext_metallib
+ )

endif()

# ----------------------------- Python Bindings -----------------------------
nanobind_add_module(
  _ext
- NB_STATIC
- STABLE_ABI
- LTO
- NOMINSIZE
- NB_DOMAIN
- mlx
- ${CMAKE_CURRENT_LIST_DIR}/bindings.cpp)
+ NB_STATIC STABLE_ABI LTO NOMINSIZE
+ NB_DOMAIN mlx
+ ${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
+ )
target_link_libraries(_ext PRIVATE mlx_ext)

if(BUILD_SHARED_LIBS)
@@ -249,8 +249,9 @@ void Axpby::eval_gpu(
  kname << (contiguous_kernel ? "contiguous_" : "general_");
  kname << type_to_name(out);

- // Make sure the metal library is available
- d.register_library("mlx_ext");
+ // Make sure the metal library is available and look for it
+ // in the same folder as this executable if needed
+ d.register_library("mlx_ext", metal::get_colocated_mtllib_path);

  // Make a kernel from this metal library
  auto kernel = d.get_kernel(kname.str(), "mlx_ext");
@@ -2,7 +2,7 @@
requires = [
    "setuptools>=42",
    "cmake>=3.24",
-   "mlx>=0.18.0",
-   "nanobind==2.2.0",
+   "mlx>=0.9.0",
+   "nanobind@git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4",
]
build-backend = "setuptools.build_meta"

@@ -1,4 +1,4 @@
setuptools>=42
cmake>=3.24
- mlx>=0.18.1
- nanobind==2.2.0
+ mlx>=0.9.0
+ nanobind@git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4

@@ -13,6 +13,7 @@ if __name__ == "__main__":
    cmdclass={"build_ext": extension.CMakeBuild},
    packages=["mlx_sample_extensions"],
    package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]},
+   extras_require={"dev": []},
    zip_safe=False,
    python_requires=">=3.8",
)
@@ -1,24 +1,25 @@
target_sources(
  mlx
- PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
-         ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
-         ${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
-         ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
-         ${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
-         ${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp
+ PRIVATE
+   ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
+   ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
+   ${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
+   ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
+   ${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
- ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h)
+ ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h
+ )

- if(MLX_BUILD_CPU)
+ if (MLX_BUILD_CPU)
  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
else()
  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_cpu)

@@ -26,15 +27,17 @@ endif()

add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/distributed)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
- if(MLX_BUILD_ACCELERATE)
+ if (MLX_BUILD_ACCELERATE)
  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate)
elseif(MLX_BUILD_CPU)
  target_sources(
    mlx
-   PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/common/default_primitives.cpp)
+   PRIVATE
+     ${CMAKE_CURRENT_SOURCE_DIR}/backend/common/default_primitives.cpp
+   )
endif()

- if(MLX_BUILD_METAL)
+ if (MLX_BUILD_METAL)
  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)
else()
  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_metal)
@@ -23,22 +23,11 @@ void free(Buffer buffer) {
}

Buffer CommonAllocator::malloc(size_t size, bool) {
- void* ptr = std::malloc(size + sizeof(size_t));
- if (ptr != nullptr) {
-   *static_cast<size_t*>(ptr) = size;
- }
- return Buffer{ptr};
+ return Buffer{std::malloc(size)};
}

void CommonAllocator::free(Buffer buffer) {
- std::free(buffer.ptr());
- }
-
- size_t CommonAllocator::size(Buffer buffer) const {
-   if (buffer.ptr() == nullptr) {
-     return 0;
-   }
-   return *static_cast<size_t*>(buffer.ptr());
+ std::free(buffer.raw_ptr());
}

Buffer malloc_or_wait(size_t size) {

@@ -41,7 +41,6 @@ class Allocator {
 public:
  virtual Buffer malloc(size_t size, bool allow_swap = false) = 0;
  virtual void free(Buffer buffer) = 0;
- virtual size_t size(Buffer buffer) const = 0;

  Allocator() = default;
  Allocator(const Allocator& other) = delete;

@@ -58,7 +57,6 @@ class CommonAllocator : public Allocator {
 public:
  virtual Buffer malloc(size_t size, bool allow_swap = false) override;
  virtual void free(Buffer buffer) override;
- virtual size_t size(Buffer buffer) const override;

 private:
  CommonAllocator() = default;
@@ -17,10 +17,6 @@ bool in_tracing() {
  return detail::InTracing::in_tracing();
}

- bool retain_graph() {
-   return detail::RetainGraph::retain_graph();
- }

} // namespace

array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)

@@ -95,34 +91,18 @@ void array::detach() {
  array_desc_->primitive = nullptr;
}

- bool array::is_available() const {
-   if (status() == Status::available) {
-     return true;
-   } else if (status() == Status::evaluated && event().is_signaled()) {
-     set_status(Status::available);
-     return true;
-   }
-   return false;
- }
-
- void array::wait() {
-   if (!is_available()) {
-     event().wait();
-     set_status(Status::available);
-   }
- }

void array::eval() {
  // Ensure the array is ready to be read
- if (status() == Status::unscheduled) {
+ if (status() == Status::scheduled) {
+   event().wait();
+   set_status(Status::available);
+ } else if (status() == Status::unscheduled) {
    mlx::core::eval({*this});
- } else {
-   wait();
  }
}

bool array::is_tracer() const {
- return array_desc_->is_tracer && in_tracing() || retain_graph();
+ return array_desc_->is_tracer && in_tracing();
}

void array::set_data(allocator::Buffer buffer, deleter_t d) {
@@ -178,10 +158,8 @@ void array::move_shared_buffer(
  array_desc_->flags = flags;
  array_desc_->data_size = data_size;
  auto char_offset = sizeof(char) * itemsize() * offset;
- auto data_ptr = other.array_desc_->data_ptr;
- other.array_desc_->data_ptr = nullptr;
- array_desc_->data_ptr =
-     static_cast<void*>(static_cast<char*>(data_ptr) + char_offset);
+ array_desc_->data_ptr = static_cast<void*>(
+     static_cast<char*>(other.array_desc_->data_ptr) + char_offset);
}

void array::move_shared_buffer(array other) {

@@ -193,11 +171,10 @@ array::~array() {
    return;
  }

- // Ignore arrays that might be detached during eval
- if (status() == array::Status::scheduled) {
+ // Ignore arrays that will be detached
+ if (status() != array::Status::unscheduled) {
    return;
  }

  // Break circular reference for non-detached arrays with siblings
  if (auto n = siblings().size(); n > 0) {
    bool do_detach = true;
@@ -260,38 +237,25 @@ array::ArrayDesc::~ArrayDesc() {
  // This calls recursively the destructor and can result in stack overflow, we
  // instead put them in a vector and destroy them one at a time resulting in a
  // max stack depth of 2.
- if (inputs.empty()) {
-   return;
- }
-
  std::vector<std::shared_ptr<ArrayDesc>> for_deletion;

- auto append_deletable_inputs = [&for_deletion](ArrayDesc& ad) {
-   std::unordered_map<std::uintptr_t, array> input_map;
-   for (array& a : ad.inputs) {
-     if (a.array_desc_) {
-       input_map.insert({a.id(), a});
-       for (auto& s : a.siblings()) {
-         input_map.insert({s.id(), s});
-       }
-     }
-   }
-   ad.inputs.clear();
-   for (auto& [_, a] : input_map) {
-     if (a.array_desc_.use_count() <= a.siblings().size() + 1) {
-       for_deletion.push_back(std::move(a.array_desc_));
-     }
-   }
- };
-
- append_deletable_inputs(*this);
+ for (array& a : inputs) {
+   if (a.array_desc_.use_count() == 1) {
+     for_deletion.push_back(std::move(a.array_desc_));
+   }
+ }

  while (!for_deletion.empty()) {
    // top is going to be deleted at the end of the block *after* the arrays
    // with inputs have been moved into the vector
    auto top = std::move(for_deletion.back());
    for_deletion.pop_back();
-   append_deletable_inputs(*top);
+   for (array& a : top->inputs) {
+     if (a.array_desc_.use_count() == 1) {
+       for_deletion.push_back(std::move(a.array_desc_));
+     }
+   }
  }
}

mlx/array.h
@@ -5,6 +5,7 @@
#include <cstdint>
#include <functional>
#include <memory>
+ #include <optional>
#include <vector>

#include "mlx/allocator.h"

@@ -219,23 +220,11 @@ class array {
  };

  struct Flags {
-   // True iff there are no gaps in the underlying data. Each item
+   // True if there are no gaps in the underlying data. Each item
    // in the underlying data buffer belongs to at least one index.
-   //
-   // True iff:
-   // prod(shape[i] for i in range(ndim) if strides[i] > 0) == data_size()
    bool contiguous : 1;

-   // True iff:
-   // strides[-1] == 1 and
-   // all(strides[i] == (shape[i+1]*strides[i+1]) or shape[i] == 1 for i in
-   // range(ndim - 1))
    bool row_contiguous : 1;

-   // True iff:
-   // strides[0] == 1 and
-   // all(strides[i] == (shape[i-1]*strides[i-1]) or shape[i] == 1 for i in
-   // range(1, ndim))
    bool col_contiguous : 1;
  };
@@ -303,16 +292,7 @@ class array {
    return array_desc_->flags;
  }

- /** The size (in elements) of the underlying buffer the array points to.
-  *
-  * This can be different than the actual size of the array if the array has
-  * been broadcast or irregularly strided. If ``first`` is the offset into
-  * the data buffer of the first element of the array (i.e. the offset
-  * corresponding to ``arr[0, 0, ...]``) and last is the offset into the
-  * data buffer of the last element of the array (i.e. the offset
-  * corresponding to ``arr[-1, -1, ...]``) then ``data_size = last - first``.
-  * Note, ``data_size`` is in units of ``item_size`` (not bytes).
-  **/
+ /** The size (in elements) of the underlying buffer the array points to. */
  size_t data_size() const {
    return array_desc_->data_size;
  }

@@ -324,10 +304,6 @@ class array {
    return array_desc_->data->buffer;
  }

- size_t buffer_size() const {
-   return allocator::allocator().size(buffer());
- }

  // Return a copy of the shared pointer
  // to the array::Data struct
  std::shared_ptr<Data> data_shared_ptr() const {
@@ -344,33 +320,11 @@ class array {
    return static_cast<T*>(array_desc_->data_ptr);
  }

- enum Status {
-   // The output of a computation which has not been scheduled.
-   // For example, the status of `x` in `auto x = a + b`.
-   unscheduled,
-
-   // The output of a computation which has been scheduled but `eval_*` has
-   // not yet been called on the array's primitive. A possible
-   // status of `x` in `auto x = a + b; eval(x);`
-   scheduled,
-
-   // The array's `eval_*` function has been run, but the computation is not
-   // necessarily complete. The array will have memory allocated and if it is
-   // not a tracer then it will be detached from the graph.
-   evaluated,
-
-   // If the array is the output of a computation then the computation
-   // is complete. Constant arrays are always available (e.g. `array({1, 2,
-   // 3})`)
-   available
- };
-
- // Check if the array is safe to read.
- bool is_available() const;
-
- // Wait on the array to be available. After this `is_available` returns
- // `true`.
- void wait();
+ enum Status { unscheduled, scheduled, available };
+
+ bool is_available() const {
+   return status() == Status::available;
+ }

  Status status() const {
    return array_desc_->status;
@@ -459,6 +413,8 @@ class array {
  void* data_ptr{nullptr};

  // The size in elements of the data buffer the array accesses
+ // This can be different than the actual size of the array if it
+ // has been broadcast or irregularly strided.
  size_t data_size;

  // Contains useful meta data about the array

@@ -610,4 +566,6 @@ inline constexpr bool is_arrays_v = (is_array_v<T> && ...);
template <typename... T>
using enable_for_arrays_t = typename std::enable_if_t<is_arrays_v<T...>>;

+ enum QuantizationMode { DEFAULT, NF4 };
+
} // namespace mlx::core
@@ -1,8 +1,10 @@
target_sources(
  mlx
- PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
-         ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
-         ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
-         ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
-         ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
-         ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp)
+ PRIVATE
+   ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
+   ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
+   ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
+   ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
+   ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
+   ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
+ )
@@ -1,9 +1,9 @@
- // Copyright © 2023-2024 Apple Inc.
+ // Copyright © 2023 Apple Inc.

#include <cassert>

- #include <Accelerate/Accelerate.h>
#include <simd/vector.h>
+ #include <vecLib/vDSP.h>

#include "mlx/backend/common/copy.h"
#include "mlx/primitives.h"

@@ -2,7 +2,8 @@

#include <cassert>

- #include <Accelerate/Accelerate.h>
+ #include <vecLib/BNNS/bnns.h>
+ #include <vecLib/cblas_new.h>

#include "mlx/backend/accelerate/utils.h"
#include "mlx/backend/common/copy.h"

@@ -3,7 +3,8 @@
#include <cassert>
#include <cmath>

- #include <Accelerate/Accelerate.h>
+ #include <vecLib/vDSP.h>
+ #include <vecLib/vForce.h>

#include "mlx/allocator.h"
#include "mlx/backend/common/binary.h"
@@ -36,7 +37,7 @@ DEFAULT(Ceil)
DEFAULT(Concatenate)
DEFAULT(Conjugate)
DEFAULT(Copy)
- DEFAULT_MULTI(CustomTransforms)
+ DEFAULT_MULTI(CustomVJP)
DEFAULT_MULTI(Depends)
DEFAULT_MULTI(DivMod)
DEFAULT(NumberOfElements)

@@ -50,7 +51,6 @@ DEFAULT(GatherMM)
DEFAULT(GatherQMM)
DEFAULT(Greater)
DEFAULT(GreaterEqual)
- DEFAULT(Hadamard)
DEFAULT(Less)
DEFAULT(LessEqual)
DEFAULT(Load)

@@ -81,7 +81,6 @@ DEFAULT_MULTI(SVD)
DEFAULT(Transpose)
DEFAULT(Inverse)
DEFAULT(Cholesky)
- DEFAULT_MULTI(Eigh)

void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
@@ -103,7 +102,7 @@ void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
  auto& b = inputs[1];

  if (a.dtype() == float32) {
-   binary_op<float>(
+   binary(
        a,
        b,
        out,

@@ -118,7 +117,7 @@ void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
          vDSP_vadd((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n);
        });
  } else if (a.dtype() == int32) {
-   binary_op<int>(
+   binary(
        a,
        b,
        out,

@@ -133,7 +132,7 @@ void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
          vDSP_vaddi((const int*)a, 1, (const int*)b, 1, (int*)o, 1, n);
        });
  } else {
-   eval(inputs, out);
+   binary(a, b, out, [](auto x, auto y) { return x + y; });
  }
}

@@ -288,7 +287,7 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
  auto& b = inputs[1];

  if (a.dtype() == int32) {
-   binary_op<int>(
+   binary(
        a,
        b,
        out,

@@ -301,7 +300,7 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
          vDSP_vdivi((const int*)b, 1, (const int*)a, 1, (int*)o, 1, n);
        });
  } else if (a.dtype() == float32) {
-   binary_op<float>(
+   binary(
        a,
        b,
        out,

@@ -316,7 +315,7 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
          vDSP_vdiv((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n);
        });
  } else {
-   eval(inputs, out);
+   binary(a, b, out, [](auto x, auto y) { return x / y; });
  }
}

@@ -327,8 +326,12 @@ void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
    set_unary_output_data(in, out);
    auto size = in.data_size();
    vvexpf(out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
+ } else if (issubdtype(out.dtype(), inexact)) {
+   unary_fp(in, out, [](auto x) { return std::exp(x); });
  } else {
-   eval(inputs, out);
+   throw std::invalid_argument(
+       "[exp] Cannot exponentiate elements in array"
+       " with non floating point type.");
  }
}

@@ -390,8 +393,12 @@ void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
    auto size = in.data_size();
    vvlog1pf(
        out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
+ } else if (issubdtype(out.dtype(), inexact)) {
+   unary_fp(in, out, [](auto x) { return std::log1p(x); });
  } else {
-   eval(inputs, out);
+   throw std::invalid_argument(
+       "[log1p] Cannot compute log of elements in array with"
+       " non floating point type.");
  }
}

@@ -401,7 +408,7 @@ void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
  auto& b = inputs[1];

  if (a.dtype() == float32) {
-   binary_op<float>(
+   binary(
        a,
        b,
        out,

@@ -416,7 +423,7 @@ void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
          vDSP_vmul((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n);
        });
  } else {
-   eval(inputs, out);
+   binary(a, b, out, [](auto x, auto y) { return x * y; });
  }
}

@@ -427,7 +434,7 @@ void Negative::eval_cpu(const std::vector<array>& inputs, array& out) {
    set_unary_output_data(in, out);
    vDSP_vneg(in.data<float>(), 1, out.data<float>(), 1, in.data_size());
  } else {
-   eval(inputs, out);
+   unary(in, out, [](auto x) { return -x; });
  }
}

@@ -514,7 +521,7 @@ void Square::eval_cpu(const std::vector<array>& inputs, array& out) {
    auto size = in.data_size();
    vDSP_vsq(in.data<float>(), 1, out.data<float>(), 1, size);
  } else {
-   eval(inputs, out);
+   unary(in, out, [](auto x) { return x * x; });
  }
}

@@ -540,7 +547,7 @@ void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
  auto& b = inputs[1];

  if (a.dtype() == float32) {
-   binary_op<float>(
+   binary(
        a,
        b,
        out,

@@ -558,7 +565,7 @@ void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
          vDSP_vsub((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n);
        });
  } else if (a.dtype() == int32) {
-   binary_op<int>(
+   binary(
        a,
        b,
        out,

@@ -570,7 +577,7 @@ void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
        },
        UseDefaultBinaryOp());
  } else {
-   eval(inputs, out);
+   binary(a, b, out, [](auto x, auto y) { return x - y; });
  }
}
@@ -18,61 +18,49 @@ void _qmm_t_4_64(
|
|||||||
const float* biases,
|
const float* biases,
|
||||||
int M,
|
int M,
|
||||||
int N,
|
int N,
|
||||||
int K,
|
int K) {
|
||||||
int B,
|
|
||||||
bool batched_w) {
|
|
||||||
constexpr int bits = 4;
|
constexpr int bits = 4;
|
||||||
constexpr int group_size = 64;
|
constexpr int group_size = 64;
|
||||||
constexpr int bitmask = (1 << bits) - 1;
|
constexpr int bitmask = (1 << bits) - 1;
|
||||||
constexpr int pack_factor = 32 / bits;
|
constexpr int pack_factor = 32 / bits;
|
||||||
constexpr int packs_in_group = group_size / pack_factor;
|
constexpr int packs_in_group = group_size / pack_factor;
|
||||||
|
|
||||||
int w_els = N * K / pack_factor;
|
for (int m = 0; m < M; m++) {
|
||||||
int g_els = w_els * pack_factor / group_size;
|
const uint32_t* w_local = w;
|
||||||
|
const float* scales_local = scales;
|
||||||
|
const float* biases_local = biases;
|
||||||
|
|
||||||
for (int i = 0; i < B; i++) {
|
for (int n = 0; n < N; n++) {
|
||||||
for (int m = 0; m < M; m++) {
|
const simd_float16* x_local = (simd_float16*)x;
|
||||||
const uint32_t* w_local = w;
|
simd_float16 sum = 0;
|
||||||
const float* scales_local = scales;
|
for (int k = 0; k < K; k += group_size) {
|
||||||
const float* biases_local = biases;
|
float scale = *scales_local++;
|
||||||
|
float bias = *biases_local++;
|
||||||
|
|
||||||
for (int n = 0; n < N; n++) {
|
for (int kw = 0; kw < packs_in_group; kw += 2) {
|
||||||
const simd_float16* x_local = (simd_float16*)x;
|
// TODO: vectorize this properly
|
||||||
simd_float16 sum = 0;
|
simd_uint16 wi;
|
||||||
for (int k = 0; k < K; k += group_size) {
|
for (int e = 0; e < 2; e++) {
|
||||||
float scale = *scales_local++;
|
uint32_t wii = *w_local++;
|
||||||
float bias = *biases_local++;
|
for (int p = 0; p < 8; p++) {
|
||||||
|
wi[e * 8 + p] = wii & bitmask;
|
||||||
for (int kw = 0; kw < packs_in_group; kw += 2) {
|
wii >>= bits;
|
||||||
// TODO: vectorize this properly
|
|
||||||
simd_uint16 wi;
|
|
||||||
for (int e = 0; e < 2; e++) {
|
|
||||||
uint32_t wii = *w_local++;
|
|
||||||
for (int p = 0; p < 8; p++) {
|
|
||||||
wi[e * 8 + p] = wii & bitmask;
|
|
||||||
wii >>= bits;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
simd_float16 wf = simd_float(wi);
|
|
||||||
wf *= scale;
|
|
||||||
wf += bias;
|
|
||||||
|
|
||||||
sum += (*x_local) * wf;
|
|
||||||
x_local++;
|
|
||||||
}
|
}
|
||||||
}
|
simd_float16 wf = simd_float(wi);
|
||||||
|
wf *= scale;
|
||||||
|
wf += bias;
|
||||||
|
|
||||||
*result = simd_reduce_add(sum);
|
sum += (*x_local) * wf;
|
||||||
result++;
|
x_local++;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
x += K;
|
*result = simd_reduce_add(sum);
|
||||||
}
|
result++;
|
||||||
if (batched_w) {
|
|
||||||
w += w_els;
|
|
||||||
scales += g_els;
|
|
||||||
biases += g_els;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
x += K;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -94,10 +82,8 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
   if (condition) {
     out.set_data(allocator::malloc_or_wait(out.nbytes()));
     int K = x.shape(-1);
-    int M = x.shape(-2);
+    int M = x.size() / K;
     int N = out.shape(-1);
-    int B = x.size() / K / M;
-    bool batched_w = w.ndim() > 2;
     _qmm_t_4_64(
         out.data<float>(),
         x.data<float>(),
@@ -106,9 +92,7 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
         biases.data<float>(),
         M,
         N,
-        K,
-        B,
-        batched_w);
+        K);
   } else {
     eval(inputs, out);
   }
@@ -2,8 +2,8 @@

 #include <cassert>

-#include <Accelerate/Accelerate.h>
 #include <simd/vector.h>
+#include <vecLib/vDSP.h>

 #include "mlx/backend/common/reduce.h"
 #include "mlx/primitives.h"
@@ -3,10 +3,7 @@
 #include <cassert>
 #include <limits>

-#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
 #include <arm_neon.h>
-#endif

 #include <simd/math.h>
 #include <simd/vector.h>

@@ -33,8 +30,8 @@ namespace {
  * Note: The implementation below is a general fast exp. There could be faster
  * implementations for numbers strictly < 0.
  */
-inline simd_float16 simd_fast_exp(simd_float16 x_init) {
-  auto x = x_init * 1.442695; // multiply with log_2(e)
+inline simd_float16 simd_fast_exp(simd_float16 x) {
+  x *= 1.442695; // multiply with log_2(e)
   simd_float16 ipart, fpart;
   simd_int16 epart;
   x = simd_clamp(x, -80, 80);
@@ -53,30 +50,28 @@ inline simd_float16 simd_fast_exp(simd_float16 x_init) {
   // bitshifting
   epart = (simd_int(ipart) + 127) << 23;

-  // Avoid supressing NaNs
-  simd_int16 eq = (x_init == x_init);
-  return simd_bitselect(x_init, (*(simd_float16*)&epart) * x, eq);
+  return (*(simd_float16*)&epart) * x;
 }

-#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
 /**
  * The ARM neon equivalent of the fast exp above.
  */
 inline float16x8_t neon_fast_exp(float16x8_t x) {
-  x = vmulq_f16(x, vdupq_n_f16(float16_t(1.442695f))); // multiply with log_2(e)
-  x = vmaxq_f16(x, vdupq_n_f16(float16_t(-14.f))); // clamp under with -14
-  x = vminq_f16(x, vdupq_n_f16(float16_t(14.f))); // clamp over with 14
+  x = vmulq_f16(x, vdupq_n_f16(1.442695)); // multiply with log_2(e)
+  x = vmaxq_f16(x, vdupq_n_f16(-14)); // clamp under with -14
+  x = vminq_f16(x, vdupq_n_f16(14)); // clamp over with 14

-  float16x8_t ipart = vrndmq_f16(vaddq_f16(x, vdupq_n_f16(float16_t(0.5f))));
+  float16x8_t ipart = vrndmq_f16(vaddq_f16(x, vdupq_n_f16(0.5)));
   float16x8_t fpart = vsubq_f16(x, ipart);

-  x = vdupq_n_f16(float16_t(1.535336188319500e-4f));
-  x = vfmaq_f16(vdupq_n_f16(float16_t(1.339887440266574e-3f)), x, fpart);
-  x = vfmaq_f16(vdupq_n_f16(float16_t(9.618437357674640e-3f)), x, fpart);
-  x = vfmaq_f16(vdupq_n_f16(float16_t(5.550332471162809e-2f)), x, fpart);
-  x = vfmaq_f16(vdupq_n_f16(float16_t(2.402264791363012e-1f)), x, fpart);
-  x = vfmaq_f16(vdupq_n_f16(float16_t(6.931472028550421e-1f)), x, fpart);
-  x = vfmaq_f16(vdupq_n_f16(float16_t(1.000000000000000f)), x, fpart);
+  x = vdupq_n_f16(1.535336188319500e-4f);
+  x = vfmaq_f16(vdupq_n_f16(1.339887440266574e-3f), x, fpart);
+  x = vfmaq_f16(vdupq_n_f16(1.339887440266574e-3f), x, fpart);
+  x = vfmaq_f16(vdupq_n_f16(9.618437357674640e-3f), x, fpart);
+  x = vfmaq_f16(vdupq_n_f16(5.550332471162809e-2f), x, fpart);
+  x = vfmaq_f16(vdupq_n_f16(2.402264791363012e-1f), x, fpart);
+  x = vfmaq_f16(vdupq_n_f16(6.931472028550421e-1f), x, fpart);
+  x = vfmaq_f16(vdupq_n_f16(1.000000000000000f), x, fpart);

   // generate 2**ipart in the floating point representation using integer
   // bitshifting
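The fast-exp trick in these hunks is easier to see in scalar form. A sketch, assuming float32 and reusing the coefficients from the vector code (not the library's API): exp(x) = 2^(x·log2 e), with the integer part of the exponent built by writing it into the float exponent bits and the fractional part handled by a small polynomial.

```cpp
#include <cmath>
#include <cstdint>
#include <cstring>

// Scalar sketch of the fast exp above. Illustrative only.
inline float fast_exp_sketch(float x) {
  x *= 1.442695f;                           // log2(e)
  x = std::fmin(std::fmax(x, -80.f), 80.f); // clamp, as in simd_fast_exp
  float ipart = std::floor(x + 0.5f);
  float fpart = x - ipart;

  // Polynomial approximation of 2^fpart (Horner form, same coefficients).
  float p = 1.535336188319500e-4f;
  p = p * fpart + 1.339887440266574e-3f;
  p = p * fpart + 9.618437357674640e-3f;
  p = p * fpart + 5.550332471162809e-2f;
  p = p * fpart + 2.402264791363012e-1f;
  p = p * fpart + 6.931472028550421e-1f;
  p = p * fpart + 1.0f;

  // 2^ipart by writing (ipart + 127) into the float exponent field.
  uint32_t e = static_cast<uint32_t>(static_cast<int>(ipart) + 127) << 23;
  float two_ipart;
  std::memcpy(&two_ipart, &e, sizeof(e));
  return two_ipart * p;
}
```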
@@ -112,55 +107,6 @@ inline float16_t neon_reduce_add(float16x8_t x) {
   return vget_lane_f16(y, 0);
 }

-template <typename T, typename VT>
-struct NeonFp16SimdOps {
-  VT init(T a) {
-    return vdupq_n_f16(a);
-  }
-
-  VT load(const T* a) {
-    return vld1q_f16(a);
-  }
-
-  void store(T* dst, VT x) {
-    vst1q_f16(dst, x);
-  }
-
-  VT max(VT a, VT b) {
-    return vmaxq_f16(a, b);
-  }
-
-  VT exp(VT x) {
-    return neon_fast_exp(x);
-  }
-
-  VT add(VT a, VT b) {
-    return vaddq_f16(a, b);
-  }
-
-  VT sub(VT a, T b) {
-    return vsubq_f16(a, vdupq_n_f16(b));
-  }
-
-  VT mul(VT a, VT b) {
-    return vmulq_f16(a, b);
-  }
-
-  VT mul(VT a, T b) {
-    return vmulq_f16(a, vdupq_n_f16(b));
-  }
-
-  T reduce_max(VT x) {
-    return neon_reduce_max(x);
-  }
-
-  T reduce_add(VT x) {
-    return neon_reduce_add(x);
-  }
-};
-
-#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-
 template <typename T, typename VT>
 struct AccelerateSimdOps {
   VT init(T a) {
@@ -208,6 +154,53 @@ struct AccelerateSimdOps {
   }
 };

+template <typename T, typename VT>
+struct NeonFp16SimdOps {
+  VT init(T a) {
+    return vdupq_n_f16(a);
+  }
+
+  VT load(const T* a) {
+    return vld1q_f16(a);
+  }
+
+  void store(T* dst, VT x) {
+    vst1q_f16(dst, x);
+  }
+
+  VT max(VT a, VT b) {
+    return vmaxq_f16(a, b);
+  }
+
+  VT exp(VT x) {
+    return neon_fast_exp(x);
+  }
+
+  VT add(VT a, VT b) {
+    return vaddq_f16(a, b);
+  }
+
+  VT sub(VT a, T b) {
+    return vsubq_f16(a, vdupq_n_f16(b));
+  }
+
+  VT mul(VT a, VT b) {
+    return vmulq_f16(a, b);
+  }
+
+  VT mul(VT a, T b) {
+    return vmulq_f16(a, vdupq_n_f16(b));
+  }
+
+  T reduce_max(VT x) {
+    return neon_reduce_max(x);
+  }
+
+  T reduce_add(VT x) {
+    return neon_reduce_add(x);
+  }
+};

 template <typename T, typename AccT, typename VT, typename Ops, int N>
 void softmax(const array& in, array& out) {
   Ops ops;
@@ -369,16 +362,12 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
           AccelerateSimdOps<float, simd_float16>,
           16>(in, out);
     } else {
-#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
       softmax<
           float16_t,
           float16_t,
           float16x8_t,
           NeonFp16SimdOps<float16_t, float16x8_t>,
           8>(in, out);
-#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-      eval(inputs, out); // Redirect to common backend for consistency
-#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
     }
     break;
   case bfloat16:
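The SimdOps structs above (init/load/max/exp/sub/mul/reduce) are the building blocks of the usual three-pass softmax that the templated `softmax<T, AccT, VT, Ops, N>` runs per row. A scalar sketch of that computation, assuming a contiguous row (illustrative, not the library's kernel):

```cpp
#include <algorithm>
#include <cmath>

// Scalar sketch of a row softmax: 1) find the row maximum, 2) accumulate
// exp(x - max), 3) normalize. Subtracting the maximum keeps exp() in range.
void softmax_row_sketch(const float* x, float* out, int n) {
  float maximum = x[0];
  for (int i = 1; i < n; ++i) {
    maximum = std::max(maximum, x[i]);
  }
  float sum = 0.f;
  for (int i = 0; i < n; ++i) {
    out[i] = std::exp(x[i] - maximum);
    sum += out[i];
  }
  float inv = 1.f / sum;
  for (int i = 0; i < n; ++i) {
    out[i] *= inv;
  }
}
```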
@@ -1,8 +1,8 @@
-// Copyright © 2023-2024 Apple Inc.
+// Copyright © 2023 Apple Inc.

 #pragma once

-#include <Accelerate/Accelerate.h>
+#include <vecLib/BNNS/bnns.h>
 #include "mlx/dtype.h"

 namespace mlx::core {
@@ -1,4 +1,5 @@
-if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
+if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
   set(COMPILER ${CMAKE_C_COMPILER})
   set(CLANG TRUE)
 else()
@@ -6,57 +7,71 @@ else()
 endif()

 add_custom_command(
   OUTPUT compiled_preamble.cpp
-  COMMAND
-    /bin/bash ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
-    ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp ${COMPILER}
-    ${PROJECT_SOURCE_DIR} ${CLANG}
-  DEPENDS make_compiled_preamble.sh
-    compiled_preamble.h
-    ${PROJECT_SOURCE_DIR}/mlx/types/half_types.h
-    ${PROJECT_SOURCE_DIR}/mlx/types/fp16.h
-    ${PROJECT_SOURCE_DIR}/mlx/types/bf16.h
-    ${PROJECT_SOURCE_DIR}/mlx/types/complex.h
-    ops.h)
-
-add_custom_target(cpu_compiled_preamble DEPENDS compiled_preamble.cpp)
+  COMMAND /bin/bash
+    ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
+    ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
+    ${COMPILER}
+    ${PROJECT_SOURCE_DIR}
+    ${CLANG}
+  DEPENDS make_compiled_preamble.sh
+    compiled_preamble.h
+    ${PROJECT_SOURCE_DIR}/mlx/types/half_types.h
+    ${PROJECT_SOURCE_DIR}/mlx/types/fp16.h
+    ${PROJECT_SOURCE_DIR}/mlx/types/bf16.h
+    ${PROJECT_SOURCE_DIR}/mlx/types/complex.h
+    ops.h
+)
+
+add_custom_target(
+  cpu_compiled_preamble
+  DEPENDS compiled_preamble.cpp
+)

 add_dependencies(mlx cpu_compiled_preamble)

 target_sources(
   mlx
-  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
-    ${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
-    ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
-    ${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
-    ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
-    ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
-    ${CMAKE_CURRENT_SOURCE_DIR}/eigh.cpp
-    ${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
-    ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
-    ${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
-    ${CMAKE_CURRENT_SOURCE_DIR}/masked_mm.cpp
-    ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
-    ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
-    ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
-    ${CMAKE_CURRENT_SOURCE_DIR}/reduce_utils.cpp
-    ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
-    ${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
-    ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
-    ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
-    ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
-    ${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
-    ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
-    ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
-    ${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp
-    ${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
-    ${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
-    ${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp
-    ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
-    ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp)
+  PRIVATE
+  ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/masked_mm.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/reduce_utils.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp
+  ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
+)

-if(IOS)
-  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_nocpu.cpp)
+if (IOS)
+  target_sources(
+    mlx
+    PRIVATE
+    ${CMAKE_CURRENT_SOURCE_DIR}/compiled_nocpu.cpp
+  )
 else()
-  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_cpu.cpp)
+  target_sources(
+    mlx
+    PRIVATE
+    ${CMAKE_CURRENT_SOURCE_DIR}/compiled_cpu.cpp
+  )
 endif()
@@ -43,15 +43,13 @@ void set_binary_op_output_data(
     array& out,
     BinaryOpType bopt,
     bool donate_with_move = false) {
-  bool b_donatable = is_donatable(b, out);
-  bool a_donatable = is_donatable(a, out);
   switch (bopt) {
     case BinaryOpType::ScalarScalar:
       out.set_data(
           allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags());
       break;
     case BinaryOpType::ScalarVector:
-      if (b_donatable) {
+      if (b.is_donatable() && b.itemsize() == out.itemsize()) {
         if (donate_with_move) {
           out.move_shared_buffer(b);
         } else {
@@ -66,7 +64,7 @@ void set_binary_op_output_data(
       }
       break;
     case BinaryOpType::VectorScalar:
-      if (a_donatable) {
+      if (a.is_donatable() && a.itemsize() == out.itemsize()) {
         if (donate_with_move) {
           out.move_shared_buffer(a);
         } else {
@@ -81,13 +79,13 @@ void set_binary_op_output_data(
       }
       break;
     case BinaryOpType::VectorVector:
-      if (a_donatable) {
+      if (a.is_donatable() && a.itemsize() == out.itemsize()) {
        if (donate_with_move) {
          out.move_shared_buffer(a);
        } else {
          out.copy_shared_buffer(a);
        }
-      } else if (b_donatable) {
+      } else if (b.is_donatable() && b.itemsize() == out.itemsize()) {
        if (donate_with_move) {
          out.move_shared_buffer(b);
        } else {
@@ -102,14 +100,16 @@ void set_binary_op_output_data(
       }
       break;
     case BinaryOpType::General:
-      if (a_donatable && a.flags().row_contiguous && a.size() == out.size()) {
+      if (a.is_donatable() && a.flags().row_contiguous &&
+          a.itemsize() == out.itemsize() && a.size() == out.size()) {
        if (donate_with_move) {
          out.move_shared_buffer(a);
        } else {
          out.copy_shared_buffer(a);
        }
      } else if (
-          b_donatable && b.flags().row_contiguous && b.size() == out.size()) {
+          b.is_donatable() && b.flags().row_contiguous &&
+          b.itemsize() == out.itemsize() && b.size() == out.size()) {
        if (donate_with_move) {
          out.move_shared_buffer(b);
        } else {
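The donation checks being changed above all follow the same pattern: an input may back the output buffer only if it is donatable, its element size matches the output's, and (in the general case) it is row contiguous with the same number of elements. A simplified sketch of that predicate (hypothetical types and helper, not the library's API):

```cpp
#include <cstddef>

// Hypothetical sketch of the donation predicate used in
// set_binary_op_output_data: an input buffer can be reused for the output
// only when it may be donated and its layout is compatible.
struct ArrayView {
  bool donatable;      // no other holders of the underlying buffer
  size_t itemsize;     // bytes per element
  size_t size;         // number of elements
  bool row_contiguous; // laid out like the output
};

inline bool can_donate_general(const ArrayView& in, const ArrayView& out) {
  return in.donatable && in.itemsize == out.itemsize && in.row_contiguous &&
      in.size == out.size;
}
```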
@@ -122,7 +122,19 @@ void set_binary_op_output_data(
   }
 }

-struct UseDefaultBinaryOp {};
+struct UseDefaultBinaryOp {
+  template <typename T, typename U>
+  void operator()(const T* a, const T* b, U* dst, int size) {
+    // Should we throw? This should normally never be called.
+    assert(false);
+  }
+
+  template <typename T, typename U>
+  void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
+    // Should we throw? This should normally never be called.
+    assert(false);
+  }
+};

 template <typename T, typename U, typename Op>
 struct DefaultVectorScalar {
@@ -138,6 +150,18 @@ struct DefaultVectorScalar {
       a++;
     }
   }
+
+  void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
+    T scalar = *b;
+    while (size-- > 0) {
+      auto dst = op(*a, scalar);
+      *dst_a = dst.first;
+      *dst_b = dst.second;
+      dst_a++;
+      dst_b++;
+      a++;
+    }
+  }
 };

 template <typename T, typename U, typename Op>
@@ -154,6 +178,18 @@ struct DefaultScalarVector {
       b++;
     }
   }
+
+  void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
+    T scalar = *a;
+    while (size-- > 0) {
+      auto dst = op(scalar, *b);
+      *dst_a = dst.first;
+      *dst_b = dst.second;
+      dst_a++;
+      dst_b++;
+      b++;
+    }
+  }
 };

 template <typename T, typename U, typename Op>
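The `operator()` overloads added above take two destination pointers because the element op is assumed to return a pair (for example, a divmod-style op yielding quotient and remainder). A minimal sketch of such an op under that assumption (hypothetical names, not the library's API):

```cpp
#include <utility>

// Hypothetical element op returning two results, matching the
// (dst_a, dst_b) overloads added to the Default*Vector helpers.
struct DivModOp {
  std::pair<int, int> operator()(int x, int y) const {
    // Assumes y != 0.
    return {x / y, x % y};
  }
};

// Usage sketch: the vector-vector case walks both inputs and writes
// dst.first / dst.second element by element.
inline void divmod_vv(const int* a, const int* b, int* dst_a, int* dst_b, int size) {
  DivModOp op;
  while (size-- > 0) {
    auto dst = op(*a++, *b++);
    *dst_a++ = dst.first;
    *dst_b++ = dst.second;
  }
}
```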
@@ -170,110 +206,204 @@ struct DefaultVectorVector {
|
|||||||
b++;
|
b++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void operator()(const T* a, const T* b, U* dst_a, U* dst_b, int size) {
|
||||||
|
while (size-- > 0) {
|
||||||
|
auto dst = op(*a, *b);
|
||||||
|
*dst_a = dst.first;
|
||||||
|
*dst_b = dst.second;
|
||||||
|
dst_a++;
|
||||||
|
dst_b++;
|
||||||
|
a++;
|
||||||
|
b++;
|
||||||
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T, typename U, typename Op, int D, bool Strided>
|
template <typename T, typename U, typename Op>
|
||||||
void binary_op_dims(
|
void binary_op_dims1(const array& a, const array& b, array& out, Op op) {
|
||||||
const T* a,
|
const T* a_ptr = a.data<T>();
|
||||||
const T* b,
|
const T* b_ptr = b.data<T>();
|
||||||
U* out,
|
U* dst = out.data<U>();
|
||||||
Op op,
|
size_t a_idx = 0;
|
||||||
const std::vector<int>& shape,
|
size_t b_idx = 0;
|
||||||
const std::vector<size_t>& a_strides,
|
for (size_t i = 0; i < out.size(); ++i) {
|
||||||
const std::vector<size_t>& b_strides,
|
dst[i] = op(a_ptr[a_idx], b_ptr[b_idx]);
|
||||||
const std::vector<size_t>& out_strides,
|
a_idx += a.strides()[0];
|
||||||
int axis) {
|
b_idx += b.strides()[0];
|
||||||
auto stride_a = a_strides[axis];
|
|
||||||
auto stride_b = b_strides[axis];
|
|
||||||
auto stride_out = out_strides[axis];
|
|
||||||
auto N = shape[axis];
|
|
||||||
|
|
||||||
for (int i = 0; i < N; i++) {
|
|
||||||
if constexpr (D > 1) {
|
|
||||||
binary_op_dims<T, U, Op, D - 1, Strided>(
|
|
||||||
a, b, out, op, shape, a_strides, b_strides, out_strides, axis + 1);
|
|
||||||
} else {
|
|
||||||
if constexpr (Strided) {
|
|
||||||
op(a, b, out, stride_out);
|
|
||||||
} else {
|
|
||||||
*out = op(*a, *b);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
out += stride_out;
|
|
||||||
a += stride_a;
|
|
||||||
b += stride_b;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, typename U, bool Strided, typename Op>
|
template <typename T, typename U, typename Op>
|
||||||
|
void binary_op_dims1(
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
array& out,
|
||||||
|
Op op,
|
||||||
|
int stride) {
|
||||||
|
const T* a_ptr = a.data<T>();
|
||||||
|
const T* b_ptr = b.data<T>();
|
||||||
|
U* dst = out.data<U>();
|
||||||
|
size_t a_idx = 0;
|
||||||
|
size_t b_idx = 0;
|
||||||
|
for (size_t i = 0; i < a.shape()[0]; i++) {
|
||||||
|
op(a_ptr + a_idx, b_ptr + b_idx, dst, stride);
|
||||||
|
a_idx += a.strides()[0];
|
||||||
|
b_idx += b.strides()[0];
|
||||||
|
dst += stride;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U, typename Op>
|
||||||
|
void binary_op_dims2(const array& a, const array& b, array& out, Op op) {
|
||||||
|
const T* a_ptr = a.data<T>();
|
||||||
|
const T* b_ptr = b.data<T>();
|
||||||
|
U* dst = out.data<U>();
|
||||||
|
size_t a_idx = 0;
|
||||||
|
size_t b_idx = 0;
|
||||||
|
size_t out_idx = 0;
|
||||||
|
for (size_t i = 0; i < a.shape()[0]; ++i) {
|
||||||
|
for (size_t j = 0; j < a.shape()[1]; ++j) {
|
||||||
|
dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx]);
|
||||||
|
a_idx += a.strides()[1];
|
||||||
|
b_idx += b.strides()[1];
|
||||||
|
}
|
||||||
|
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
|
||||||
|
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U, typename Op>
|
||||||
|
void binary_op_dims2(
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
array& out,
|
||||||
|
Op op,
|
||||||
|
int stride) {
|
||||||
|
const T* a_ptr = a.data<T>();
|
||||||
|
const T* b_ptr = b.data<T>();
|
||||||
|
U* dst = out.data<U>();
|
||||||
|
size_t a_idx = 0;
|
||||||
|
size_t b_idx = 0;
|
||||||
|
for (size_t i = 0; i < a.shape()[0]; ++i) {
|
||||||
|
for (size_t j = 0; j < a.shape()[1]; ++j) {
|
||||||
|
op(a_ptr + a_idx, b_ptr + b_idx, dst, stride);
|
||||||
|
a_idx += a.strides()[1];
|
||||||
|
b_idx += b.strides()[1];
|
||||||
|
dst += stride;
|
||||||
|
}
|
||||||
|
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
|
||||||
|
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U, typename Op>
|
||||||
|
void binary_op_dims3(const array& a, const array& b, array& out, Op op) {
|
||||||
|
const T* a_ptr = a.data<T>();
|
||||||
|
const T* b_ptr = b.data<T>();
|
||||||
|
U* dst = out.data<U>();
|
||||||
|
size_t a_idx = 0;
|
||||||
|
size_t b_idx = 0;
|
||||||
|
size_t out_idx = 0;
|
||||||
|
for (size_t i = 0; i < a.shape()[0]; ++i) {
|
||||||
|
for (size_t j = 0; j < a.shape()[1]; ++j) {
|
||||||
|
for (size_t k = 0; k < a.shape()[2]; ++k) {
|
||||||
|
dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx]);
|
||||||
|
a_idx += a.strides()[2];
|
||||||
|
b_idx += b.strides()[2];
|
||||||
|
}
|
||||||
|
a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
|
||||||
|
b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
|
||||||
|
}
|
||||||
|
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
|
||||||
|
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U, typename Op>
|
||||||
|
void binary_op_dims4(const array& a, const array& b, array& out, Op op) {
|
||||||
|
const T* a_ptr = a.data<T>();
|
||||||
|
const T* b_ptr = b.data<T>();
|
||||||
|
U* dst = out.data<U>();
|
||||||
|
size_t a_idx = 0;
|
||||||
|
size_t b_idx = 0;
|
||||||
|
size_t out_idx = 0;
|
||||||
|
for (size_t i = 0; i < a.shape()[0]; ++i) {
|
||||||
|
for (size_t j = 0; j < a.shape()[1]; ++j) {
|
||||||
|
for (size_t k = 0; k < a.shape()[2]; ++k) {
|
||||||
|
for (size_t ii = 0; ii < a.shape()[3]; ++ii) {
|
||||||
|
dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx]);
|
||||||
|
a_idx += a.strides()[3];
|
||||||
|
b_idx += b.strides()[3];
|
||||||
|
}
|
||||||
|
a_idx += a.strides()[2] - a.strides()[3] * a.shape()[3];
|
||||||
|
b_idx += b.strides()[2] - b.strides()[3] * b.shape()[3];
|
||||||
|
}
|
||||||
|
a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
|
||||||
|
b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
|
||||||
|
}
|
||||||
|
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
|
||||||
|
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U, typename Op>
|
||||||
|
void binary_op_dispatch_dims(
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
array& out,
|
||||||
|
Op op) {
|
||||||
|
switch (out.ndim()) {
|
||||||
|
case 1:
|
||||||
|
binary_op_dims1<T, U, Op>(a, b, out, op);
|
||||||
|
return;
|
||||||
|
case 2:
|
||||||
|
binary_op_dims2<T, U, Op>(a, b, out, op);
|
||||||
|
return;
|
||||||
|
case 3:
|
||||||
|
binary_op_dims3<T, U, Op>(a, b, out, op);
|
||||||
|
return;
|
||||||
|
case 4:
|
||||||
|
binary_op_dims4<T, U, Op>(a, b, out, op);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const T* a_ptr = a.data<T>();
|
||||||
|
const T* b_ptr = b.data<T>();
|
||||||
|
U* dst = out.data<U>();
|
||||||
|
for (size_t i = 0; i < out.size(); i++) {
|
||||||
|
int a_idx = elem_to_loc(i, a.shape(), a.strides());
|
||||||
|
int b_idx = elem_to_loc(i, b.shape(), b.strides());
|
||||||
|
dst[i] = op(a_ptr[a_idx], b_ptr[b_idx]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U, typename Op>
|
||||||
void binary_op_dispatch_dims(
|
void binary_op_dispatch_dims(
|
||||||
const array& a,
|
const array& a,
|
||||||
const array& b,
|
const array& b,
|
||||||
array& out,
|
array& out,
|
||||||
Op op,
|
Op op,
|
||||||
int dim,
|
int dim,
|
||||||
const std::vector<int>& shape,
|
int stride) {
|
||||||
const std::vector<size_t>& a_strides,
|
// Number of dimensions to loop over for vectorized ops
|
||||||
const std::vector<size_t>& b_strides,
|
|
||||||
const std::vector<size_t>& out_strides) {
|
|
||||||
const T* a_ptr = a.data<T>();
|
|
||||||
const T* b_ptr = b.data<T>();
|
|
||||||
U* out_ptr = out.data<U>();
|
|
||||||
switch (dim) {
|
switch (dim) {
|
||||||
case 1:
|
case 1:
|
||||||
binary_op_dims<T, U, Op, 1, Strided>(
|
binary_op_dims1<T, U, Op>(a, b, out, op, stride);
|
||||||
a_ptr,
|
|
||||||
b_ptr,
|
|
||||||
out_ptr,
|
|
||||||
op,
|
|
||||||
shape,
|
|
||||||
a_strides,
|
|
||||||
b_strides,
|
|
||||||
out_strides,
|
|
||||||
0);
|
|
||||||
return;
|
return;
|
||||||
case 2:
|
case 2:
|
||||||
binary_op_dims<T, U, Op, 2, Strided>(
|
binary_op_dims2<T, U, Op>(a, b, out, op, stride);
|
||||||
a_ptr,
|
|
||||||
b_ptr,
|
|
||||||
out_ptr,
|
|
||||||
op,
|
|
||||||
shape,
|
|
||||||
a_strides,
|
|
||||||
b_strides,
|
|
||||||
out_strides,
|
|
||||||
0);
|
|
||||||
return;
|
|
||||||
case 3:
|
|
||||||
binary_op_dims<T, U, Op, 3, Strided>(
|
|
||||||
a_ptr,
|
|
||||||
b_ptr,
|
|
||||||
out_ptr,
|
|
||||||
op,
|
|
||||||
shape,
|
|
||||||
a_strides,
|
|
||||||
b_strides,
|
|
||||||
out_strides,
|
|
||||||
0);
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
ContiguousIterator<size_t> a_it(shape, a_strides, dim - 3);
|
const T* a_ptr = a.data<T>();
|
||||||
ContiguousIterator<size_t> b_it(shape, b_strides, dim - 3);
|
const T* b_ptr = b.data<T>();
|
||||||
size_t stride = out_strides[dim - 4];
|
U* dst = out.data<U>();
|
||||||
for (size_t elem = 0; elem < a.size(); elem += stride) {
|
for (size_t i = 0; i < out.size(); i += stride) {
|
||||||
binary_op_dims<T, U, Op, 3, Strided>(
|
int a_idx = elem_to_loc(i, a.shape(), a.strides());
|
||||||
a_ptr + a_it.loc,
|
int b_idx = elem_to_loc(i, b.shape(), b.strides());
|
||||||
b_ptr + b_it.loc,
|
op(a_ptr + a_idx, b_ptr + b_idx, dst, stride);
|
||||||
out_ptr + elem,
|
dst += stride;
|
||||||
op,
|
|
||||||
shape,
|
|
||||||
a_strides,
|
|
||||||
b_strides,
|
|
||||||
out_strides,
|
|
||||||
dim - 3);
|
|
||||||
a_it.step();
|
|
||||||
b_it.step();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -320,33 +450,29 @@ void binary_op(
   }

   // General computation so let's try to optimize
-  auto [new_shape, new_strides] = collapse_contiguous_dims(
-      a.shape(), {a.strides(), b.strides(), out.strides()});
-  const auto& a_strides = new_strides[0];
-  const auto& b_strides = new_strides[1];
-  const auto& strides = new_strides[2];

   // Get the left-most dim such that the array is row contiguous after
-  auto leftmost_rc_dim = [&strides](const std::vector<size_t>& arr_strides) {
-    int d = arr_strides.size() - 1;
-    for (; d >= 0 && arr_strides[d] == strides[d]; d--) {
+  auto& strides = out.strides();
+  auto leftmost_rc_dim = [&strides](const array& arr) {
+    int d = arr.ndim() - 1;
+    for (; d >= 0 && arr.strides()[d] == strides[d]; d--) {
     }
     return d + 1;
   };
-  auto a_rc_dim = leftmost_rc_dim(a_strides);
-  auto b_rc_dim = leftmost_rc_dim(b_strides);
+  auto a_rc_dim = leftmost_rc_dim(a);
+  auto b_rc_dim = leftmost_rc_dim(b);

   // Get the left-most dim such that the array is a broadcasted "scalar" after
-  auto leftmost_s_dim = [](const std::vector<size_t>& arr_strides) {
-    int d = arr_strides.size() - 1;
-    for (; d >= 0 && arr_strides[d] == 0; d--) {
+  auto leftmost_s_dim = [](const array& arr) {
+    int d = arr.ndim() - 1;
+    for (; d >= 0 && arr.strides()[d] == 0; d--) {
     }
     return d + 1;
   };
-  auto a_s_dim = leftmost_s_dim(a_strides);
-  auto b_s_dim = leftmost_s_dim(b_strides);
+  auto a_s_dim = leftmost_s_dim(a);
+  auto b_s_dim = leftmost_s_dim(b);

-  auto ndim = new_shape.size();
+  auto ndim = out.ndim();

   // Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
   int dim = ndim;
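A standalone sketch of the "left-most row-contiguous dimension" scan used above (assumed helper, not the library's API): walk dimensions from the innermost outwards and stop at the first one whose stride no longer matches the output's, so everything to its right can be handed to a vectorized kernel as one contiguous row.

```cpp
#include <cstddef>
#include <vector>

// Sketch of leftmost_rc_dim: returns the first dimension index (from the
// left) after which the array's strides match the output's strides.
inline int leftmost_rc_dim_sketch(
    const std::vector<size_t>& arr_strides,
    const std::vector<size_t>& out_strides) {
  int d = static_cast<int>(arr_strides.size()) - 1;
  while (d >= 0 && arr_strides[d] == out_strides[d]) {
    d--;
  }
  return d + 1;
}

// Example: strides {12, 4, 1} vs output strides {12, 4, 1} give 0 (fully row
// contiguous); strides {0, 4, 1} vs {12, 4, 1} give 1, so the last two
// dimensions form the contiguous "M" in the LxM / FxM cases above.
```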
@@ -368,27 +494,27 @@ void binary_op(
|
|||||||
// Can be sure dim > 0 since otherwise we would have used one of the fully
|
// Can be sure dim > 0 since otherwise we would have used one of the fully
|
||||||
// contiguous methods above. Except for the case that the flags do not
|
// contiguous methods above. Except for the case that the flags do not
|
||||||
// correspond to the underlying contiguity.
|
// correspond to the underlying contiguity.
|
||||||
|
size_t stride;
|
||||||
if (dim == 0 || strides[dim - 1] < 16) {
|
if (dim == 0 || strides[dim - 1] < 16) {
|
||||||
|
stride = 1;
|
||||||
bopt = BinaryOpType::General;
|
bopt = BinaryOpType::General;
|
||||||
dim = ndim;
|
dim = ndim;
|
||||||
|
} else {
|
||||||
|
stride = strides[dim - 1];
|
||||||
}
|
}
|
||||||
|
|
||||||
switch (bopt) {
|
switch (bopt) {
|
||||||
case BinaryOpType::VectorVector:
|
case BinaryOpType::VectorVector:
|
||||||
binary_op_dispatch_dims<T, U, true>(
|
binary_op_dispatch_dims<T, U>(a, b, out, opvv, dim, stride);
|
||||||
a, b, out, opvv, dim, new_shape, a_strides, b_strides, strides);
|
|
||||||
break;
|
break;
|
||||||
case BinaryOpType::VectorScalar:
|
case BinaryOpType::VectorScalar:
|
||||||
binary_op_dispatch_dims<T, U, true>(
|
binary_op_dispatch_dims<T, U>(a, b, out, opvs, dim, stride);
|
||||||
a, b, out, opvs, dim, new_shape, a_strides, b_strides, strides);
|
|
||||||
break;
|
break;
|
||||||
case BinaryOpType::ScalarVector:
|
case BinaryOpType::ScalarVector:
|
||||||
binary_op_dispatch_dims<T, U, true>(
|
binary_op_dispatch_dims<T, U>(a, b, out, opsv, dim, stride);
|
||||||
a, b, out, opsv, dim, new_shape, a_strides, b_strides, strides);
|
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
binary_op_dispatch_dims<T, U, false>(
|
binary_op_dispatch_dims<T, U>(a, b, out, op);
|
||||||
a, b, out, op, dim, new_shape, a_strides, b_strides, strides);
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -405,9 +531,9 @@ void binary_op(
|
|||||||
// TODO: The following mess of constexpr evaluations can probably be achieved
|
// TODO: The following mess of constexpr evaluations can probably be achieved
|
||||||
// with template specializations and overloading. Would it be simpler?
|
// with template specializations and overloading. Would it be simpler?
|
||||||
|
|
||||||
if constexpr (std::is_same<decltype(opsv), UseDefaultBinaryOp>::value) {
|
if (std::is_same<decltype(opsv), UseDefaultBinaryOp>::value) {
|
||||||
if constexpr (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
|
if (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
|
||||||
if constexpr (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
|
if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
|
||||||
// All ops are UseDefaultBinaryOp (why oh why would someone call that?)
|
// All ops are UseDefaultBinaryOp (why oh why would someone call that?)
|
||||||
binary_op<T, T>(
|
binary_op<T, T>(
|
||||||
a,
|
a,
|
||||||
@@ -428,8 +554,7 @@ void binary_op(
|
|||||||
DefaultVectorScalar<T, T, Op>(op),
|
DefaultVectorScalar<T, T, Op>(op),
|
||||||
opvv);
|
opvv);
|
||||||
}
|
}
|
||||||
} else if constexpr (std::is_same<decltype(opvv), UseDefaultBinaryOp>::
|
} else if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
|
||||||
value) {
|
|
||||||
// opsv and opvv were UseDefaultBinaryOp
|
// opsv and opvv were UseDefaultBinaryOp
|
||||||
binary_op<T, T>(
|
binary_op<T, T>(
|
||||||
a,
|
a,
|
||||||
@@ -444,8 +569,7 @@ void binary_op(
|
|||||||
binary_op<T, T>(
|
binary_op<T, T>(
|
||||||
a, b, out, op, DefaultScalarVector<T, T, Op>(op), opvs, opvv);
|
a, b, out, op, DefaultScalarVector<T, T, Op>(op), opvs, opvv);
|
||||||
}
|
}
|
||||||
} else if constexpr (std::is_same<decltype(opvs), UseDefaultBinaryOp>::
|
} else if (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
|
||||||
value) {
|
|
||||||
if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
|
if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
|
||||||
// opvs and opvv were UseDefaultBinaryOp
|
// opvs and opvv were UseDefaultBinaryOp
|
||||||
binary_op<T, T>(
|
binary_op<T, T>(
|
||||||
@@ -461,8 +585,7 @@ void binary_op(
|
|||||||
binary_op<T, T>(
|
binary_op<T, T>(
|
||||||
a, b, out, op, opsv, DefaultVectorScalar<T, T, Op>(op), opvv);
|
a, b, out, op, opsv, DefaultVectorScalar<T, T, Op>(op), opvv);
|
||||||
}
|
}
|
||||||
} else if constexpr (std::is_same<decltype(opvv), UseDefaultBinaryOp>::
|
} else if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
|
||||||
value) {
|
|
||||||
// opvv was UseDefaultBinaryOp
|
// opvv was UseDefaultBinaryOp
|
||||||
binary_op<T, T>(
|
binary_op<T, T>(
|
||||||
a, b, out, op, opsv, opvs, DefaultVectorVector<T, T, Op>(op));
|
a, b, out, op, opsv, opvs, DefaultVectorVector<T, T, Op>(op));
|
||||||
|
@@ -9,43 +9,168 @@ namespace mlx::core {
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
template <typename T, typename U, typename Op, int D>
|
template <typename T, typename U, typename Op>
|
||||||
void binary_op_dims(
|
void binary_op_dims1(
|
||||||
const T* a,
|
const array& a,
|
||||||
const T* b,
|
const array& b,
|
||||||
U* out_a,
|
array& out_a,
|
||||||
U* out_b,
|
array& out_b,
|
||||||
Op op,
|
Op op) {
|
||||||
const std::vector<int>& shape,
|
const T* a_ptr = a.data<T>();
|
||||||
const std::vector<size_t>& a_strides,
|
const T* b_ptr = b.data<T>();
|
||||||
const std::vector<size_t>& b_strides,
|
U* dst_a = out_a.data<U>();
|
||||||
const std::vector<size_t>& out_strides,
|
U* dst_b = out_b.data<U>();
|
||||||
int axis) {
|
size_t a_idx = 0;
|
||||||
auto stride_a = a_strides[axis];
|
size_t b_idx = 0;
|
||||||
auto stride_b = b_strides[axis];
|
for (size_t i = 0; i < out_a.size(); ++i) {
|
||||||
auto stride_out = out_strides[axis];
|
auto dst = op(a_ptr[a_idx], b_ptr[b_idx]);
|
||||||
auto N = shape[axis];
|
dst_a[i] = dst.first;
|
||||||
|
dst_b[i] = dst.second;
|
||||||
|
a_idx += a.strides()[0];
|
||||||
|
b_idx += b.strides()[0];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for (int i = 0; i < N; i++) {
|
template <typename T, typename U, typename Op>
|
||||||
if constexpr (D > 1) {
|
void binary_op_dims1(
|
||||||
binary_op_dims<T, U, Op, D - 1>(
|
const array& a,
|
||||||
a,
|
const array& b,
|
||||||
b,
|
array& out_a,
|
||||||
out_a,
|
array& out_b,
|
||||||
out_b,
|
Op op,
|
||||||
op,
|
int stride) {
|
||||||
shape,
|
const T* a_ptr = a.data<T>();
|
||||||
a_strides,
|
const T* b_ptr = b.data<T>();
|
||||||
b_strides,
|
U* dst_a = out_a.data<U>();
|
||||||
out_strides,
|
U* dst_b = out_b.data<U>();
|
||||||
axis + 1);
|
size_t a_idx = 0;
|
||||||
} else {
|
size_t b_idx = 0;
|
||||||
std::tie(*out_a, *out_b) = op(*a, *b);
|
for (size_t i = 0; i < a.shape()[0]; i++) {
|
||||||
|
op(a_ptr + a_idx, b_ptr + b_idx, dst_a, dst_b, stride);
|
||||||
|
a_idx += a.strides()[0];
|
||||||
|
b_idx += b.strides()[0];
|
||||||
|
dst_a += stride;
|
||||||
|
dst_b += stride;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U, typename Op>
|
||||||
|
void binary_op_dims2(
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
array& out_a,
|
||||||
|
array& out_b,
|
||||||
|
Op op) {
|
||||||
|
const T* a_ptr = a.data<T>();
|
||||||
|
const T* b_ptr = b.data<T>();
|
||||||
|
U* dst_a = out_a.data<U>();
|
||||||
|
U* dst_b = out_b.data<U>();
|
||||||
|
size_t a_idx = 0;
|
||||||
|
size_t b_idx = 0;
|
||||||
|
size_t out_idx = 0;
|
||||||
|
for (size_t i = 0; i < a.shape()[0]; ++i) {
|
||||||
|
for (size_t j = 0; j < a.shape()[1]; ++j) {
|
||||||
|
auto dst = op(a_ptr[a_idx], b_ptr[b_idx]);
|
||||||
|
dst_a[out_idx] = dst.first;
|
||||||
|
dst_b[out_idx++] = dst.second;
|
||||||
|
a_idx += a.strides()[1];
|
||||||
|
b_idx += b.strides()[1];
|
||||||
}
|
}
|
||||||
a += stride_a;
|
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
|
||||||
b += stride_b;
|
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
|
||||||
out_a += stride_out;
|
}
|
||||||
out_b += stride_out;
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U, typename Op>
|
||||||
|
void binary_op_dims2(
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
array& out_a,
|
||||||
|
array& out_b,
|
||||||
|
Op op,
|
||||||
|
int stride) {
|
||||||
|
const T* a_ptr = a.data<T>();
|
||||||
|
const T* b_ptr = b.data<T>();
|
||||||
|
U* dst_a = out_a.data<U>();
|
||||||
|
U* dst_b = out_b.data<U>();
|
||||||
|
size_t a_idx = 0;
|
||||||
|
size_t b_idx = 0;
|
||||||
|
for (size_t i = 0; i < a.shape()[0]; ++i) {
|
||||||
|
for (size_t j = 0; j < a.shape()[1]; ++j) {
|
||||||
|
op(a_ptr + a_idx, b_ptr + b_idx, dst_a, dst_b, stride);
|
||||||
|
a_idx += a.strides()[1];
|
||||||
|
b_idx += b.strides()[1];
|
||||||
|
dst_a += stride;
|
||||||
|
dst_b += stride;
|
||||||
|
}
|
||||||
|
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
|
||||||
|
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U, typename Op>
|
||||||
|
void binary_op_dims3(
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
array& out_a,
|
||||||
|
array& out_b,
|
||||||
|
Op op) {
|
||||||
|
const T* a_ptr = a.data<T>();
|
||||||
|
const T* b_ptr = b.data<T>();
|
||||||
|
U* dst_a = out_a.data<U>();
|
||||||
|
U* dst_b = out_b.data<U>();
|
||||||
|
size_t a_idx = 0;
|
||||||
|
size_t b_idx = 0;
|
||||||
|
size_t out_idx = 0;
|
||||||
|
for (size_t i = 0; i < a.shape()[0]; ++i) {
|
||||||
|
for (size_t j = 0; j < a.shape()[1]; ++j) {
|
||||||
|
for (size_t k = 0; k < a.shape()[2]; ++k) {
|
||||||
|
auto dst = op(a_ptr[a_idx], b_ptr[b_idx]);
|
||||||
|
dst_a[out_idx] = dst.first;
|
||||||
|
dst_b[out_idx++] = dst.second;
|
||||||
|
a_idx += a.strides()[2];
|
||||||
|
b_idx += b.strides()[2];
|
||||||
|
}
|
||||||
|
a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
|
||||||
|
b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
|
||||||
|
}
|
||||||
|
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
|
||||||
|
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U, typename Op>
|
||||||
|
void binary_op_dims4(
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
array& out_a,
|
||||||
|
array& out_b,
|
||||||
|
Op op) {
|
||||||
|
const T* a_ptr = a.data<T>();
|
||||||
|
const T* b_ptr = b.data<T>();
|
||||||
|
U* dst_a = out_a.data<U>();
|
||||||
|
U* dst_b = out_b.data<U>();
|
||||||
|
size_t a_idx = 0;
|
||||||
|
size_t b_idx = 0;
|
||||||
|
size_t out_idx = 0;
|
||||||
|
for (size_t i = 0; i < a.shape()[0]; ++i) {
|
||||||
|
for (size_t j = 0; j < a.shape()[1]; ++j) {
|
||||||
|
for (size_t k = 0; k < a.shape()[2]; ++k) {
|
||||||
|
for (size_t ii = 0; ii < a.shape()[3]; ++ii) {
|
||||||
|
auto dst = op(a_ptr[a_idx], b_ptr[b_idx]);
|
||||||
|
dst_a[out_idx] = dst.first;
|
||||||
|
dst_b[out_idx++] = dst.second;
|
||||||
|
a_idx += a.strides()[3];
|
||||||
|
b_idx += b.strides()[3];
|
||||||
|
}
|
||||||
|
a_idx += a.strides()[2] - a.strides()[3] * a.shape()[3];
|
||||||
|
b_idx += b.strides()[2] - b.strides()[3] * b.shape()[3];
|
||||||
|
}
|
||||||
|
a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
|
||||||
|
b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
|
||||||
|
}
|
||||||
|
a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
|
||||||
|
b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -56,160 +181,352 @@ void binary_op_dispatch_dims(
|
|||||||
array& out_a,
|
array& out_a,
|
||||||
array& out_b,
|
array& out_b,
|
||||||
Op op) {
|
Op op) {
|
||||||
auto [shape, strides] = collapse_contiguous_dims(
|
switch (out_a.ndim()) {
|
||||||
a.shape(), {a.strides(), b.strides(), out_a.strides()});
|
|
||||||
const auto& a_strides = strides[0];
|
|
||||||
const auto& b_strides = strides[1];
|
|
||||||
const auto& out_strides = strides[2];
|
|
||||||
const T* a_ptr = a.data<T>();
|
|
||||||
const T* b_ptr = b.data<T>();
|
|
||||||
U* out_a_ptr = out_a.data<U>();
|
|
||||||
U* out_b_ptr = out_b.data<U>();
|
|
||||||
|
|
||||||
int ndim = shape.size();
|
|
||||||
switch (ndim) {
|
|
||||||
case 1:
|
case 1:
|
||||||
binary_op_dims<T, U, Op, 1>(
|
binary_op_dims1<T, U, Op>(a, b, out_a, out_b, op);
|
||||||
a_ptr,
|
|
||||||
b_ptr,
|
|
||||||
out_a_ptr,
|
|
||||||
out_b_ptr,
|
|
||||||
op,
|
|
||||||
shape,
|
|
||||||
a_strides,
|
|
||||||
b_strides,
|
|
||||||
out_strides,
|
|
||||||
0);
|
|
||||||
return;
|
return;
|
||||||
case 2:
|
case 2:
|
||||||
binary_op_dims<T, U, Op, 2>(
|
binary_op_dims2<T, U, Op>(a, b, out_a, out_b, op);
|
||||||
a_ptr,
|
return;
|
||||||
b_ptr,
|
case 3:
|
||||||
out_a_ptr,
|
binary_op_dims3<T, U, Op>(a, b, out_a, out_b, op);
|
||||||
out_b_ptr,
|
return;
|
||||||
op,
|
case 4:
|
||||||
shape,
|
binary_op_dims4<T, U, Op>(a, b, out_a, out_b, op);
|
||||||
a_strides,
|
|
||||||
b_strides,
|
|
||||||
out_strides,
|
|
||||||
0);
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
ContiguousIterator<size_t> a_it(shape, a_strides, ndim - 2);
|
const T* a_ptr = a.data<T>();
|
||||||
ContiguousIterator<size_t> b_it(shape, b_strides, ndim - 2);
|
const T* b_ptr = b.data<T>();
|
||||||
size_t stride = out_strides[ndim - 3];
|
U* dst_a = out_a.data<U>();
|
||||||
for (size_t elem = 0; elem < a.size(); elem += stride) {
|
U* dst_b = out_b.data<U>();
|
||||||
binary_op_dims<T, U, Op, 2>(
|
for (size_t i = 0; i < out_a.size(); i++) {
|
||||||
a_ptr + a_it.loc,
|
int a_idx = elem_to_loc(i, a.shape(), a.strides());
|
||||||
b_ptr + b_it.loc,
|
int b_idx = elem_to_loc(i, b.shape(), b.strides());
|
||||||
out_a_ptr + elem,
|
std::tie(dst_a[i], dst_b[i]) = op(a_ptr[a_idx], b_ptr[b_idx]);
|
||||||
out_b_ptr + elem,
|
|
||||||
op,
|
|
||||||
shape,
|
|
||||||
a_strides,
|
|
||||||
b_strides,
|
|
||||||
out_strides,
|
|
||||||
ndim - 2);
|
|
||||||
a_it.step();
|
|
||||||
b_it.step();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, typename U = T, typename Op>
|
template <typename T, typename U, typename Op>
|
||||||
|
void binary_op_dispatch_dims(
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
array& out_a,
|
||||||
|
array& out_b,
|
||||||
|
Op op,
|
||||||
|
int dim,
|
||||||
|
int stride) {
|
||||||
|
// Number of dimensions to loop over for vectorized ops
|
||||||
|
switch (dim) {
|
||||||
|
case 1:
|
||||||
|
binary_op_dims1<T, U, Op>(a, b, out_a, out_b, op, stride);
|
||||||
|
return;
|
||||||
|
case 2:
|
||||||
|
binary_op_dims2<T, U, Op>(a, b, out_a, out_b, op, stride);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const T* a_ptr = a.data<T>();
|
||||||
|
const T* b_ptr = b.data<T>();
|
||||||
|
U* dst_a = out_a.data<U>();
|
||||||
|
U* dst_b = out_b.data<U>();
|
||||||
|
for (size_t i = 0; i < out_a.size(); i += stride) {
|
||||||
|
int a_idx = elem_to_loc(i, a.shape(), a.strides());
|
||||||
|
int b_idx = elem_to_loc(i, b.shape(), b.strides());
|
||||||
|
op(a_ptr + a_idx, b_ptr + b_idx, dst_a, dst_b, stride);
|
||||||
|
dst_a += stride;
|
||||||
|
dst_b += stride;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename T,
|
||||||
|
typename U,
|
||||||
|
typename Op,
|
||||||
|
typename OpSV,
|
||||||
|
typename OpVS,
|
||||||
|
typename OpVV>
|
||||||
|
void binary_op(
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
array& out_a,
|
||||||
|
array& out_b,
|
||||||
|
Op op,
|
||||||
|
OpSV opsv,
|
||||||
|
OpVS opvs,
|
||||||
|
OpVV opvv) {
|
||||||
|
auto bopt = get_binary_op_type(a, b);
|
||||||
|
set_binary_op_output_data(a, b, out_a, bopt);
|
||||||
|
set_binary_op_output_data(a, b, out_b, bopt);
|
||||||
|
|
||||||
|
// The full computation is scalar scalar so call the base op once
|
||||||
|
if (bopt == BinaryOpType::ScalarScalar) {
|
||||||
|
std::tie(*(out_a.data<U>()), *(out_b.data<U>())) =
|
||||||
|
op(*a.data<T>(), *b.data<T>());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// The full computation is scalar vector so delegate to the op
|
||||||
|
if (bopt == BinaryOpType::ScalarVector) {
|
||||||
|
opsv(
|
||||||
|
a.data<T>(),
|
||||||
|
b.data<T>(),
|
||||||
|
out_a.data<U>(),
|
||||||
|
out_b.data<U>(),
|
||||||
|
b.data_size());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// The full computation is vector scalar so delegate to the op
|
||||||
|
if (bopt == BinaryOpType::VectorScalar) {
|
||||||
|
opvs(
|
||||||
|
a.data<T>(),
|
||||||
|
b.data<T>(),
|
||||||
|
out_a.data<U>(),
|
||||||
|
out_b.data<U>(),
|
||||||
|
a.data_size());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// The full computation is vector vector so delegate to the op
|
||||||
|
if (bopt == BinaryOpType::VectorVector) {
|
||||||
|
opvv(
|
||||||
|
a.data<T>(),
|
||||||
|
b.data<T>(),
|
||||||
|
out_a.data<U>(),
|
||||||
|
out_b.data<U>(),
|
||||||
|
out_a.size());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// General computation so let's try to optimize
|
||||||
|
|
||||||
|
// Get the left-most dim such that the array is row contiguous after
|
||||||
|
auto& strides = out_a.strides();
|
||||||
|
auto leftmost_rc_dim = [&strides](const array& arr) {
|
||||||
|
int d = arr.ndim() - 1;
|
||||||
|
for (; d >= 0 && arr.strides()[d] == strides[d]; d--) {
|
||||||
|
}
|
||||||
|
return d + 1;
|
||||||
|
};
|
||||||
|
auto a_rc_dim = leftmost_rc_dim(a);
|
||||||
|
auto b_rc_dim = leftmost_rc_dim(b);
|
||||||
|
|
||||||
|
// Get the left-most dim such that the array is a broadcasted "scalar" after
|
||||||
|
auto leftmost_s_dim = [](const array& arr) {
|
||||||
|
int d = arr.ndim() - 1;
|
||||||
|
for (; d >= 0 && arr.strides()[d] == 0; d--) {
|
||||||
|
}
|
||||||
|
return d + 1;
|
||||||
|
};
|
||||||
|
auto a_s_dim = leftmost_s_dim(a);
|
||||||
|
auto b_s_dim = leftmost_s_dim(b);
|
||||||
|
|
||||||
|
auto ndim = out_a.ndim();
|
||||||
|
|
||||||
|
// Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
|
||||||
|
int dim = ndim;
|
||||||
|
if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
|
||||||
|
bopt = BinaryOpType::VectorVector;
|
||||||
|
dim = d;
|
||||||
|
// Case 2: LxM and Fx1 where L and F are broadcastable and M is row
|
||||||
|
// contiguous
|
||||||
|
} else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
|
||||||
|
bopt = BinaryOpType::VectorScalar;
|
||||||
|
dim = d;
|
||||||
|
// Case 3: Lx1 and FxM where L and F are broadcastable and M is row
|
||||||
|
// contiguous
|
||||||
|
} else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
|
||||||
|
bopt = BinaryOpType::ScalarVector;
|
||||||
|
dim = d;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Can be sure dim > 0 since otherwise we would have used one of the fully
|
||||||
|
// contiguous methods above. Except for the case that the flags do not
|
||||||
|
// correspond to the underlying contiguity.
|
||||||
|
size_t stride;
|
||||||
|
if (dim == 0 || strides[dim - 1] < 16) {
|
||||||
|
stride = 1;
|
||||||
|
bopt = BinaryOpType::General;
|
||||||
|
dim = ndim;
|
||||||
|
} else {
|
||||||
|
stride = strides[dim - 1];
|
||||||
|
}
|
||||||
|
|
||||||
|
switch (bopt) {
|
||||||
|
case BinaryOpType::VectorVector:
|
||||||
|
binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opvv, dim, stride);
|
||||||
|
break;
|
||||||
|
case BinaryOpType::VectorScalar:
|
||||||
|
binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opvs, dim, stride);
|
||||||
|
break;
|
||||||
|
case BinaryOpType::ScalarVector:
|
||||||
|
binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opsv, dim, stride);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, op);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename Op, typename OpSV, typename OpVS, typename OpVV>
|
||||||
|
void binary_op(
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
std::vector<array>& outputs,
|
||||||
|
Op op,
|
||||||
|
OpSV opsv,
|
||||||
|
OpVS opvs,
|
||||||
|
OpVV opvv) {
|
||||||
|
// TODO: The following mess of constexpr evaluations can probably be achieved
|
||||||
|
// with template specializations and overloading. Would it be simpler?
|
||||||
|
|
||||||
|
if (std::is_same<decltype(opsv), UseDefaultBinaryOp>::value) {
|
||||||
|
if (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
|
||||||
|
if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
|
||||||
|
// All ops are UseDefaultBinaryOp (why oh why would someone call that?)
|
||||||
|
binary_op<T, T>(
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
outputs[0],
|
||||||
|
outputs[1],
|
||||||
|
op,
|
||||||
|
DefaultScalarVector<T, T, Op>(op),
|
||||||
|
DefaultVectorScalar<T, T, Op>(op),
|
||||||
|
DefaultVectorVector<T, T, Op>(op));
|
||||||
|
} else {
|
||||||
|
// opsv and opvs were UseDefaultBinaryOp
|
||||||
|
binary_op<T, T>(
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
outputs[0],
|
||||||
|
outputs[1],
|
||||||
|
op,
|
||||||
|
DefaultScalarVector<T, T, Op>(op),
|
||||||
|
DefaultVectorScalar<T, T, Op>(op),
|
||||||
|
opvv);
|
||||||
|
}
|
||||||
|
} else if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
|
||||||
|
// opsv and opvv were UseDefaultBinaryOp
|
||||||
|
binary_op<T, T>(
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
outputs[0],
|
||||||
|
outputs[1],
|
||||||
|
op,
|
||||||
|
DefaultScalarVector<T, T, Op>(op),
|
||||||
|
opvs,
|
||||||
|
DefaultVectorVector<T, T, Op>(op));
|
||||||
|
} else {
|
||||||
|
// opsv was UseDefaultBinaryOp
|
||||||
|
binary_op<T, T>(
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
outputs[0],
|
||||||
|
outputs[1],
|
||||||
|
op,
|
||||||
|
DefaultScalarVector<T, T, Op>(op),
|
||||||
|
opvs,
|
||||||
|
opvv);
|
||||||
|
}
|
||||||
|
} else if (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
|
||||||
|
if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
|
||||||
|
// opvs and opvv were UseDefaultBinaryOp
|
||||||
|
binary_op<T, T>(
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
outputs[0],
|
||||||
|
outputs[1],
|
||||||
|
op,
|
||||||
|
opsv,
|
||||||
|
DefaultVectorScalar<T, T, Op>(op),
|
||||||
|
DefaultVectorVector<T, T, Op>(op));
|
||||||
|
} else {
|
||||||
|
// opvs was UseDefaultBinaryOp
|
||||||
|
binary_op<T, T>(
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
outputs[0],
|
||||||
|
outputs[1],
|
||||||
|
op,
|
||||||
|
opsv,
|
||||||
|
DefaultVectorScalar<T, T, Op>(op),
|
||||||
|
opvv);
|
||||||
|
}
|
||||||
|
} else if (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
|
||||||
|
// opvv was UseDefaultBinaryOp
|
||||||
|
binary_op<T, T>(
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
outputs[0],
|
||||||
|
outputs[1],
|
||||||
|
op,
|
||||||
|
opsv,
|
||||||
|
opvs,
|
||||||
|
DefaultVectorVector<T, T, Op>(op));
|
||||||
|
} else {
|
||||||
|
// All ops provided
|
||||||
|
binary_op<T, T>(a, b, outputs[0], outputs[1], op, opsv, opvs, opvv);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename Op>
|
||||||
void binary_op(
|
void binary_op(
|
||||||
const array& a,
|
const array& a,
|
||||||
const array& b,
|
const array& b,
|
||||||
std::vector<array>& outputs,
|
std::vector<array>& outputs,
|
||||||
Op op) {
|
Op op) {
|
||||||
auto bopt = get_binary_op_type(a, b);
|
DefaultScalarVector<T, T, Op> opsv(op);
|
||||||
auto& out_a = outputs[0];
|
DefaultVectorScalar<T, T, Op> opvs(op);
|
||||||
auto& out_b = outputs[1];
|
DefaultVectorVector<T, T, Op> opvv(op);
|
||||||
set_binary_op_output_data(a, b, out_a, bopt);
|
binary_op<T, T>(a, b, outputs[0], outputs[1], op, opsv, opvs, opvv);
|
||||||
set_binary_op_output_data(a, b, out_b, bopt);
|
|
||||||
|
|
||||||
// The full computation is scalar scalar so call the base op once
|
|
||||||
if (bopt == BinaryOpType::General) {
|
|
||||||
binary_op_dispatch_dims<T, U, Op>(a, b, out_a, out_b, op);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto a_ptr = a.data<T>();
|
|
||||||
auto b_ptr = b.data<T>();
|
|
||||||
auto out_a_ptr = out_a.data<U>();
|
|
||||||
auto out_b_ptr = out_b.data<U>();
|
|
||||||
if (bopt == BinaryOpType::ScalarScalar) {
|
|
||||||
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
|
|
||||||
} else if (bopt == BinaryOpType::ScalarVector) {
|
|
||||||
for (size_t i = 0; i < b.size(); ++i) {
|
|
||||||
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
|
|
||||||
out_a_ptr++;
|
|
||||||
out_b_ptr++;
|
|
||||||
b_ptr++;
|
|
||||||
}
|
|
||||||
} else if (bopt == BinaryOpType::VectorScalar) {
|
|
||||||
for (size_t i = 0; i < a.size(); ++i) {
|
|
||||||
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
|
|
||||||
out_a_ptr++;
|
|
||||||
out_b_ptr++;
|
|
||||||
a_ptr++;
|
|
||||||
}
|
|
||||||
} else { // VectorVector
|
|
||||||
for (size_t i = 0; i < a.size(); ++i) {
|
|
||||||
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
|
|
||||||
out_a_ptr++;
|
|
||||||
out_b_ptr++;
|
|
||||||
a_ptr++;
|
|
||||||
b_ptr++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
-template <typename Op>
+template <typename... Ops>
 void binary(
     const array& a,
     const array& b,
     std::vector<array>& outputs,
-    Op op) {
+    Ops... ops) {
   switch (outputs[0].dtype()) {
     case bool_:
-      binary_op<bool>(a, b, outputs, op);
+      binary_op<bool>(a, b, outputs, ops...);
       break;
     case uint8:
-      binary_op<uint8_t>(a, b, outputs, op);
+      binary_op<uint8_t>(a, b, outputs, ops...);
       break;
    ... (the cases for uint16, uint32, uint64, int8, int16, int32, int64,
        float16, float32, bfloat16 and complex64 change identically:
        `op` becomes `ops...`)
   }
 }
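As an aside, the variadic `Ops... ops` dispatch above just forwards a pack of operators through a dtype switch into a typed kernel. A small self-contained sketch of that pattern (not MLX code; all names here are illustrative):

// Hypothetical demo of forwarding a parameter pack through a dtype switch.
#include <cstdio>
#include <vector>

enum class Dtype { Float32, Int32 };

template <typename T, typename... Ops>
void kernel(const std::vector<T>& a, const std::vector<T>& b, Ops... ops) {
  // Apply each forwarded operator element-wise and print the results.
  auto apply = [&](auto f) {
    for (size_t i = 0; i < a.size(); ++i) std::printf("%g ", double(f(a[i], b[i])));
    std::printf("\n");
  };
  (apply(ops), ...); // fold over the pack, mirroring the `ops...` forwarding
}

template <typename... Ops>
void binary(Dtype t, Ops... ops) {
  std::vector<float> fa{1, 2}, fb{3, 4};
  std::vector<int> ia{1, 2}, ib{3, 4};
  switch (t) {
    case Dtype::Float32: kernel<float>(fa, fb, ops...); break;
    case Dtype::Int32:   kernel<int>(ia, ib, ops...);   break;
  }
}

int main() {
  binary(Dtype::Float32,
         [](auto x, auto y) { return x + y; },
         [](auto x, auto y) { return x * y; });
}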
@@ -2,12 +2,46 @@
 #include "mlx/allocator.h"
 #include "mlx/backend/common/copy.h"
-#include "mlx/backend/common/lapack.h"
 #include "mlx/linalg.h"
 #include "mlx/primitives.h"
 
+#ifdef ACCELERATE_NEW_LAPACK
+#include <Accelerate/Accelerate.h>
+#else
+#include <lapack.h>
+#endif
+
 namespace mlx::core {
 
+namespace {
+
+// Delegate to the Cholesky factorization taking into account differences in
+// LAPACK implementations (basically how to pass the 'uplo' string to fortran).
+int spotrf_wrapper(char uplo, float* matrix, int N) {
+  int info;
+
+#ifdef LAPACK_FORTRAN_STRLEN_END
+  spotrf_(
+      /* uplo = */ &uplo,
+      /* n = */ &N,
+      /* a = */ matrix,
+      /* lda = */ &N,
+      /* info = */ &info,
+      /* uplo_len = */ static_cast<size_t>(1));
+#else
+  spotrf_(
+      /* uplo = */ &uplo,
+      /* n = */ &N,
+      /* a = */ matrix,
+      /* lda = */ &N,
+      /* info = */ &info);
+#endif
+
+  return info;
+}
+
+} // namespace
+
 void cholesky_impl(const array& a, array& factor, bool upper) {
   // Lapack uses the column-major convention. We take advantage of the fact that
   // the matrix should be symmetric:
@@ -32,14 +66,7 @@ void cholesky_impl(const array& a, array& factor, bool upper) {
   for (int i = 0; i < num_matrices; i++) {
     // Compute Cholesky factorization.
-    int info;
-    MLX_LAPACK_FUNC(spotrf)
-    (
-        /* uplo = */ &uplo,
-        /* n = */ &N,
-        /* a = */ matrix,
-        /* lda = */ &N,
-        /* info = */ &info);
+    int info = spotrf_wrapper(uplo, matrix, N);
 
     // TODO: We do nothing when the matrix is not positive semi-definite
     // because throwing an error would result in a crash. If we figure out how
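The `uplo` handling above matters because LAPACK assumes column-major storage; for a symmetric input the column-major and row-major layouts coincide, so only the triangle selector has to be flipped. A minimal standalone sketch of that trick (not MLX code; it assumes a system LAPACK exporting the plain Fortran `spotrf_` symbol, link with -llapack):

// Hypothetical demo, separate from the diff above.
#include <cstdio>
#include <vector>

extern "C" void spotrf_(const char* uplo, const int* n, float* a, const int* lda, int* info);

int main() {
  int n = 2, info = 0;
  // Row-major symmetric positive-definite matrix {{4, 2}, {2, 3}}.
  std::vector<float> a = {4.f, 2.f, 2.f, 3.f};
  // Column-major 'U' addresses the lower triangle of the row-major view,
  // so no transpose is needed for a symmetric input.
  char uplo = 'U';
  spotrf_(&uplo, &n, a.data(), &n, &info);
  std::printf("info = %d, factor(0,0) = %.1f\n", info, a[0]); // expect 2.0
}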
@@ -66,7 +66,7 @@ void Copy::eval(const std::vector<array>& inputs, array& out) {
   out.copy_shared_buffer(inputs[0]);
 }
 
-void CustomTransforms::eval(
+void CustomVJP::eval(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
   assert(inputs.size() > outputs.size());
@@ -156,7 +156,8 @@ std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(
   }
 
   // Firstly let's collapse all the contiguous dimensions of the input
-  auto [shape, strides] = collapse_contiguous_dims(in);
+  auto [shape, _strides] = collapse_contiguous_dims(in);
+  auto& strides = _strides[0];
 
   // If shapes fit exactly in the contiguous dims then no copy is necessary so
   // let's check.
@@ -18,8 +18,7 @@ void print_constant(std::ostream& os, const array& x) {
     case complex64:
       return print_complex_constant<complex64_t>(os, x);
     case int8:
-      os << static_cast<int32_t>(x.item<int8_t>());
-      return;
+      return print_int_constant<int8_t>(os, x);
     case int16:
       return print_int_constant<int16_t>(os, x);
     case int32:
@@ -27,8 +26,7 @@ void print_constant(std::ostream& os, const array& x) {
     case int64:
       return print_int_constant<int64_t>(os, x);
     case uint8:
-      os << static_cast<uint32_t>(x.item<uint8_t>());
-      return;
+      return print_int_constant<uint8_t>(os, x);
     case uint16:
       return print_int_constant<uint16_t>(os, x);
     case uint32:
@@ -207,8 +205,8 @@ void compiled_allocate_outputs(
     // - Donatable
     // - Correct size
     // - Not a constant
-    if (in.flags().row_contiguous && in.size() == outputs[o].size() &&
-        in.itemsize() == outputs[o].itemsize() && in.is_donatable() &&
+    if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() &&
+        in.is_donatable() &&
         constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
       if (move_buffers) {
         outputs[o].move_shared_buffer(
@@ -2,10 +2,7 @@
 #include <dlfcn.h>
 #include <filesystem>
-#include <fstream>
 #include <list>
-#include <mutex>
-#include <shared_mutex>
 
 #include "mlx/backend/common/compiled.h"
 #include "mlx/backend/common/compiled_preamble.h"
@@ -14,30 +11,6 @@
 namespace mlx::core {
 
-struct CompilerCache {
-  struct DLib {
-    DLib(const std::string& libname) {
-      lib = dlopen(libname.c_str(), RTLD_NOW);
-      if (!lib) {
-        std::ostringstream msg;
-        msg << "Could not load C++ shared library " << dlerror();
-        throw std::runtime_error(msg.str());
-      }
-    }
-    ~DLib() {
-      dlclose(lib);
-    }
-    void* lib;
-  };
-  // Statics to cache compiled libraries and functions
-  std::list<DLib> libs;
-  std::unordered_map<std::string, void*> kernels;
-  std::shared_mutex mtx;
-};
-
-static CompilerCache cache{};
-
 // GPU compile is always available if the GPU is available and since we are in
 // this file CPU compile is also available.
 namespace detail {
@@ -53,19 +26,32 @@ std::string get_temp_file(const std::string& name) {
 // Return a pointer to a compiled function
 void* compile(
     const std::string& kernel_name,
-    const std::function<std::string(void)>& source_builder) {
-  {
-    std::shared_lock lock(cache.mtx);
-    if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
-      return it->second;
-    }
-  }
-  std::unique_lock lock(cache.mtx);
-  if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
+    const std::string& source_code = "") {
+  struct DLib {
+    DLib(const std::string& libname) {
+      lib = dlopen(libname.c_str(), RTLD_NOW);
+      if (!lib) {
+        std::ostringstream msg;
+        msg << "Could not load C++ shared library " << dlerror();
+        throw std::runtime_error(msg.str());
+      }
+    }
+    ~DLib() {
+      dlclose(lib);
+    }
+    void* lib;
+  };
+  // Statics to cache compiled libraries and functions
+  static std::list<DLib> libs;
+  static std::unordered_map<std::string, void*> kernels;
+  if (auto it = kernels.find(kernel_name); it != kernels.end()) {
     return it->second;
   }
-  std::string source_code = source_builder();
+  if (source_code.empty()) {
+    return nullptr;
+  }
 
   std::string kernel_file_name;
 
   // Deal with long kernel names. Maximum length for files on macOS is 255
@@ -103,8 +89,8 @@ void* compile(
   source_file.close();
 
   std::ostringstream build_command;
-  build_command << "g++ -std=c++17 -O3 -Wall -fPIC -shared '"
-                << source_file_path << "' -o '" << shared_lib_path << "'";
+  build_command << "g++ -std=c++17 -O2 -Wall -fPIC -shared "
+                << source_file_path << " -o " << shared_lib_path;
   std::string build_command_str = build_command.str();
   auto return_code = system(build_command_str.c_str());
   if (return_code) {
@@ -116,10 +102,10 @@ void* compile(
   }
 
   // load library
-  cache.libs.emplace_back(shared_lib_path);
+  libs.emplace_back(shared_lib_path);
 
   // Load function
-  void* fun = dlsym(cache.libs.back().lib, kernel_name.c_str());
+  void* fun = dlsym(libs.back().lib, kernel_name.c_str());
   if (!fun) {
     std::ostringstream msg;
     msg << "[Compile::eval_cpu] Failed to load compiled function "
@@ -127,7 +113,7 @@ void* compile(
         << dlerror();
     throw std::runtime_error(msg.str());
   }
-  cache.kernels.insert({kernel_name, fun});
+  kernels.insert({kernel_name, fun});
   return fun;
 }
 
@@ -329,7 +315,10 @@ void Compiled::eval_cpu(
   }
 
   // Get the function
-  auto fn_ptr = compile(kernel_name, [&]() {
+  auto fn_ptr = compile(kernel_name);
+
+  // If it doesn't exist, compile it
+  if (fn_ptr == nullptr) {
     std::ostringstream kernel;
     kernel << get_kernel_preamble() << std::endl;
     kernel << "extern \"C\" {" << std::endl;
@@ -344,8 +333,10 @@ void Compiled::eval_cpu(
         ndim);
     // Close extern "C"
     kernel << "}" << std::endl;
-    return kernel.str();
-  });
+
+    // Compile and get function pointer
+    fn_ptr = compile(kernel_name, kernel.str());
+  }
 
   compiled_allocate_outputs(
       inputs, outputs, inputs_, constant_ids_, contiguous, false);
@@ -3,8 +3,13 @@
 #include <cassert>
 #include <numeric>
 
+#ifdef ACCELERATE_NEW_LAPACK
+#include <Accelerate/Accelerate.h>
+#else
+#include <cblas.h>
+#endif
+
 #include "mlx/backend/common/copy.h"
-#include "mlx/backend/common/lapack.h"
 #include "mlx/primitives.h"
 #include "mlx/utils.h"
@@ -679,32 +684,6 @@ void dispatch_slow_conv_3D(
 // Explicit gemm conv
 ///////////////////////////////////////////////////////////////////////////////
 
-template <typename T>
-void flip_spatial_dims_inplace(array& wt) {
-  T* x = wt.data<T>();
-  size_t out_channels = wt.shape(0);
-  size_t in_channels = wt.shape(-1);
-
-  // Calculate the total size of the spatial dimensions
-  int spatial_size = 1;
-  for (int d = 1; d < wt.ndim() - 1; ++d) {
-    spatial_size *= wt.shape(d);
-  }
-
-  for (size_t i = 0; i < out_channels; i++) {
-    T* top = x + i * spatial_size * in_channels;
-    T* bottom =
-        x + i * spatial_size * in_channels + (spatial_size - 1) * in_channels;
-    for (size_t j = 0; j < spatial_size / 2; j++) {
-      for (size_t k = 0; k < in_channels; k++) {
-        std::swap(top[k], bottom[k]);
-      }
-      top += in_channels;
-      bottom -= in_channels;
-    }
-  }
-}
-
 void explicit_gemm_conv_1D_cpu(
     const array& in,
     const array& wt,
@@ -931,8 +910,7 @@ void explicit_gemm_conv_ND_cpu(
     array out,
     const std::vector<int>& padding,
     const std::vector<int>& wt_strides,
-    const std::vector<int>& wt_dilation,
-    const bool flip) {
+    const std::vector<int>& wt_dilation) {
   const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
   const auto iDim = std::vector<int>(
       in.shape().begin() + 1, in.shape().end() - 1); // Input spatial dim
@@ -1022,14 +1000,6 @@ void explicit_gemm_conv_ND_cpu(
     copy(wt, gemm_wt, ctype);
   }
 
-  if (flip) {
-    auto gemm_wt_ = array(gemm_wt.shape(), float32, nullptr, {});
-    copy(gemm_wt, gemm_wt_, CopyType::Vector);
-
-    flip_spatial_dims_inplace<float>(gemm_wt_);
-    gemm_wt = gemm_wt_;
-  }
-
   if (out.dtype() != float32) {
     gemm_out = array(out.shape(), float32, nullptr, {});
     gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
@@ -1072,15 +1042,10 @@ void conv_1D_cpu(
     const std::vector<int>& wt_dilation,
     const std::vector<int>& in_dilation,
     bool flip) {
-  const int groups = in.shape().back() / wt.shape().back();
   if (wt_dilation[0] == 1 && in_dilation[0] == 1 && !flip) {
     return explicit_gemm_conv_1D_cpu(
         in, wt, out, padding, wt_strides, wt_dilation);
   }
-  if (wt_dilation[0] == 1 && in_dilation[0] == 1 && groups == 1) {
-    return explicit_gemm_conv_ND_cpu(
-        in, wt, out, padding, wt_strides, wt_dilation, flip);
-  }
 
   return dispatch_slow_conv_1D(
       in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
@@ -1095,13 +1060,6 @@ void conv_2D_cpu(
     const std::vector<int>& wt_dilation,
     const std::vector<int>& in_dilation,
     bool flip) {
-  const int groups = in.shape().back() / wt.shape().back();
-  if (wt_dilation[0] == 1 && wt_dilation[1] == 1 && in_dilation[0] == 1 &&
-      in_dilation[1] == 1 && groups == 1) {
-    return explicit_gemm_conv_ND_cpu(
-        in, wt, out, padding, wt_strides, wt_dilation, flip);
-  }
-
   return dispatch_slow_conv_2D(
       in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
 }
@@ -1115,14 +1073,6 @@ void conv_3D_cpu(
     const std::vector<int>& wt_dilation,
     const std::vector<int>& in_dilation,
     bool flip) {
-  const int groups = in.shape().back() / wt.shape().back();
-  if (wt_dilation[0] == 1 && wt_dilation[1] == 1 && wt_dilation[2] == 1 &&
-      in_dilation[0] == 1 && in_dilation[1] == 1 && in_dilation[2] == 1 &&
-      groups == 1) {
-    return explicit_gemm_conv_ND_cpu(
-        in, wt, out, padding, wt_strides, wt_dilation, flip);
-  }
-
   return dispatch_slow_conv_3D(
       in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
 }
@@ -1175,7 +1125,7 @@ void Convolution::eval(const std::vector<array>& inputs, array& out) {
   else {
     std::ostringstream msg;
     msg << "[Convolution::eval] Convolution currently only supports"
-        << " 1D, 2D and 3D convolutions. Got inputs with " << in.ndim() - 2
+        << " 1D and 2D convolutions. Got inputs with " << in.ndim() - 2
        << " spatial dimensions";
     throw std::invalid_argument(msg.str());
   }
@@ -4,7 +4,6 @@
 
 #include "mlx/allocator.h"
 #include "mlx/backend/common/copy.h"
-#include "mlx/backend/common/utils.h"
 
 namespace mlx::core {
 
@@ -26,117 +25,252 @@ void copy_vector(const array& src, array& dst) {
   std::copy(src_ptr, src_ptr + src.data_size(), dst_ptr);
 }
 
-template <typename SrcT, typename DstT, typename StrideT, int D>
-inline void copy_dims(
-    const SrcT* src,
-    DstT* dst,
-    const std::vector<int>& shape,
-    const std::vector<StrideT>& i_strides,
-    const std::vector<StrideT>& o_strides,
-    int axis) {
-  auto stride_src = i_strides[axis];
-  auto stride_dst = o_strides[axis];
-  auto N = shape[axis];
-
-  for (int i = 0; i < N; i++) {
-    if constexpr (D > 1) {
-      copy_dims<SrcT, DstT, StrideT, D - 1>(
-          src, dst, shape, i_strides, o_strides, axis + 1);
-    } else {
-      *dst = static_cast<DstT>(*src);
-    }
-    src += stride_src;
-    dst += stride_dst;
-  }
-}
-
+template <typename SrcT, typename DstT, typename stride_t>
+void copy_general_dim1(
+    const array& src,
+    array& dst,
+    const std::vector<int>& data_shape,
+    const std::vector<stride_t>& i_strides,
+    int64_t i_offset) {
+  const SrcT* src_ptr = src.data<SrcT>();
+  DstT* dst_ptr = dst.data<DstT>();
+  stride_t src_idx = i_offset;
+  stride_t dst_idx = 0;
+  for (int i = 0; i < data_shape[0]; ++i) {
+    dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
+    src_idx += i_strides[0];
+  }
+}
+
+template <typename SrcT, typename DstT>
+inline void copy_general_dim1(const array& src, array& dst) {
+  return copy_general_dim1<SrcT, DstT, size_t>(
+      src, dst, src.shape(), src.strides(), 0);
+}
+
+template <typename SrcT, typename DstT, typename stride_t>
+void copy_general_dim2(
+    const array& src,
+    array& dst,
+    const std::vector<int>& data_shape,
+    const std::vector<stride_t>& i_strides,
+    int64_t i_offset) {
+  const SrcT* src_ptr = src.data<SrcT>();
+  DstT* dst_ptr = dst.data<DstT>();
+  stride_t src_idx = i_offset;
+  stride_t dst_idx = 0;
+  for (int i = 0; i < data_shape[0]; ++i) {
+    for (int j = 0; j < data_shape[1]; ++j) {
+      dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
+      src_idx += i_strides[1];
+    }
+    src_idx += i_strides[0] - i_strides[1] * data_shape[1];
+  }
+}
+
+... (copy_general_dim3 and copy_general_dim4, their <SrcT, DstT> convenience
+    overloads, and a seven-argument copy_general overload that forwards to the
+    five-argument one, are added with the same structure, one more nested loop
+    per dimension)
+
+template <typename SrcT, typename DstT, typename stride_t>
+void copy_general(
+    const array& src,
+    array& dst,
+    const std::vector<int>& data_shape,
+    const std::vector<stride_t>& i_strides,
+    int64_t i_offset) {
+  switch (src.ndim()) {
+    case 1:
+      copy_general_dim1<SrcT, DstT, stride_t>(
+          src, dst, data_shape, i_strides, i_offset);
+      return;
+    case 2:
+      copy_general_dim2<SrcT, DstT, stride_t>(
+          src, dst, data_shape, i_strides, i_offset);
+      return;
+    case 3:
+      copy_general_dim3<SrcT, DstT, stride_t>(
+          src, dst, data_shape, i_strides, i_offset);
+      return;
+    case 4:
+      copy_general_dim4<SrcT, DstT, stride_t>(
+          src, dst, data_shape, i_strides, i_offset);
+      return;
+  }
+
+  auto src_ptr = src.data<SrcT>() + i_offset;
+  auto dst_ptr = dst.data<DstT>();
+  for (size_t i = 0; i < dst.size(); ++i) {
+    stride_t src_elem = elem_to_loc(i, data_shape, i_strides);
+    dst_ptr[i] = static_cast<DstT>(src_ptr[src_elem]);
+  }
+}
+
+template <typename SrcT, typename DstT>
+inline void copy_general(const array& src, array& dst) {
+  return copy_general<SrcT, DstT, size_t>(
+      src, dst, src.shape(), src.strides(), 0);
+}
+
+template <typename SrcT, typename DstT, typename stride_t, int D>
+inline void copy_general_general_dims(
+    const array& src,
+    array& dst,
+    const std::vector<int>& data_shape,
+    const std::vector<stride_t>& i_strides,
+    const std::vector<stride_t>& o_strides,
+    stride_t i_offset,
+    stride_t o_offset) {
+  if constexpr (D > 1) {
+    int axis = src.ndim() - D;
+    auto stride_src = i_strides[axis];
+    auto stride_dst = o_strides[axis];
+    auto N = data_shape[axis];
+    for (int i = 0; i < N; i++) {
+      copy_general_general_dims<SrcT, DstT, stride_t, D - 1>(
+          src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
+      i_offset += stride_src;
+      o_offset += stride_dst;
+    }
+  } else {
+    int axis = src.ndim() - 1;
+    auto stride_src = i_strides[axis];
+    auto stride_dst = o_strides[axis];
+    auto N = data_shape[axis];
+    const SrcT* src_ptr = src.data<SrcT>() + i_offset;
+    DstT* dst_ptr = dst.data<DstT>() + o_offset;
+    for (int i = 0; i < N; i++) {
+      *dst_ptr = static_cast<DstT>(*src_ptr);
+      src_ptr += stride_src;
+      dst_ptr += stride_dst;
+    }
+  }
+}
+
+template <typename SrcT, typename DstT, typename stride_t>
 void copy_general_general(
     const array& src,
     array& dst,
     const std::vector<int>& data_shape,
-    const std::vector<StrideT>& i_strides,
-    const std::vector<StrideT>& o_strides,
-    int64_t i_offset,
-    int64_t o_offset) {
-  if (data_shape.empty()) {
-    auto val = static_cast<DstT>(*(src.data<SrcT>() + i_offset));
-    auto dst_ptr = dst.data<DstT>() + o_offset;
-    *dst_ptr = val;
-    return;
-  }
-  auto [shape, strides] = collapse_contiguous_dims(
-      data_shape, std::vector<std::vector<StrideT>>{i_strides, o_strides});
-  auto src_ptr = src.data<SrcT>() + i_offset;
-  auto dst_ptr = dst.data<DstT>() + o_offset;
-  int ndim = shape.size();
-  if (ndim == 1) {
-    copy_dims<SrcT, DstT, StrideT, 1>(
-        src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
-    return;
-  } else if (ndim == 2) {
-    copy_dims<SrcT, DstT, StrideT, 2>(
-        src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
-    return;
-  } else if (ndim == 3) {
-    copy_dims<SrcT, DstT, StrideT, 3>(
-        src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
-    return;
-  }
-  ContiguousIterator<StrideT> in(shape, strides[0], ndim - 3);
-  ContiguousIterator<StrideT> out(shape, strides[1], ndim - 3);
-  StrideT stride = std::accumulate(
-      shape.end() - 3, shape.end(), 1, std::multiplies<StrideT>());
-  for (StrideT elem = 0; elem < src.size(); elem += stride) {
-    copy_dims<SrcT, DstT, StrideT, 3>(
-        src_ptr + in.loc,
-        dst_ptr + out.loc,
-        shape,
-        strides[0],
-        strides[1],
-        ndim - 3);
-    in.step();
-    out.step();
-  }
+    const std::vector<stride_t>& i_strides,
+    const std::vector<stride_t>& o_strides,
+    stride_t i_offset,
+    stride_t o_offset) {
+  switch (src.ndim()) {
+    case 1:
+      copy_general_general_dims<SrcT, DstT, stride_t, 1>(
+          src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
+      return;
+    case 2:
+      copy_general_general_dims<SrcT, DstT, stride_t, 2>(
+          src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
+      return;
+    ... (cases 3, 4 and 5 dispatch identically to D = 3, 4 and 5)
+  }
+  int size = std::accumulate(
+      data_shape.end() - 5, data_shape.end(), 1, std::multiplies<int>());
+  for (int i = 0; i < src.size(); i += size) {
+    stride_t src_offset = i_offset + elem_to_loc(i, data_shape, i_strides);
+    stride_t dst_offset = o_offset + elem_to_loc(i, dst.shape(), o_strides);
+    copy_general_general_dims<SrcT, DstT, stride_t, 5>(
+        src, dst, data_shape, i_strides, o_strides, src_offset, dst_offset);
+  }
 }
 
 template <typename SrcT, typename DstT>
 inline void copy_general_general(const array& src, array& dst) {
-  copy_general_general<SrcT, DstT, size_t>(
+  return copy_general_general<SrcT, DstT, size_t>(
       src, dst, src.shape(), src.strides(), dst.strides(), 0, 0);
 }
 
-template <typename SrcT, typename DstT, typename StrideT>
-void copy_general(
-    const array& src,
-    array& dst,
-    const std::vector<int>& data_shape,
-    const std::vector<StrideT>& i_strides,
-    const std::vector<StrideT>&,
-    int64_t i_offset,
-    int64_t o_offset) {
-  copy_general_general<SrcT, DstT, StrideT>(
-      src,
-      dst,
-      data_shape,
-      i_strides,
-      make_contiguous_strides<StrideT>(data_shape),
-      i_offset,
-      o_offset);
-}
-
-template <typename SrcT, typename DstT>
-inline void copy_general(const array& src, array& dst) {
-  copy_general_general<SrcT, DstT, size_t>(
-      src,
-      dst,
-      src.shape(),
-      src.strides(),
-      make_contiguous_strides<size_t>(src.shape()),
-      0,
-      0);
-}
-
 template <typename SrcT, typename DstT, typename... Args>
 void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
   switch (ctype) {
@@ -151,7 +285,6 @@ void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
       return;
     case CopyType::GeneralGeneral:
       copy_general_general<SrcT, DstT>(src, dst, std::forward<Args>(args)...);
-      return;
   }
 }
 
@@ -252,7 +385,7 @@ inline void copy_inplace_dispatch(
 } // namespace
 
 void copy_inplace(const array& src, array& dst, CopyType ctype) {
-  copy_inplace_dispatch(src, dst, ctype);
+  return copy_inplace_dispatch(src, dst, ctype);
 }
 
 void copy(const array& src, array& dst, CopyType ctype) {
@@ -282,20 +415,20 @@ void copy(const array& src, array& dst, CopyType ctype) {
   copy_inplace(src, dst, ctype);
 }
 
-template <typename StrideT>
+template <typename stride_t>
 void copy_inplace(
     const array& src,
     array& dst,
     const std::vector<int>& data_shape,
-    const std::vector<StrideT>& i_strides,
-    const std::vector<StrideT>& o_strides,
+    const std::vector<stride_t>& i_strides,
+    const std::vector<stride_t>& o_strides,
     int64_t i_offset,
     int64_t o_offset,
     CopyType ctype) {
   switch (ctype) {
     case CopyType::General:
     case CopyType::GeneralGeneral:
-      copy_inplace_dispatch(
+      return copy_inplace_dispatch(
           src,
           dst,
           ctype,
@@ -304,24 +437,15 @@ void copy_inplace(
           o_strides,
           i_offset,
           o_offset);
-      break;
     case CopyType::Scalar:
     case CopyType::Vector:
-      copy_inplace_dispatch(src, dst, ctype);
+      return copy_inplace_dispatch(src, dst, ctype);
   }
 }
 
-template void copy_inplace<size_t>(
-    const array& src,
-    array& dst,
-    const std::vector<int>& data_shape,
-    const std::vector<size_t>& i_strides,
-    const std::vector<size_t>& o_strides,
-    int64_t i_offset,
-    int64_t o_offset,
-    CopyType ctype);
-
-template void copy_inplace<int64_t>(
+template <>
+void copy_inplace<int64_t>(
     const array& src,
     array& dst,
     const std::vector<int>& data_shape,
@@ -329,6 +453,24 @@ template void copy_inplace<int64_t>(
     const std::vector<int64_t>& o_strides,
     int64_t i_offset,
     int64_t o_offset,
-    CopyType ctype);
+    CopyType ctype) {
+  switch (ctype) {
+    case CopyType::General:
+    case CopyType::GeneralGeneral:
+      return copy_inplace_dispatch(
+          src, dst, ctype, data_shape, i_strides, o_strides, i_offset, o_offset);
+    case CopyType::Scalar:
+    case CopyType::Vector:
+      return copy_inplace_dispatch(src, dst, ctype);
+  }
+}
 
 } // namespace mlx::core
@@ -1,10 +1,14 @@
 // Copyright © 2023-2024 Apple Inc.
 
+#ifdef ACCELERATE_NEW_LAPACK
+#include <Accelerate/Accelerate.h>
+#else
+#include <cblas.h>
+#endif
 #include <cstring>
 
 #include "mlx/array.h"
 #include "mlx/backend/common/copy.h"
-#include "mlx/backend/common/lapack.h"
 #include "mlx/backend/common/utils.h"
 #include "mlx/primitives.h"
 
@@ -48,7 +52,7 @@ DEFAULT(Convolution)
 DEFAULT(Copy)
 DEFAULT(Cos)
 DEFAULT(Cosh)
-DEFAULT_MULTI(CustomTransforms)
+DEFAULT_MULTI(CustomVJP)
 DEFAULT_MULTI(Depends)
 DEFAULT(Divide)
 DEFAULT(NumberOfElements)
@@ -64,7 +68,6 @@ DEFAULT(Full)
 DEFAULT(Gather)
 DEFAULT(Greater)
 DEFAULT(GreaterEqual)
-DEFAULT(Hadamard)
 DEFAULT(Less)
 DEFAULT(LessEqual)
 DEFAULT(Load)
@@ -110,7 +113,6 @@ DEFAULT(Tanh)
 DEFAULT(Transpose)
 DEFAULT(Inverse)
 DEFAULT(Cholesky)
-DEFAULT_MULTI(Eigh)
 
 namespace {
@@ -1,117 +0,0 @@
-// Copyright © 2023-2024 Apple Inc.
-
-#include "mlx/allocator.h"
-#include "mlx/array.h"
-#include "mlx/backend/common/copy.h"
-#include "mlx/backend/common/lapack.h"
-#include "mlx/linalg.h"
-#include "mlx/primitives.h"
-
-namespace mlx::core {
-
-namespace {
-
-void ssyevd(
-    char jobz,
-    char uplo,
-    float* a,
-    int N,
-    float* w,
-    float* work,
-    int lwork,
-    int* iwork,
-    int liwork) {
-  int info;
-  MLX_LAPACK_FUNC(ssyevd)
-  (
-      /* jobz = */ &jobz,
-      /* uplo = */ &uplo,
-      /* n = */ &N,
-      /* a = */ a,
-      /* lda = */ &N,
-      /* w = */ w,
-      /* work = */ work,
-      /* lwork = */ &lwork,
-      /* iwork = */ iwork,
-      /* liwork = */ &liwork,
-      /* info = */ &info);
-  if (info != 0) {
-    std::stringstream msg;
-    msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
-        << info;
-    throw std::runtime_error(msg.str());
-  }
-}
-
-} // namespace
-
-void Eigh::eval(const std::vector<array>& inputs, std::vector<array>& outputs) {
-  const auto& a = inputs[0];
-  auto& values = outputs[0];
-
-  auto vectors = compute_eigenvectors_
-      ? outputs[1]
-      : array(a.shape(), a.dtype(), nullptr, {});
-
-  values.set_data(allocator::malloc_or_wait(values.nbytes()));
-
-  copy(
-      a,
-      vectors,
-      a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
-
-  if (compute_eigenvectors_) {
-    // Set the strides and flags so the eigenvectors
-    // are in the columns of the output
-    auto flags = vectors.flags();
-    auto strides = vectors.strides();
-    auto ndim = a.ndim();
-    std::swap(strides[ndim - 1], strides[ndim - 2]);
-
-    if (a.size() > 1) {
-      flags.row_contiguous = false;
-      if (ndim > 2) {
-        flags.col_contiguous = false;
-      } else {
-        flags.col_contiguous = true;
-      }
-    }
-    vectors.move_shared_buffer(vectors, strides, flags, vectors.data_size());
-  }
-
-  auto vec_ptr = vectors.data<float>();
-  auto eig_ptr = values.data<float>();
-
-  char jobz = compute_eigenvectors_ ? 'V' : 'N';
-  auto N = a.shape(-1);
-
-  // Work query
-  int lwork;
-  int liwork;
-  {
-    float work;
-    int iwork;
-    ssyevd(jobz, uplo_[0], nullptr, N, nullptr, &work, -1, &iwork, -1);
-    lwork = static_cast<int>(work);
-    liwork = iwork;
-  }
-
-  auto work_buf = array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
-  auto iwork_buf = array::Data{allocator::malloc_or_wait(sizeof(int) * liwork)};
-  for (size_t i = 0; i < a.size() / (N * N); ++i) {
-    ssyevd(
-        jobz,
-        uplo_[0],
-        vec_ptr,
-        N,
-        eig_ptr,
-        static_cast<float*>(work_buf.buffer.raw_ptr()),
-        lwork,
-        static_cast<int*>(iwork_buf.buffer.raw_ptr()),
-        liwork);
-    vec_ptr += N * N;
-    eig_ptr += N;
-  }
-}
-
-} // namespace mlx::core
@@ -1,107 +0,0 @@
-// Copyright © 2024 Apple Inc.
-
-#include <cassert>
-
-#include "mlx/backend/common/copy.h"
-#include "mlx/backend/common/hadamard.h"
-#include "mlx/primitives.h"
-
-namespace mlx::core {
-
-// n = 2^k component
-template <typename T>
-void hadamard_n(array& out, int n, int m, float scale) {
-  for (int b = 0; b < out.size() / n; b++) {
-    size_t loc = b * n;
-    T* data_ptr = out.data<T>() + loc;
-    int h = 1;
-    int n_over_2 = n / 2;
-    while (h < n) {
-      for (int i = 0; i < n / 2; i++) {
-        int k = i & (h - 1);
-        int j = ((i - k) << 1) + k;
-        float x = *(data_ptr + j);
-        float y = *(data_ptr + j + h);
-        *(data_ptr + j) = x + y;
-        *(data_ptr + j + h) = x - y;
-        if (h == n_over_2) {
-          *(data_ptr + j) *= scale;
-          *(data_ptr + j + h) *= scale;
-        }
-      }
-      h <<= 1;
-    }
-  }
-}
-
-// m component
-template <typename T>
-void hadamard_m(array& out, int n, int m, float scale) {
-  auto h_matrices = hadamard_matrices();
-  auto& matrix = h_matrices[m];
-  auto start = 1;
-  auto end = matrix.find('\n', start);
-  std::vector<bool> hmat_vec;
-  while (end != std::string_view::npos) {
-    auto row = matrix.substr(start, end - start);
-    for (int i = 0; i < row.length(); i++) {
-      hmat_vec.push_back(row[i] == '+');
-    }
-    start = end + 1;
-    end = matrix.find('\n', start);
-  }
-
-  for (int b = 0; b < out.size() / m / n; b++) {
-    size_t loc = b * n * m;
-    T* data_ptr = out.data<T>() + loc;
-    for (int i = 0; i < n; i++) {
-      std::vector<float> out(m);
-      for (int j = 0; j < m; j++) {
-        for (int k = 0; k < m; k++) {
-          float x = *(data_ptr + i + k * n);
-          if (hmat_vec[k + j * m]) {
-            out[j] += x;
-          } else {
-            out[j] -= x;
-          }
-        }
-      }
-      for (int j = 0; j < m; j++) {
-        *(data_ptr + i + j * n) = out[j] * scale;
-      }
-    }
-  }
-}
-
-template <typename T>
-void hadamard(array& out, int n, int m, float scale) {
-  float n_scale = m > 1 ? 1.0 : scale;
-  hadamard_n<T>(out, n, m, n_scale);
-  if (m > 1) {
-    hadamard_m<T>(out, n, m, scale);
-  }
-}
-
-void Hadamard::eval(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 1);
-  auto& in = inputs[0];
-
-  // Copy input to output
-  copy(in, out, CopyType::General);
-
-  int axis = out.ndim() - 1;
-  auto [n, m] = decompose_hadamard(out.shape(axis));
-
-  switch (in.dtype()) {
-    case float32:
-      return hadamard<float>(out, n, m, scale_);
-    case float16:
-      return hadamard<float16_t>(out, n, m, scale_);
-    case bfloat16:
-      return hadamard<bfloat16_t>(out, n, m, scale_);
-    default:
-      throw std::invalid_argument("[hadamard] Unsupported type.");
-  }
-}
-
-} // namespace mlx::core
@@ -1,105 +0,0 @@
-// Copyright © 2024 Apple Inc.
-
-#pragma once
-
-#include <map>
-
-#include "mlx/utils.h"
-
-namespace mlx::core {
-
-// From http://neilsloane.com/hadamard/
-constexpr std::string_view h12 = R"(
-+-++++++++++
---+-+-+-+-+-
-+++-++----++
-+---+--+-++-
-+++++-++----
-+-+---+--+-+
-++--+++-++--
-+--++---+--+
-++----+++-++
-+--+-++---+-
-++++----+++-
-+-+--+-++---
-)";
-
-constexpr std::string_view h20 = R"(
-... (20 x 20 sign matrix, removed along with the file)
-)";
-
-constexpr std::string_view h28 = R"(
-... (28 x 28 sign matrix, removed along with the file)
-)";
-
-inline const std::map<int, std::string_view> hadamard_matrices() {
-  return {{12, h12}, {20, h20}, {28, h28}};
-}
-
-inline std::pair<int, int> decompose_hadamard(int n) {
-  // n = m*2^k
-  int m = 1;
-  if (!is_power_of_2(n)) {
-    auto h_matrices = hadamard_matrices();
-    for (auto [factor, _] : h_matrices) {
-      if (n % factor == 0) {
-        m = factor;
-        n /= factor;
-        break;
-      }
-    }
-    if (m == 1) {
-      throw std::invalid_argument(
-          "[hadamard] Only supports n = m*2^k where m in (1, 12, 20, 28).");
-    }
-  }
-  return {n, m};
-}
-
-} // namespace mlx::core
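For context, decompose_hadamard above splits a transform length n into n = m * 2^k, where m is 1 or one of the stored matrix sizes. A small self-contained sketch of that split (not MLX code; the names are illustrative):

// Hypothetical demo of the n = m * 2^k decomposition, e.g. 48 -> (n = 4, m = 12).
#include <cstdio>
#include <utility>

static bool is_power_of_2(int x) { return x > 0 && (x & (x - 1)) == 0; }

static std::pair<int, int> decompose(int n) {
  int m = 1;
  if (!is_power_of_2(n)) {
    for (int factor : {12, 20, 28}) {
      if (n % factor == 0) { m = factor; n /= factor; break; }
    }
  }
  return {n, m};
}

int main() {
  auto [n, m] = decompose(48);
  std::printf("n = %d, m = %d\n", n, m); // prints n = 4, m = 12
}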
@@ -1,4 +1,5 @@
 // Copyright © 2023 Apple Inc.
 
 #include <algorithm>
 #include <cassert>
 #include <cmath>
@@ -80,18 +81,11 @@ void gather(
   T* dst_ptr = out.data<T>();
   size_t out_idx = 0;
 
-  std::vector<ContiguousIterator<size_t>> its(inds.begin(), inds.end());
-  ContiguousIterator<size_t> src_it;
-  if (!can_copy && src.ndim() > 0) {
-    src_it = std::move(
-        ContiguousIterator<size_t>(slice_sizes, src.strides(), src.ndim()));
-  }
   for (int idx = 0; idx < ind_size; idx++) {
     size_t src_idx = 0;
     for (int ii = 0; ii < inds.size(); ++ii) {
       auto ax = axes[ii];
-      auto idx_loc = its[ii].loc;
-      its[ii].step();
+      auto idx_loc = elem_to_loc(idx, inds[ii]);
       auto idx_val =
           offset_neg_idx(inds[ii].data<IdxT>()[idx_loc], src.shape(ax));
       src_idx += (idx_val * src.strides()[ax]);
@@ -105,10 +99,9 @@ void gather(
       out_idx += slice_size;
     } else {
       for (int jj = 0; jj < slice_size; jj++) {
-        dst_ptr[out_idx++] = src_ptr[src_idx + src_it.loc];
-        src_it.step();
+        auto src_offset = elem_to_loc(jj, slice_sizes, src.strides());
+        dst_ptr[out_idx++] = src_ptr[src_idx + src_offset];
       }
-      src_it.reset();
     }
   }
 }
@@ -230,29 +223,21 @@ void scatter(
     update_size *= us;
   }
 
-  std::vector<ContiguousIterator<size_t>> its(inds.begin(), inds.end());
-  ContiguousIterator<size_t> update_it(updates);
-  ContiguousIterator<size_t> out_it(update_shape, out.strides(), out.ndim());
-
   for (int i = 0; i < n_updates; ++i) {
     size_t out_offset = 0;
     for (int j = 0; j < nind; ++j) {
       auto ax = axes[j];
-      auto idx_loc = its[j].loc;
-      its[j].step();
+      auto idx_loc = elem_to_loc(i, inds[j]);
       auto idx_val =
           offset_neg_idx(inds[j].data<IdxT>()[idx_loc], out.shape(ax));
       out_offset += (idx_val * out.strides()[ax]);
     }
-    update_it.seek(i * update_size);
     for (int j = 0; j < update_size; ++j) {
-      op(updates.data<InT>()[update_it.loc],
-         out.data<InT>() + out_offset + out_it.loc);
-      update_it.step();
-      out_it.step();
+      auto update_loc = elem_to_loc(i * update_size + j, updates);
+      auto out_loc = elem_to_loc(j, update_shape, out.strides());
+      op(updates.data<InT>()[update_loc],
+         out.data<InT>() + out_offset + out_loc);
    }
-    out_it.reset();
-    update_it.reset();
  }
 }
@@ -2,94 +2,17 @@
 
 #include "mlx/allocator.h"
 #include "mlx/backend/common/copy.h"
-#include "mlx/backend/common/lapack.h"
 #include "mlx/primitives.h"
 
-int strtri_wrapper(char uplo, char diag, float* matrix, int N) {
-  int info;
-  MLX_LAPACK_FUNC(strtri)
-  (
-      /* uplo = */ &uplo,
-      /* diag = */ &diag,
-      /* N = */ &N,
-      /* a = */ matrix,
-      /* lda = */ &N,
-      /* info = */ &info);
-  return info;
-}
+#ifdef ACCELERATE_NEW_LAPACK
+#include <Accelerate/Accelerate.h>
+#else
+#include <lapack.h>
+#endif
 
 namespace mlx::core {
 
-void general_inv(array& inv, int N, int i) {
-  int info;
-  auto ipiv = array::Data{allocator::malloc_or_wait(sizeof(int) * N)};
-  // Compute LU factorization.
-  sgetrf_(
-      /* m = */ &N,
-      /* n = */ &N,
-      /* a = */ inv.data<float>() + N * N * i,
-      /* lda = */ &N,
-      /* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
-      /* info = */ &info);
-
-  if (info != 0) {
-    std::stringstream ss;
-    ss << "inverse_impl: LU factorization failed with error code " << info;
-    throw std::runtime_error(ss.str());
-  }
-
-  static const int lwork_query = -1;
-  float workspace_size = 0;
-
-  // Compute workspace size.
-  sgetri_(
-      /* m = */ &N,
-      /* a = */ nullptr,
-      /* lda = */ &N,
-      /* ipiv = */ nullptr,
-      /* work = */ &workspace_size,
-      /* lwork = */ &lwork_query,
-      /* info = */ &info);
-
-  if (info != 0) {
-    std::stringstream ss;
-    ss << "inverse_impl: LU workspace calculation failed with error code "
-       << info;
-    throw std::runtime_error(ss.str());
-  }
-
-  const int lwork = workspace_size;
-  auto scratch = array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
-
-  // Compute inverse.
-  sgetri_(
-      /* m = */ &N,
-      /* a = */ inv.data<float>() + N * N * i,
-      /* lda = */ &N,
-      /* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
-      /* work = */ static_cast<float*>(scratch.buffer.raw_ptr()),
-      /* lwork = */ &lwork,
-      /* info = */ &info);
-
-  if (info != 0) {
-    std::stringstream ss;
-    ss << "inverse_impl: inversion failed with error code " << info;
-    throw std::runtime_error(ss.str());
-  }
-}
-
-void tri_inv(array& inv, int N, int i, bool upper) {
-  const char uplo = upper ? 'L' : 'U';
-  const char diag = 'N';
-  int info = strtri_wrapper(uplo, diag, inv.data<float>() + N * N * i, N);
-  if (info != 0) {
-    std::stringstream ss;
-    ss << "inverse_impl: triangular inversion failed with error code " << info;
-    throw std::runtime_error(ss.str());
-  }
-}
-
-void inverse_impl(const array& a, array& inv, bool tri, bool upper) {
+void inverse_impl(const array& a, array& inv) {
   // Lapack uses the column-major convention. We take advantage of the following
   // identity to avoid transposing (see
   // https://math.stackexchange.com/a/340234):
@@ -101,11 +24,63 @@ void inverse_impl(const array& a, array& inv, bool tri, bool upper) {
   const int N = a.shape(-1);
   const size_t num_matrices = a.size() / (N * N);
 
+  int info;
+  auto ipiv = array::Data{allocator::malloc_or_wait(sizeof(int) * N)};
+
   for (int i = 0; i < num_matrices; i++) {
-    if (tri) {
-      tri_inv(inv, N, i, upper);
-    } else {
-      general_inv(inv, N, i);
+    // Compute LU factorization.
+    sgetrf_(
+        /* m = */ &N,
+        /* n = */ &N,
+        /* a = */ inv.data<float>() + N * N * i,
+        /* lda = */ &N,
+        /* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
+        /* info = */ &info);
+
+    if (info != 0) {
+      std::stringstream ss;
+      ss << "inverse_impl: LU factorization failed with error code " << info;
+      throw std::runtime_error(ss.str());
+    }
+
+    static const int lwork_query = -1;
+    float workspace_size = 0;
+
+    // Compute workspace size.
+    sgetri_(
+        /* m = */ &N,
+        /* a = */ nullptr,
+        /* lda = */ &N,
+        /* ipiv = */ nullptr,
+        /* work = */ &workspace_size,
+        /* lwork = */ &lwork_query,
+        /* info = */ &info);
+
+    if (info != 0) {
+      std::stringstream ss;
+      ss << "inverse_impl: LU workspace calculation failed with error code "
+         << info;
+      throw std::runtime_error(ss.str());
+    }
+
+    const int lwork = workspace_size;
+    auto scratch =
+        array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
+
+    // Compute inverse.
+    sgetri_(
+        /* m = */ &N,
+        /* a = */ inv.data<float>() + N * N * i,
+        /* lda = */ &N,
+        /* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
+        /* work = */ static_cast<float*>(scratch.buffer.raw_ptr()),
+        /* lwork = */ &lwork,
+        /* info = */ &info);
+
+    if (info != 0) {
+      std::stringstream ss;
+      ss << "inverse_impl: inversion failed with error code " << info;
+      throw std::runtime_error(ss.str());
     }
   }
 }
@@ -114,7 +89,7 @@ void Inverse::eval(const std::vector<array>& inputs, array& output) {
   if (inputs[0].dtype() != float32) {
     throw std::runtime_error("[Inverse::eval] only supports float32.");
   }
-  inverse_impl(inputs[0], output, tri_, upper_);
+  inverse_impl(inputs[0], output);
 }
 
 } // namespace mlx::core
 
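For readers unfamiliar with the LAPACK convention used on both sides of the inverse hunks: `sgetrf_` produces the LU factorization, a first `sgetri_` call with `lwork = -1` only queries the workspace size, and a second `sgetri_` call performs the actual inversion. A self-contained sketch of that sequence (plain LAPACK and std::vector, no MLX allocator; error handling reduced to exceptions):

#include <stdexcept>
#include <vector>

extern "C" {
void sgetrf_(const int* m, const int* n, float* a, const int* lda, int* ipiv, int* info);
void sgetri_(const int* n, float* a, const int* lda, const int* ipiv, float* work, const int* lwork, int* info);
}

// Invert an N x N float matrix stored column-major, in place.
inline void invert_in_place(float* a, int n) {
  int info = 0;
  std::vector<int> ipiv(n);

  // LU factorization: A = P * L * U.
  sgetrf_(&n, &n, a, &n, ipiv.data(), &info);
  if (info != 0) {
    throw std::runtime_error("LU factorization failed");
  }

  // Workspace query: lwork = -1 asks LAPACK for the optimal buffer size.
  int lwork = -1;
  float work_query = 0;
  sgetri_(&n, a, &n, ipiv.data(), &work_query, &lwork, &info);
  lwork = static_cast<int>(work_query);

  // Actual inversion using the LU factors.
  std::vector<float> work(lwork);
  sgetri_(&n, a, &n, ipiv.data(), work.data(), &lwork, &info);
  if (info != 0) {
    throw std::runtime_error("matrix inversion failed");
  }
}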
@@ -1,11 +1,10 @@
-// Copyright © 2023-2024 Apple Inc.
+// Copyright © 2024 Apple Inc.
 
 #pragma once
 
 #ifdef ACCELERATE_NEW_LAPACK
 #include <Accelerate/Accelerate.h>
 #else
-#include <cblas.h>
 #include <lapack.h>
 #endif
 
@@ -5,9 +5,11 @@
 #include <utility>
 
 #include "mlx/allocator.h"
-#include "mlx/backend/common/load.h"
+#include "mlx/io/load.h"
 #include "mlx/primitives.h"
 
+namespace mlx::core {
+
 namespace {
 
 template <const uint8_t scalar_size>
@@ -27,14 +29,12 @@ void swap_endianness(uint8_t* data_bytes, size_t N) {
 
 } // namespace
 
-namespace mlx::core {
-
-void load(
-    array& out,
-    size_t offset,
-    const std::shared_ptr<io::Reader>& reader,
-    bool swap_endianness_) {
-  reader->read(out.data<char>(), out.nbytes(), offset);
+void Load::eval(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 0);
+  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+
+  reader_->seek(offset_, std::ios_base::beg);
+  reader_->read(out.data<char>(), out.nbytes());
 
   if (swap_endianness_) {
     switch (out.itemsize()) {
@@ -51,11 +51,4 @@ void load(
   }
 }
 
-void Load::eval(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 0);
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
-
-  load(out, offset_, reader_, swap_endianness_);
-}
-
 } // namespace mlx::core
 
@@ -1,14 +0,0 @@
-// Copyright © 2024 Apple Inc.
-
-#include "mlx/array.h"
-#include "mlx/io/load.h"
-
-namespace mlx::core {
-
-void load(
-    array& out,
-    size_t offset,
-    const std::shared_ptr<io::Reader>& reader,
-    bool swap_endianess);
-
-} // namespace mlx::core
@@ -18,12 +18,10 @@ if [ "$CLANG" = "TRUE" ]; then
#include <cstdint>
#include <vector>
EOM
-  CC_FLAGS=""
-else
-  CC_FLAGS="-std=c++17"
 fi
 
-CONTENT=$($GCC $CC_FLAGS -I "$SRCDIR" -E "$SRCDIR/mlx/backend/common/compiled_preamble.h" 2>/dev/null)
+CONTENT=$($GCC -I $SRCDIR -E $SRCDIR/mlx/backend/common/compiled_preamble.h 2>/dev/null)
 
 cat << EOF > "$OUTPUT_FILE"
 const char* get_kernel_preamble() {
 
@@ -1,10 +1,15 @@
 // Copyright © 2024 Apple Inc.
 
+#ifdef ACCELERATE_NEW_LAPACK
+#include <Accelerate/Accelerate.h>
+#else
+#include <cblas.h>
+#endif
+
 #include <cstring>
 
 #include "mlx/array.h"
 #include "mlx/backend/common/copy.h"
-#include "mlx/backend/common/lapack.h"
 #include "mlx/backend/common/utils.h"
 #include "mlx/primitives.h"
 
@@ -295,13 +295,6 @@ struct Floor {
   }
 };
 
-struct Imag {
-  template <typename T>
-  T operator()(T x) {
-    return std::imag(x);
-  }
-};
-
 struct Log {
   template <typename T>
   T operator()(T x) {
@@ -344,13 +337,6 @@ struct Negative {
   }
 };
 
-struct Real {
-  template <typename T>
-  T operator()(T x) {
-    return std::real(x);
-  }
-};
-
 struct Round {
   template <typename T>
   T operator()(T x) {
@@ -387,10 +373,6 @@ struct Sign {
   uint64_t operator()(uint64_t x) {
     return x != 0;
   }
-
-  complex64_t operator()(complex64_t x) {
-    return x == complex64_t(0) ? x : x / std::abs(x);
-  }
 };
 
 struct Sin {
 
@@ -273,10 +273,6 @@ void Full::eval(const std::vector<array>& inputs, array& out) {
   copy(in, out, ctype);
 }
 
-void Imag::eval_cpu(const std::vector<array>& inputs, array& out) {
-  unary_op<complex64_t, float>(inputs[0], out, detail::Imag());
-}
-
 void Log::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
@@ -402,10 +398,6 @@ void RandomBits::eval(const std::vector<array>& inputs, array& out) {
   }
 }
 
-void Real::eval_cpu(const std::vector<array>& inputs, array& out) {
-  unary_op<complex64_t, float>(inputs[0], out, detail::Real());
-}
-
 void Reshape::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
@@ -413,8 +405,7 @@ void Reshape::eval(const std::vector<array>& inputs, array& out) {
   auto [copy_necessary, out_strides] = prepare_reshape(in, out);
 
   if (copy_necessary) {
-    out.set_data(allocator::malloc_or_wait(out.nbytes()));
-    copy_inplace(in, out, CopyType::General);
+    copy(in, out, in.data_size() == 1 ? CopyType::Scalar : CopyType::General);
   } else {
     shared_buffer_reshape(in, out_strides, out);
   }
@@ -504,16 +495,8 @@ void Slice::eval(const std::vector<array>& inputs, array& out) {
         /* int64_t o_offset = */ 0,
         /* CopyType ctype = */ CopyType::General);
   } else {
-    size_t data_end = 1;
-    for (int i = 0; i < end_indices_.size(); ++i) {
-      if (in.shape()[i] > 1) {
-        auto end_idx = start_indices_[i] + out.shape()[i] * strides_[i] - 1;
-        data_end += end_idx * in.strides()[i];
-      }
-    }
-    size_t data_size = data_end - data_offset;
     std::vector<size_t> ostrides{inp_strides.begin(), inp_strides.end()};
-    shared_buffer_slice(in, ostrides, data_offset, data_size, out);
+    shared_buffer_slice(in, ostrides, data_offset, out);
   }
 }
 
@@ -611,18 +594,11 @@ void View::eval_cpu(const std::vector<array>& inputs, array& out) {
       strides[i] /= obytes;
     }
     out.copy_shared_buffer(
-        in, strides, in.flags(), in.data_size() * ibytes / obytes);
+        in, strides, in.flags(), in.data_size() * obytes / ibytes);
   } else {
-    auto tmp = array(
-        in.shape(), in.dtype() == bool_ ? uint8 : in.dtype(), nullptr, {});
+    auto tmp = array(in.shape(), in.dtype(), nullptr, {});
     tmp.set_data(allocator::malloc_or_wait(tmp.nbytes()));
-    if (in.dtype() == bool_) {
-      auto in_tmp = array(in.shape(), uint8, nullptr, {});
-      in_tmp.copy_shared_buffer(in);
-      copy_inplace(in_tmp, tmp, CopyType::General);
-    } else {
-      copy_inplace(in, tmp, CopyType::General);
-    }
+    copy_inplace(in, tmp, CopyType::General);
 
     auto flags = out.flags();
     flags.contiguous = true;
 
@@ -2,9 +2,14 @@
 
 #include "mlx/allocator.h"
 #include "mlx/backend/common/copy.h"
-#include "mlx/backend/common/lapack.h"
 #include "mlx/primitives.h"
 
+#ifdef ACCELERATE_NEW_LAPACK
+#include <Accelerate/Accelerate.h>
+#else
+#include <lapack.h>
+#endif
+
 namespace mlx::core {
 
 template <typename T>
@@ -201,61 +201,55 @@ void _qmm_dispatch(
     int group_size,
     bool transposed_w) {
   int K = x.shape(-1);
-  int M = x.shape(-2);
+  int M = x.size() / K;
   int N = out.shape(-1);
 
-  int w_els = w.ndim() > 2 ? w.shape(-1) * w.shape(-2) : 0;
-  int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0;
-
-  int batch_size = x.size() / x.shape(-1) / x.shape(-2);
-  for (int i = 0; i < batch_size; i++) {
-    switch (x.dtype()) {
-      case float32:
-        _qmm_dispatch_typed<float>(
-            out.data<float>() + i * M * N,
-            x.data<float>() + elem_to_loc(i * M * K, x),
-            w.data<uint32_t>() + elem_to_loc(i * w_els, w),
-            scales.data<float>() + elem_to_loc(i * g_els, scales),
-            biases.data<float>() + elem_to_loc(i * g_els, biases),
-            M,
-            N,
-            K,
-            bits,
-            group_size,
-            transposed_w);
-        break;
-      case float16:
-        _qmm_dispatch_typed<float16_t>(
-            out.data<float16_t>() + i * M * N,
-            x.data<float16_t>() + elem_to_loc(i * M * K, x),
-            w.data<uint32_t>() + elem_to_loc(i * w_els, w),
-            scales.data<float16_t>() + elem_to_loc(i * g_els, scales),
-            biases.data<float16_t>() + elem_to_loc(i * g_els, biases),
-            M,
-            N,
-            K,
-            bits,
-            group_size,
-            transposed_w);
-        break;
-      case bfloat16:
-        _qmm_dispatch_typed<bfloat16_t>(
-            out.data<bfloat16_t>() + i * M * N,
-            x.data<bfloat16_t>() + elem_to_loc(i * M * K, x),
-            w.data<uint32_t>() + elem_to_loc(i * w_els, w),
-            scales.data<bfloat16_t>() + elem_to_loc(i * g_els, scales),
-            biases.data<bfloat16_t>() + elem_to_loc(i * g_els, biases),
-            M,
-            N,
-            K,
-            bits,
-            group_size,
-            transposed_w);
-        break;
-      default:
-        throw std::invalid_argument(
-            "[quantized_matmul] only floating types are supported");
-    }
+  switch (x.dtype()) {
+    case float32:
+      _qmm_dispatch_typed<float>(
+          out.data<float>(),
+          x.data<float>(),
+          w.data<uint32_t>(),
+          scales.data<float>(),
+          biases.data<float>(),
+          M,
+          N,
+          K,
+          bits,
+          group_size,
+          transposed_w);
+      break;
+    case float16:
+      _qmm_dispatch_typed<float16_t>(
+          out.data<float16_t>(),
+          x.data<float16_t>(),
+          w.data<uint32_t>(),
+          scales.data<float16_t>(),
+          biases.data<float16_t>(),
+          M,
+          N,
+          K,
+          bits,
+          group_size,
+          transposed_w);
+      break;
+    case bfloat16:
+      _qmm_dispatch_typed<bfloat16_t>(
+          out.data<bfloat16_t>(),
+          x.data<bfloat16_t>(),
+          w.data<uint32_t>(),
+          scales.data<bfloat16_t>(),
+          biases.data<bfloat16_t>(),
+          M,
+          N,
+          K,
+          bits,
+          group_size,
+          transposed_w);
+      break;
+    default:
+      throw std::invalid_argument(
+          "[quantized_matmul] only floating types are supported");
   }
 }
 
@@ -87,38 +87,6 @@ struct OrReduce {
   }
 };
 
-struct MaxReduce {
-  template <typename T>
-  std::enable_if_t<std::is_integral_v<T>> operator()(T* y, T x) {
-    (*y) = (*y > x) ? *y : x;
-  };
-
-  template <typename T>
-  std::enable_if_t<!std::is_integral_v<T>> operator()(T* y, T x) {
-    if (std::isnan(x)) {
-      *y = x;
-    } else {
-      (*y) = (*y > x) ? *y : x;
-    }
-  };
-};
-
-struct MinReduce {
-  template <typename T>
-  std::enable_if_t<std::is_integral_v<T>> operator()(T* y, T x) {
-    (*y) = (*y < x) ? *y : x;
-  };
-
-  template <typename T>
-  std::enable_if_t<!std::is_integral_v<T>> operator()(T* y, T x) {
-    if (std::isnan(x)) {
-      *y = x;
-    } else {
-      (*y) = (*y < x) ? *y : x;
-    }
-  };
-};
-
 template <typename InT>
 void reduce_dispatch_out(
     const array& in,
@@ -150,13 +118,15 @@ void reduce_dispatch_out(
       break;
     }
     case Reduce::Max: {
+      auto op = [](auto y, auto x) { (*y) = (*y > x) ? *y : x; };
       auto init = Limits<InT>::min;
-      reduction_op<InT, InT>(in, out, axes, init, MaxReduce());
+      reduction_op<InT, InT>(in, out, axes, init, op);
       break;
     }
     case Reduce::Min: {
+      auto op = [](auto y, auto x) { (*y) = (*y < x) ? *y : x; };
       auto init = Limits<InT>::max;
-      reduction_op<InT, InT>(in, out, axes, init, MinReduce());
+      reduction_op<InT, InT>(in, out, axes, init, op);
       break;
     }
   }
 
@@ -49,7 +49,7 @@ struct ReductionPlan {
   ReductionPlan(ReductionOpType type_) : type(type_) {}
 };
 
-ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes);
+ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes);
 
 // Helper for the ndimensional strided loop
 // Should this be in utils?
 
@@ -19,7 +19,7 @@ std::pair<std::vector<int>, std::vector<size_t>> shapes_without_reduction_axes(
   return std::make_pair(shape, strides);
 }
 
-ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
+ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
   // The data is all there and we are reducing over everything
   if (x.size() == x.data_size() && axes.size() == x.ndim() &&
       x.flags().contiguous) {
@@ -32,7 +32,7 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
     std::vector<int> shape = {x.shape(axes[0])};
     std::vector<size_t> strides = {x.strides()[axes[0]]};
     for (int i = 1; i < axes.size(); i++) {
-      if (axes[i] - 1 == axes[i - 1] && x.shape(axes[i]) > 1) {
+      if (axes[i] - 1 == axes[i - 1]) {
         shape.back() *= x.shape(axes[i]);
         strides.back() = x.strides()[axes[i]];
       } else {
@@ -41,14 +41,6 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
       }
     }
 
-    // Remove singleton axes from the plan
-    for (int i = shape.size() - 1; i >= 0; i--) {
-      if (shape[i] == 1) {
-        shape.erase(shape.begin() + i);
-        strides.erase(strides.begin() + i);
-      }
-    }
-
     if (strides.back() == 1) {
       return ReductionPlan(ContiguousReduce, shape, strides);
     } else if (strides.back() > 1) {
@@ -71,14 +63,10 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
   // have a contiguous reduction.
   std::vector<std::pair<int, size_t>> reductions;
   for (auto a : axes) {
-    if (x.shape(a) > 1) {
-      reductions.push_back(std::make_pair(x.shape(a), x.strides()[a]));
-    }
+    reductions.push_back(std::make_pair(x.shape(a), x.strides()[a]));
   }
   std::sort(reductions.begin(), reductions.end(), [](auto a, auto b) {
-    bool a_is_zero = a.second == 0;
-    bool b_is_zero = b.second == 0;
-    return (a_is_zero != b_is_zero) ? a.second < b.second : a.second > b.second;
+    return a.second > b.second;
   });
   // Extract the two smallest and try to merge them in case the contiguous
   // reduction can be bigger than just the last axis.
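The three-line comparator on the removed (`-`) side above is worth spelling out: it orders the reduced (size, stride) pairs by decreasing stride but pushes zero-stride (broadcast/expanded) axes to the front, so they never land among the smallest-stride pairs merged at the tail. A small illustration with made-up values:

#include <algorithm>
#include <cstdio>
#include <utility>
#include <vector>

int main() {
  // (size, stride) pairs for reduced axes; stride 0 marks a broadcast axis.
  std::vector<std::pair<int, size_t>> reductions = {{4, 0}, {2, 1}, {8, 16}};
  std::sort(reductions.begin(), reductions.end(), [](auto a, auto b) {
    bool a_is_zero = a.second == 0;
    bool b_is_zero = b.second == 0;
    // Zero-stride axes sort first; real axes sort by decreasing stride.
    return (a_is_zero != b_is_zero) ? a.second < b.second : a.second > b.second;
  });
  for (auto& r : reductions) {
    std::printf("size %d, stride %zu\n", r.first, r.second);
  }
  // Prints: size 4, stride 0 / size 8, stride 16 / size 2, stride 1.
  return 0;
}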
@@ -110,33 +98,16 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
   // strides.back() are contiguous.
   if (strides.back() > 1) {
     int size = 1;
-    bool have_expand = false;
     for (int i = x.ndim() - 1; i >= 0; i--) {
       if (axes.back() == i) {
         continue;
       }
-
-      size_t stride_i = x.strides()[i];
-      int shape_i = x.shape(i);
-      if (stride_i == 0) {
-        if (shape_i == 1) {
-          continue;
-        }
-
-        have_expand = true;
+      if (x.strides()[i] != size) {
         break;
       }
-
-      if (stride_i != size && shape_i != 1) {
-        break;
-      }
-      size *= shape_i;
+      size *= x.shape(i);
     }
-    // In the case of an expanded dimension we are being conservative and
-    // require the smallest reduction stride to be smaller than the maximum row
-    // contiguous size. The reason is that we can't easily know if the reduced
-    // axis is before or after an expanded dimension.
-    if (size > strides.back() || (size == strides.back() && !have_expand)) {
+    if (size >= strides.back()) {
       return ReductionPlan(GeneralStridedReduce, shape, strides);
     }
   }
 
@@ -6,16 +6,18 @@ namespace mlx::core {
 
 std::tuple<bool, int64_t, std::vector<int64_t>> prepare_slice(
     const array& in,
-    const std::vector<int>& start_indices,
-    const std::vector<int>& strides) {
+    std::vector<int>& start_indices,
+    std::vector<int>& strides) {
   int64_t data_offset = 0;
   bool copy_needed = false;
   std::vector<int64_t> inp_strides(in.ndim(), 0);
   for (int i = 0; i < in.ndim(); ++i) {
     data_offset += start_indices[i] * in.strides()[i];
     inp_strides[i] = in.strides()[i] * strides[i];
 
     copy_needed |= strides[i] < 0;
   }
 
   return std::make_tuple(copy_needed, data_offset, inp_strides);
 }
 
@@ -23,16 +25,26 @@ void shared_buffer_slice(
     const array& in,
     const std::vector<size_t>& out_strides,
     size_t data_offset,
-    size_t data_size,
     array& out) {
   // Compute row/col contiguity
-  auto [no_bsx_size, is_row_contiguous, is_col_contiguous] =
+  auto [data_size, is_row_contiguous, is_col_contiguous] =
       check_contiguity(out.shape(), out_strides);
 
   auto flags = in.flags();
   flags.row_contiguous = is_row_contiguous;
   flags.col_contiguous = is_col_contiguous;
-  flags.contiguous = (no_bsx_size == data_size);
+
+  if (data_size == 1) {
+    // Broadcasted scalar array is contiguous.
+    flags.contiguous = true;
+  } else if (data_size == in.data_size()) {
+    // Means we sliced a broadcasted dimension so leave the "no holes" flag
+    // alone.
+  } else {
+    // We sliced something. So either we are row or col contiguous or we
+    // punched a hole.
+    flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
+  }
 
   out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset);
 }
 
@@ -8,14 +8,13 @@ namespace mlx::core {
 
 std::tuple<bool, int64_t, std::vector<int64_t>> prepare_slice(
     const array& in,
-    const std::vector<int>& start_indices,
-    const std::vector<int>& strides);
+    std::vector<int>& start_indices,
+    std::vector<int>& strides);
 
 void shared_buffer_slice(
     const array& in,
     const std::vector<size_t>& out_strides,
     size_t data_offset,
-    size_t data_size,
     array& out);
 
 } // namespace mlx::core
 
@@ -111,29 +111,26 @@ void sort(const array& in, array& out, int axis) {
 
   // Get axis, shape and stride info
   axis = axis < 0 ? axis + in.ndim() : axis;
-  size_t in_size = in.flags().contiguous ? in.data_size() : in.size();
-  size_t n_rows = in_size / in.shape(axis);
+  size_t n_rows = in.size() / in.shape(axis);
 
-  auto remaining_shape = out.shape();
+  auto remaining_shape = in.shape();
   remaining_shape.erase(remaining_shape.begin() + axis);
 
-  auto remaining_strides = out.strides();
+  auto remaining_strides = in.strides();
   remaining_strides.erase(remaining_strides.begin() + axis);
 
-  size_t axis_stride = out.strides()[axis];
-  int axis_size = out.shape(axis);
+  size_t axis_stride = in.strides()[axis];
+  int axis_size = in.shape(axis);
 
   // Perform sorting in place
-  ContiguousIterator<size_t> src_it(
-      remaining_shape, remaining_strides, remaining_shape.size());
   for (int i = 0; i < n_rows; i++) {
-    T* data_ptr = out.data<T>() + src_it.loc;
+    size_t loc = elem_to_loc(i, remaining_shape, remaining_strides);
+    T* data_ptr = out.data<T>() + loc;
 
     StridedIterator st(data_ptr, axis_stride, 0);
     StridedIterator ed(data_ptr, axis_stride, axis_size);
 
     std::stable_sort(st, ed);
-    src_it.step();
   }
 }
 
@@ -146,46 +143,34 @@ void argsort(const array& in, array& out, int axis) {
   axis = axis < 0 ? axis + in.ndim() : axis;
   size_t n_rows = in.size() / in.shape(axis);
 
-  auto in_remaining_shape = in.shape();
-  in_remaining_shape.erase(in_remaining_shape.begin() + axis);
+  auto remaining_shape = in.shape();
+  remaining_shape.erase(remaining_shape.begin() + axis);
 
-  auto in_remaining_strides = in.strides();
-  in_remaining_strides.erase(in_remaining_strides.begin() + axis);
+  auto remaining_strides = in.strides();
+  remaining_strides.erase(remaining_strides.begin() + axis);
 
-  auto out_remaining_shape = out.shape();
-  out_remaining_shape.erase(out_remaining_shape.begin() + axis);
-
-  auto out_remaining_strides = out.strides();
-  out_remaining_strides.erase(out_remaining_strides.begin() + axis);
-
-  size_t in_stride = in.strides()[axis];
-  size_t out_stride = out.strides()[axis];
+  size_t axis_stride = in.strides()[axis];
   int axis_size = in.shape(axis);
 
   // Perform sorting
-  ContiguousIterator<size_t> in_it(
-      in_remaining_shape, in_remaining_strides, in_remaining_shape.size());
-  ContiguousIterator<size_t> out_it(
-      out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
   for (int i = 0; i < n_rows; i++) {
-    const T* data_ptr = in.data<T>() + in_it.loc;
-    IdxT* idx_ptr = out.data<IdxT>() + out_it.loc;
-    in_it.step();
-    out_it.step();
+    size_t loc = elem_to_loc(i, remaining_shape, remaining_strides);
+    const T* data_ptr = in.data<T>() + loc;
+    IdxT* idx_ptr = out.data<IdxT>() + loc;
 
-    StridedIterator st_(idx_ptr, out_stride, 0);
-    StridedIterator ed_(idx_ptr, out_stride, axis_size);
+    StridedIterator st_(idx_ptr, axis_stride, 0);
+    StridedIterator ed_(idx_ptr, axis_stride, axis_size);
 
     // Initialize with iota
     std::iota(st_, ed_, IdxT(0));
 
     // Sort according to vals
-    StridedIterator st(idx_ptr, out_stride, 0);
-    StridedIterator ed(idx_ptr, out_stride, axis_size);
+    StridedIterator st(idx_ptr, axis_stride, 0);
+    StridedIterator ed(idx_ptr, axis_stride, axis_size);
 
-    std::stable_sort(st, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
-      auto v1 = data_ptr[a * in_stride];
-      auto v2 = data_ptr[b * in_stride];
+    std::stable_sort(st, ed, [data_ptr, axis_stride](IdxT a, IdxT b) {
+      auto v1 = data_ptr[a * axis_stride];
+      auto v2 = data_ptr[b * axis_stride];
       return v1 < v2 || (v1 == v2 && a < b);
     });
   }
@@ -199,8 +184,7 @@ void partition(const array& in, array& out, int axis, int kth) {
 
   // Get axis, shape and stride info
   axis = axis < 0 ? axis + in.ndim() : axis;
-  size_t in_size = in.flags().contiguous ? in.data_size() : in.size();
-  size_t n_rows = in_size / in.shape(axis);
+  size_t n_rows = in.size() / in.shape(axis);
 
   auto remaining_shape = in.shape();
   remaining_shape.erase(remaining_shape.begin() + axis);
@@ -214,11 +198,9 @@ void partition(const array& in, array& out, int axis, int kth) {
   kth = kth < 0 ? kth + axis_size : kth;
 
   // Perform partition in place
-  ContiguousIterator<size_t> src_it(
-      remaining_shape, remaining_strides, remaining_shape.size());
   for (int i = 0; i < n_rows; i++) {
-    T* data_ptr = out.data<T>() + src_it.loc;
-    src_it.step();
+    size_t loc = elem_to_loc(i, remaining_shape, remaining_strides);
+    T* data_ptr = out.data<T>() + loc;
 
     StridedIterator st(data_ptr, axis_stride, 0);
     StridedIterator md(data_ptr, axis_stride, kth);
@@ -237,49 +219,37 @@ void argpartition(const array& in, array& out, int axis, int kth) {
   axis = axis < 0 ? axis + in.ndim() : axis;
   size_t n_rows = in.size() / in.shape(axis);
 
-  auto in_remaining_shape = in.shape();
-  in_remaining_shape.erase(in_remaining_shape.begin() + axis);
+  auto remaining_shape = in.shape();
+  remaining_shape.erase(remaining_shape.begin() + axis);
 
-  auto in_remaining_strides = in.strides();
-  in_remaining_strides.erase(in_remaining_strides.begin() + axis);
+  auto remaining_strides = in.strides();
+  remaining_strides.erase(remaining_strides.begin() + axis);
 
-  auto out_remaining_shape = out.shape();
-  out_remaining_shape.erase(out_remaining_shape.begin() + axis);
-
-  auto out_remaining_strides = out.strides();
-  out_remaining_strides.erase(out_remaining_strides.begin() + axis);
-
-  size_t in_stride = in.strides()[axis];
-  size_t out_stride = out.strides()[axis];
+  size_t axis_stride = in.strides()[axis];
   int axis_size = in.shape(axis);
 
   kth = kth < 0 ? kth + axis_size : kth;
 
   // Perform partition
-  ContiguousIterator<size_t> in_it(
-      in_remaining_shape, in_remaining_strides, in_remaining_shape.size());
-  ContiguousIterator<size_t> out_it(
-      out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
   for (int i = 0; i < n_rows; i++) {
-    const T* data_ptr = in.data<T>() + in_it.loc;
-    IdxT* idx_ptr = out.data<IdxT>() + out_it.loc;
-    in_it.step();
-    out_it.step();
+    size_t loc = elem_to_loc(i, remaining_shape, remaining_strides);
+    const T* data_ptr = in.data<T>() + loc;
+    IdxT* idx_ptr = out.data<IdxT>() + loc;
 
-    StridedIterator st_(idx_ptr, out_stride, 0);
-    StridedIterator ed_(idx_ptr, out_stride, axis_size);
+    StridedIterator st_(idx_ptr, axis_stride, 0);
+    StridedIterator ed_(idx_ptr, axis_stride, axis_size);
 
     // Initialize with iota
     std::iota(st_, ed_, IdxT(0));
 
     // Sort according to vals
-    StridedIterator st(idx_ptr, out_stride, 0);
-    StridedIterator md(idx_ptr, out_stride, kth);
-    StridedIterator ed(idx_ptr, out_stride, axis_size);
+    StridedIterator st(idx_ptr, axis_stride, 0);
+    StridedIterator md(idx_ptr, axis_stride, kth);
+    StridedIterator ed(idx_ptr, axis_stride, axis_size);
 
-    std::nth_element(st, md, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
-      auto v1 = data_ptr[a * in_stride];
-      auto v2 = data_ptr[b * in_stride];
+    std::nth_element(st, md, ed, [data_ptr, axis_stride](IdxT a, IdxT b) {
+      auto v1 = data_ptr[a * axis_stride];
+      auto v2 = data_ptr[b * axis_stride];
       return v1 < v2 || (v1 == v2 && a < b);
     });
   }
 
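Both sides of the sort/partition hunks above share the same per-row argsort recipe: fill the index row with iota, then stable-sort the indices by the values they reference through the axis stride, breaking ties by index. A freestanding sketch with raw pointers (contiguous index output for simplicity; the names are illustrative, not the MLX API):

#include <algorithm>
#include <cstdint>
#include <numeric>

// Argsort one logical row of `axis_size` elements spaced `axis_stride` apart.
inline void argsort_row(
    const float* data_ptr,
    size_t axis_stride,
    int axis_size,
    uint32_t* idx) {
  std::iota(idx, idx + axis_size, uint32_t(0));
  std::stable_sort(
      idx, idx + axis_size, [data_ptr, axis_stride](uint32_t a, uint32_t b) {
        auto v1 = data_ptr[a * axis_stride];
        auto v2 = data_ptr[b * axis_stride];
        return v1 < v2 || (v1 == v2 && a < b);
      });
}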
@@ -2,7 +2,7 @@
 
 #include "mlx/allocator.h"
 #include "mlx/backend/common/copy.h"
-#include "mlx/backend/common/lapack.h"
+#include "mlx/backend/common/lapack_helper.h"
 #include "mlx/primitives.h"
 
 namespace mlx::core {
 
@@ -12,7 +12,6 @@ namespace {
 // TODO: Add support for more combinations of input types.
 enum class TernaryOpType {
   ScalarScalarScalar,
-  VectorVectorVector,
   General,
 };
 
@@ -21,12 +20,6 @@ get_ternary_op_type(const array& a, const array& b, const array& c) {
   TernaryOpType topt;
   if (a.data_size() == 1 && b.data_size() == 1 && c.data_size() == 1) {
     topt = TernaryOpType::ScalarScalarScalar;
-  } else if (
-      (a.flags().row_contiguous && b.flags().row_contiguous &&
-       c.flags().row_contiguous) ||
-      (a.flags().col_contiguous && b.flags().col_contiguous &&
-       c.flags().col_contiguous)) {
-    topt = TernaryOpType::VectorVectorVector;
   } else {
     topt = TernaryOpType::General;
   }
@@ -40,77 +33,138 @@ void set_ternary_op_output_data(
     array& out,
     TernaryOpType topt,
     bool donate_with_move = false) {
-  auto maybe_donate = [&out, donate_with_move](const array& x) {
-    if (is_donatable(x, out)) {
-      if (donate_with_move) {
-        out.move_shared_buffer(x);
-      } else {
-        out.copy_shared_buffer(x);
-      }
-      return true;
-    }
-    return false;
-  };
-
   switch (topt) {
     case TernaryOpType::ScalarScalarScalar:
       out.set_data(
           allocator::malloc_or_wait(out.itemsize()), 1, b.strides(), b.flags());
       break;
-    case TernaryOpType::VectorVectorVector:
-      if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) {
-        out.set_data(
-            allocator::malloc_or_wait(out.itemsize() * b.data_size()),
-            b.data_size(),
-            b.strides(),
-            b.flags());
-      }
-      break;
     case TernaryOpType::General:
       out.set_data(allocator::malloc_or_wait(out.nbytes()));
       break;
   }
 }
-template <typename T1, typename T2, typename T3, typename U, typename Op, int D>
-void ternary_op_dims(
-    const T1* a,
-    const T2* b,
-    const T3* c,
-    U* out,
-    Op op,
-    const std::vector<int>& shape,
-    const std::vector<size_t>& a_strides,
-    const std::vector<size_t>& b_strides,
-    const std::vector<size_t>& c_strides,
-    const std::vector<size_t>& out_strides,
-    int axis) {
-  auto stride_a = a_strides[axis];
-  auto stride_b = b_strides[axis];
-  auto stride_c = c_strides[axis];
-  auto stride_out = out_strides[axis];
-  auto N = shape[axis];
-
-  for (int i = 0; i < N; i++) {
-    if constexpr (D > 1) {
-      ternary_op_dims<T1, T2, T3, U, Op, D - 1>(
-          a,
-          b,
-          c,
-          out,
-          op,
-          shape,
-          a_strides,
-          b_strides,
-          c_strides,
-          out_strides,
-          axis + 1);
-    } else {
-      *out = op(*a, *b, *c);
-    }
-    a += stride_a;
-    b += stride_b;
-    c += stride_c;
-    out += stride_out;
+
+template <typename T1, typename T2, typename T3, typename U, typename Op>
+void ternary_op_dims1(
+    const array& a,
+    const array& b,
+    const array& c,
+    array& out,
+    Op op) {
+  const T1* a_ptr = a.data<T1>();
+  const T2* b_ptr = b.data<T2>();
+  const T3* c_ptr = c.data<T3>();
+
+  U* dst = out.data<U>();
+  size_t a_idx = 0;
+  size_t b_idx = 0;
+  size_t c_idx = 0;
+  for (size_t i = 0; i < out.size(); ++i) {
+    dst[i] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
+    a_idx += a.strides()[0];
+    b_idx += b.strides()[0];
+    c_idx += c.strides()[0];
+  }
+}
+
+template <typename T1, typename T2, typename T3, typename U, typename Op>
+void ternary_op_dims2(
+    const array& a,
+    const array& b,
+    const array& c,
+    array& out,
+    Op op) {
+  const T1* a_ptr = a.data<T1>();
+  const T2* b_ptr = b.data<T2>();
+  const T3* c_ptr = c.data<T3>();
+
+  U* dst = out.data<U>();
+  size_t a_idx = 0;
+  size_t b_idx = 0;
+  size_t c_idx = 0;
+  size_t out_idx = 0;
+  for (size_t i = 0; i < a.shape()[0]; ++i) {
+    for (size_t j = 0; j < a.shape()[1]; ++j) {
+      dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
+      a_idx += a.strides()[1];
+      b_idx += b.strides()[1];
+      c_idx += c.strides()[1];
+    }
+    a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
+    b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
+    c_idx += c.strides()[0] - c.strides()[1] * c.shape()[1];
+  }
+}
+
+template <typename T1, typename T2, typename T3, typename U, typename Op>
+void ternary_op_dims3(
+    const array& a,
+    const array& b,
+    const array& c,
+    array& out,
+    Op op) {
+  const T1* a_ptr = a.data<T1>();
+  const T2* b_ptr = b.data<T2>();
+  const T3* c_ptr = c.data<T3>();
+  U* dst = out.data<U>();
+  size_t a_idx = 0;
+  size_t b_idx = 0;
+  size_t c_idx = 0;
+  size_t out_idx = 0;
+  for (size_t i = 0; i < a.shape()[0]; ++i) {
+    for (size_t j = 0; j < a.shape()[1]; ++j) {
+      for (size_t k = 0; k < a.shape()[2]; ++k) {
+        dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
+        a_idx += a.strides()[2];
+        b_idx += b.strides()[2];
+        c_idx += c.strides()[2];
+      }
+      a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
+      b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
+      c_idx += c.strides()[1] - c.strides()[2] * c.shape()[2];
+    }
+    a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
+    b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
+    c_idx += c.strides()[0] - c.strides()[1] * c.shape()[1];
+  }
+}
+
+template <typename T1, typename T2, typename T3, typename U, typename Op>
+void ternary_op_dims4(
+    const array& a,
+    const array& b,
+    const array& c,
+    array& out,
+    Op op) {
+  const T1* a_ptr = a.data<T1>();
+  const T2* b_ptr = b.data<T2>();
+  const T3* c_ptr = c.data<T3>();
+
+  U* dst = out.data<U>();
+  size_t a_idx = 0;
+  size_t b_idx = 0;
+  size_t c_idx = 0;
+  size_t out_idx = 0;
+  for (size_t i = 0; i < a.shape()[0]; ++i) {
+    for (size_t j = 0; j < a.shape()[1]; ++j) {
+      for (size_t k = 0; k < a.shape()[2]; ++k) {
+        for (size_t ii = 0; ii < a.shape()[3]; ++ii) {
+          dst[out_idx++] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
+          a_idx += a.strides()[3];
+          b_idx += b.strides()[3];
+          c_idx += c.strides()[3];
+        }
+        a_idx += a.strides()[2] - a.strides()[3] * a.shape()[3];
+        b_idx += b.strides()[2] - b.strides()[3] * b.shape()[3];
+        c_idx += c.strides()[2] - c.strides()[3] * c.shape()[3];
+      }
+      a_idx += a.strides()[1] - a.strides()[2] * a.shape()[2];
+      b_idx += b.strides()[1] - b.strides()[2] * b.shape()[2];
+      c_idx += c.strides()[1] - c.strides()[2] * c.shape()[2];
+    }
+    a_idx += a.strides()[0] - a.strides()[1] * a.shape()[1];
+    b_idx += b.strides()[0] - b.strides()[1] * b.shape()[1];
+    c_idx += c.strides()[0] - c.strides()[1] * c.shape()[1];
   }
 }
 
@@ -121,69 +175,30 @@ void ternary_op_dispatch_dims(
     const array& c,
     array& out,
     Op op) {
-  auto [shape, strides] = collapse_contiguous_dims(
-      a.shape(), {a.strides(), b.strides(), c.strides(), out.strides()});
-  const auto& a_strides = strides[0];
-  const auto& b_strides = strides[1];
-  const auto& c_strides = strides[2];
-  const auto& out_strides = strides[3];
+  switch (out.ndim()) {
+    case 1:
+      ternary_op_dims1<T1, T2, T3, U, Op>(a, b, c, out, op);
+      return;
+    case 2:
+      ternary_op_dims2<T1, T2, T3, U, Op>(a, b, c, out, op);
+      return;
+    case 3:
+      ternary_op_dims3<T1, T2, T3, U, Op>(a, b, c, out, op);
+      return;
+    case 4:
+      ternary_op_dims4<T1, T2, T3, U, Op>(a, b, c, out, op);
+      return;
+  }
 
   const T1* a_ptr = a.data<T1>();
   const T2* b_ptr = b.data<T2>();
   const T3* c_ptr = c.data<T3>();
-  U* out_ptr = out.data<T3>();
-  int ndim = shape.size();
-  switch (ndim) {
-    case 1:
-      ternary_op_dims<T1, T2, T3, U, Op, 1>(
-          a_ptr,
-          b_ptr,
-          c_ptr,
-          out_ptr,
-          op,
-          shape,
-          a_strides,
-          b_strides,
-          c_strides,
-          out_strides,
-          0);
-      return;
-    case 2:
-      ternary_op_dims<T1, T2, T3, U, Op, 2>(
-          a_ptr,
-          b_ptr,
-          c_ptr,
-          out_ptr,
-          op,
-          shape,
-          a_strides,
-          b_strides,
-          c_strides,
-          out_strides,
-          0);
-      return;
-  }
-
-  ContiguousIterator<size_t> a_it(shape, a_strides, ndim - 2);
-  ContiguousIterator<size_t> b_it(shape, b_strides, ndim - 2);
-  ContiguousIterator<size_t> c_it(shape, c_strides, ndim - 2);
-  size_t stride = out_strides[ndim - 3];
-  for (size_t elem = 0; elem < a.size(); elem += stride) {
-    ternary_op_dims<T1, T2, T3, U, Op, 2>(
-        a_ptr + a_it.loc,
-        b_ptr + b_it.loc,
-        c_ptr + c_it.loc,
-        out_ptr + elem,
-        op,
-        shape,
-        a_strides,
-        b_strides,
-        c_strides,
-        out_strides,
-        ndim - 2);
-    a_it.step();
-    b_it.step();
-    c_it.step();
+  U* dst = out.data<U>();
+  for (size_t i = 0; i < out.size(); i++) {
+    int a_idx = elem_to_loc(i, a.shape(), a.strides());
+    int b_idx = elem_to_loc(i, b.shape(), b.strides());
+    int c_idx = elem_to_loc(i, c.shape(), c.strides());
+    dst[i] = op(a_ptr[a_idx], b_ptr[b_idx], c_ptr[c_idx]);
   }
 }
 
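The `+` side's general ternary fallback above pays one index decode per input per element. A compact sketch of that strategy for an arbitrary-rank, strided ternary op (single dtype, row-contiguous output; a simplified stand-in for the templated MLX version, with the index decode inlined):

#include <cstddef>
#include <vector>

template <typename T, typename Op>
void ternary_general(
    const T* a, const std::vector<size_t>& a_strides,
    const T* b, const std::vector<size_t>& b_strides,
    const T* c, const std::vector<size_t>& c_strides,
    T* out, const std::vector<int>& shape, Op op) {
  size_t size = 1;
  for (auto s : shape) {
    size *= s;
  }
  for (size_t i = 0; i < size; ++i) {
    // Decode the flat index into one offset per input.
    size_t a_loc = 0, b_loc = 0, c_loc = 0, elem = i;
    for (int d = shape.size() - 1; d >= 0; --d) {
      size_t coord = elem % shape[d];
      elem /= shape[d];
      a_loc += coord * a_strides[d];
      b_loc += coord * b_strides[d];
      c_loc += coord * c_strides[d];
    }
    out[i] = op(a[a_loc], b[b_loc], c[c_loc]);
  }
}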
@@ -200,21 +215,10 @@ void ternary_op(
   // The full computation is scalar-scalar-scalar so we call the base op once.
   if (topt == TernaryOpType::ScalarScalarScalar) {
     *(out.data<U>()) = op(*a.data<T1>(), *b.data<T2>(), *c.data<T3>());
-  } else if (topt == TernaryOpType::VectorVectorVector) {
-    const T1* a_ptr = a.data<T1>();
-    const T2* b_ptr = b.data<T2>();
-    const T3* c_ptr = c.data<T3>();
-    U* out_ptr = out.data<U>();
-    for (size_t i = 0; i < out.size(); ++i) {
-      *out_ptr = op(*a_ptr, *b_ptr, *c_ptr);
-      a_ptr++;
-      b_ptr++;
-      c_ptr++;
-      out_ptr++;
-    }
-  } else {
-    ternary_op_dispatch_dims<T1, T2, T3, U>(a, b, c, out, op);
+    return;
   }
+
+  ternary_op_dispatch_dims<T1, T2, T3, U>(a, b, c, out, op);
 }
 
 } // namespace
 
@@ -12,7 +12,7 @@ namespace mlx::core {
 namespace {
 
 void set_unary_output_data(const array& in, array& out) {
-  if (is_donatable(in, out)) {
+  if (in.is_donatable() && in.itemsize() == out.itemsize()) {
     out.copy_shared_buffer(in);
   } else {
     auto size = in.data_size();
@@ -24,36 +24,22 @@ void set_unary_output_data(const array& in, array& out) {
   }
 }
 
-template <typename T, typename U = T, typename Op>
-void unary_op(const T* a, U* out, Op op, size_t shape, size_t stride) {
-  for (size_t i = 0; i < shape; i += 1) {
-    out[i] = op(*a);
-    a += stride;
-  }
-}
-
-template <typename T, typename U = T, typename Op>
+template <typename T, typename Op>
 void unary_op(const array& a, array& out, Op op) {
   const T* a_ptr = a.data<T>();
   if (a.flags().contiguous) {
     set_unary_output_data(a, out);
-    U* dst = out.data<U>();
+    T* dst = out.data<T>();
     for (size_t i = 0; i < a.data_size(); ++i) {
       dst[i] = op(a_ptr[i]);
     }
   } else {
     out.set_data(allocator::malloc_or_wait(out.nbytes()));
-    U* dst = out.data<U>();
-    size_t shape = a.ndim() > 0 ? a.shape(-1) : 1;
-    size_t stride = a.ndim() > 0 ? a.strides(-1) : 1;
-    if (a.ndim() <= 1) {
-      unary_op(a_ptr, dst, op, shape, stride);
-      return;
-    }
-    ContiguousIterator it(a.shape(), a.strides(), a.ndim() - 1);
-    for (size_t elem = 0; elem < a.size(); elem += shape) {
-      unary_op(a_ptr + it.loc, dst + elem, op, shape, stride);
-      it.step();
+    T* dst = out.data<T>();
+    for (size_t i = 0; i < out.size(); ++i) {
+      // TODO this is super inefficient, need to fix.
+      int a_idx = elem_to_loc(i, a.shape(), a.strides());
+      dst[i] = op(a_ptr[a_idx]);
     }
   }
 }
 
@@ -1,138 +0,0 @@
|
|||||||
// Copyright © 2023-2024 Apple Inc.
|
|
||||||
|
|
||||||
#include "mlx/backend/common/utils.h"
|
|
||||||
|
|
||||||
namespace mlx::core {
|
|
||||||
|
|
||||||
template <typename StrideT>
|
|
||||||
std::tuple<std::vector<int>, std::vector<std::vector<StrideT>>>
|
|
||||||
collapse_contiguous_dims_impl(
|
|
||||||
const std::vector<int>& shape,
|
|
||||||
const std::vector<std::vector<StrideT>>& strides,
|
|
||||||
StrideT size_cap) {
|
|
||||||
// Make a vector that has axes separated with -1. Collapse all axes between
|
|
||||||
// -1.
|
|
||||||
std::vector<int> to_collapse;
|
|
||||||
if (shape.size() > 0) {
|
|
||||||
if (shape[0] != 1) {
|
|
||||||
to_collapse.push_back(0);
|
|
||||||
}
|
|
||||||
size_t size = shape[0];
|
|
||||||
for (int i = 1; i < shape.size(); i++) {
|
|
||||||
bool contiguous = true;
|
|
||||||
size *= shape[i];
|
|
||||||
for (const std::vector<StrideT>& st : strides) {
|
|
||||||
if (st[i] * shape[i] != st[i - 1] || size > size_cap) {
|
|
||||||
contiguous = false;
|
|
||||||
size = shape[i];
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (!contiguous) {
|
|
||||||
to_collapse.push_back(-1);
|
|
||||||
}
|
|
||||||
if (shape[i] != 1) {
|
|
||||||
to_collapse.push_back(i);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
to_collapse.push_back(-1);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<int> out_shape;
|
|
||||||
std::vector<std::vector<StrideT>> out_strides(strides.size());
|
|
||||||
for (int i = 0;;) {
|
|
||||||
while (i < to_collapse.size() && to_collapse[i] == -1) {
|
|
||||||
++i;
|
|
||||||
};
|
|
||||||
if (i == to_collapse.size()) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
int current_shape = shape[to_collapse[i]];
|
|
||||||
int k = i;
|
|
||||||
while (to_collapse[++k] != -1) {
|
|
||||||
current_shape *= shape[to_collapse[k]];
|
|
||||||
}
|
|
||||||
out_shape.push_back(current_shape);
|
|
||||||
for (int j = 0; j < strides.size(); j++) {
|
|
||||||
const std::vector<StrideT>& st = strides[j];
|
|
||||||
out_strides[j].push_back(st[to_collapse[k - 1]]);
|
|
||||||
}
|
|
||||||
i = k + 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!shape.empty() && out_shape.empty()) {
|
|
||||||
out_shape.push_back(1);
|
|
||||||
for (auto& out_stride : out_strides) {
|
|
||||||
out_stride.push_back(0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return std::make_tuple(out_shape, out_strides);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::tuple<std::vector<int>, std::vector<std::vector<int64_t>>>
|
|
||||||
collapse_contiguous_dims(
|
|
||||||
const std::vector<int>& shape,
|
|
||||||
const std::vector<std::vector<int64_t>>& strides,
|
|
||||||
int64_t size_cap /* = std::numeric_limits<int32_t>::max() */) {
|
|
||||||
return collapse_contiguous_dims_impl(shape, strides, size_cap);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
|
|
||||||
collapse_contiguous_dims(
|
|
||||||
const std::vector<int>& shape,
|
|
||||||
const std::vector<std::vector<size_t>>& strides,
|
|
||||||
size_t size_cap /* = std::numeric_limits<int32>::max() */) {
|
|
||||||
return collapse_contiguous_dims_impl(shape, strides, size_cap);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename StrideT>
|
|
||||||
std::pair<std::vector<int>, std::vector<StrideT>> collapse_contiguous_dims_impl(
|
|
||||||
const std::vector<int>& shape,
|
|
||||||
const std::vector<StrideT>& strides,
|
|
||||||
StrideT size_cap) {
|
|
||||||
std::vector<int> collapsed_shape;
|
|
||||||
std::vector<StrideT> collapsed_strides;
|
|
||||||
|
|
||||||
if (shape.size() > 0) {
|
|
||||||
collapsed_shape.push_back(shape[0]);
|
|
||||||
collapsed_strides.push_back(strides[0]);
|
|
||||||
for (int i = 1; i < shape.size(); i++) {
|
|
||||||
if (shape[i] == 1) {
|
|
||||||
continue;
|
|
||||||
} else if (
|
|
||||||
strides[i] * shape[i] != collapsed_strides.back() ||
|
|
||||||
collapsed_shape.back() * static_cast<StrideT>(shape[i]) > size_cap) {
|
|
||||||
collapsed_shape.push_back(shape[i]);
|
|
||||||
collapsed_strides.push_back(strides[i]);
|
|
||||||
} else {
|
|
||||||
collapsed_shape.back() *= shape[i];
|
|
||||||
collapsed_strides.back() = strides[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return std::make_pair(collapsed_shape, collapsed_strides);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::pair<std::vector<int>, std::vector<int64_t>> collapse_contiguous_dims(
|
|
||||||
const std::vector<int>& shape,
|
|
||||||
const std::vector<int64_t>& strides,
|
|
||||||
int64_t size_cap /* = std::numeric_limits<int32_t>::max() */) {
|
|
||||||
return collapse_contiguous_dims_impl<int64_t>(shape, strides, size_cap);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::pair<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims(
|
|
||||||
const std::vector<int>& shape,
|
|
||||||
const std::vector<size_t>& strides,
|
|
||||||
size_t size_cap /* = std::numeric_limits<int32_t>::max() */) {
|
|
||||||
return collapse_contiguous_dims_impl<size_t>(shape, strides, size_cap);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::pair<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims(
|
|
||||||
const array& a,
|
|
||||||
size_t size_cap /* = std::numeric_limits<int32_t>::max()*/) {
|
|
||||||
return collapse_contiguous_dims_impl<size_t>(
|
|
||||||
a.shape(), a.strides(), size_cap);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace mlx::core
|
|
@@ -8,12 +8,12 @@
 
 namespace mlx::core {
 
-template <typename StrideT>
-inline StrideT elem_to_loc(
+template <typename stride_t>
+inline stride_t elem_to_loc(
     int elem,
     const std::vector<int>& shape,
-    const std::vector<StrideT>& strides) {
-  StrideT loc = 0;
+    const std::vector<stride_t>& strides) {
+  stride_t loc = 0;
   for (int i = shape.size() - 1; i >= 0; --i) {
     auto q_and_r = ldiv(elem, shape[i]);
     loc += q_and_r.rem * strides[i];
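elem_to_loc turns a linear element index into a storage offset by peeling mixed-radix digits (last axis fastest) and weighting each digit by its stride. Below is a self-contained sketch of the same loop, using % and / instead of ldiv, which also checks that the transposed {2, 2, 2}/{1, 4, 2} view and its collapsed {2, 4}/{1, 2} form address identical offsets (the standalone function and shapes are illustrative, not the MLX template):

#include <cstdint>
#include <cstdio>
#include <vector>

// Standalone version of the elem_to_loc loop above: decompose the linear
// index in mixed radix (last axis varies fastest) and weight by the strides.
int64_t elem_to_loc(
    int elem,
    const std::vector<int>& shape,
    const std::vector<int64_t>& strides) {
  int64_t loc = 0;
  for (int i = shape.size() - 1; i >= 0; --i) {
    loc += (elem % shape[i]) * strides[i];
    elem /= shape[i];
  }
  return loc;
}

int main() {
  // The transposed view {2, 2, 2}/{1, 4, 2} and its collapsed form
  // {2, 4}/{1, 2} must touch identical storage offsets in the same order.
  for (int e = 0; e < 8; ++e) {
    int64_t a = elem_to_loc(e, {2, 2, 2}, {1, 4, 2});
    int64_t b = elem_to_loc(e, {2, 4}, {1, 2});
    std::printf("%lld %lld\n", (long long)a, (long long)b); // equal pairs
  }
  return 0;
}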
@@ -29,41 +29,64 @@ inline size_t elem_to_loc(int elem, const array& a) {
   return elem_to_loc(elem, a.shape(), a.strides());
 }
 
-template <typename StrideT>
-std::vector<StrideT> make_contiguous_strides(const std::vector<int>& shape) {
-  std::vector<StrideT> strides(shape.size(), 1);
-  for (int i = shape.size() - 1; i > 0; i--) {
-    strides[i - 1] = strides[i] * shape[i];
-  }
-  return strides;
-}
-
 // Collapse dims that are contiguous to possibly route to a better kernel
 // e.g. for x = transpose(array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 2, 2}), {2, 0, 1})
 // should return {{2, 4}, {{1, 2}}}.
 //
 // When multiple arrays are passed they should all have the same shape. The
 // collapsed axes are also the same so one shape is returned.
-std::tuple<std::vector<int>, std::vector<std::vector<int64_t>>>
-collapse_contiguous_dims(
-    const std::vector<int>& shape,
-    const std::vector<std::vector<int64_t>>& strides,
-    int64_t size_cap = std::numeric_limits<int32_t>::max());
-std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
-collapse_contiguous_dims(
-    const std::vector<int>& shape,
-    const std::vector<std::vector<size_t>>& strides,
-    size_t size_cap = std::numeric_limits<int32_t>::max());
+template <typename stride_t>
+inline std::tuple<std::vector<int>, std::vector<std::vector<stride_t>>>
+collapse_contiguous_dims(
+    const std::vector<int>& shape,
+    const std::vector<std::vector<stride_t>> strides) {
+  // Make a vector that has axes separated with -1. Collapse all axes between
+  // -1.
+  std::vector<int> to_collapse;
+  if (shape.size() > 0) {
+    to_collapse.push_back(0);
+    for (int i = 1; i < shape.size(); i++) {
+      bool contiguous = true;
+      for (const std::vector<stride_t>& st : strides) {
+        if (st[i] * shape[i] != st[i - 1]) {
+          contiguous = false;
+        }
+        if (!contiguous) {
+          break;
+        }
+      }
+      if (!contiguous) {
+        to_collapse.push_back(-1);
+      }
+      to_collapse.push_back(i);
+    }
+    to_collapse.push_back(-1);
+  }
+
+  std::vector<int> out_shape;
+  std::vector<std::vector<stride_t>> out_strides(strides.size());
+  for (int i = 0; i < to_collapse.size(); i++) {
+    int current_shape = shape[to_collapse[i]];
+    while (to_collapse[++i] != -1) {
+      current_shape *= shape[to_collapse[i]];
+    }
+    out_shape.push_back(current_shape);
+    for (int j = 0; j < strides.size(); j++) {
+      const std::vector<stride_t>& st = strides[j];
+      out_strides[j].push_back(st[to_collapse[i - 1]]);
+    }
+  }
+
+  return std::make_tuple(out_shape, out_strides);
+}
 
 inline std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
-collapse_contiguous_dims(
-    const std::vector<array>& xs,
-    size_t size_cap = std::numeric_limits<int32_t>::max()) {
+collapse_contiguous_dims(const std::vector<array>& xs) {
   std::vector<std::vector<size_t>> strides;
   for (auto& x : xs) {
    strides.emplace_back(x.strides());
   }
-  return collapse_contiguous_dims(xs[0].shape(), strides, size_cap);
+  return collapse_contiguous_dims(xs[0].shape(), strides);
 }
 
 template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
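make_contiguous_strides, removed in the hunk above (it exists only on the v0.20.0 side), builds row-major strides as a reverse running product of the trailing extents. A standalone copy with a small usage check; the size_t instantiation and the {2, 3, 4} shape are just an example:

#include <cstdio>
#include <vector>

// Standalone copy of make_contiguous_strides above: row-major strides are a
// reverse running product of the trailing extents.
std::vector<size_t> make_contiguous_strides(const std::vector<int>& shape) {
  std::vector<size_t> strides(shape.size(), 1);
  for (int i = shape.size() - 1; i > 0; i--) {
    strides[i - 1] = strides[i] * shape[i];
  }
  return strides;
}

int main() {
  for (size_t s : make_contiguous_strides({2, 3, 4})) {
    std::printf("%zu ", s); // 12 4 1
  }
  std::printf("\n");
  return 0;
}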
@@ -72,110 +95,27 @@ inline auto collapse_contiguous_dims(Arrays&&... xs) {
       std::vector<array>{std::forward<Arrays>(xs)...});
 }
 
-// The single array version of the above.
-std::pair<std::vector<int>, std::vector<int64_t>> collapse_contiguous_dims(
-    const std::vector<int>& shape,
-    const std::vector<int64_t>& strides,
-    int64_t size_cap = std::numeric_limits<int32_t>::max());
-std::pair<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims(
-    const std::vector<int>& shape,
-    const std::vector<size_t>& strides,
-    size_t size_cap = std::numeric_limits<int32_t>::max());
-std::pair<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims(
-    const array& a,
-    size_t size_cap = std::numeric_limits<int32_t>::max());
-
-template <typename StrideT>
-struct ContiguousIterator {
-  inline void step() {
-    int dims = shape_.size();
-    if (dims == 0) {
-      return;
-    }
-    int i = dims - 1;
-    while (pos_[i] == (shape_[i] - 1) && i > 0) {
-      pos_[i] = 0;
-      loc -= (shape_[i] - 1) * strides_[i];
-      i--;
-    }
-    pos_[i]++;
-    loc += strides_[i];
-  }
-
-  void seek(StrideT n) {
-    loc = 0;
-    for (int i = shape_.size() - 1; i >= 0; --i) {
-      auto q_and_r = ldiv(n, shape_[i]);
-      loc += q_and_r.rem * strides_[i];
-      pos_[i] = q_and_r.rem;
-      n = q_and_r.quot;
-    }
-  }
-
-  void reset() {
-    loc = 0;
-    std::fill(pos_.begin(), pos_.end(), 0);
-  }
-
-  ContiguousIterator() {};
-
-  explicit ContiguousIterator(const array& a)
-      : shape_(a.shape()), strides_(a.strides()) {
-    if (!shape_.empty()) {
-      std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_);
-      pos_ = std::vector<int>(shape_.size(), 0);
-    }
-  }
-
-  explicit ContiguousIterator(
-      const std::vector<int>& shape,
-      const std::vector<StrideT>& strides,
-      int dims)
-      : shape_(shape.begin(), shape.begin() + dims),
-        strides_(strides.begin(), strides.begin() + dims) {
-    if (!shape_.empty()) {
-      std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_);
-      pos_ = std::vector<int>(shape_.size(), 0);
-    }
-  }
-
-  StrideT loc{0};
-
- private:
-  std::vector<int> shape_;
-  std::vector<StrideT> strides_;
-  std::vector<int> pos_;
-};
-
-template <typename StrideT>
+template <typename stride_t>
 inline auto check_contiguity(
     const std::vector<int>& shape,
-    const std::vector<StrideT>& strides) {
-  size_t no_broadcast_data_size = 1;
+    const std::vector<stride_t>& strides) {
+  size_t data_size = 1;
   size_t f_stride = 1;
   size_t b_stride = 1;
   bool is_row_contiguous = true;
   bool is_col_contiguous = true;
 
   for (int i = 0, ri = shape.size() - 1; ri >= 0; i++, ri--) {
-    is_col_contiguous &= strides[i] == f_stride || shape[i] == 1;
-    is_row_contiguous &= strides[ri] == b_stride || shape[ri] == 1;
+    is_row_contiguous &= strides[i] == f_stride || shape[i] == 1;
+    is_col_contiguous &= strides[ri] == b_stride || shape[ri] == 1;
     f_stride *= shape[i];
     b_stride *= shape[ri];
     if (strides[i] > 0) {
-      no_broadcast_data_size *= shape[i];
+      data_size *= shape[i];
     }
   }
 
-  return std::make_tuple(
-      no_broadcast_data_size, is_row_contiguous, is_col_contiguous);
-}
-
-inline bool is_donatable(const array& in, const array& out) {
-  constexpr size_t donation_extra = 16384;
-
-  return in.is_donatable() && in.itemsize() == out.itemsize() &&
-      in.buffer_size() <= out.nbytes() + donation_extra;
+  return std::make_tuple(data_size, is_row_contiguous, is_col_contiguous);
 }
 
 } // namespace mlx::core
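ContiguousIterator, also present only on the v0.20.0 side, is an odometer over the logical shape that keeps a running storage offset loc, collapsing contiguous dims up front so each step is a cheap increment rather than a full elem_to_loc recomputation. Below is a minimal standalone mirror of its step() logic (the StridedIterator name is hypothetical, and the dim collapsing and empty-shape guard are omitted), walking a transposed {3, 2} view:

#include <cstdint>
#include <cstdio>
#include <vector>

// Standalone mirror of ContiguousIterator::step() above: an odometer over the
// logical shape that maintains a running storage offset `loc`.
struct StridedIterator {
  std::vector<int> shape, pos;
  std::vector<int64_t> strides;
  int64_t loc = 0;

  StridedIterator(std::vector<int> s, std::vector<int64_t> st)
      : shape(std::move(s)), pos(shape.size(), 0), strides(std::move(st)) {}

  void step() {
    int i = shape.size() - 1;
    while (pos[i] == shape[i] - 1 && i > 0) {
      pos[i] = 0;
      loc -= (shape[i] - 1) * strides[i];
      i--;
    }
    pos[i]++;
    loc += strides[i];
  }
};

int main() {
  // Walk a {3, 2} transposed view (strides {1, 3}) in logical order; the
  // visited storage offsets come out interleaved: 0 3 1 4 2 5.
  StridedIterator it({3, 2}, {1, 3});
  for (int e = 0; e < 6; ++e) {
    std::printf("%lld ", (long long)it.loc);
    it.step();
  }
  std::printf("\n");
  return 0;
}

The printed offsets 0 3 1 4 2 5 are exactly what elem_to_loc returns for linear indices 0 through 5, which is the invariant the incremental update maintains.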
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user