Mirror of https://github.com/ml-explore/mlx.git — synced 2025-09-07 00:54:37 +08:00

Compare commits: `simple-gem...winograd_q` (1 commit, f14b4d72de)
.circleci/config.yml

The two branches differ throughout the CI configuration. The markers distinguishing the two sides are not preserved in this view, so each hunk below lists both sides' versions of the affected lines.

@@ -7,9 +7,15 @@ parameters: — the boolean pipeline parameters (all defaulting to false) cover nightly_build, weekly_build, test_release, and linux_release; the block is 9 lines on one side and 15 on the other.

@@ -18,8 +24,8 @@ jobs: — build_documentation runs on xcode "16.2.0" with resource_class m2pro.medium on one side, and on xcode "15.2.0" with resource_class macos.m1.medium.gen1 on the other.

@@ -32,7 +38,7 @@ — the documentation build installs the package with plain `pip install . -v` on one side and with CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install . -v on the other (after upgrading pip and cmake and installing docs/requirements.txt on both).

@@ -64,9 +70,9 @@ — linux_build_and_test uses a machine executor (ubuntu-2204:current, resource_class large) on one side and a docker executor (cimg/python:3.9) on the other.

@@ -78,36 +84,34 @@ — dependency installation either adds OpenMPI (openmpi-bin, openmpi-common, libopenmpi-dev) and bootstraps uv (curl -LsSf https://astral.sh/uv/install.sh | sh), or pip-installs cmake, nanobind==2.4.0, and numpy alongside the BLAS/LAPACK apt packages. The package install is uv-based on one side (uv venv; uv pip install cmake; uv pip install -e ".[dev]" -v) and setuptools-based on the other (CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" CMAKE_BUILD_PARALLEL_LEVEL=`nproc` python3 setup.py build_ext --inplace, then setup.py develop). Stub generation is uv run --no-project setup.py generate_stubs versus pip install typing_extensions followed by python setup.py generate_stubs. The test step is either a bare python3 -m unittest discover python/tests -v or a venv-activated run that also exercises the distributed backends:

mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi

Both sides finish with a CPP-only debug build (mkdir -p build && cd build; cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG; make -j `nproc`), modulo venv activation.

@@ -118,63 +122,57 @@ — mac_build_and_test: the xcode_version parameter defaults to "16.2.0" on one side and "15.2.0" on the other; one side adds a macosx_deployment_target parameter (default "") exported as MACOSX_DEPLOYMENT_TARGET, and the resource class is again m2pro.medium versus macos.m1.medium.gen1. Dependencies are either brew install openmpi uv (with HOMEBREW_NO_AUTO_UPDATE=1 HOMEBREW_NO_INSTALL_CLEANUP=1) or brew install python@3.9 and openmpi plus a python3.9 venv with pip-installed cmake, nanobind==2.4.0, numpy, torch, tensorflow, and unittest-xml-reporting. The editable install is DEBUG=1 CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" uv pip install -e . -v (after uv venv --python 3.9 and uv pip install of the same dependency list) versus DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install -e . -v. Python tests run through xmlrunner on cpu and gpu (LOW_MEMORY=1, DEVICE=cpu/gpu, METAL_DEVICE_WRAPPER_TYPE=1, METAL_DEBUG_ERROR_MODE=0, results under test-results/), with one side adding the mpirun run (with -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/) and the mlx.launch ring test with the same stderr WARN check. The example extension in examples/extensions is built with uv (uv pip install -r requirements.txt; uv run --no-project setup.py build_ext --inplace; uv run --no-project python test.py) or pip (pip install -r requirements.txt; python setup.py build_ext -j8); both sides then build CPP only (mkdir -p build && cd build && cmake .. && make -j `sysctl -n hw.ncpu`) and run the CPP tests.

@@ -183,7 +181,7 @@ and @@ -195,60 +193,13 @@ — the small-binary build (cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel -DBUILD_SHARED_LIBS=ON ...) and the JIT test run (CMAKE_ARGS="-DMLX_METAL_JIT=ON" editable install, then an xmlrunner sweep into test-results/gpu_jit) differ only in venv handling, but one side additionally defines a cuda_build_and_test job: a "linux-cuda-12:<< parameters.image_date >>" machine (image_date default "2023.11.1"), resource_class gpu.nvidia.small.gen2, a ccache cache keyed cuda-<< parameters.image_date >>-{{ arch }}-{{ epoch }}, apt installs of libcudnn9-dev-cuda-12 plus BLAS/LAPACK, ccache 4.11.3 from the upstream release tarball, a uv bootstrap, an editable install with CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`", LOW_MEMORY cpu and gpu test runs, and a CCache report step (--show-stats, --zero-stats, --max-size 400MB, --cleanup) before saving /home/circleci/.cache/ccache.

@@ -257,18 +208,13 @@ and @@ -289,30 +235,22 @@ — build_release parameterizes python_version (default "3.9"), xcode_version ("16.2.0" versus "15.2.0"), build_env (default ""), and macosx_deployment_target. The install step is env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 pip install . -v on one side and DEV_RELEASE=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install . -v on the other. After stub generation, one side builds wheels in two stages (<< parameters.build_env >> MLX_BUILD_STAGE=1 python -m build -w, plus a python-3.9-only "common package" pass with MLX_BUILD_STAGE=2 after setup.py clean --all), while the other runs a single << parameters.build_env >> CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` python -m build -w; build_env gates the upload step on both.

@@ -329,100 +267,52 @@ — build_linux_release: one side runs on a machine executor (ubuntu-2204:current, large) with a build_env parameter and sudo apt setup (tzdata with TZ=Etc/UTC, the deadsnakes PPA, $PYTHON with -dev and -full, BLAS/LAPACK), installs with << parameters.build_env >> pip install ".[dev]" -v, generates stubs, and builds in two stages (MLX_BUILD_STAGE=1 repaired by bash python/scripts/repair_linux.sh, then a 3.9-only MLX_BUILD_STAGE=2 pass repaired with auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_x86_64) before a gated twine upload of wheelhouse/*.whl. The other side runs in a docker ubuntu:20.04 container with an extra_env parameter (default "DEV_RELEASE=1"), a root apt setup adding apt-utils, software-properties-common, build-essential, and git, a venv with pip-installed cmake, nanobind==2.4.0, setuptools, numpy, auditwheel, patchelf, build, and twine, a single << parameters.extra_env >> CMAKE_BUILD_PARALLEL_LEVEL=`nproc` python -m build --wheel, and auditwheel show/repair dist/* --plat manylinux_2_31_x86_64. One side also defines build_cuda_release: the NVIDIA cuda-keyring 1.1-1 (ubuntu2404 repo), cuda-toolkit-12-9 and libcudnn9-dev-cuda-12 plus BLAS/LAPACK and zip, PATH and LD_LIBRARY_PATH extended with /usr/local/cuda, << parameters.build_env >> MLX_BUILD_STAGE=2 CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" python -m build -w, bash python/scripts/repair_cuda.sh, and a gated twine upload; both store wheelhouse/ as artifacts.

@@ -434,23 +324,21 @@ through @@ -585,140 +387,27 @@ — the workflows change accordingly. The default workflow (branch pattern "^(?!pull/)[-\\w]+$", with the nightly/weekly/test_release parameters negated) runs mac_build_and_test over macosx_deployment_target ["13.5", "14.0"] and xcode_version ["15.0.0", "15.2.0", "16.0.0"], linux_build_and_test, on one side cuda_build_and_test over image_date ["2023.11.1", "2025.05.1"], and build_documentation. build_pypi_release runs build_release over python_version ["3.9", "3.10", "3.11", "3.12", "3.13"], macosx_deployment_target ["13.5", "14.0", "15.0"], build_env ["PYPI_RELEASE=1"], and xcode_version ["15.0.0", "15.2.0"] on one side versus ["16.2.0", "15.0.0"] on the other, the latter excluding every python version at target "13.5" with xcode "16.2.0" and at targets "14.0" and "15.0" with xcode "15.0.0"; it also tag-gates (/^v.*/, branches ignored) build_documentation with upload-docs: true, build_linux_release, and build_cuda_release. The prb workflow holds mac_build_and_test, linux_build_and_test, and on one side cuda_build_and_test behind approval with the same matrices. The nightly and DEV_RELEASE builds (a weekly_build workflow gated on the weekly_build parameter versus a build_dev_release workflow gated on test_release, both on main) repeat the release matrices with build_env ["DEV_RELEASE=1"] and the same exclusion lists, and a linux_test_release workflow (main branch plus the linux_release parameter) runs build_linux_release over all five python versions with DEV_RELEASE=1 and build_cuda_release with build_env ["DEV_RELEASE=1"] / extra_env ["PYPI_RELEASE=1"].
.gitignore (vendored) — 1 line changed

@@ -36,7 +36,6 @@ share/python-wheels/ — one side's ignore list contains `uv.lock` (between `MANIFEST` and the `# vim` / `*.swp` block, alongside `.installed.cfg` and `*.egg`); the other's does not.
.pre-commit-config.yaml

@@ -1,16 +1,16 @@ repos: — only the hook revisions differ between the two sides: mirrors-clang-format v19.1.7 versus v19.1.4, psf/black-pre-commit-mirror 25.1.0 versus 24.10.0 (the mirror is used because mypyc-compiled black is about 2x faster), and pycqa/isort 6.0.0 versus 5.13.2.
ACKNOWLEDGMENTS.md

@@ -7,7 +7,7 @@ — Nripesh Niketan's entry ends at "Added `cross`." on one side and additionally credits "Added `orthogonal` initializer." on the other.

@@ -19,7 +19,6 @@ — one side lists an extra contributor after Max-Heinrich Laves' conv_transpose entry: "Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer." The neighboring entries (Gleb Pobudzey, Paul Paczuski) and the contrib.rocks contributors image are unchanged.
CMakeLists.txt

@@ -1,24 +1,6 @@ — version handling: one side derives MLX_PROJECT_VERSION by parsing MLX_VERSION_MAJOR/MINOR/PATCH out of mlx/version.h when MLX_VERSION is unset (otherwise regex-normalizing the given MLX_VERSION to major.minor.patch) and passes it to project(mlx LANGUAGES C CXX VERSION ${MLX_PROJECT_VERSION}); the other calls project(mlx LANGUAGES C CXX) and later defaults MLX_VERSION to 0.21.1, defining it via add_compile_definitions("MLX_VERSION=${MLX_VERSION}").

@@ -34,18 +16,21 @@ — the option list: both sides declare MLX_BUILD_TESTS/EXAMPLES/BENCHMARKS/PYTHON_BINDINGS, MLX_BUILD_METAL (ON), MLX_BUILD_CPU (ON), MLX_METAL_DEBUG, MLX_ENABLE_X64_MAC, MLX_BUILD_GGUF, MLX_BUILD_SAFETENSORS, MLX_BUILD_BLAS_FROM_SOURCE, MLX_METAL_JIT, and BUILD_SHARED_LIBS; MLX_BUILD_CUDA (OFF), MLX_USE_CCACHE (ON), and USE_SYSTEM_FMT (OFF) appear on one side only.

@@ -66,17 +51,10 @@ — on non-Darwin systems one side simply sets MLX_BUILD_METAL OFF and, when MLX_USE_CCACHE finds a ccache program, wires it in as the C/CXX/CUDA compiler launcher; the other instead warns "MLX is prioritised for Apple silicon systems using macOS." (Both keep the x86_64-on-Darwin warning.)

@@ -93,10 +71,6 @@ — one side adds enable_language(CUDA) when MLX_BUILD_CUDA is set, ahead of the shared "Metal not found. Unable to build GPU" fallback.

@@ -173,7 +147,6 @@ — with Accelerate enabled, one side additionally defines MLX_USE_ACCELERATE next to ACCELERATE_NEW_LAPACK; the OpenBLAS-from-source branch is shared.

@@ -226,13 +199,23 @@ — one side fetches nlohmann/json v3.11.3 via FetchContent and adds its single_include/nlohmann directory to the target; the other runs find_package(MPI) and, when found, checks `mpirun --version`: it includes MPI_INCLUDE_PATH only for Open MPI, disables MPI with a warning when mpirun is unavailable ("MPI found but mpirun is not available. Building without MPI."), and likewise when the implementation is not Open MPI.

@@ -240,19 +223,12 @@ — one side clears the target's DEFINE_SYMBOL (no mlx_EXPORTS define for the shared library) and fetches fmt 10.2.1 only when USE_SYSTEM_FMT is off (otherwise find_package(fmt REQUIRED)); the other unconditionally fetches fmt 10.2.1. Both link $<BUILD_INTERFACE:fmt::fmt-header-only> privately before the MLX_BUILD_PYTHON_BINDINGS section.
CONTRIBUTING.md

@@ -5,26 +5,26 @@ — the Pull Requests checklist is unchanged in substance; the two sides differ only in trailing whitespace on items 1 and 5 and in how the manual-formatting examples are fenced (one side tags them `shell`, the other leaves them untagged):

1. Fork and submit pull requests to the repo.
2. If you've added code that should be tested, add tests.
3. If a change is likely to impact efficiency, run some of the benchmarks before and after the change. Examples of benchmarks can be found in `benchmarks/python/`.
4. If you've changed APIs, update the documentation.
5. Every PR should have passing tests and at least one review.
6. For code formatting install `pre-commit` using something like `pip install pre-commit` and run `pre-commit install`. This should install hooks for running `black` and `clang-format` to ensure consistent style for C++ and python code.

You can also run the formatters manually as follows:

```shell
clang-format -i file.cpp
```

```shell
black file.py
```

or run `pre-commit run --all-files` to check all files in the repo.

## Issues
MANIFEST.in

@@ -1,6 +1,4 @@ — the manifest has 6 lines on one side and 4 on the other; the entries shown are `include CMakeLists.txt`, `include mlx.pc.in`, `recursive-include mlx/ *`, `include cmake/*`, `include python/src/*`, and `include python/mlx/py.typed # support type hinting as in PEP-561`, two of which appear on only one side.
README.md — 21 lines changed

@@ -11,10 +11,10 @@ — whitespace-only changes in the key-features list ("Familiar APIs": a Python API that closely follows NumPy; fully featured C++, [C](https://github.com/ml-explore/mlx-c), and [Swift](https://github.com/ml-explore/mlx-swift/) APIs that closely mirror it; higher-level `mlx.nn` and `mlx.optimizers` packages with APIs that closely follow PyTorch).

@@ -68,23 +68,18 @@ — the Installation section differs. One side reads:

MLX is available on [PyPI](https://pypi.org/project/mlx/). To install MLX on macOS, run:

```bash
pip install mlx
```

To install the CUDA backend on Linux, run:

```bash
pip install mlx[cuda]
```

To install a CPU-only Linux package, run:

```bash
pip install mlx[cpu]
```

The other advertises only the Python API, with pip (`pip install mlx`) and conda (`conda install -c conda-forge mlx`) instructions. Both close by pointing at the build documentation ("Checkout the ...").
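Not part of the README diff, but a quick way to sanity-check whichever install path was used — a hedged smoke test (assumes a working MLX install; the printed value follows MLX's default float32 promotion):

```python
import mlx.core as mx

a = mx.array([1.0, 2.0, 3.0])
print(mx.sum(a * 2))  # expected: array(12, dtype=float32)
```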
Single-op C++ benchmark (Copyright © 2023 Apple Inc.)

@@ -1,6 +1,5 @@ — one side adds `#include <cstring>` ahead of `<iostream>` and `<sstream>`.

@@ -192,22 +192,6 @@ — in `time_reductions()`, after timing `argmin` along axis 1, one side also scatters a NaN into the array and times max/min reductions along both axes before `time_gather_scatter()`:

```cpp
auto indices = mx::array({1});
auto updates = mx::reshape(mx::array({NAN}), {1, 1, 1});
std::vector<int> axes{0};
auto b = scatter(a, {indices}, updates, axes);
mx::eval(b);

auto max_along_0 = [&b]() { return mx::max(b, 0, false); };
TIME(max_along_0);
auto max_along_1 = [&b]() { return mx::max(b, 1, false); };
TIME(max_along_1);

auto min_along_0 = [&b]() { return mx::min(b, 0, false); };
TIME(min_along_0);
auto min_along_1 = [&b]() { return mx::min(b, 1, false); };
TIME(min_along_1);
```
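For readers following along in Python, a minimal sketch of what that block exercises — reductions over an array with a NaN written into it. This is an illustrative rewrite using the mlx Python API, not the benchmark itself:

```python
import math

import mlx.core as mx

b = mx.zeros((4, 4))
b[1, 0] = math.nan        # put a NaN into the array, as the C++ scatter does
mx.eval(b)
print(mx.max(b, axis=0))  # the removed block timed max/min along both axes
print(mx.min(b, axis=1))
```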
PyTorch comparison benchmark (Python)

@@ -5,7 +5,6 @@ — one side imports both `torch.cuda` and `torch.mps`; the other only `torch.mps`.

@@ -45,10 +44,8 @@ — `sync_if_needed(x)` either special-cases both backends:

```python
def sync_if_needed(x):
    if x.device == torch.device("mps"):
        torch.mps.synchronize()
    elif x.device == torch.device("cuda"):
        torch.cuda.synchronize()
```

or synchronizes MPS for any non-CPU device (`if x.device != torch.device("cpu"): torch.mps.synchronize()`).

@@ -102,14 +99,6 @@ — one side adds a `sum_and_add` benchmark ahead of `softmax`:

```python
@torch.no_grad()
def sum_and_add(axis, x, y):
    z = x.sum(axis=axis, keepdims=True)
    for i in range(50):
        z = (z + y).sum(axis=axis, keepdims=True)
    sync_if_needed(x)
```

@@ -351,11 +340,7 @@ — after `torch.set_num_threads(1)`, device selection is either CUDA-aware (start from "mps", switch to "cuda" when `torch.cuda.is_available()`, and to "cpu" when `args.cpu`) or simply `device = "cpu" if args.cpu else "mps"`.

@@ -475,8 +460,5 @@ — the CUDA-aware side dispatches the new benchmark (`elif args.benchmark == "sum_and_add": print(bench(sum_and_add, axis, *xs))`) before the shared final clause that raises `ValueError(f"Unknown benchmark `{args.benchmark}`.")`.
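All of the benchmark scripts in this diff share one timing idiom: warm up, synchronize, then measure with `perf_counter_ns`. A minimal generic sketch — `sync` here is a stand-in for whichever barrier applies (torch.mps.synchronize, torch.cuda.synchronize, or an `mx.eval` over the produced outputs), and the `warmup`/`iters` defaults mirror the constants these files use:

```python
import time

def bench(f, sync, *args, warmup=10, iters=100):
    for _ in range(warmup):   # let caches and lazy compilation settle
        f(*args)
    sync()
    start = time.perf_counter_ns()
    for _ in range(iters):
        f(*args)
    sync()                    # ensure asynchronous work actually finished
    return (time.perf_counter_ns() - start) * 1e-9  # seconds
```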
conv2d comparison benchmark (Python), present on only one side of the compare (@@ -1,107 +0,0 @@)

The file benchmarks `mx.conv2d` against `torch.conv2d` on MPS. It warms up for 10 iterations and times 100 iterations of a 5-call inner loop (`N_warmup = 10`, `N_iter_bench = 100`, `N_iter_func = 5`), building matching closures with `make_mx_conv_2D` / `make_pt_conv_2D` (the MLX one ends with `mx.eval(ys)`, the PyTorch one with `torch.mps.synchronize()`). `bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype)` draws NHWC inputs and (O, kH, kW, C/groups) weights with NumPy, transposes them via `(0, 3, 1, 2)` to NCHW for PyTorch, verifies the two outputs with `np.allclose` (atol 2e-5 for float32, else 1e-4), and prints "Failed at ..." with the max absolute difference on mismatch. The main block sweeps the shapes (4, 32, 32, 21, 3, 3, 128), (4, 32, 32, 21, 3, 3, 37), (4, 32, 32, 370, 3, 3, 370), (4, 32, 32, 370, 7, 7, 128), and (2, 320, 640, 21, 7, 7, 21) at float32, prints the relative speed difference per shape, and flags "ATTENTION ^^^^^^^" whenever MLX is at least 2x slower than PyTorch.
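The (0, 3, 1, 2) transposes in that script exist because the two frameworks disagree on convolution layout. A hedged sketch of the MLX side (shapes chosen purely for illustration):

```python
import mlx.core as mx
import numpy as np

a = mx.array(np.random.uniform(0, 1, (1, 8, 8, 3)).astype(np.float32))   # NHWC input
w = mx.array(np.random.uniform(-1, 1, (4, 3, 3, 3)).astype(np.float32))  # (O, kH, kW, C) weights
y = mx.conv2d(a, w, stride=(1, 1), padding=(0, 0))
print(y.shape)  # (1, 6, 6, 4): mlx stays channels-last, unlike torch's NCHW
```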
Benchmark script header (Copyright © 2023-2024 Apple Inc.)

@@ -1,6 +1,7 @@ — one side adds `from time import time` between `import argparse` and the `mlx.core` / `torch` imports.
gather_mm benchmark (Python), present on only one side of the compare (@@ -1,74 +0,0 @@, Copyright © 2025 Apple Inc.)

With N = D = M = 1024, E = 32 experts, and I = 4 indices per row, the file defines:

- `gather_sort(x, indices)`: flattens the indices, computes `order = mx.argsort(indices)` and its inverse `inv_order = mx.argsort(order)`, and returns `x.flatten(0, -3)[order // M]`, the sorted indices, and `inv_order`.
- `scatter_unsort(x, inv_order, shape=None)`: applies the inverse permutation and, when `shape` is given, `mx.unflatten`s the leading axis back.
- `gather_mm_simulate(x, w, indices)`: a reference that sorts, concatenates per-index products `x[i] @ w[j].T` across two layers, and unsorts.
- `time_gather_mm()`: with w1 of shape (E, M, D) and w2 of shape (E, D, M), times two chained `mx.gather_mm(x, w.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)` calls for unsorted indices, pre-sorted indices, and explicit sort/unsort, then times the dense equivalent `x @ w1.T @ w2.T` on (N * I, D) inputs for comparison.
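The sort/unsort trick in that file rests on a small invariant: argsort applied to a sorting permutation yields its inverse. A minimal self-contained check (values chosen arbitrarily):

```python
import mlx.core as mx

indices = mx.array([3, 1, 2, 1, 0, 3])
order = mx.argsort(indices)    # permutation that sorts `indices`
inv_order = mx.argsort(order)  # its inverse permutation

x = mx.arange(indices.size)
assert mx.array_equal(x[order][inv_order], x)  # unsorting restores the original order
```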
gather_qmm benchmark (Python), also present on only one side of the compare (@@ -1,84 +0,0 @@, Copyright © 2025 Apple Inc.)

The quantized counterpart of the gather_mm script: same N, D, M, E, I constants and the same `gather_sort` / `scatter_unsort` helpers, but both weight stacks are quantized with `mx.quantize`. The reference `gather_mm_simulate` concatenates `mx.quantized_matmul(x[i], w[0][j], w[1][j], w[2][j], transpose=True)` per index; `time_gather_qmm()` times two chained `mx.gather_qmm(x, *w, transpose=True, rhs_indices=idx, sorted_indices=sort)` calls for the unsorted, pre-sorted, and sort-and-unsort variants, and compares against a dense two-layer `mx.quantized_matmul(x, *w, transpose=True)` pipeline on (N * I, D) inputs.
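For context, the quantized-weight plumbing that script relies on: `mx.quantize` splits a weight matrix into packed values, scales, and biases, which `mx.quantized_matmul` consumes. A hedged sketch at MLX's default quantization settings, with small shapes chosen for illustration:

```python
import mlx.core as mx

w = mx.random.normal((64, 64))
w_q, scales, biases = mx.quantize(w)  # defaults: 4-bit values, group_size 64
x = mx.random.normal((1, 64))
y = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)  # approximates x @ w.T
print(y.shape)  # (1, 64)
```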
layer_norm benchmark (Python, Copyright © 2023-2024 Apple Inc.)

@@ -1,7 +1,5 @@ — one side imports `functools.partial`; the other does not.

@@ -12,71 +10,32 @@ — one side generalizes the script, the other keeps a fixed configuration. The reference kernel differs accordingly: the general version makes weight and bias optional —

```python
def layer_norm(x, w, b, eps):
    x = x.astype(mx.float32)
    mu = mx.mean(x, -1, keepdims=True)
    v = mx.var(x, -1, keepdims=True)
    y = (x - mu) * mx.rsqrt(v + eps)
    if w is not None:
        y = y * w
    if b is not None:
        y = y + b
    return y
```

— while the fixed version returns `(x - mu) * mx.rsqrt(v + eps) * w + b` directly. Both define f1/f2 as the sums of the reference and of `mx.fast.layer_norm` (eps 1e-5) scaled by y, with gradients g1/g2 over (x, w, b). The fixed `time_layer_norm()` uses float16 tensors of shape (8, 1024, 4096) with (4096,) weight and bias and times 32-step gradient loops for g1, g2, and their `mx.compile`d versions. The parameterized `time_layer_norm(N, dt)` sets L = 1024, uses shapes (8, L, N) and (N,), first times 32-step forward loops of `partial(layer_norm, eps=1e-5)` against `partial(mx.fast.layer_norm, eps=1e-5)`, then the same four gradient loops, and finally repeats the x-only gradient comparison for the weightless/biasless variants (`layer_norm(x, None, None, 1e-5)` versus `mx.fast.layer_norm(x, None, None, 1e-5)`); its main block sweeps dt over [mx.float32, mx.float16, mx.bfloat16] and N over [1024, 2048, 4096, 8192, 8192 + 1024], whereas the fixed script just calls `time_layer_norm()`.
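A minimal self-contained version of the f1/f2 comparison the benchmark times — checking the float32 reference against `mx.fast.layer_norm`. The shapes and the float16 tolerance here are illustrative choices, not values from the benchmark:

```python
import mlx.core as mx

def layer_norm_ref(x, w, b, eps):
    x32 = x.astype(mx.float32)                 # accumulate statistics in float32
    mu = mx.mean(x32, -1, keepdims=True)
    v = mx.var(x32, -1, keepdims=True)
    return ((x32 - mu) * mx.rsqrt(v + eps) * w + b).astype(x.dtype)

x = mx.random.uniform(shape=(2, 16, 64)).astype(mx.float16)
w = mx.random.uniform(shape=(64,)).astype(mx.float16)
b = mx.random.uniform(shape=(64,)).astype(mx.float16)
assert mx.allclose(layer_norm_ref(x, w, b, 1e-5),
                   mx.fast.layer_norm(x, w, b, 1e-5), atol=1e-2)
```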
rms_norm benchmark (Python)

@@ -9,10 +9,7 @@ — the reference kernel again differs in whether the weight is optional: one side computes `n = mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)` in float32, casts back with `.astype(ot)`, and multiplies by `w` only `if w is not None`; the other returns `(x * n).astype(ot) * w` unconditionally.

@@ -37,27 +34,6 @@ — after the shared timing of eager and `mx.compile`d gradient loops, the optional-weight side adds a weightless comparison: x-only gradients of `(rms_norm(x, None, 1e-5) * y).sum()` versus `(mx.fast.rms_norm(x, None, 1e-5) * y).sum()` on float16 tensors of shape (8, 1024, 4096) with a (4096,) weight, each run for 32 iterations eagerly and under `mx.compile`. Both sides end by calling `time_rms_norm()` from the main block.
scaled-dot-product-attention benchmark (Python)

@@ -28,34 +28,11 @@ def bench(f, *args): onward — the two sides share the bench() harness but structure the comparison differently.

One side factors input construction into `prepare_inputs(B, qL, kL, D, qH, kH, mask, transpose, dtype)`, which builds (B, L, H, D) or (B, H, L, D) arrays depending on `transpose`, scales k and v by 1/sqrt(D), and materializes an optional mask: "additive" (normal noise of shape (B, qH, qL, kL)) or "bool" (uniform < 0.5). Its reference `mlx_ref_attn(q, k, v, scale=1.0, mask=None)` handles grouped-query attention by reshaping q into (B, n_kv_heads, n_repeats, L, -1) and expanding k/v, builds a "causal" mask on the fly —

```python
if mask == "causal":
    q_offset = max(0, kL - L)
    q_indices = mx.arange(q_offset, q_offset + L)
    k_indices = mx.arange(kL)
    mask = q_indices[:, None] >= k_indices[None]
```

— expands or `mx.unflatten`s the mask across the repeat axis, applies boolean masks with `mx.where(mask, scores, -np.float32(np.inf))` and additive masks with `scores += mask`, and uses `mx.softmax(scores, axis=-1, precise=True)`. `mlx_fused_attn` wraps `mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)`, and `do_attention` / `do_attention_bench` perform the (0, 2, 1, 3) layout transposes around either implementation. Its `bench_shape(..., mask_in=None)` times both, checks them with `mx.allclose` (atol 1e-5 for float32, else 2e-4, with matching rtol), and reports failures with the mask in the message.

The other side keeps the older structure: `mlx_sdpa_fused_inner` (mask=None), `mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False)` with an optional float32 softmax, `mlx_spda_unfused` / `mlx_spda_fused` loops that transpose in and out around `N_iter_func` calls, and a maskless `bench_shape` that builds NumPy inputs scaled by 1/sqrt(head_dim), compares fused against unfused (f32softmax=True) with atol 1e-5 for float32 else 1e-4, and prints failures without a mask field.

@@ -173,51 +151,39 @@ — the shape tables differ in the number of KV heads (n_kvh 8 versus 32 at sequence lengths 1024/2048/4096 for head_dim 64, 80, and 128; the shorter 128–512 sequences at head_dim 64 keep 32/32), one side sweeps `masks = [None, "bool", "causal"]` as an extra loop, and the result header/rows gain a mask column:

B, qsl, ksl, hdim, n_qh, n_kvh, t, dtype, mask, t_unfs, t_fuse, diff%

versus

B, qsl, ksl, hdim, n_qh, n_kvh, tpose, dtype, t_unfs, t_fuse, diff%
|
||||
)
|
||||
np_dtype = getattr(np, dtype)
|
||||
time_mlx_fused, time_mlx_unfused = bench_shape(
|
||||
B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose
|
||||
)
|
||||
diff = time_mlx_unfused / time_mlx_fused - 1.0
|
||||
t_str = 1 if transpose else 0
|
||||
print(
|
||||
f"{B:3d}, {qsl:5d}, {ksl:5d}, {head_dim:4d}, {n_q_heads:4d}, {n_kv_heads:5d}, {t_str:5d}, {dtype}, {time_mlx_unfused: 2.3f}, {time_mlx_fused: 2.3f}, {100. * diff:+5.2f}%"
|
||||
)
|
||||
|
@@ -8,44 +8,30 @@ L = 16384
H = 32
H_k = H // 4
D = 128
V = 128
dtype = mx.float16
loops = 10


def upproject(x, w):
    if w is None:
        return x
    else:
        return x @ w.T


def attention(q, k, v, mask=None, w=None):
def attention(q, k, v):
    def _sdpa(q, k, v):
        B, Hq, L, D = q.shape
        _, Hk, S, _ = k.shape
        _, _, _, V = v.shape
        q = q.reshape(B, Hk, Hq // Hk, L, D)
        k = k[:, :, None, :, :]
        v = v[:, :, None, :, :]
        s = q @ k.transpose(0, 1, 2, 4, 3)
        if mask is not None:
            m = mx.broadcast_to(mask, (B, Hq, L, S)).reshape(B, Hk, Hq // Hk, L, S)
            s = mx.where(m, s, mx.finfo(s.dtype).min)
        p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
        o = p @ v
        return o.reshape(B, Hq, L, V)
        return o.reshape(B, Hq, L, D)

    for i in range(loops):
        q = _sdpa(q, k, v)
        q = upproject(q, w)
    return q


def sdpa(q, k, v, mask=None, w=None):
def sdpa(q, k, v):
    for i in range(loops):
        q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
        q = upproject(q, w)
        q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0)
    return q


@@ -53,43 +39,20 @@ def time_self_attention_primitives():
    mx.random.seed(3)
    q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
    k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
    v = mx.random.uniform(shape=(1, H_k, L, V)).astype(dtype)
    w = mx.random.uniform(shape=(D, V)).astype(dtype) if V != D else None
    mx.eval(q, k, v, w)
    time_fn(attention, q, k, v, w=w)
    v = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
    mx.eval(q, k, v)
    time_fn(attention, q, k, v)


def time_self_attention_sdpa():
    mx.random.seed(3)
    q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
    k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
    v = mx.random.uniform(shape=(1, H_k, L, V)).astype(dtype)
    w = mx.random.uniform(shape=(D, V)).astype(dtype) if V != D else None
    mx.eval(q, k, v, w)
    time_fn(sdpa, q, k, v, w=w)


def time_self_attention_sdpa_with_mask():
    mx.random.seed(3)
    q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
    k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
    v = mx.random.uniform(shape=(1, H_k, L, V)).astype(dtype)
    w = mx.random.uniform(shape=(D, V)).astype(dtype) if V != D else None
    mask = mx.full((L,), True)
    mask[L // 2 :] = False
    mx.eval(q, k, v, mask, w)

    def sdpa_mask(*args):
        return sdpa(*args, mask=mask, w=w)

    def attention_mask(*args):
        return attention(*args, mask=mask, w=w)

    time_fn(attention_mask, q, k, v)
    time_fn(sdpa_mask, q, k, v)
    v = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
    mx.eval(q, k, v)
    time_fn(sdpa, q, k, v)


if __name__ == "__main__":
    time_self_attention_sdpa()
    time_self_attention_primitives()
    time_self_attention_sdpa_with_mask()
@@ -51,20 +51,6 @@ def time_maximum():
    time_fn(mx.maximum, a, b)


def time_max():
    a = mx.random.uniform(shape=(32, 1024, 1024))
    a[1, 1] = mx.nan
    mx.eval(a)
    time_fn(mx.max, a, 0)


def time_min():
    a = mx.random.uniform(shape=(32, 1024, 1024))
    a[1, 1] = mx.nan
    mx.eval(a)
    time_fn(mx.min, a, 0)


def time_negative():
    a = mx.random.uniform(shape=(10000, 1000))
    mx.eval(a)
@@ -122,8 +108,6 @@ if __name__ == "__main__":

    time_add()
    time_matmul()
    time_min()
    time_max()
    time_maximum()
    time_exp()
    time_negative()
@@ -1,55 +0,0 @@
import time

import mlx.core as mx

rank = mx.distributed.init().rank()


def timeit(fn, a):

    # warmup
    for _ in range(5):
        mx.eval(fn(a))

    its = 10
    tic = time.perf_counter()
    for _ in range(its):
        mx.eval(fn(a))
    toc = time.perf_counter()
    ms = 1000 * (toc - tic) / its
    return ms


def all_reduce_benchmark():
    a = mx.ones((5, 5), mx.int32)

    its_per_eval = 100

    def fn(x):
        for _ in range(its_per_eval):
            x = mx.distributed.all_sum(x)
            x = x - 1
        return x

    ms = timeit(fn, a) / its_per_eval
    if rank == 0:
        print(f"All Reduce: time per iteration {ms:.6f} (ms)")


def all_gather_benchmark():
    a = mx.ones((5, 5), mx.int32)
    its_per_eval = 100

    def fn(x):
        for _ in range(its_per_eval):
            x = mx.distributed.all_gather(x)[0]
        return x

    ms = timeit(fn, a) / its_per_eval
    if rank == 0:
        print(f"All gather: time per iteration {ms:.6f} (ms)")


if __name__ == "__main__":
    all_reduce_benchmark()
    all_gather_benchmark()
@@ -1,7 +1,5 @@
include(CMakeParseArguments)

# clang format off
#
# ##############################################################################
# Build metal library
#
@@ -11,14 +9,11 @@ include(CMakeParseArguments)
# Args: TARGET: Custom target to be added for the metal library TITLE: Name of
# the .metallib OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib SOURCES: List
# of source files INCLUDE_DIRS: List of include dirs DEPS: List of dependency
# files (like headers) DEBUG: Boolean, if true, enables debug compile options
# for this specific library. If not provided, uses global MLX_METAL_DEBUG.
# files (like headers)
#
# clang format on

macro(mlx_build_metallib)
  # Parse args
  set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY DEBUG)
  set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY)
  set(multiValueArgs SOURCES INCLUDE_DIRS DEPS)
  cmake_parse_arguments(MTLLIB "" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})

@@ -26,11 +21,7 @@ macro(mlx_build_metallib)
  set(MTLLIB_BUILD_TARGET "${MTLLIB_OUTPUT_DIRECTORY}/${MTLLIB_TITLE}.metallib")

  # Collect compile options
  set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math -Wno-c++17-extensions)
  if(MLX_METAL_DEBUG OR MTLLIB_DEBUG)
    set(MTLLIB_COMPILE_OPTIONS ${MTLLIB_COMPILE_OPTIONS} -gline-tables-only
                               -frecord-sources)
  endif()
  set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math)

  # Prepare metallib build command
  add_custom_command(
@@ -13,7 +13,7 @@ EXCLUDE_PATTERNS = */private/*
CREATE_SUBDIRS = NO
FULL_PATH_NAMES = YES
RECURSIVE = YES
GENERATE_HTML = NO
GENERATE_HTML = YES
GENERATE_LATEX = NO
GENERATE_XML = YES
XML_PROGRAMLISTING = YES
@@ -1,5 +1,4 @@
sphinx
breathe
sphinx-book-theme
sphinx-copybutton
mlx
@@ -10,7 +10,7 @@ import mlx.core as mx
# -- Project information -----------------------------------------------------

project = "MLX"
copyright = "2023, Apple"
copyright = "2023, MLX Contributors"
author = "MLX Contributors"
version = ".".join(mx.__version__.split(".")[:3])
release = version
@@ -18,7 +18,6 @@ release = version
# -- General configuration ---------------------------------------------------

extensions = [
    "sphinx_copybutton",
    "sphinx.ext.autodoc",
    "sphinx.ext.autosummary",
    "sphinx.ext.intersphinx",
@@ -8,26 +8,23 @@ MLX supports writing custom Metal kernels through the Python and C++ APIs.
Simple Example
--------------

.. currentmodule:: mlx.core

Let's write a custom kernel that computes ``exp`` elementwise:

.. code-block:: python

  source = """
      uint elem = thread_position_in_grid.x;
      T tmp = inp[elem];
      out[elem] = metal::exp(tmp);
  """

  kernel = mx.fast.metal_kernel(
      name="myexp",
      input_names=["inp"],
      output_names=["out"],
      source=source,
  )

  def exp_elementwise(a: mx.array):
      source = """
          uint elem = thread_position_in_grid.x;
          T tmp = inp[elem];
          out[elem] = metal::exp(tmp);
      """

      kernel = mx.fast.metal_kernel(
          name="myexp",
          input_names=["inp"],
          output_names=["out"],
          source=source,
      )
      outputs = kernel(
          inputs=[a],
          template=[("T", mx.float32)],
@@ -42,13 +39,8 @@ Let's write a custom kernel that computes ``exp`` elementwise:
  b = exp_elementwise(a)
  assert mx.allclose(b, mx.exp(a))

Every time you make a kernel, a new Metal library is created and possibly
JIT compiled. To reduce the overhead from that, build the kernel once with
:func:`fast.metal_kernel` and then use it many times.
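A minimal sketch of that reuse, using the ``kernel`` and ``exp_elementwise``
defined above (the loop count is arbitrary):

.. code-block:: python

  # The Metal library behind `kernel` is built at most once;
  # every call below reuses the compiled kernel.
  for _ in range(10):
      b = exp_elementwise(a)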
.. note::
   Only pass the body of the Metal kernel in ``source``. The function
   signature is generated automatically.
We are only required to pass the body of the Metal kernel in ``source``.

The full function signature will be generated using:

@@ -86,52 +78,44 @@ Putting this all together, the generated function signature for ``myexp`` is as

  template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;

Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads
<https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_
function. This means we will launch ``mx.prod(grid)`` threads, subdivided into
``threadgroup`` size threadgroups. For optimal performance, each thread group
dimension should be less than or equal to the corresponding grid dimension.
Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads <https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_ function.
This means we will launch ``mx.prod(grid)`` threads, subdivided into ``threadgroup`` size threadgroups.
For optimal performance, each thread group dimension should be less than or equal to the corresponding grid dimension.

Passing ``verbose=True`` to :func:`fast.metal_kernel.__call__` will print the
generated code for debugging purposes.
Passing ``verbose=True`` to ``mx.fast.metal_kernel.__call__`` will print the generated code for debugging purposes.
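For example, a sketch of a verbose call with the ``myexp`` kernel from above
(all other arguments as in the earlier example):

.. code-block:: python

  outputs = kernel(
      inputs=[a],
      template=[("T", mx.float32)],
      grid=(a.size, 1, 1),
      threadgroup=(256, 1, 1),
      output_shapes=[a.shape],
      output_dtypes=[a.dtype],
      verbose=True,  # prints the generated Metal kernel source
  )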
Using Shape/Strides
-------------------

:func:`fast.metal_kernel` supports an argument ``ensure_row_contiguous`` which
is ``True`` by default. This will copy the array inputs if needed
before the kernel is launched to ensure that the memory layout is row
contiguous. Generally this makes writing the kernel easier, since we don't
have to worry about gaps or the ordering of the dims when indexing.
``mx.fast.metal_kernel`` supports an argument ``ensure_row_contiguous`` which is ``True`` by default.
This will copy the ``mx.array`` inputs if needed before the kernel is launched to ensure that the memory layout is row contiguous.
Generally this makes writing the kernel easier, since we don't have to worry about gaps or the ordering of the dims
when indexing.

If we want to avoid this copy, :func:`fast.metal_kernel` automatically passes
``a_shape``, ``a_strides`` and ``a_ndim`` for each input array ``a`` if any are
present in ``source``. We can then use MLX's built in indexing utils to fetch
the right elements for each thread.
If we want to avoid this copy, ``metal_kernel`` automatically passes ``a_shape``, ``a_strides`` and ``a_ndim`` for each
input array ``a`` if any are present in ``source``.
We can then use MLX's built in indexing utils to fetch the right elements for each thread.

Let's convert ``myexp`` above to support arbitrarily strided arrays without
relying on a copy from ``ensure_row_contiguous``:
Let's convert ``myexp`` above to support arbitrarily strided arrays without relying on a copy from ``ensure_row_contiguous``:

.. code-block:: python

  source = """
      uint elem = thread_position_in_grid.x;
      // Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
      uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
      T tmp = inp[loc];
      // Output arrays are always row contiguous
      out[elem] = metal::exp(tmp);
  """

  kernel = mx.fast.metal_kernel(
      name="myexp_strided",
      input_names=["inp"],
      output_names=["out"],
      source=source,
      ensure_row_contiguous=False,
  )

  def exp_elementwise(a: mx.array):
      source = """
          uint elem = thread_position_in_grid.x;
          // Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
          uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
          T tmp = inp[loc];
          // Output arrays are always row contiguous
          out[elem] = metal::exp(tmp);
      """

      kernel = mx.fast.metal_kernel(
          name="myexp_strided",
          input_names=["inp"],
          output_names=["out"],
          source=source
      )
      outputs = kernel(
          inputs=[a],
          template=[("T", mx.float32)],
@@ -139,6 +123,7 @@ relying on a copy from ``ensure_row_contiguous``:
          threadgroup=(256, 1, 1),
          output_shapes=[a.shape],
          output_dtypes=[a.dtype],
          ensure_row_contiguous=False,
      )
      return outputs[0]
@@ -157,139 +142,137 @@ We'll start with the following MLX implementation using standard ops:

.. code-block:: python

  def grid_sample_ref(x, grid):
      N, H_in, W_in, _ = x.shape
      ix = ((grid[..., 0] + 1) * W_in - 1) / 2
      iy = ((grid[..., 1] + 1) * H_in - 1) / 2

      ix_nw = mx.floor(ix).astype(mx.int32)
      iy_nw = mx.floor(iy).astype(mx.int32)

      ix_ne = ix_nw + 1
      iy_ne = iy_nw

      ix_sw = ix_nw
      iy_sw = iy_nw + 1

      ix_se = ix_nw + 1
      iy_se = iy_nw + 1

      nw = (ix_se - ix) * (iy_se - iy)
      ne = (ix - ix_sw) * (iy_sw - iy)
      sw = (ix_ne - ix) * (iy - iy_ne)
      se = (ix - ix_nw) * (iy - iy_nw)

      I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :]
      I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :]
      I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :]
      I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :]

      mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1)
      mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1)
      mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1)
      mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1)

      I_nw *= mask_nw[..., None]
      I_ne *= mask_ne[..., None]
      I_sw *= mask_sw[..., None]
      I_se *= mask_se[..., None]

      output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se

      return output

Now let's use :func:`custom_function` together with :func:`fast.metal_kernel`
Now let's use ``mx.custom_function`` together with ``mx.fast.metal_kernel``
to write a fast GPU kernel for both the forward and backward passes.

First we'll implement the forward pass as a fused kernel:

.. code-block:: python

  source = """
      uint elem = thread_position_in_grid.x;
      int H = x_shape[1];
      int W = x_shape[2];
      int C = x_shape[3];
      int gH = grid_shape[1];
      int gW = grid_shape[2];

      int w_stride = C;
      int h_stride = W * w_stride;
      int b_stride = H * h_stride;

      uint grid_idx = elem / C * 2;
      float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
      float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;

      int ix_nw = floor(ix);
      int iy_nw = floor(iy);

      int ix_ne = ix_nw + 1;
      int iy_ne = iy_nw;

      int ix_sw = ix_nw;
      int iy_sw = iy_nw + 1;

      int ix_se = ix_nw + 1;
      int iy_se = iy_nw + 1;

      T nw = (ix_se - ix) * (iy_se - iy);
      T ne = (ix - ix_sw) * (iy_sw - iy);
      T sw = (ix_ne - ix) * (iy - iy_ne);
      T se = (ix - ix_nw) * (iy - iy_nw);

      int batch_idx = elem / C / gH / gW * b_stride;
      int channel_idx = elem % C;
      int base_idx = batch_idx + channel_idx;

      T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];
      T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];
      T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];
      T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];

      I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0;
      I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0;
      I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;
      I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;

      out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
  """

  kernel = mx.fast.metal_kernel(
      name="grid_sample",
      input_names=["x", "grid"],
      output_names=["out"],
      source=source,
  )

  @mx.custom_function
  def grid_sample(x, grid):

      assert x.ndim == 4, "`x` must be 4D."
      assert grid.ndim == 4, "`grid` must be 4D."

      B, _, _, C = x.shape
      _, gN, gM, D = grid.shape
      out_shape = (B, gN, gM, C)

      assert D == 2, "Last dim of `grid` must be size 2."

      outputs = kernel(
          inputs=[x, grid],
          template=[("T", x.dtype)],
          output_shapes=[out_shape],
          output_dtypes=[x.dtype],
          grid=(np.prod(out_shape), 1, 1),
          threadgroup=(256, 1, 1),
      )
      return outputs[0]

  @mx.custom_function
  def grid_sample(x, grid):

      assert x.ndim == 4, "`x` must be 4D."
      assert grid.ndim == 4, "`grid` must be 4D."

      B, _, _, C = x.shape
      _, gN, gM, D = grid.shape
      out_shape = (B, gN, gM, C)

      assert D == 2, "Last dim of `grid` must be size 2."

      source = """
          uint elem = thread_position_in_grid.x;
          int H = x_shape[1];
          int W = x_shape[2];
          int C = x_shape[3];
          int gH = grid_shape[1];
          int gW = grid_shape[2];

          int w_stride = C;
          int h_stride = W * w_stride;
          int b_stride = H * h_stride;

          uint grid_idx = elem / C * 2;
          float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
          float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;

          int ix_nw = floor(ix);
          int iy_nw = floor(iy);

          int ix_ne = ix_nw + 1;
          int iy_ne = iy_nw;

          int ix_sw = ix_nw;
          int iy_sw = iy_nw + 1;

          int ix_se = ix_nw + 1;
          int iy_se = iy_nw + 1;

          T nw = (ix_se - ix) * (iy_se - iy);
          T ne = (ix - ix_sw) * (iy_sw - iy);
          T sw = (ix_ne - ix) * (iy - iy_ne);
          T se = (ix - ix_nw) * (iy - iy_nw);

          int batch_idx = elem / C / gH / gW * b_stride;
          int channel_idx = elem % C;
          int base_idx = batch_idx + channel_idx;

          T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];
          T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];
          T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];
          T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];

          I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0;
          I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0;
          I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;
          I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;

          out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
      """

      kernel = mx.fast.metal_kernel(
          name="grid_sample",
          input_names=["x", "grid"],
          output_names=["out"],
          source=source,
      )
      outputs = kernel(
          inputs=[x, grid],
          template=[("T", x.dtype)],
          output_shapes=[out_shape],
          output_dtypes=[x.dtype],
          grid=(np.prod(out_shape), 1, 1),
          threadgroup=(256, 1, 1),
      )
      return outputs[0]

For a reasonably sized input such as:

.. code-block:: python

  x.shape = (8, 1024, 1024, 64)
  grid.shape = (8, 256, 256, 2)

On an M1 Max, we see a big performance improvement:
@@ -298,11 +281,11 @@ On an M1 Max, we see a big performance improvement:
Grid Sample VJP
---------------

Since we decorated ``grid_sample`` with :func:`custom_function`, we can now
define its custom vjp transform so MLX can differentiate it.
Since we decorated ``grid_sample`` with ``mx.custom_function``, we can now define
its custom vjp transform so MLX can differentiate it.

The backwards pass requires atomically updating ``x_grad``/``grid_grad`` and so
requires a few extra :func:`fast.metal_kernel` features:
requires a few extra ``mx.fast.metal_kernel`` features:

* ``init_value=0``
  Initialize all of the kernel's outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel.
@@ -316,129 +299,128 @@ We can then implement the backwards pass as follows:

.. code-block:: python

  source = """
      uint elem = thread_position_in_grid.x;
      int H = x_shape[1];
      int W = x_shape[2];
      int C = x_shape[3];
      // Pad C to the nearest larger simdgroup size multiple
      int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;

      int gH = grid_shape[1];
      int gW = grid_shape[2];

      int w_stride = C;
      int h_stride = W * w_stride;
      int b_stride = H * h_stride;

      uint grid_idx = elem / C_padded * 2;
      float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
      float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;

      int ix_nw = floor(ix);
      int iy_nw = floor(iy);

      int ix_ne = ix_nw + 1;
      int iy_ne = iy_nw;

      int ix_sw = ix_nw;
      int iy_sw = iy_nw + 1;

      int ix_se = ix_nw + 1;
      int iy_se = iy_nw + 1;

      T nw = (ix_se - ix) * (iy_se - iy);
      T ne = (ix - ix_sw) * (iy_sw - iy);
      T sw = (ix_ne - ix) * (iy - iy_ne);
      T se = (ix - ix_nw) * (iy - iy_nw);

      int batch_idx = elem / C_padded / gH / gW * b_stride;
      int channel_idx = elem % C_padded;
      int base_idx = batch_idx + channel_idx;

      T gix = T(0);
      T giy = T(0);
      if (channel_idx < C) {
          int cot_index = elem / C_padded * C + channel_idx;
          T cot = cotangent[cot_index];
          if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {
              int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;
              atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);

              T I_nw = x[offset];
              gix -= I_nw * (iy_se - iy) * cot;
              giy -= I_nw * (ix_se - ix) * cot;
          }
          if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) {
              int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;
              atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed);

              T I_ne = x[offset];
              gix += I_ne * (iy_sw - iy) * cot;
              giy -= I_ne * (ix - ix_sw) * cot;
          }
          if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) {
              int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;
              atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed);

              T I_sw = x[offset];
              gix -= I_sw * (iy - iy_ne) * cot;
              giy += I_sw * (ix_ne - ix) * cot;
          }
          if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) {
              int offset = base_idx + iy_se * h_stride + ix_se * w_stride;
              atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed);

              T I_se = x[offset];
              gix += I_se * (iy - iy_nw) * cot;
              giy += I_se * (ix - ix_nw) * cot;
          }
      }

      T gix_mult = W / 2;
      T giy_mult = H / 2;

      // Reduce across each simdgroup first.
      // This is much faster than relying purely on atomics.
      gix = simd_sum(gix);
      giy = simd_sum(giy);

      if (thread_index_in_simdgroup == 0) {
          atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);
          atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);
      }
  """
  kernel = mx.fast.metal_kernel(
      name="grid_sample_grad",
      input_names=["x", "grid", "cotangent"],
      output_names=["x_grad", "grid_grad"],
      source=source,
      atomic_outputs=True,
  )

  @grid_sample.vjp
  def grid_sample_vjp(primals, cotangent, _):
      x, grid = primals
      B, _, _, C = x.shape
      _, gN, gM, D = grid.shape

      assert D == 2, "Last dim of `grid` must be size 2."

      # pad the output channels to simd group size
      # so that our `simd_sum`s don't overlap.
      simdgroup_size = 32
      C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
      grid_size = B * gN * gM * C_padded
      outputs = kernel(
          inputs=[x, grid, cotangent],
          template=[("T", x.dtype)],
          output_shapes=[x.shape, grid.shape],
          output_dtypes=[x.dtype, x.dtype],
          grid=(grid_size, 1, 1),
          threadgroup=(256, 1, 1),
          init_value=0,
      )
      return outputs[0], outputs[1]

  @grid_sample.vjp
  def grid_sample_vjp(primals, cotangent, _):
      x, grid = primals
      B, _, _, C = x.shape
      _, gN, gM, D = grid.shape

      assert D == 2, "Last dim of `grid` must be size 2."

      source = """
          uint elem = thread_position_in_grid.x;
          int H = x_shape[1];
          int W = x_shape[2];
          int C = x_shape[3];
          // Pad C to the nearest larger simdgroup size multiple
          int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;

          int gH = grid_shape[1];
          int gW = grid_shape[2];

          int w_stride = C;
          int h_stride = W * w_stride;
          int b_stride = H * h_stride;

          uint grid_idx = elem / C_padded * 2;
          float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
          float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;

          int ix_nw = floor(ix);
          int iy_nw = floor(iy);

          int ix_ne = ix_nw + 1;
          int iy_ne = iy_nw;

          int ix_sw = ix_nw;
          int iy_sw = iy_nw + 1;

          int ix_se = ix_nw + 1;
          int iy_se = iy_nw + 1;

          T nw = (ix_se - ix) * (iy_se - iy);
          T ne = (ix - ix_sw) * (iy_sw - iy);
          T sw = (ix_ne - ix) * (iy - iy_ne);
          T se = (ix - ix_nw) * (iy - iy_nw);

          int batch_idx = elem / C_padded / gH / gW * b_stride;
          int channel_idx = elem % C_padded;
          int base_idx = batch_idx + channel_idx;

          T gix = T(0);
          T giy = T(0);
          if (channel_idx < C) {
              int cot_index = elem / C_padded * C + channel_idx;
              T cot = cotangent[cot_index];
              if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {
                  int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;
                  atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);

                  T I_nw = x[offset];
                  gix -= I_nw * (iy_se - iy) * cot;
                  giy -= I_nw * (ix_se - ix) * cot;
              }
              if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) {
                  int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;
                  atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed);

                  T I_ne = x[offset];
                  gix += I_ne * (iy_sw - iy) * cot;
                  giy -= I_ne * (ix - ix_sw) * cot;
              }
              if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) {
                  int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;
                  atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed);

                  T I_sw = x[offset];
                  gix -= I_sw * (iy - iy_ne) * cot;
                  giy += I_sw * (ix_ne - ix) * cot;
              }
              if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) {
                  int offset = base_idx + iy_se * h_stride + ix_se * w_stride;
                  atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed);

                  T I_se = x[offset];
                  gix += I_se * (iy - iy_nw) * cot;
                  giy += I_se * (ix - ix_nw) * cot;
              }
          }

          T gix_mult = W / 2;
          T giy_mult = H / 2;

          // Reduce across each simdgroup first.
          // This is much faster than relying purely on atomics.
          gix = simd_sum(gix);
          giy = simd_sum(giy);

          if (thread_index_in_simdgroup == 0) {
              atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);
              atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);
          }
      """
      kernel = mx.fast.metal_kernel(
          name="grid_sample_grad",
          input_names=["x", "grid", "cotangent"],
          output_names=["x_grad", "grid_grad"],
          source=source,
          atomic_outputs=True,
      )
      # pad the output channels to simd group size
      # so that our `simd_sum`s don't overlap.
      simdgroup_size = 32
      C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
      grid_size = B * gN * gM * C_padded
      outputs = kernel(
          inputs=[x, grid, cotangent],
          template=[("T", x.dtype)],
          output_shapes=[x.shape, grid.shape],
          output_dtypes=[x.dtype, x.dtype],
          grid=(grid_size, 1, 1),
          threadgroup=(256, 1, 1),
          init_value=0,
      )
      return outputs[0], outputs[1]

There's an even larger speed up for the vjp:
@@ -22,12 +22,12 @@ You can do that in MLX directly:
This function performs that operation while leaving the implementation and
function transformations to MLX.

However, you may want to customize the underlying implementation, perhaps to
make it faster. In this tutorial we will go through adding custom extensions.
It will cover:
However you may need to customize the underlying implementation, perhaps to
make it faster or for custom differentiation. In this tutorial we will go
through adding custom extensions. It will cover:

* The structure of the MLX library.
* Implementing a CPU operation.
* Implementing a CPU operation that redirects to Accelerate_ when appropriate.
* Implementing a GPU operation using metal.
* Adding the ``vjp`` and ``jvp`` function transformation.
* Building a custom extension and binding it to python.
@@ -45,7 +45,7 @@ Operations
Operations are the front-end functions that operate on arrays. They are defined
in the C++ API (:ref:`cpp_ops`), and the Python API (:ref:`ops`) binds them.

We would like an operation :meth:`axpby` that takes in two arrays, ``x`` and
We would like an operation, :meth:`axpby` that takes in two arrays ``x`` and
``y``, and two scalars, ``alpha`` and ``beta``. This is how to define it in
C++:

@@ -55,7 +55,7 @@ C++:
  * Scale and sum two vectors element-wise
  * z = alpha * x + beta * y
  *
  * Use NumPy-style broadcasting between x and y
  * Follow numpy style broadcasting between x and y
  * Inputs are upcasted to floats if needed
  **/
  array axpby(
@@ -66,7 +66,7 @@ C++:
      StreamOrDevice s = {} // Stream on which to schedule the operation
  );

The simplest way to implement this is with existing operations:
The simplest way to this operation is in terms of existing operations:

.. code-block:: C++

@@ -93,9 +93,9 @@ Primitives
^^^^^^^^^^^

A :class:`Primitive` is part of the computation graph of an :class:`array`. It
defines how to create output arrays given input arrays. Further, a
defines how to create outputs arrays given a input arrays. Further, a
:class:`Primitive` has methods to run on the CPU or GPU and for function
transformations such as ``vjp`` and ``jvp``. Let's go back to our example to be
transformations such as ``vjp`` and ``jvp``. Lets go back to our example to be
more concrete:
.. code-block:: C++

@@ -128,7 +128,7 @@ more concrete:
    /** The vector-Jacobian product. */
    std::vector<array> vjp(
        const std::vector<array>& primals,
        const std::vector<array>& cotangents,
        const array& cotan,
        const std::vector<int>& argnums,
        const std::vector<array>& outputs) override;

@@ -138,13 +138,13 @@ more concrete:
     * representing the vectorized computation and the axis which
     * corresponds to the output vectorized dimension.
     */
    std::pair<std::vector<array>, std::vector<int>> vmap(
    virtual std::pair<std::vector<array>, std::vector<int>> vmap(
        const std::vector<array>& inputs,
        const std::vector<int>& axes) override;

    /** The name of primitive. */
    const char* name() const override {
      return "Axpby";
    /** Print the primitive. */
    void print(std::ostream& os) override {
      os << "Axpby";
    }

    /** Equivalence check **/
@@ -153,6 +153,9 @@ more concrete:
   private:
    float alpha_;
    float beta_;

    /** Fall back implementation for evaluation on CPU */
    void eval(const std::vector<array>& inputs, array& out);
  };

The :class:`Axpby` class derives from the base :class:`Primitive` class. The
@@ -185,7 +188,7 @@ Let's reimplement our operation now in terms of our :class:`Axpby` primitive.
      auto promoted_dtype = promote_types(x.dtype(), y.dtype());

      // Upcast to float32 for non-floating point inputs x and y
      auto out_dtype = issubdtype(promoted_dtype, float32)
      auto out_dtype = is_floating_point(promoted_dtype)
          ? promoted_dtype
          : promote_types(promoted_dtype, float32);
@@ -231,57 +234,49 @@ the execution of the computation graph, and calls :meth:`Axpby::eval_cpu` or

Implementing the CPU Back-end
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Let's start by implementing :meth:`Axpby::eval_cpu`.
Let's start by implementing a naive and generic version of
:meth:`Axpby::eval_cpu`. We declared this as a private member function of
:class:`Axpby` earlier called :meth:`Axpby::eval`.

The method will go over each element of the output array, find the
Our naive method will go over each element of the output array, find the
corresponding input elements of ``x`` and ``y`` and perform the operation
point-wise. This is captured in the templated function :meth:`axpby_impl`.

.. code-block:: C++

  template <typename T>
  void axpby_impl(
      const mx::array& x,
      const mx::array& y,
      mx::array& out,
      float alpha_,
      float beta_,
      mx::Stream stream) {
    out.set_data(mx::allocator::malloc(out.nbytes()));

    // Get the CPU command encoder and register input and output arrays
    auto& encoder = mx::cpu::get_command_encoder(stream);
    encoder.set_input_array(x);
    encoder.set_input_array(y);
    encoder.set_output_array(out);

    // Launch the CPU kernel
    encoder.dispatch([x_ptr = x.data<T>(),
                      y_ptr = y.data<T>(),
                      out_ptr = out.data<T>(),
                      size = out.size(),
                      shape = out.shape(),
                      x_strides = x.strides(),
                      y_strides = y.strides(),
                      alpha_,
                      beta_]() {
      // Cast alpha and beta to the relevant types
      T alpha = static_cast<T>(alpha_);
      T beta = static_cast<T>(beta_);

      // Do the element-wise operation for each output
      for (size_t out_idx = 0; out_idx < size; out_idx++) {
        // Map linear indices to offsets in x and y
        auto x_offset = mx::elem_to_loc(out_idx, shape, x_strides);
        auto y_offset = mx::elem_to_loc(out_idx, shape, y_strides);

        // We allocate the output to be contiguous and regularly strided
        // (defaults to row major) and hence it doesn't need additional mapping
        out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
      }
    });
  }

  template <typename T>
  void axpby_impl(
      const array& x,
      const array& y,
      array& out,
      float alpha_,
      float beta_) {
    // We only allocate memory when we are ready to fill the output
    // malloc_or_wait synchronously allocates available memory
    // There may be a wait executed here if the allocation is requested
    // under memory-pressured conditions
    out.set_data(allocator::malloc_or_wait(out.nbytes()));

    // Collect input and output data pointers
    const T* x_ptr = x.data<T>();
    const T* y_ptr = y.data<T>();
    T* out_ptr = out.data<T>();

    // Cast alpha and beta to the relevant types
    T alpha = static_cast<T>(alpha_);
    T beta = static_cast<T>(beta_);

    // Do the element-wise operation for each output
    for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
      // Map linear indices to offsets in x and y
      auto x_offset = elem_to_loc(out_idx, x.shape(), x.strides());
      auto y_offset = elem_to_loc(out_idx, y.shape(), y.strides());

      // We allocate the output to be contiguous and regularly strided
      // (defaults to row major) and hence it doesn't need additional mapping
      out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
    }
  }

Our implementation should work for all incoming floating point arrays.
Accordingly, we add dispatches for ``float32``, ``float16``, ``bfloat16`` and
@@ -289,32 +284,112 @@ Accordingly, we add dispatches for ``float32``, ``float16``, ``bfloat16`` and

.. code-block:: C++

  void Axpby::eval_cpu(
      const std::vector<mx::array>& inputs,
      std::vector<mx::array>& outputs) {
    auto& x = inputs[0];
    auto& y = inputs[1];
    auto& out = outputs[0];

    // Dispatch to the correct dtype
    if (out.dtype() == mx::float32) {
      return axpby_impl<float>(x, y, out, alpha_, beta_, stream());
    } else if (out.dtype() == mx::float16) {
      return axpby_impl<mx::float16_t>(x, y, out, alpha_, beta_, stream());
    } else if (out.dtype() == mx::bfloat16) {
      return axpby_impl<mx::bfloat16_t>(x, y, out, alpha_, beta_, stream());
    } else if (out.dtype() == mx::complex64) {
      return axpby_impl<mx::complex64_t>(x, y, out, alpha_, beta_, stream());
    } else {
      throw std::runtime_error(
          "Axpby is only supported for floating point types.");
    }
  }

  /** Fall back implementation for evaluation on CPU */
  void Axpby::eval(
      const std::vector<array>& inputs,
      const std::vector<array>& outputs) {
    auto& x = inputs[0];
    auto& y = inputs[1];
    auto& out = outputs[0];

    // Dispatch to the correct dtype
    if (out.dtype() == float32) {
      return axpby_impl<float>(x, y, out, alpha_, beta_);
    } else if (out.dtype() == float16) {
      return axpby_impl<float16_t>(x, y, out, alpha_, beta_);
    } else if (out.dtype() == bfloat16) {
      return axpby_impl<bfloat16_t>(x, y, out, alpha_, beta_);
    } else if (out.dtype() == complex64) {
      return axpby_impl<complex64_t>(x, y, out, alpha_, beta_);
    } else {
      throw std::runtime_error(
          "[Axpby] Only supports floating point types.");
    }
  }

This is good as a fallback implementation. We can use the ``axpby`` routine
provided by the Accelerate_ framework for a faster implementation in certain
cases:

#. Accelerate does not provide implementations of ``axpby`` for half precision
   floats. We can only use it for ``float32`` types.
#. Accelerate assumes the inputs ``x`` and ``y`` are contiguous and all
   elements have fixed strides between them. We only direct to Accelerate
   if both ``x`` and ``y`` are row contiguous or column contiguous.
#. Accelerate performs the routine ``Y = (alpha * X) + (beta * Y)`` in-place.
   MLX expects to write the output to a new array. We must copy the elements
   of ``y`` into the output and use that as an input to ``axpby``.

Let's write an implementation that uses Accelerate in the right conditions.
It allocates data for the output, copies ``y`` into it, and then calls the
:func:`catlas_saxpby` from accelerate.

.. code-block:: C++

  template <typename T>
  void axpby_impl_accelerate(
      const array& x,
      const array& y,
      array& out,
      float alpha_,
      float beta_) {
    // Accelerate library provides catlas_saxpby which does
    // Y = (alpha * X) + (beta * Y) in place
    // To use it, we first copy the data in y over to the output array
    out.set_data(allocator::malloc_or_wait(out.nbytes()));

    // We then copy over the elements using the contiguous vector specialization
    copy_inplace(y, out, CopyType::Vector);

    // Get x and y pointers for catlas_saxpby
    const T* x_ptr = x.data<T>();
    T* y_ptr = out.data<T>();

    T alpha = static_cast<T>(alpha_);
    T beta = static_cast<T>(beta_);

    // Call the inplace accelerate operator
    catlas_saxpby(
        /* N = */ out.size(),
        /* ALPHA = */ alpha,
        /* X = */ x_ptr,
        /* INCX = */ 1,
        /* BETA = */ beta,
        /* Y = */ y_ptr,
        /* INCY = */ 1);
  }

For inputs that do not fit the criteria for accelerate, we fall back to
:meth:`Axpby::eval`. With this in mind, let's finish our
:meth:`Axpby::eval_cpu`.

.. code-block:: C++

  /** Evaluate primitive on CPU using accelerate specializations */
  void Axpby::eval_cpu(
      const std::vector<array>& inputs,
      const std::vector<array>& outputs) {
    assert(inputs.size() == 2);
    auto& x = inputs[0];
    auto& y = inputs[1];
    auto& out = outputs[0];

    // Accelerate specialization for contiguous single precision float arrays
    if (out.dtype() == float32 &&
        ((x.flags().row_contiguous && y.flags().row_contiguous) ||
         (x.flags().col_contiguous && y.flags().col_contiguous))) {
      axpby_impl_accelerate<float>(x, y, out, alpha_, beta_);
      return;
    }

    // Fall back to common back-end if specializations are not available
    eval(inputs, outputs);
  }

Just this much is enough to run the operation :meth:`axpby` on a CPU stream! If
you do not plan on running the operation on the GPU or using transforms on
computation graphs that contain :class:`Axpby`, you can stop implementing the
primitive here.
primitive here and enjoy the speed-ups you get from the Accelerate library.
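As a sketch of what calling it could look like from Python once the bindings
from the later sections are in place (the ``stream`` keyword here is an
assumption, not part of this diff):

.. code-block:: python

  import mlx.core as mx
  from mlx_sample_extensions import axpby

  x = mx.ones((3, 4))
  y = mx.ones((3, 4))

  # Hits the Accelerate fast path: contiguous float32 inputs on the CPU
  z = axpby(x, y, 4.0, 2.0, stream=mx.cpu)
  print(z)  # expect all 6.0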
Implementing the GPU Back-end
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

@@ -391,17 +466,17 @@ below.
    auto& d = metal::device(s.device);

    // Allocate output memory
    out.set_data(allocator::malloc(out.nbytes()));
    out.set_data(allocator::malloc_or_wait(out.nbytes()));

    // Resolve name of kernel
    std::string kname;
    kname = "axpby_general_" + type_to_name(out);
    std::ostringstream kname;
    kname << "axpby_" << "general_" << type_to_name(out);

    // Load the metal library
    auto lib = d.get_library("mlx_ext", current_binary_dir());
    // Make sure the metal library is available
    d.register_library("mlx_ext");

    // Make a kernel from this metal library
    auto kernel = d.get_kernel(kname, lib);
    auto kernel = d.get_kernel(kname.str(), "mlx_ext");

    // Prepare to encode kernel
    auto& compute_encoder = d.get_command_encoder(s.index);
@@ -469,7 +544,7 @@ one we just defined:
      const std::vector<array>& tangents,
      const std::vector<int>& argnums) {
    // Forward mode diff that pushes along the tangents
    // The jvp transform on the primitive can be built with ops
    // The jvp transform on the primitive can built with ops
    // that are scheduled on the same stream as the primitive

    // If argnums = {0}, we only push along x in which case the
@@ -481,7 +556,7 @@ one we just defined:
      auto scale_arr = array(scale, tangents[0].dtype());
      return {multiply(scale_arr, tangents[0], stream())};
    }
    // If argnums = {0, 1}, we take contributions from both
    // If, argnums = {0, 1}, we take contributions from both
    // which gives us jvp = tangent_x * alpha + tangent_y * beta
    else {
      return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())};
@@ -735,7 +810,7 @@ Let's look at a simple script and its results:

    print(f"c shape: {c.shape}")
    print(f"c dtype: {c.dtype}")
    print(f"c is correct: {mx.all(c == 6.0).item()}")
    print(f"c correct: {mx.all(c == 6.0).item()}")
Output:

@@ -743,13 +818,13 @@ Output:

  c shape: [3, 4]
  c dtype: float32
  c is correct: True
  c correctness: True

Results
^^^^^^^

Let's run a quick benchmark and see how our new ``axpby`` operation compares
with the naive :meth:`simple_axpby` we first defined.
with the naive :meth:`simple_axpby` we first defined on the CPU.

.. code-block:: python

@@ -757,11 +832,13 @@ with the naive :meth:`simple_axpby` we first defined.
  from mlx_sample_extensions import axpby
  import time

  mx.set_default_device(mx.cpu)

  def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
      return alpha * x + beta * y

  M = 4096
  N = 4096
  M = 256
  N = 512

  x = mx.random.normal((M, N))
  y = mx.random.normal((M, N))
@@ -772,24 +849,24 @@ with the naive :meth:`simple_axpby` we first defined.

  def bench(f):
      # Warm up
      for i in range(5):
      for i in range(100):
          z = f(x, y, alpha, beta)
          mx.eval(z)

      # Timed run
      s = time.time()
      for i in range(100):
      for i in range(5000):
          z = f(x, y, alpha, beta)
          mx.eval(z)
      e = time.time()
      return 1000 * (e - s) / 100
      return e - s

  simple_time = bench(simple_axpby)
  custom_time = bench(axpby)

  print(f"Simple axpby: {simple_time:.3f} ms | Custom axpby: {custom_time:.3f} ms")
  print(f"Simple axpby: {simple_time:.3f} s | Custom axpby: {custom_time:.3f} s")

The results are ``Simple axpby: 1.559 ms | Custom axpby: 0.774 ms``. We see
The results are ``Simple axpby: 0.114 s | Custom axpby: 0.109 s``. We see
modest improvements right away!

This operation is now good to be used to build other operations, in
@@ -70,8 +70,6 @@ are the CPU and GPU.

   python/fft
   python/linalg
   python/metal
   python/cuda
   python/memory_management
   python/nn
   python/optimizers
   python/distributed
@@ -13,7 +13,7 @@ silicon computer is

  pip install mlx

To install from PyPI your system must meet the following requirements:
To install from PyPI you must meet the following requirements:

- Using an M series chip (Apple silicon)
- Using a native Python >= 3.9
@@ -23,39 +23,12 @@ To install from PyPI your system must meet the following requirements:
   MLX is only available on devices running macOS >= 13.5
   It is highly recommended to use macOS 14 (Sonoma)

CUDA
^^^^

MLX has a CUDA backend which you can install with:
MLX is also available on conda-forge. To install MLX with conda do:

.. code-block:: shell

  pip install mlx[cuda]

To install the CUDA package from PyPi your system must meet the following
requirements:

- Nvidia architecture >= SM 7.0 (Volta)
- Nvidia driver >= 550.54.14
- CUDA toolkit >= 12.0
- Linux distribution with glibc >= 2.35
- Python >= 3.9


CPU-only (Linux)
^^^^^^^^^^^^^^^^

For a CPU-only version of MLX that runs on Linux use:

.. code-block:: shell

  pip install mlx[cpu]

To install the CPU-only package from PyPi your system must meet the following
requirements:

- Linux distribution with glibc >= 2.35
- Python >= 3.9
  conda install conda-forge::mlx


Troubleshooting
@@ -92,8 +65,6 @@ Build Requirements
Python API
^^^^^^^^^^

.. _python install:

To build and install the MLX python library from source, first, clone MLX from
`its GitHub repo <https://github.com/ml-explore/mlx>`_:

@@ -105,20 +76,20 @@ Then simply build and install MLX using pip:

.. code-block:: shell

  pip install .
  CMAKE_BUILD_PARALLEL_LEVEL=8 pip install .

For developing, install the package with development dependencies, and use an
editable install:

.. code-block:: shell

  pip install -e ".[dev]"
  CMAKE_BUILD_PARALLEL_LEVEL=8 pip install -e ".[dev]"

Once the development dependencies are installed, you can build faster with:

.. code-block:: shell

  python setup.py build_ext --inplace
  CMAKE_BUILD_PARALLEL_LEVEL=8 python setup.py build_ext --inplace

Run the tests with:

@@ -136,8 +107,6 @@ IDE:
C++ API
^^^^^^^

.. _cpp install:

Currently, MLX must be built and installed from source.

Similarly to the python library, to build and install the MLX C++ library start
@@ -216,7 +185,6 @@ should point to the path to the built metal library.

  xcrun -sdk macosx --show-sdk-version


Binary Size Minimization
~~~~~~~~~~~~~~~~~~~~~~~~

@@ -245,50 +213,6 @@ be anywhere from a few hundred milliseconds to a few seconds depending on the
application. Once a kernel is compiled, it will be cached by the system. The
Metal kernel cache persists across reboots.

Linux
^^^^^

To build from source on Linux (CPU only), install the BLAS and LAPACK headers.
For example on Ubuntu, run the following:

.. code-block:: shell

  apt-get update -y
  apt-get install libblas-dev liblapack-dev liblapacke-dev -y

From here follow the instructions to install either the :ref:`Python <python
install>` or :ref:`C++ <cpp install>` APIs.

CUDA
^^^^

To build from source on Linux with CUDA, install the BLAS and LAPACK headers
and the CUDA toolkit. For example on Ubuntu, run the following:

.. code-block:: shell

  wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
  dpkg -i cuda-keyring_1.1-1_all.deb
  apt-get update -y
  apt-get -y install cuda-toolkit-12-9
  apt-get install libblas-dev liblapack-dev liblapacke-dev -y


When building either the Python or C++ APIs make sure to pass the cmake flag
``MLX_BUILD_CUDA=ON``. For example, to build the Python API run:

.. code-block:: shell

  CMAKE_ARGS="-DMLX_BUILD_CUDA=ON" pip install -e ".[dev]"

To build the C++ package run:

.. code-block:: shell

  mkdir -p build && cd build
  cmake .. -DMLX_BUILD_CUDA=ON && make -j


Troubleshooting
^^^^^^^^^^^^^^^
@@ -19,8 +19,6 @@ Array

   array.ndim
   array.shape
   array.size
   array.real
   array.imag
   array.abs
   array.all
   array.any

@@ -40,7 +38,6 @@ Array

   array.log10
   array.log1p
   array.log2
   array.logcumsumexp
   array.logsumexp
   array.max
   array.mean
@@ -1,9 +0,0 @@

CUDA
=====

.. currentmodule:: mlx.core.cuda

.. autosummary::
   :toctree: _autosummary

   is_available
@@ -51,20 +51,11 @@ The default floating point type is ``float32`` and the default integer type is

  * - ``float32``
    - 4
    - 32-bit float
  * - ``float64``
    - 8
    - 64-bit double
  * - ``complex64``
    - 8
    - 64-bit complex float

.. note::

   Arrays with type ``float64`` only work with CPU operations. Using
   ``float64`` arrays on the GPU will result in an exception.

Data types are arranged in a hierarchy. See the :obj:`DtypeCategory` object
documentation for more information. Use :func:`issubdtype` to determine if one
``dtype`` (or category) is a subtype of another category.
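For instance (a quick illustration; ``mx.floating`` and ``mx.inexact`` are
members of :obj:`DtypeCategory`):

.. code-block:: python

    import mlx.core as mx

    # A concrete dtype is a subtype of its category ...
    mx.issubdtype(mx.float32, mx.floating)  # True
    # ... but not of an unrelated category
    mx.issubdtype(mx.int32, mx.floating)    # False
    # Categories nest: floating is itself a subcategory of inexact
    mx.issubdtype(mx.floating, mx.inexact)  # True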
@@ -13,4 +13,3 @@ Fast

   rope
   scaled_dot_product_attention
   metal_kernel
   cuda_kernel

@@ -20,5 +20,3 @@ FFT

   irfft2
   rfftn
   irfftn
   fftshift
   ifftshift
@@ -5,8 +5,8 @@ Linear Algebra

.. currentmodule:: mlx.core.linalg

.. autosummary::
   :toctree: _autosummary
.. autosummary::
   :toctree: _autosummary

   inv
   tri_inv

@@ -16,12 +16,5 @@ Linear Algebra

   cross
   qr
   svd
   eigvals
   eig
   eigvalsh
   eigh
   lu
   lu_factor
   pinv
   solve
   solve_triangular
@@ -1,16 +0,0 @@

Memory Management
=================

.. currentmodule:: mlx.core

.. autosummary::
   :toctree: _autosummary

   get_active_memory
   get_peak_memory
   reset_peak_memory
   get_cache_memory
   set_memory_limit
   set_cache_limit
   set_wired_limit
   clear_cache
@@ -8,5 +8,13 @@ Metal

   is_available
   device_info
   get_active_memory
   get_peak_memory
   reset_peak_memory
   get_cache_memory
   set_memory_limit
   set_cache_limit
   set_wired_limit
   clear_cache
   start_capture
   stop_capture
@@ -174,7 +174,6 @@ In detail:

   value_and_grad
   quantize
   average_gradients

.. toctree::
@@ -32,16 +32,13 @@ Operations

   atleast_2d
   atleast_3d
   bitwise_and
   bitwise_invert
   bitwise_or
   bitwise_xor
   block_masked_mm
   broadcast_arrays
   broadcast_to
   ceil
   clip
   concatenate
   contiguous
   conj
   conjugate
   convolve

@@ -103,7 +100,6 @@ Operations

   log10
   log1p
   logaddexp
   logcumsumexp
   logical_not
   logical_and
   logical_or

@@ -149,8 +145,6 @@ Operations

   sign
   sin
   sinh
   slice
   slice_update
   softmax
   sort
   split
@@ -51,14 +51,14 @@ the saved state. Here's a simple example:

    optimizer.update(model, grads)

    # Save the state
    state = tree_flatten(optimizer.state, destination={})
    mx.save_safetensors("optimizer.safetensors", state)
    state = tree_flatten(optimizer.state)
    mx.save_safetensors("optimizer.safetensors", dict(state))

    # Later on, for example when loading from a checkpoint,
    # recreate the optimizer and load the state
    optimizer = optim.Adam(learning_rate=1e-2)

    state = tree_unflatten(mx.load("optimizer.safetensors"))
    state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))
    optimizer.state = state

Note, not every optimizer configuration parameter is saved in the state. For
@@ -18,5 +18,3 @@ Common Optimizers

   AdamW
   Adamax
   Lion
   MultiOptimizer
   Muon

@@ -9,7 +9,6 @@ Transforms

   :toctree: _autosummary

   eval
   async_eval
   compile
   custom_function
   disable_compile
@@ -225,7 +225,7 @@ In some cases returning updated state can be pretty inconvenient. Hence,

    def fun(x, y):
        z = x + y
        state.append(z)
        return mx.exp(z)
        return mx.exp(z), state

    fun(mx.array(1.0), mx.array(2.0))
    # Prints [array(3, dtype=float32)]
@@ -5,27 +5,21 @@ Distributed Communication

.. currentmodule:: mlx.core.distributed

MLX supports distributed communication operations that allow the computational cost
of training or inference to be shared across many physical machines. At the
moment we support two different communication backends:

* `MPI <https://en.wikipedia.org/wiki/Message_Passing_Interface>`_ a
  full-featured and mature distributed communications library
* A **ring** backend of our own that uses native TCP sockets and should be
  faster for thunderbolt connections.

The list of all currently supported operations and their documentation can be
seen in the :ref:`API docs<distributed>`.
MLX utilizes `MPI <https://en.wikipedia.org/wiki/Message_Passing_Interface>`_ to
provide distributed communication operations that allow the computational cost
of training or inference to be shared across many physical machines. You can
see a list of the supported operations in the :ref:`API docs<distributed>`.

.. note::
   Some operations may not be supported or not as fast as they should be.
   A lot of operations may not be supported or not as fast as they should be.
   We are adding more and tuning the ones we have as we are figuring out the
   best way to do distributed computing on Macs using MLX.

Getting Started
---------------

A distributed program in MLX is as simple as:
MLX already comes with the ability to "talk" to MPI if it is installed on the
machine. The minimal distributed program in MLX is as simple as:

.. code:: python
@@ -36,79 +30,74 @@ A distributed program in MLX is as simple as:

   print(world.rank(), x)

The program above sums the array ``mx.ones(10)`` across all
distributed processes. However, when this script is run with ``python`` only
one process is launched and no distributed communication takes place. Namely,
all operations in ``mx.distributed`` are noops when the distributed group has a
size of one. This property allows us to avoid code that checks if we are in a
distributed setting similar to the one below:
distributed processes. If simply run with ``python``, however, only one
process is launched and no distributed communication takes place.

.. code:: python

   import mlx.core as mx

   x = ...
   world = mx.distributed.init()
   # No need for the check we can simply do x = mx.distributed.all_sum(x)
   if world.size() > 1:
       x = mx.distributed.all_sum(x)
Running Distributed Programs
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

MLX provides ``mlx.launch`` a helper script to launch distributed programs.
Continuing with our initial example we can run it on localhost with 4 processes using
To launch the program in distributed mode we need to use ``mpirun`` or
``mpiexec`` depending on the MPI installation. The simplest possible way is the
following:

.. code:: shell

   $ mlx.launch -n 4 my_script.py
   3 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
   2 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
   1 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
   0 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
   $ mpirun -np 2 python test.py
   1 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
   0 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)

We can also run it on some remote hosts by providing their IPs (provided that
the script exists on all hosts and they are reachable by ssh)
The above launches two processes on the same (local) machine and we can see
both standard output streams. The processes send the array of 1s to each other
and compute the sum which is printed. Launching with ``mpirun -np 4 ...`` would
print 4 etc.

Installing MPI
---------------

MPI can be installed with Homebrew, using the Anaconda package manager or
compiled from source. Most of our testing is done using ``openmpi`` installed
with the Anaconda package manager as follows:

.. code:: shell

   $ mlx.launch --hosts ip1,ip2,ip3,ip4 my_script.py
   3 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
   2 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
   1 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
   0 array([4, 4, 4, ..., 4, 4, 4], dtype=float32)
   $ conda install openmpi

Consult the dedicated :doc:`usage guide<launching_distributed>` for more
information on using ``mlx.launch``.
Installing with Homebrew may require specifying the location of ``libmpi.dyld``
so that MLX can find it and load it at runtime. This can simply be achieved by
passing the ``DYLD_LIBRARY_PATH`` environment variable to ``mpirun``.

Selecting Backend
^^^^^^^^^^^^^^^^^

.. code:: shell

You can select the backend you want to use when calling :func:`init` by passing
one of ``{'any', 'ring', 'mpi'}``. When passing ``any``, MLX will try to
initialize the ``ring`` backend and if it fails the ``mpi`` backend. If they
both fail then a singleton group is created.
   $ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python test.py
Setting up Remote Hosts
-----------------------

MPI can automatically connect to remote hosts and set up the communication over
the network if the remote hosts can be accessed via ssh. A good checklist to
debug connectivity issues is the following:

* ``ssh hostname`` works from all machines to all machines without asking for
  password or host confirmation
* ``mpirun`` is accessible on all machines. You can call ``mpirun`` using its
  full path to force all machines to use a specific path.
* Ensure that the ``hostname`` used by MPI is the one that you have configured
  in the ``.ssh/config`` files on all machines.

.. note::
   After a distributed backend is successfully initialized :func:`init` will
   return **the same backend** if called without arguments or with backend set to
   ``any``.
   For an example hostname ``foo.bar.com`` MPI can use only ``foo`` as
   the hostname passed to ssh if the current hostname matches ``*.bar.com``.

The following examples aim to clarify the backend initialization logic in MLX:
An easy way to pass the host names to MPI is using a host file. A host file
looks like the following, where ``host1`` and ``host2`` should be the fully
qualified domain names or IPs for these hosts.

.. code:: python
.. code::

   # Case 1: Initialize MPI regardless if it was possible to initialize the ring backend
   world = mx.distributed.init(backend="mpi")
   world2 = mx.distributed.init()  # subsequent calls return the MPI backend!
   host1 slots=1
   host2 slots=1

   # Case 2: Initialize any backend
   world = mx.distributed.init(backend="any")  # equivalent to no arguments
   world2 = mx.distributed.init()  # same as above

   # Case 3: Initialize both backends at the same time
   world_mpi = mx.distributed.init(backend="mpi")
   world_ring = mx.distributed.init(backend="ring")
   world_any = mx.distributed.init()  # same as MPI because it was initialized first!
When using MLX, it is very likely that you want to use 1 slot per host, ie one
process per host. The hostfile also needs to contain the current
host if you want to run on the local host. Passing the host file to
``mpirun`` is simply done using the ``--hostfile`` command line argument.
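For instance (an illustrative invocation; ``hosts.txt`` is a hypothetical host
file with the contents shown above):

.. code:: shell

   $ mpirun --hostfile hosts.txt python test.py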
Training Example
----------------

@@ -166,179 +155,13 @@ everything else remaining the same.

       optimizer.update(model, grads)
       return loss

Utilizing ``nn.average_gradients``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Tuning All Reduce
-----------------

Although the code example above works correctly, it performs one communication
per gradient. It is significantly more efficient to aggregate several gradients
together and perform fewer communication steps.
We are working on improving the performance of all reduce on MLX but for now
the two main things one can do to extract the most out of distributed training with MLX are:

This is the purpose of :func:`mlx.nn.average_gradients`. The final code looks
almost identical to the example above:

.. code:: python

   model = ...
   optimizer = ...
   dataset = ...

   def step(model, x, y):
       loss, grads = loss_grad_fn(model, x, y)
       grads = mlx.nn.average_gradients(grads)  # <---- This line was added
       optimizer.update(model, grads)
       return loss

   for x, y in dataset:
       loss = step(model, x, y)
       mx.eval(loss, model.parameters())
Getting Started with MPI
------------------------

MLX already comes with the ability to "talk" to MPI if it is installed on the
machine. Launching distributed MLX programs that use MPI can be done with
``mpirun`` as expected. However, in the following examples we will be using
``mlx.launch --backend mpi`` which takes care of some nuisances such as setting
absolute paths for the ``mpirun`` executable and the ``libmpi.dyld`` shared
library.

The simplest possible usage is the following which, assuming the minimal
example in the beginning of this page, should result in:

.. code:: shell

   $ mlx.launch --backend mpi -n 2 test.py
   1 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
   0 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)

The above launches two processes on the same (local) machine and we can see
both standard output streams. The processes send the array of 1s to each other
and compute the sum which is printed. Launching with ``mlx.launch -n 4 ...`` would
print 4 etc.

Installing MPI
^^^^^^^^^^^^^^

MPI can be installed with Homebrew, using the Anaconda package manager or
compiled from source. Most of our testing is done using ``openmpi`` installed
with the Anaconda package manager as follows:

.. code:: shell

   $ conda install conda-forge::openmpi

Installing with Homebrew may require specifying the location of ``libmpi.dyld``
so that MLX can find it and load it at runtime. This can simply be achieved by
passing the ``DYLD_LIBRARY_PATH`` environment variable to ``mpirun`` and it is
done automatically by ``mlx.launch``.

.. code:: shell

   $ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python test.py
   $ # or simply
   $ mlx.launch -n 2 test.py

Setting up Remote Hosts
^^^^^^^^^^^^^^^^^^^^^^^

MPI can automatically connect to remote hosts and set up the communication over
the network if the remote hosts can be accessed via ssh. A good checklist to
debug connectivity issues is the following:

* ``ssh hostname`` works from all machines to all machines without asking for
  password or host confirmation
* ``mpirun`` is accessible on all machines.
* Ensure that the ``hostname`` used by MPI is the one that you have configured
  in the ``.ssh/config`` files on all machines.

Tuning MPI All Reduce
^^^^^^^^^^^^^^^^^^^^^

.. note::

   For faster all reduce consider using the ring backend either with Thunderbolt
   connections or over Ethernet.

Configure MPI to use N tcp connections between each host to improve bandwidth
by passing ``--mca btl_tcp_links N``.

Force MPI to use the most performant network interface by setting ``--mca
btl_tcp_if_include <iface>`` where ``<iface>`` should be the interface you want
to use.
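Putting the two flags together (an illustrative command line; the process
count and the ``en0`` interface are placeholders for your own setup):

.. code:: shell

   $ mpirun -np 2 --mca btl_tcp_links 4 --mca btl_tcp_if_include en0 python test.py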
Getting Started with Ring
-------------------------

The ring backend does not depend on any third party library so it is always
available. It uses TCP sockets so the nodes need to be reachable via a network.
As the name suggests the nodes are connected in a ring which means that rank 1
can only communicate with rank 0 and rank 2, rank 2 only with rank 1 and rank 3
and so on and so forth. As a result :func:`send` and :func:`recv` with
arbitrary sender and receiver are not supported in the ring backend.
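A minimal sketch of neighbor-only communication under this constraint
(assuming :func:`send` and :func:`recv_like` with the usual rank arguments):

.. code:: python

   import mlx.core as mx

   world = mx.distributed.init(backend="ring")
   r, n = world.rank(), world.size()

   x = mx.ones(4)
   # Only the immediate neighbors are valid peers in the ring backend
   s = mx.distributed.send(x, (r + 1) % n)
   y = mx.distributed.recv_like(x, (r - 1) % n)
   mx.eval(s, y)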
Defining a Ring
^^^^^^^^^^^^^^^

The easiest way to define and use a ring is via a JSON hostfile and the
``mlx.launch`` :doc:`helper script <launching_distributed>`. For each node one
defines a hostname to ssh into to run commands on this node and one or more IPs
that this node will listen to for connections.

For example the hostfile below defines a 4 node ring. ``hostname1`` will be
rank 0, ``hostname2`` rank 1 etc.

.. code:: json

   [
       {"ssh": "hostname1", "ips": ["123.123.123.1"]},
       {"ssh": "hostname2", "ips": ["123.123.123.2"]},
       {"ssh": "hostname3", "ips": ["123.123.123.3"]},
       {"ssh": "hostname4", "ips": ["123.123.123.4"]}
   ]

Running ``mlx.launch --hostfile ring-4.json my_script.py`` will ssh into each
node, run the script which will listen for connections in each of the provided
IPs. Specifically, ``hostname1`` will connect to ``123.123.123.2`` and accept a
connection from ``123.123.123.4`` and so on and so forth.

Thunderbolt Ring
^^^^^^^^^^^^^^^^

Although the ring backend can have benefits over MPI even for Ethernet, its
main purpose is to use Thunderbolt rings for higher bandwidth communication.
Setting up such thunderbolt rings can be done manually, but is a relatively
tedious process. To simplify this, we provide the utility ``mlx.distributed_config``.

To use ``mlx.distributed_config`` your computers need to be accessible by ssh via
Ethernet or Wi-Fi. Subsequently, connect them via thunderbolt cables and then call the
utility as follows:

.. code:: shell

   mlx.distributed_config --verbose --hosts host1,host2,host3,host4

By default the script will attempt to discover the thunderbolt ring and provide
you with the commands to configure each node as well as the ``hostfile.json``
to use with ``mlx.launch``. If password-less ``sudo`` is available on the nodes
then ``--auto-setup`` can be used to configure them automatically.

To validate your connection without configuring anything
``mlx.distributed_config`` can also plot the ring using DOT format.

.. code:: shell

   mlx.distributed_config --verbose --hosts host1,host2,host3,host4 --dot >ring.dot
   dot -Tpng ring.dot >ring.png
   open ring.png

If you want to go through the process manually, the steps are as follows:

* Disable the thunderbolt bridge interface
* For the cable connecting rank ``i`` to rank ``i + 1`` find the interfaces
  corresponding to that cable in nodes ``i`` and ``i + 1``.
* Set up a unique subnetwork connecting the two nodes for the corresponding
  interfaces. For instance if the cable corresponds to ``en2`` on node ``i``
  and ``en2`` also on node ``i + 1`` then we may assign IPs ``192.168.0.1`` and
  ``192.168.0.2`` respectively to the two nodes. For more details you can see
  the commands prepared by the utility script.
1. Perform a few large reductions instead of many small ones to improve
   bandwidth and latency
2. Pass ``--mca btl_tcp_links 4`` to ``mpirun`` to configure it to use 4 tcp
   connections between each host to improve bandwidth
@@ -7,17 +7,17 @@ Exporting Functions

MLX has an API to export and import functions to and from a file. This lets you
run computations written in one MLX front-end (e.g. Python) in another MLX
front-end (e.g. C++).
front-end (e.g. C++).

This guide walks through the basics of the MLX export API with some examples.
To see the full list of functions check-out the :ref:`API documentation
<export>`.

Basics of Exporting
Basics of Exporting
-------------------

Let's start with a simple example:

.. code-block:: python

    def fun(x, y):

@@ -67,7 +67,7 @@ specified as variable positional arguments or as a tuple of arrays:

    x = mx.array(1.0)
    y = mx.array(1.0)

    # Both arguments to fun are positional
    mx.export_function("add.mlxfn", fun, x, y)

@@ -133,7 +133,7 @@ parameters are also saved to the ``model.mlxfn`` file.

   For enclosed arrays inside an exported function, be extra careful to ensure
   they are evaluated. The computation graph that gets exported will include
   the computation that produces enclosed inputs.

   If the above example was missing ``mx.eval(model.parameters())``, the
   exported function would include the random initialization of the
   :obj:`mlx.nn.Module` parameters.

@@ -150,8 +150,8 @@ parameters, pass them as inputs to the ``call`` wrapper:

        # Set the model's parameters to the input parameters
        model.update(tree_unflatten(list(params.items())))
        return model(x)

    params = tree_flatten(model.parameters(), destination={})
    params = dict(tree_flatten(model.parameters()))
    mx.export_function("model.mlxfn", call, (mx.zeros(4),), params)

@@ -169,8 +169,8 @@ to export a function which can be used for inputs with variable shapes:

    # Ok
    out, = imported_abs(mx.array(-1.0))

    # Also ok
    # Also ok
    out, = imported_abs(mx.array([-1.0, -2.0]))

With ``shapeless=False`` (which is the default), the second call to

@@ -197,7 +197,7 @@ a single file by creating an exporting context manager with :func:`exporter`:

    def fun(x, y=None):
        constant = mx.array(3.0)
        if y is not None:
            x += y
            x += y
        return x + constant

    with mx.exporter("fun.mlxfn", fun) as exporter:

@@ -215,7 +215,7 @@ a single file by creating an exporting context manager with :func:`exporter`:

    print(out)

In the above example the function constant data, (i.e. ``constant``), is only
saved once.
saved once.

Transformations with Imported Functions
---------------------------------------

@@ -238,7 +238,7 @@ on imported functions just like regular Python functions:

    # Prints: array(1, dtype=float32)
    print(dfdx(x))

    # Compile the imported function
    # Compile the imported function
    mx.compile(imported_fun)
    # Prints: array(0, dtype=float32)
    print(compiled_fun(x)[0])

@@ -275,7 +275,7 @@ Import and run the function in C++ with only a few lines of code:

    // Prints: array(2, dtype=float32)
    std::cout << outputs[0] << std::endl;

Imported functions can be transformed in C++ just like in Python. Use
Imported functions can be transformed in C++ just like in Python. Use
``std::vector<mx::array>`` for positional arguments and ``std::map<std::string,
mx::array>`` for keyword arguments when calling imported functions in C++.
@@ -107,16 +107,6 @@ same array:

   >>> a
   array([1, 2, 0], dtype=int32)

Note, unlike NumPy, updates to the same location are nondeterministic:

.. code-block:: shell

   >>> a = mx.array([1, 2, 3])
   >>> a[[0, 0]] = mx.array([4, 5])

The first element of ``a`` could be ``4`` or ``5``.

Transformations of functions which use in-place updates are allowed and work as
expected. For example:
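(A minimal sketch; the original snippet is cut off by the hunk boundary above,
so this example is only illustrative:)

.. code-block:: python

   import mlx.core as mx

   def fun(x):
       x[0] = x[0] ** 2  # in-place update inside the transformed function
       return x.sum()

   # The gradient flows through the in-place update as expected
   print(mx.grad(fun)(mx.array([3.0, 2.0])))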
@@ -1,105 +0,0 @@

:orphan:

.. _usage_launch_distributed:

Launching Distributed Programs
==============================

.. currentmodule:: mlx.core.distributed

Installing the MLX python package provides a helper script ``mlx.launch`` that
can be used to run python scripts distributed on several nodes. It allows
launching using either the MPI backend or the ring backend. See the
:doc:`distributed docs <distributed>` for the different backends.

Usage
-----

The minimal usage example of ``mlx.launch`` is simply

.. code:: shell

   mlx.launch --hosts ip1,ip2 my_script.py

or for testing on localhost

.. code:: shell

   mlx.launch -n 2 my_script.py

The ``mlx.launch`` command connects to the provided host and launches the input
script on each host. It monitors each of the launched processes and terminates
the rest if one of them fails unexpectedly or if ``mlx.launch`` is terminated.
It also takes care of forwarding the output of each remote process to stdout
and stderr respectively.

Providing Hosts
^^^^^^^^^^^^^^^^

Hosts can be provided as command line arguments, like above, but the way that
allows you to fully define a list of hosts is via a JSON hostfile. The hostfile
has a very simple schema. It is simply a list of objects that define each host
via a hostname to ssh to and a list of IPs to utilize for the communication.

.. code:: json

   [
       {"ssh": "hostname1", "ips": ["123.123.1.1", "123.123.2.1"]},
       {"ssh": "hostname2", "ips": ["123.123.1.2", "123.123.2.2"]}
   ]

You can use ``mlx.distributed_config --over ethernet`` to create a hostfile
with IPs corresponding to the ``en0`` interface.

Setting up Remote Hosts
^^^^^^^^^^^^^^^^^^^^^^^^

In order to be able to launch the script on each host we need to be able to
connect via ssh. Moreover the input script and python binary need to be on each
host and on the same path. A good checklist to debug errors is the following:

* ``ssh hostname`` works without asking for password or host confirmation
* the python binary is available on all hosts at the same path. You can use
  ``mlx.launch --print-python`` to see what that path is.
* the script you want to run is available on all hosts at the same path

.. _mpi_specifics:

MPI Specifics
-------------

One can use MPI by passing ``--backend mpi`` to ``mlx.launch``. In that case,
``mlx.launch`` is a thin wrapper over ``mpirun``. Moreover,

* The IPs in the hostfile are ignored
* The ssh connectivity requirement is stronger as every node needs to be able
  to connect to every other node
* ``mpirun`` needs to be available on every node at the same path

Finally, one can pass arguments to ``mpirun`` using ``--mpi-arg``. For instance
to choose a specific interface for the byte-transfer-layer of MPI we can call
``mlx.launch`` as follows:

.. code:: shell

   mlx.launch --backend mpi --mpi-arg '--mca btl_tcp_if_include en0' --hostfile hosts.json my_script.py

.. _ring_specifics:

Ring Specifics
--------------

The ring backend, which is also the default backend, can be explicitly selected
with the argument ``--backend ring``. The ring backend has some specific
requirements and arguments that are different to MPI:

* The argument ``--hosts`` only accepts IPs and not hostnames. If we need to
  ssh to a hostname that does not correspond to the IP we want to bind to we
  have to provide a hostfile.
* ``--starting-port`` defines the port to bind to on the remote hosts.
  Specifically rank 0 for the first IP will use this port and each subsequent
  IP or rank will add 1 to this port.
* ``--connections-per-ip`` allows us to increase the number of connections
  between neighboring nodes. This corresponds to ``--mca btl_tcp_links 2`` for
  ``mpirun``.
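Combining these arguments (an illustrative command line; the port and
connection counts are placeholders):

.. code:: shell

   mlx.launch --backend ring --starting-port 5000 --connections-per-ip 2 --hostfile hosts.json my_script.py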
@@ -21,13 +21,11 @@ Let's convert an array to NumPy and back.

.. note::

    Since NumPy does not support ``bfloat16`` arrays, you will need to convert
    to ``float16`` or ``float32`` first: ``np.array(a.astype(mx.float32))``.
    Otherwise, you will receive an error like: ``Item size 2 for PEP 3118
    buffer format string does not match the dtype V item size 0.``
    Since NumPy does not support ``bfloat16`` arrays, you will need to convert to ``float16`` or ``float32`` first:
    ``np.array(a.astype(mx.float32))``.
    Otherwise, you will receive an error like: ``Item size 2 for PEP 3118 buffer format string does not match the dtype V item size 0.``

By default, NumPy copies data to a new array. This can be prevented by creating
an array view:
By default, NumPy copies data to a new array. This can be prevented by creating an array view:

.. code-block:: python

@@ -37,16 +35,10 @@ an array view:

    a_view[0] = 1
    print(a[0].item()) # 1

.. note::
   A NumPy array view is a normal NumPy array, except that it does not own its memory.
   This means writing to the view is reflected in the original array.

   NumPy arrays with type ``float64`` will by default be converted to MLX arrays
   with type ``float32``.

A NumPy array view is a normal NumPy array, except that it does not own its
memory. This means writing to the view is reflected in the original array.

While this is quite powerful to prevent copying arrays, it should be noted that
external changes to the memory of arrays cannot be reflected in gradients.
While this is quite powerful to prevent copying arrays, it should be noted that external changes to the memory of arrays cannot be reflected in gradients.

Let's demonstrate this in an example:

@@ -64,12 +56,11 @@ Let's demonstrate this in an example:

The function ``f`` indirectly modifies the array ``x`` through a memory view.
However, this modification is not reflected in the gradient, as seen in the
last line outputting ``1.0``, representing the gradient of the sum operation
alone. The squaring of ``x`` occurs externally to MLX, meaning that no
gradient is incorporated. It's important to note that a similar issue arises
during array conversion and copying. For instance, a function defined as
``mx.array(np.array(x)**2).sum()`` would also result in an incorrect gradient,
However, this modification is not reflected in the gradient, as seen in the last line outputting ``1.0``,
representing the gradient of the sum operation alone.
The squaring of ``x`` occurs externally to MLX, meaning that no gradient is incorporated.
It's important to note that a similar issue arises during array conversion and copying.
For instance, a function defined as ``mx.array(np.array(x)**2).sum()`` would also result in an incorrect gradient,
even though no in-place operations on MLX memory are executed.

PyTorch
-------

@@ -80,8 +71,7 @@ PyTorch

   PyTorch Support for :obj:`memoryview` is experimental and can break for
   multi-dimensional arrays. Casting to NumPy first is advised for now.

PyTorch supports the buffer protocol, but it requires an explicit
:obj:`memoryview`.
PyTorch supports the buffer protocol, but it requires an explicit :obj:`memoryview`.

.. code-block:: python

@@ -92,8 +82,7 @@ PyTorch supports the buffer protocol, but it requires an explicit

    b = torch.tensor(memoryview(a))
    c = mx.array(b.numpy())

Conversion from PyTorch tensors back to arrays must be done via intermediate
NumPy arrays with ``numpy()``.
Conversion from PyTorch tensors back to arrays must be done via intermediate NumPy arrays with ``numpy()``.

JAX
---

@@ -111,8 +100,7 @@ JAX fully supports the buffer protocol.

TensorFlow
----------

TensorFlow supports the buffer protocol, but it requires an explicit
:obj:`memoryview`.
TensorFlow supports the buffer protocol, but it requires an explicit :obj:`memoryview`.

.. code-block:: python
@@ -10,6 +10,7 @@ set(CMAKE_POSITION_INDEPENDENT_CODE ON)

option(BUILD_SHARED_LIBS "Build extensions as a shared library" ON)

# ----------------------------- Dependencies -----------------------------
find_package(MLX CONFIG REQUIRED)
find_package(
  Python 3.8
  COMPONENTS Interpreter Development.Module

@@ -20,12 +21,6 @@ execute_process(
  OUTPUT_VARIABLE nanobind_ROOT)
find_package(nanobind CONFIG REQUIRED)

execute_process(
  COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
  OUTPUT_STRIP_TRAILING_WHITESPACE
  OUTPUT_VARIABLE MLX_ROOT)
find_package(MLX CONFIG REQUIRED)

# ----------------------------- Extensions -----------------------------

# Add library
@@ -1,15 +1,19 @@

// Copyright © 2023-2025 Apple Inc.
// Copyright © 2023-2024 Apple Inc.

#include <dlfcn.h>
#include <cassert>
#include <iostream>
#include <sstream>

#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/utils.h"

#include "axpby/axpby.h"

#ifdef ACCELERATE_NEW_LAPACK
#include <vecLib/cblas_new.h>
#endif

#ifdef _METAL_
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/utils.h"

@@ -17,19 +21,6 @@

namespace my_ext {

// A helper function to find the location of the current binary on disk.
// The Metal library ("mlx_ext.mtllib"), should be in the same directory.
std::string current_binary_dir() {
  static std::string binary_dir = []() {
    Dl_info info;
    if (!dladdr(reinterpret_cast<void*>(&current_binary_dir), &info)) {
      throw std::runtime_error("Unable to get current binary dir.");
    }
    return std::filesystem::path(info.dli_fname).parent_path().string();
  }();
  return binary_dir;
}

///////////////////////////////////////////////////////////////////////////////
// Operation Implementation
///////////////////////////////////////////////////////////////////////////////

@@ -84,65 +75,136 @@ void axpby_impl(
    const mx::array& y,
    mx::array& out,
    float alpha_,
    float beta_,
    mx::Stream stream) {
  out.set_data(mx::allocator::malloc(out.nbytes()));
    float beta_) {
  // We only allocate memory when we are ready to fill the output
  // malloc_or_wait synchronously allocates available memory
  // There may be a wait executed here if the allocation is requested
  // under memory-pressured conditions
  out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));

  // Get the CPU command encoder and register input and output arrays
  auto& encoder = mx::cpu::get_command_encoder(stream);
  encoder.set_input_array(x);
  encoder.set_input_array(y);
  encoder.set_output_array(out);
  // Collect input and output data pointers
  const T* x_ptr = x.data<T>();
  const T* y_ptr = y.data<T>();
  T* out_ptr = out.data<T>();

  // Launch the CPU kernel
  encoder.dispatch([x_ptr = x.data<T>(),
                    y_ptr = y.data<T>(),
                    out_ptr = out.data<T>(),
                    size = out.size(),
                    shape = out.shape(),
                    x_strides = x.strides(),
                    y_strides = y.strides(),
                    alpha_,
                    beta_]() {
    // Cast alpha and beta to the relevant types
    T alpha = static_cast<T>(alpha_);
    T beta = static_cast<T>(beta_);
  // Cast alpha and beta to the relevant types
  T alpha = static_cast<T>(alpha_);
  T beta = static_cast<T>(beta_);

    // Do the element-wise operation for each output
    for (size_t out_idx = 0; out_idx < size; out_idx++) {
      // Map linear indices to offsets in x and y
      auto x_offset = mx::elem_to_loc(out_idx, shape, x_strides);
      auto y_offset = mx::elem_to_loc(out_idx, shape, y_strides);
  // Do the element-wise operation for each output
  for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
    // Map linear indices to offsets in x and y
    auto x_offset = mx::elem_to_loc(out_idx, x.shape(), x.strides());
    auto y_offset = mx::elem_to_loc(out_idx, y.shape(), y.strides());

      // We allocate the output to be contiguous and regularly strided
      // (defaults to row major) and hence it doesn't need additional mapping
      out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
    }
  });
    // We allocate the output to be contiguous and regularly strided
    // (defaults to row major) and hence it doesn't need additional mapping
    out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
  }
}
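// The elem_to_loc mapping above walks the linear output index from the
// innermost dimension outwards; sketched in Python (an illustrative analog,
// not the C++ API itself):
//
//   def elem_to_loc(idx, shape, strides):
//       loc = 0
//       for dim, stride in zip(reversed(shape), reversed(strides)):
//           loc += (idx % dim) * stride
//           idx //= dim
//       return loc
//
//   # A 2x3 array stored with column-major strides:
//   assert elem_to_loc(1, (2, 3), (1, 2)) == 2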
void Axpby::eval_cpu(
/** Fall back implementation for evaluation on CPU */
void Axpby::eval(
    const std::vector<mx::array>& inputs,
    std::vector<mx::array>& outputs) {
  // Check the inputs (registered in the op while constructing the out array)
  assert(inputs.size() == 2);
  auto& x = inputs[0];
  auto& y = inputs[1];
  auto& out = outputs[0];

  // Dispatch to the correct dtype
  if (out.dtype() == mx::float32) {
    return axpby_impl<float>(x, y, out, alpha_, beta_, stream());
    return axpby_impl<float>(x, y, out, alpha_, beta_);
  } else if (out.dtype() == mx::float16) {
    return axpby_impl<mx::float16_t>(x, y, out, alpha_, beta_, stream());
    return axpby_impl<mx::float16_t>(x, y, out, alpha_, beta_);
  } else if (out.dtype() == mx::bfloat16) {
    return axpby_impl<mx::bfloat16_t>(x, y, out, alpha_, beta_, stream());
    return axpby_impl<mx::bfloat16_t>(x, y, out, alpha_, beta_);
  } else if (out.dtype() == mx::complex64) {
    return axpby_impl<mx::complex64_t>(x, y, out, alpha_, beta_, stream());
    return axpby_impl<mx::complex64_t>(x, y, out, alpha_, beta_);
  } else {
    throw std::runtime_error(
        "Axpby is only supported for floating point types.");
  }
}

///////////////////////////////////////////////////////////////////////////////
// Primitive Accelerate Backend Implementation
///////////////////////////////////////////////////////////////////////////////

#ifdef ACCELERATE_NEW_LAPACK

template <typename T>
void axpby_impl_accelerate(
    const mx::array& x,
    const mx::array& y,
    mx::array& out,
    float alpha_,
    float beta_) {
  // Accelerate library provides catlas_saxpby which does
  // Y = (alpha * X) + (beta * Y) in place
  // To use it, we first copy the data in y over to the output array

  // This specialization requires both x and y be contiguous in the same mode
  // i.e: corresponding linear indices in both point to corresponding elements
  // The data in the output array is allocated to match the strides in y
  // such that x, y, and out are contiguous in the same mode and
  // no transposition is needed
  out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));

  // We then copy over the elements using the contiguous vector specialization
  copy_inplace(y, out, mx::CopyType::Vector);

  // Get x and y pointers for catlas_saxpby
  const T* x_ptr = x.data<T>();
  T* y_ptr = out.data<T>();

  T alpha = static_cast<T>(alpha_);
  T beta = static_cast<T>(beta_);

  // Call the inplace accelerate operator
  catlas_saxpby(
      /* N = */ out.size(),
      /* ALPHA = */ alpha,
      /* X = */ x_ptr,
      /* INCX = */ 1,
      /* BETA = */ beta,
      /* Y = */ y_ptr,
      /* INCY = */ 1);
}

/** Evaluate primitive on CPU using accelerate specializations */
void Axpby::eval_cpu(
    const std::vector<mx::array>& inputs,
    std::vector<mx::array>& outputs) {
  assert(inputs.size() == 2);
  auto& x = inputs[0];
  auto& y = inputs[1];
  auto& out = outputs[0];

  // Accelerate specialization for contiguous single precision float arrays
  if (out.dtype() == mx::float32 &&
      ((x.flags().row_contiguous && y.flags().row_contiguous) ||
       (x.flags().col_contiguous && y.flags().col_contiguous))) {
    axpby_impl_accelerate<float>(x, y, out, alpha_, beta_);
    return;
  }

  // Fall back to common backend if specializations are not available
  eval(inputs, outputs);
}

#else // Accelerate not available

/** Evaluate primitive on CPU falling back to common backend */
void Axpby::eval_cpu(
    const std::vector<mx::array>& inputs,
    std::vector<mx::array>& outputs) {
  eval(inputs, outputs);
}

#endif

///////////////////////////////////////////////////////////////////////////////
// Primitive Metal Backend Implementation
///////////////////////////////////////////////////////////////////////////////

@@ -154,6 +216,7 @@ void Axpby::eval_gpu(
    const std::vector<mx::array>& inputs,
    std::vector<mx::array>& outputs) {
  // Prepare inputs
  assert(inputs.size() == 2);
  auto& x = inputs[0];
  auto& y = inputs[1];
  auto& out = outputs[0];

@@ -172,24 +235,25 @@ void Axpby::eval_gpu(
  // Allocate output memory with strides based on specialization
  if (contiguous_kernel) {
    out.set_data(
        mx::allocator::malloc(x.data_size() * out.itemsize()),
        mx::allocator::malloc_or_wait(x.data_size() * out.itemsize()),
        x.data_size(),
        x.strides(),
        x.flags());
  } else {
    out.set_data(mx::allocator::malloc(out.nbytes()));
    out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
  }

  // Resolve name of kernel (corresponds to axpby.metal)
  std::string kname = "axpby_";
  kname += (contiguous_kernel ? "contiguous_" : "general_");
  kname += type_to_name(out);
  std::ostringstream kname;
  kname << "axpby_";
  kname << (contiguous_kernel ? "contiguous_" : "general_");
  kname << type_to_name(out);

  // Load the metal library
  auto lib = d.get_library("mlx_ext", current_binary_dir());
  // Make sure the metal library is available
  d.register_library("mlx_ext");

  // Make a kernel from this metal library
  auto kernel = d.get_kernel(kname, lib);
  auto kernel = d.get_kernel(kname.str(), "mlx_ext");

  // Prepare to encode kernel
  auto& compute_encoder = d.get_command_encoder(s.index);
@@ -1,4 +1,4 @@

// Copyright © 2023-2025 Apple Inc.
// Copyright © 2023 Apple Inc.

#pragma once

@@ -74,9 +74,9 @@ class Axpby : public mx::Primitive {
      const std::vector<mx::array>& inputs,
      const std::vector<int>& axes) override;

  /** The name of primitive. */
  const char* name() const override {
    return "Axpby";
  /** Print the primitive. */
  void print(std::ostream& os) override {
    os << "Axpby";
  }

  /** Equivalence check **/

@@ -85,6 +85,11 @@ class Axpby : public mx::Primitive {
 private:
  float alpha_;
  float beta_;

  /** Fall back implementation for evaluation on CPU */
  void eval(
      const std::vector<mx::array>& inputs,
      std::vector<mx::array>& outputs);
};

} // namespace my_ext
@@ -1,4 +1,4 @@

// Copyright © 2023-2025 Apple Inc.
// Copyright © 2023 Apple Inc.

#include <metal_stdlib>
@@ -1,4 +1,4 @@

setuptools>=42
cmake>=3.25
mlx>=0.21.0
nanobind==2.4.0
nanobind==2.2.0
@@ -3,10 +3,8 @@ from mlx_sample_extensions import axpby

a = mx.ones((3, 4))
b = mx.ones((3, 4))
c_cpu = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
c_gpu = axpby(a, b, 4.0, 2.0, stream=mx.gpu)
c = axpby(a, b, 4.0, 2.0, stream=mx.cpu)

print(f"c shape: {c_cpu.shape}")
print(f"c dtype: {c_cpu.dtype}")
print(f"c_cpu correct: {mx.all(c_cpu == 6.0).item()}")
print(f"c_gpu correct: {mx.all(c_gpu == 6.0).item()}")
print(f"c shape: {c.shape}")
print(f"c dtype: {c.dtype}")
print(f"c correct: {mx.all(c == 6.0).item()}")
@@ -5,7 +5,6 @@ target_sources(
  ${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/dtype_utils.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/export.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp

@@ -20,11 +19,6 @@ target_sources(
  ${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h)

# Define MLX_VERSION only in the version.cpp file.
add_library(mlx_version OBJECT ${CMAKE_CURRENT_SOURCE_DIR}/version.cpp)
target_compile_definitions(mlx_version PRIVATE MLX_VERSION="${MLX_VERSION}")
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:mlx_version>)

if(MSVC)
  # Disable some MSVC warnings to speed up compilation.
  target_compile_options(mlx PUBLIC /wd4068 /wd4244 /wd4267 /wd4804)

@@ -35,33 +29,24 @@ if(WIN32)
  set_target_properties(mlx PROPERTIES WINDOWS_EXPORT_ALL_SYMBOLS TRUE)
endif()

add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)

if(MLX_BUILD_CPU)
  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/cpu)
  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
else()
  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_cpu)
endif()

add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/distributed)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
if(MLX_BUILD_ACCELERATE)
  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate)
elseif(MLX_BUILD_CPU)
  target_sources(
    mlx
    PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/common/default_primitives.cpp)
endif()

if(MLX_BUILD_METAL)
  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)
else()
  target_sources(mlx
                 PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/no_metal.cpp)
endif()

if(MLX_BUILD_CUDA)
  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda)
else()
  target_sources(mlx
                 PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda/no_cuda.cpp)
endif()

if(MLX_BUILD_METAL OR MLX_BUILD_CUDA)
  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/gpu)
else()
  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_gpu)
  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_metal)
endif()
@@ -4,11 +4,12 @@

#include <sstream>

#include "mlx/allocator.h"
#include "mlx/scheduler.h"

namespace mlx::core::allocator {

Buffer malloc(size_t size) {
  auto buffer = allocator().malloc(size);
  auto buffer = allocator().malloc(size, /* allow_swap */ true);
  if (size && !buffer.ptr()) {
    std::ostringstream msg;
    msg << "[malloc] Unable to allocate " << size << " bytes.";

@@ -21,4 +22,45 @@ void free(Buffer buffer) {
  allocator().free(buffer);
}

Buffer CommonAllocator::malloc(size_t size, bool) {
  void* ptr = std::malloc(size + sizeof(size_t));
  if (ptr != nullptr) {
    *static_cast<size_t*>(ptr) = size;
  }
  return Buffer{ptr};
}

void CommonAllocator::free(Buffer buffer) {
  std::free(buffer.ptr());
}

size_t CommonAllocator::size(Buffer buffer) const {
  if (buffer.ptr() == nullptr) {
    return 0;
  }
  return *static_cast<size_t*>(buffer.ptr());
}

Buffer malloc_or_wait(size_t size) {
  auto buffer = allocator().malloc(size);

  while (size && !buffer.ptr() && scheduler::n_active_tasks() > 0) {
    scheduler::wait_for_one();
    buffer = allocator().malloc(size);
  }

  // Try swapping if needed
  if (size && !buffer.ptr()) {
    buffer = allocator().malloc(size, /* allow_swap = */ true);
  }

  if (size && !buffer.ptr()) {
    std::ostringstream msg;
    msg << "[malloc_or_wait] Unable to allocate " << size << " bytes.";
    throw std::runtime_error(msg.str());
  }

  return buffer;
}

} // namespace mlx::core::allocator
@@ -32,10 +32,14 @@ Buffer malloc(size_t size);

void free(Buffer buffer);

// Wait for running tasks to finish and free up memory
// if allocation fails
Buffer malloc_or_wait(size_t size);

class Allocator {
  /** Abstract base class for a memory allocator. */
 public:
  virtual Buffer malloc(size_t size) = 0;
  virtual Buffer malloc(size_t size, bool allow_swap = false) = 0;
  virtual void free(Buffer buffer) = 0;
  virtual size_t size(Buffer buffer) const = 0;

@@ -49,4 +53,16 @@ class Allocator {

Allocator& allocator();

class CommonAllocator : public Allocator {
  /** A general CPU allocator. */
 public:
  virtual Buffer malloc(size_t size, bool allow_swap = false) override;
  virtual void free(Buffer buffer) override;
  virtual size_t size(Buffer buffer) const override;

 private:
  CommonAllocator() = default;
  friend Allocator& allocator();
};

} // namespace mlx::core::allocator
@@ -10,6 +10,20 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
/** Return true if we are currently performing a function transformation in
|
||||
* order to keep the graph when evaluating tracer arrays. */
|
||||
bool in_tracing() {
|
||||
return detail::InTracing::in_tracing();
|
||||
}
|
||||
|
||||
bool retain_graph() {
|
||||
return detail::RetainGraph::retain_graph();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
|
||||
: array_desc_(std::make_shared<ArrayDesc>(Shape{}, dtype)) {
|
||||
auto cval = static_cast<complex64_t>(val);
|
||||
@@ -25,18 +39,7 @@ array::array(
|
||||
std::move(shape),
|
||||
dtype,
|
||||
std::move(primitive),
|
||||
std::move(inputs))) {
|
||||
if (has_primitive() && this->primitive().stream().device == Device::gpu) {
|
||||
for (auto& in : this->inputs()) {
|
||||
if (in.dtype() == float64) {
|
||||
throw std::invalid_argument("float64 is not supported on the GPU");
|
||||
}
|
||||
}
|
||||
if (this->dtype() == float64) {
|
||||
throw std::invalid_argument("float64 is not supported on the GPU");
|
||||
}
|
||||
}
|
||||
}
|
||||
std::move(inputs))) {}
|
||||
|
||||
std::vector<array> array::make_arrays(
|
||||
std::vector<Shape> shapes,
|
||||
@@ -56,18 +59,6 @@ std::vector<array> array::make_arrays(
|
||||
return outputs;
|
||||
}
|
||||
|
||||
array array::unsafe_weak_copy(const array& other) {
|
||||
auto cpy = array(other.shape(), other.dtype(), nullptr, {});
|
||||
cpy.set_data(
|
||||
other.buffer(),
|
||||
other.data_size(),
|
||||
other.strides(),
|
||||
other.flags(),
|
||||
[](auto) {});
|
||||
cpy.array_desc_->data_ptr = other.array_desc_->data_ptr;
|
||||
return cpy;
|
||||
}
|
||||
|
||||
array::array(std::initializer_list<float> data)
|
||||
: array_desc_(std::make_shared<ArrayDesc>(
|
||||
Shape{static_cast<ShapeElem>(data.size())},
|
||||
@@ -89,26 +80,22 @@ array::array(allocator::Buffer data, Shape shape, Dtype dtype, Deleter deleter)
|
||||
}
|
||||
|
void array::detach() {
  array_desc_->primitive = nullptr;
  for (auto& s : array_desc_->siblings) {
    s.array_desc_->primitive = nullptr;
  }
  for (auto& s : array_desc_->siblings) {
    s.array_desc_->inputs.clear();
    s.array_desc_->siblings.clear();
    s.array_desc_->position = 0;
    s.array_desc_->primitive = nullptr;
  }
  array_desc_->inputs.clear();
  array_desc_->siblings.clear();
  array_desc_->position = 0;
  array_desc_->primitive = nullptr;
}

bool array::is_available() const {
  if (status() == Status::available) {
    return true;
  } else if (
      status() == Status::evaluated &&
      (!event().valid() || event().is_signaled())) {
  } else if (status() == Status::evaluated && event().is_signaled()) {
    set_status(Status::available);
    return true;
  }
@@ -117,10 +104,7 @@ bool array::is_available() const {

void array::wait() {
  if (!is_available()) {
    if (event().valid()) {
      event().wait();
      detach_event();
    }
    event().wait();
    set_status(Status::available);
  }
}
@@ -135,8 +119,7 @@ void array::eval() {
}

bool array::is_tracer() const {
  return (array_desc_->is_tracer && detail::in_tracing()) ||
      detail::retain_graph();
  return (array_desc_->is_tracer && in_tracing()) || retain_graph();
}
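// Not from the mlx sources: a short sketch of how the status machinery above
// surfaces through the public API, assuming mlx's documented eval(); the
// state names in the comments are the internal Status values from array.h.
#include "mlx/mlx.h"

namespace mx = mlx::core;

void status_walkthrough() {
  auto a = mx::array({1.0f, 2.0f});
  auto b = mx::array({3.0f, 4.0f});
  auto x = a + b;   // unscheduled: a graph node with no buffer yet
  mx::eval(x);      // scheduled, then evaluated once its primitive runs
  // Reading the data relies on wait(): the array's event (if any) is
  // waited on and the status becomes available.
  float first = x.data<float>()[0];
  (void)first;
}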
void array::set_data(allocator::Buffer buffer, Deleter d) {
@@ -181,13 +164,34 @@ void array::copy_shared_buffer(const array& other) {
  copy_shared_buffer(other, other.strides(), other.flags(), other.data_size());
}

void array::move_shared_buffer(
    array other,
    const Strides& strides,
    Flags flags,
    size_t data_size,
    size_t offset /* = 0 */) {
  array_desc_->data = std::move(other.array_desc_->data);
  array_desc_->strides = strides;
  array_desc_->flags = flags;
  array_desc_->data_size = data_size;
  auto char_offset = sizeof(char) * itemsize() * offset;
  auto data_ptr = other.array_desc_->data_ptr;
  other.array_desc_->data_ptr = nullptr;
  array_desc_->data_ptr =
      static_cast<void*>(static_cast<char*>(data_ptr) + char_offset);
}

void array::move_shared_buffer(array other) {
  move_shared_buffer(other, other.strides(), other.flags(), other.data_size());
}

array::~array() {
  if (array_desc_ == nullptr) {
    return;
  }

  // Detached/detaching
  if (array_desc_->primitive == nullptr) {
    // Ignore arrays that might be detached during eval
    if (status() == array::Status::scheduled) {
      return;
    }
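// Not from the mlx sources: the offset arithmetic in move_shared_buffer, in
// isolation. sizeof(char) is 1 by definition, so the multiplication in the
// source is kept only for readability; a view starting `offset` elements
// into a buffer is just the base pointer advanced by offset * itemsize bytes.
#include <cstddef>

void* offset_view(void* base, std::size_t offset, std::size_t itemsize) {
  auto char_offset = itemsize * offset;
  return static_cast<void*>(static_cast<char*>(base) + char_offset);
}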
mlx/array.h (57 changes)
@@ -10,7 +10,6 @@
#include "mlx/allocator.h"
#include "mlx/dtype.h"
#include "mlx/event.h"
#include "mlx/small_vector.h"

namespace mlx::core {

@@ -19,8 +18,8 @@ class Primitive;

using Deleter = std::function<void(allocator::Buffer)>;
using ShapeElem = int32_t;
using Shape = SmallVector<ShapeElem>;
using Strides = SmallVector<int64_t>;
using Shape = std::vector<ShapeElem>;
using Strides = std::vector<int64_t>;

class array {
  /* An array is really a node in a graph. It contains a shared ArrayDesc
@@ -36,29 +35,29 @@ class array {
  explicit array(const std::complex<float>& val, Dtype dtype = complex64);

  template <typename It>
  explicit array(
  array(
      It data,
      Shape shape,
      Dtype dtype =
          TypeToDtype<typename std::iterator_traits<It>::value_type>());

  template <typename T>
  explicit array(std::initializer_list<T> data, Dtype dtype = TypeToDtype<T>());
  array(std::initializer_list<T> data, Dtype dtype = TypeToDtype<T>());

  /* Special case so empty lists default to float32. */
  explicit array(std::initializer_list<float> data);
  array(std::initializer_list<float> data);

  /* Special case so array({}, type) is an empty array. */
  explicit array(std::initializer_list<int> data, Dtype dtype);
  array(std::initializer_list<int> data, Dtype dtype);

  template <typename T>
  explicit array(
  array(
      std::initializer_list<T> data,
      Shape shape,
      Dtype dtype = TypeToDtype<T>());

  /* Build an array from a buffer */
  explicit array(
  array(
      allocator::Buffer data,
      Shape shape,
      Dtype dtype,
@@ -200,13 +199,6 @@ class array {
      const std::shared_ptr<Primitive>& primitive,
      const std::vector<array>& inputs);

  /**
   * Get a new array that refers to the same data as the input but with a
   * non-owning pointer to it. Note the array is detached from the graph and has
   * no inputs, siblings or primitive.
   */
  static array unsafe_weak_copy(const array& other);

  /** A unique identifier for an array. */
  std::uintptr_t id() const {
    return reinterpret_cast<std::uintptr_t>(array_desc_.get());
@@ -225,10 +217,6 @@ class array {
    // Not copyable
    Data(const Data& d) = delete;
    Data& operator=(const Data& d) = delete;
    Data(Data&& o) : buffer(o.buffer), d(o.d) {
      o.buffer = allocator::Buffer(nullptr);
      o.d = [](allocator::Buffer) {};
    }
    ~Data() {
      d(buffer);
    }
@@ -344,11 +332,11 @@ class array {
    return allocator::allocator().size(buffer());
  }

  // Return the shared pointer to the array::Data struct
  const std::shared_ptr<Data>& data_shared_ptr() const {
  // Return a copy of the shared pointer
  // to the array::Data struct
  std::shared_ptr<Data> data_shared_ptr() const {
    return array_desc_->data;
  }

  // Return a raw pointer to the arrays data
  template <typename T>
  T* data() {
@@ -361,10 +349,15 @@ class array {
  }
  enum Status {
    // The output of a computation which has not been scheduled.
    // For example, the status of `x` in `auto x = a + b`.
    unscheduled,

    // The output of a computation which has been scheduled but `eval_*` has
    // not yet been called on the array's primitive. A possible
    // status of `x` in `auto x = a + b; eval(x);`
    scheduled,

    // The array's `eval_*` function has been run, but the computation is not
    // necessarily complete. The array will have memory allocated and if it is
    // not a tracer then it will be detached from the graph.
@@ -401,10 +394,6 @@ class array {
    array_desc_->event = std::move(e);
  }

  void detach_event() const {
    array_desc_->event = Event{};
  }

  // Mark the array as a tracer array (true) or not.
  void set_tracer(bool is_tracer) {
    array_desc_->is_tracer = is_tracer;
@@ -430,6 +419,15 @@ class array {

  void copy_shared_buffer(const array& other);

  void move_shared_buffer(
      array other,
      const Strides& strides,
      Flags flags,
      size_t data_size,
      size_t offset = 0);

  void move_shared_buffer(array other);

  void overwrite_descriptor(const array& other) {
    array_desc_ = other.array_desc_;
  }
@@ -596,9 +594,6 @@ void array::init(It src) {
    case float32:
      std::copy(src, src + size(), data<float>());
      break;
    case float64:
      std::copy(src, src + size(), data<double>());
      break;
    case bfloat16:
      std::copy(src, src + size(), data<bfloat16_t>());
      break;
mlx/backend/accelerate/CMakeLists.txt (new file, 8 lines)
@@ -0,0 +1,8 @@
target_sources(
  mlx
  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp)
mlx/backend/accelerate/conv.cpp (new file, 20 lines)
@@ -0,0 +1,20 @@
// Copyright © 2023-2024 Apple Inc.

#include <cassert>

#include <Accelerate/Accelerate.h>
#include <simd/vector.h>

#include "mlx/backend/common/copy.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"

namespace mlx::core {

void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
  eval(inputs, out);

  // TODO: Add accelerate based optimizations for CPU conv
}

} // namespace mlx::core
mlx/backend/accelerate/matmul.cpp (new file, 253 lines)
@@ -0,0 +1,253 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include <Accelerate/Accelerate.h>
|
||||
|
||||
#include "mlx/backend/accelerate/utils.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
std::tuple<bool, size_t, array> check_transpose(const array& arr) {
|
||||
auto stx = arr.strides()[arr.ndim() - 2];
|
||||
auto sty = arr.strides()[arr.ndim() - 1];
|
||||
if (stx == arr.shape(-1) && sty == 1) {
|
||||
return std::make_tuple(false, stx, arr);
|
||||
} else if (stx == 1 && sty == arr.shape(-2)) {
|
||||
return std::make_tuple(true, sty, arr);
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy(arr, arr_copy, CopyType::General);
|
||||
size_t stx = arr.shape(-1);
|
||||
return std::make_tuple(false, stx, arr_copy);
|
||||
}
|
||||
}
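// Not from the mlx sources: a sketch of what check_transpose's stride test
// means for a plain 2D matrix. The hypothetical classify_2d helper returns
// {needs_copy, transposed, ld}, mirroring the three branches above.
#include <tuple>
#include <vector>

std::tuple<bool, bool, long> classify_2d(
    const std::vector<long>& strides, long rows, long cols) {
  long stx = strides[0]; // stride between rows
  long sty = strides[1]; // stride between columns
  if (stx == cols && sty == 1) {
    return {false, false, stx}; // row-major: feed to BLAS as-is
  } else if (stx == 1 && sty == rows) {
    return {false, true, sty}; // column-major: feed as a transpose
  } else {
    return {true, false, cols}; // general strides: copy to row-major first
  }
}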
|
||||
|
||||
inline void matmul_cblas_general(
|
||||
const array& a_pre,
|
||||
const array& b_pre,
|
||||
array& out,
|
||||
float alpha = 1.0f,
|
||||
float beta = 0.0f) {
|
||||
if (out.dtype() != float32) {
|
||||
throw std::runtime_error(
|
||||
"[matmul_cblas] on CPU currently only supports float32");
|
||||
}
|
||||
|
||||
auto [a_transposed, lda, a] = check_transpose(a_pre);
|
||||
auto [b_transposed, ldb, b] = check_transpose(b_pre);
|
||||
size_t M = a.shape(-2);
|
||||
size_t N = b.shape(-1);
|
||||
size_t K = a.shape(-1);
|
||||
|
||||
if (M == 0 || N == 0) {
|
||||
return;
|
||||
}
|
||||
if (K == 0) {
|
||||
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
|
||||
return;
|
||||
}
|
||||
|
||||
for (int i = 0; i < (a.size() / (M * K)); ++i) {
|
||||
cblas_sgemm(
|
||||
CblasRowMajor,
|
||||
a_transposed ? CblasTrans : CblasNoTrans, // transA
|
||||
b_transposed ? CblasTrans : CblasNoTrans, // transB
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
alpha, // alpha
|
||||
a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides()),
|
||||
lda,
|
||||
b.data<float>() + elem_to_loc(K * N * i, b.shape(), b.strides()),
|
||||
ldb,
|
||||
beta, // beta
|
||||
out.data<float>() + M * N * i,
|
||||
out.shape(-1) // ldc
|
||||
);
|
||||
}
|
||||
}
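// Not from the mlx sources: one un-batched slice of the loop above with the
// leading dimensions spelled out. In row-major layout lda/ldb/ldc are the
// column counts of A, B, and C. Requires linking a CBLAS implementation such
// as Accelerate.
#include <cblas.h>

void sgemm_slice(
    const float* a, const float* b, float* c, int M, int N, int K) {
  cblas_sgemm(
      CblasRowMajor, CblasNoTrans, CblasNoTrans,
      M, N, K,
      /* alpha */ 1.0f, a, /* lda */ K, b, /* ldb */ N,
      /* beta */ 0.0f, c, /* ldc */ N);
}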
|
||||
|
||||
inline void matmul_cblas(const array& a_pre, const array& b_pre, array& out) {
|
||||
if (out.dtype() != float32) {
|
||||
throw std::runtime_error(
|
||||
"[matmul_cblas] on CPU currently only supports float32");
|
||||
}
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
return matmul_cblas_general(a_pre, b_pre, out);
|
||||
}
|
||||
|
||||
inline void matmul_bnns_general(
|
||||
const array& a_pre,
|
||||
const array& b_pre,
|
||||
array& out,
|
||||
float alpha = 1.0f,
|
||||
float beta = 0.0f) {
|
||||
// TODO: Update to utilize BNNS broadcasting
|
||||
|
||||
auto [a_transposed, lda, a] = check_transpose(a_pre);
|
||||
auto [b_transposed, ldb, b] = check_transpose(b_pre);
|
||||
size_t M = a.shape(-2);
|
||||
size_t N = b.shape(-1);
|
||||
size_t K = a.shape(-1);
|
||||
|
||||
if (M == 0 || N == 0) {
|
||||
return;
|
||||
}
|
||||
if (K == 0) {
|
||||
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
|
||||
return;
|
||||
}
|
||||
|
||||
BNNSDataType bnns_dtype = to_bnns_dtype(out.dtype());
|
||||
|
||||
const BNNSLayerParametersBroadcastMatMul gemm_params{
|
||||
/* float alpha = */ alpha,
|
||||
/* float beta = */ beta,
|
||||
/* bool transA = */ a_transposed,
|
||||
/* bool transB = */ b_transposed,
|
||||
/* bool quadratic = */ false,
|
||||
/* bool a_is_weights = */ false,
|
||||
/* bool b_is_weights = */ false,
|
||||
/* BNNSNDArrayDescriptor iA_desc = */
|
||||
BNNSNDArrayDescriptor{
|
||||
/* BNNSNDArrayFlags flags = */ BNNSNDArrayFlagBackpropSet,
|
||||
/* BNNSDataLayout layout = */ BNNSDataLayoutRowMajorMatrix,
|
||||
|
||||
/* size_t size[BNNS_MAX_TENSOR_DIMENSION] = */
|
||||
{lda, (M * K) / lda, 0, 0, 0, 0, 0, 0},
|
||||
/* size_t stride[BNNS_MAX_TENSOR_DIMENSION] = */
|
||||
{1, lda, 0, 0, 0, 0, 0, 0},
|
||||
|
||||
/* void * _Nullable data = */ nullptr,
|
||||
/* BNNSDataType data_type = */ bnns_dtype,
|
||||
|
||||
/* void * _Nullable table_data = */ nullptr,
|
||||
/* BNNSDataType table_data_type = */ bnns_dtype,
|
||||
|
||||
/* float data_scale = */ 1.0,
|
||||
/* float data_bias = */ 0.0,
|
||||
},
|
||||
/* BNNSNDArrayDescriptor iB_desc = */
|
||||
BNNSNDArrayDescriptor{
|
||||
/* BNNSNDArrayFlags flags = */ BNNSNDArrayFlagBackpropSet,
|
||||
/* BNNSDataLayout layout = */ BNNSDataLayoutRowMajorMatrix,
|
||||
|
||||
/* size_t size[BNNS_MAX_TENSOR_DIMENSION] = */
|
||||
{ldb, (K * N) / ldb, 0, 0, 0, 0, 0, 0},
|
||||
/* size_t stride[BNNS_MAX_TENSOR_DIMENSION] = */
|
||||
{1, ldb, 0, 0, 0, 0, 0, 0},
|
||||
|
||||
/* void * _Nullable data = */ nullptr,
|
||||
/* BNNSDataType data_type = */ bnns_dtype,
|
||||
|
||||
/* void * _Nullable table_data = */ nullptr,
|
||||
/* BNNSDataType table_data_type = */ bnns_dtype,
|
||||
|
||||
/* float data_scale = */ 1.0,
|
||||
/* float data_bias = */ 0.0,
|
||||
},
|
||||
/* BNNSNDArrayDescriptor o_desc = */
|
||||
BNNSNDArrayDescriptor{
|
||||
/* BNNSNDArrayFlags flags = */ BNNSNDArrayFlagBackpropSet,
|
||||
/* BNNSDataLayout layout = */ BNNSDataLayoutRowMajorMatrix,
|
||||
|
||||
/* size_t size[BNNS_MAX_TENSOR_DIMENSION] = */
|
||||
{N, M, 0, 0, 0, 0, 0, 0},
|
||||
/* size_t stride[BNNS_MAX_TENSOR_DIMENSION] = */
|
||||
{1, N, 0, 0, 0, 0, 0, 0},
|
||||
|
||||
/* void * _Nullable data = */ nullptr,
|
||||
/* BNNSDataType data_type = */ bnns_dtype,
|
||||
|
||||
/* void * _Nullable table_data = */ nullptr,
|
||||
/* BNNSDataType table_data_type = */ bnns_dtype,
|
||||
|
||||
/* float data_scale = */ 1.0,
|
||||
/* float data_bias = */ 0.0,
|
||||
},
|
||||
};
|
||||
|
||||
auto bnns_filter =
|
||||
BNNSFilterCreateLayerBroadcastMatMul(&gemm_params, nullptr);
|
||||
|
||||
for (int i = 0; i < (a.size() / (M * K)); ++i) {
|
||||
BNNSFilterApplyTwoInput(
|
||||
bnns_filter,
|
||||
a.data<uint8_t>() +
|
||||
elem_to_loc(M * K * i, a.shape(), a.strides()) * a.itemsize(),
|
||||
b.data<uint8_t>() +
|
||||
elem_to_loc(K * N * i, b.shape(), b.strides()) * b.itemsize(),
|
||||
out.data<uint8_t>() + M * N * i * out.itemsize());
|
||||
}
|
||||
|
||||
BNNSFilterDestroy(bnns_filter);
|
||||
}
|
||||
|
||||
inline void matmul_bnns(const array& a_pre, const array& b_pre, array& out) {
|
||||
// TODO: Update to utilize BNNS broadcasting
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
return matmul_bnns_general(a_pre, b_pre, out);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline void mask_matrix(
|
||||
T* data,
|
||||
const bool* mask,
|
||||
int tile_size,
|
||||
const int X,
|
||||
const int Y,
|
||||
const size_t X_data_str,
|
||||
const size_t Y_data_str,
|
||||
const size_t X_mask_str,
|
||||
const size_t Y_mask_str) {
|
||||
int tX = (X + tile_size - 1) / tile_size;
|
||||
int tY = (Y + tile_size - 1) / tile_size;
|
||||
|
||||
for (int i = 0; i < tX; i++) {
|
||||
for (int j = 0; j < tY; j++) {
|
||||
bool do_mask = mask[i * X_mask_str + j * Y_mask_str];
|
||||
if (!do_mask) {
|
||||
int loc_x = i * tile_size;
|
||||
int loc_y = j * tile_size;
|
||||
T* data_block = data + loc_x * X_data_str + loc_y * Y_data_str;
|
||||
|
||||
int size_x = std::min(tile_size, X - loc_x);
|
||||
int size_y = std::min(tile_size, Y - loc_y);
|
||||
for (int ii = 0; ii < size_x; ii++) {
|
||||
for (int jj = 0; jj < size_y; jj++) {
|
||||
data_block[ii * X_data_str + jj * Y_data_str] = T(0.);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
if (out.dtype() == float32) {
|
||||
return matmul_cblas(inputs[0], inputs[1], out);
|
||||
}
|
||||
return matmul_bnns(inputs[0], inputs[1], out);
|
||||
}
|
||||
|
||||
void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
// Fill output with C
|
||||
auto& c = inputs[2];
|
||||
CopyType ctype = c.data_size() == 1 ? CopyType::Scalar : CopyType::General;
|
||||
copy(c, out, ctype);
|
||||
|
||||
if (out.dtype() == float32) {
|
||||
return matmul_cblas_general(inputs[0], inputs[1], out, alpha_, beta_);
|
||||
}
|
||||
return matmul_bnns_general(inputs[0], inputs[1], out, alpha_, beta_);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
mlx/backend/accelerate/primitives.cpp (new file, 602 lines)
@@ -0,0 +1,602 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
|
||||
#include <Accelerate/Accelerate.h>
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/binary.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/unary.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#define DEFAULT(primitive) \
|
||||
void primitive::eval_cpu(const std::vector<array>& inputs, array& out) { \
|
||||
primitive::eval(inputs, out); \
|
||||
}
|
||||
|
||||
#define DEFAULT_MULTI(primitive) \
|
||||
void primitive::eval_cpu( \
|
||||
const std::vector<array>& inputs, std::vector<array>& outputs) { \
|
||||
primitive::eval(inputs, outputs); \
|
||||
}
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// Use the default implementation for the following primitives
|
||||
DEFAULT(Arange)
|
||||
DEFAULT(ArgPartition)
|
||||
DEFAULT(ArgReduce)
|
||||
DEFAULT(ArgSort)
|
||||
DEFAULT(AsStrided)
|
||||
DEFAULT(BlockMaskedMM)
|
||||
DEFAULT(Broadcast)
|
||||
DEFAULT(Ceil)
|
||||
DEFAULT(Concatenate)
|
||||
DEFAULT(Conjugate)
|
||||
DEFAULT(Copy)
|
||||
DEFAULT_MULTI(CustomTransforms)
|
||||
DEFAULT_MULTI(Depends)
|
||||
DEFAULT_MULTI(DivMod)
|
||||
DEFAULT(NumberOfElements)
|
||||
DEFAULT(Equal)
|
||||
DEFAULT(Erf)
|
||||
DEFAULT(ErfInv)
|
||||
DEFAULT(ExpandDims)
|
||||
DEFAULT(FFT)
|
||||
DEFAULT(Floor)
|
||||
DEFAULT(Gather)
|
||||
DEFAULT(GatherMM)
|
||||
DEFAULT(GatherQMM)
|
||||
DEFAULT(Greater)
|
||||
DEFAULT(GreaterEqual)
|
||||
DEFAULT(Hadamard)
|
||||
DEFAULT(Less)
|
||||
DEFAULT(LessEqual)
|
||||
DEFAULT(Load)
|
||||
DEFAULT(LogicalNot)
|
||||
DEFAULT(LogicalAnd)
|
||||
DEFAULT(LogicalOr)
|
||||
DEFAULT(LogAddExp)
|
||||
DEFAULT(Maximum)
|
||||
DEFAULT(Minimum)
|
||||
DEFAULT(NotEqual)
|
||||
DEFAULT(Pad)
|
||||
DEFAULT(Partition)
|
||||
DEFAULT_MULTI(QRF)
|
||||
DEFAULT(RandomBits)
|
||||
DEFAULT(Remainder)
|
||||
DEFAULT(Round)
|
||||
DEFAULT(Scatter)
|
||||
DEFAULT(Select)
|
||||
DEFAULT(Sigmoid)
|
||||
DEFAULT(Sign)
|
||||
DEFAULT(Slice)
|
||||
DEFAULT(SliceUpdate)
|
||||
DEFAULT_MULTI(Split)
|
||||
DEFAULT(Sort)
|
||||
DEFAULT(Squeeze)
|
||||
DEFAULT(StopGradient)
|
||||
DEFAULT_MULTI(SVD)
|
||||
DEFAULT(Transpose)
|
||||
DEFAULT(Inverse)
|
||||
DEFAULT(Cholesky)
|
||||
DEFAULT_MULTI(Eigh)
|
||||
|
||||
void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (in.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
vDSP_vabs(in.data<float>(), 1, out.data<float>(), 1, in.data_size());
|
||||
} else if (in.dtype() == int32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
vDSP_vabsi(in.data<int>(), 1, out.data<int>(), 1, in.data_size());
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
|
||||
if (a.dtype() == float32) {
|
||||
binary_op<float>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
[](auto x, auto y) { return x + y; },
|
||||
[](const auto* s, const auto* vec, auto* o, auto n) {
|
||||
vDSP_vsadd((const float*)vec, 1, (const float*)s, (float*)o, 1, n);
|
||||
},
|
||||
[](const auto* vec, const auto* s, auto* o, auto n) {
|
||||
vDSP_vsadd((const float*)vec, 1, (const float*)s, (float*)o, 1, n);
|
||||
},
|
||||
[](const auto* a, const auto* b, auto* o, auto n) {
|
||||
vDSP_vadd((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n);
|
||||
});
|
||||
} else if (a.dtype() == int32) {
|
||||
binary_op<int>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
[](auto x, auto y) { return x + y; },
|
||||
[](const auto* s, const auto* vec, auto* o, auto n) {
|
||||
vDSP_vsaddi((const int*)vec, 1, (const int*)s, (int*)o, 1, n);
|
||||
},
|
||||
[](const auto* vec, const auto* s, auto* o, auto n) {
|
||||
vDSP_vsaddi((const int*)vec, 1, (const int*)s, (int*)o, 1, n);
|
||||
},
|
||||
[](const auto* a, const auto* b, auto* o, auto n) {
|
||||
vDSP_vaddi((const int*)a, 1, (const int*)b, 1, (int*)o, 1, n);
|
||||
});
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
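// Not from the mlx sources: binary_op takes a scalar fallback plus three
// vectorized kernels, selected by which operand is a broadcast scalar. For a
// commutative op like Add the first two kernels coincide, which is why both
// lambdas above call vDSP_vsadd. A reduced sketch of that dispatch, using a
// hypothetical helper rather than the real mlx signature:
#include <cstddef>

template <typename T, typename SV, typename VS, typename VV>
void binary_dispatch_sketch(
    const T* a, std::size_t a_size,
    const T* b, std::size_t b_size,
    T* out, std::size_t n,
    SV scalar_vector, VS vector_scalar, VV vector_vector) {
  if (a_size == 1) {
    scalar_vector(a, b, out, n); // a is a broadcast scalar
  } else if (b_size == 1) {
    vector_scalar(a, b, out, n); // b is a broadcast scalar
  } else {
    vector_vector(a, b, out, n); // both full vectors of length n
  }
}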
|
||||
|
||||
void ArcCos::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
int size = in.data_size();
|
||||
vvacosf(out.data<float>(), in.data<float>(), &size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void ArcCosh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
int size = in.data_size();
|
||||
vvacoshf(out.data<float>(), in.data<float>(), &size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void ArcSin::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
int size = in.data_size();
|
||||
vvasinf(out.data<float>(), in.data<float>(), &size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void ArcSinh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
int size = in.data_size();
|
||||
vvasinhf(out.data<float>(), in.data<float>(), &size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void ArcTan::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
int size = in.data_size();
|
||||
vvatanf(out.data<float>(), in.data<float>(), &size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void ArcTan2::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
if (out.dtype() == float32 && a.flags().row_contiguous &&
|
||||
b.flags().row_contiguous) {
|
||||
if (a.is_donatable()) {
|
||||
out.copy_shared_buffer(a);
|
||||
} else if (b.is_donatable()) {
|
||||
out.copy_shared_buffer(b);
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
}
|
||||
int size = a.data_size();
|
||||
vvatan2f(out.data<float>(), a.data<float>(), b.data<float>(), &size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void ArcTanh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
int size = in.data_size();
|
||||
vvatanhf(out.data<float>(), in.data<float>(), &size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void AsType::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
|
||||
if (in.flags().contiguous) {
|
||||
// Use accelerate functions if possible
|
||||
if (in.dtype() == float32 && out.dtype() == uint32) {
|
||||
set_unary_output_data(in, out);
|
||||
vDSP_vfixu32(
|
||||
in.data<float>(), 1, out.data<uint32_t>(), 1, in.data_size());
|
||||
return;
|
||||
} else if (in.dtype() == float32 && out.dtype() == int32) {
|
||||
set_unary_output_data(in, out);
|
||||
vDSP_vfix32(in.data<float>(), 1, out.data<int32_t>(), 1, in.data_size());
|
||||
return;
|
||||
} else if (in.dtype() == uint32 && out.dtype() == float32) {
|
||||
set_unary_output_data(in, out);
|
||||
vDSP_vfltu32(
|
||||
in.data<uint32_t>(), 1, out.data<float>(), 1, in.data_size());
|
||||
return;
|
||||
} else if (in.dtype() == int32 && out.dtype() == float32) {
|
||||
set_unary_output_data(in, out);
|
||||
vDSP_vflt32(in.data<int32_t>(), 1, out.data<float>(), 1, in.data_size());
|
||||
return;
|
||||
}
|
||||
}
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void Cos::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
int size = in.data_size();
|
||||
vvcosf(out.data<float>(), in.data<float>(), &size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Cosh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
int size = in.data_size();
|
||||
vvcoshf(out.data<float>(), in.data<float>(), &size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
|
||||
if (a.dtype() == int32) {
|
||||
binary_op<int>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
[](auto x, auto y) { return x / y; },
|
||||
UseDefaultBinaryOp(),
|
||||
[](const auto* vec, const auto* s, auto* o, auto n) {
|
||||
vDSP_vsdivi((const int*)vec, 1, (const int*)s, (int*)o, 1, n);
|
||||
},
|
||||
[](const auto* a, const auto* b, auto* o, auto n) {
|
||||
vDSP_vdivi((const int*)b, 1, (const int*)a, 1, (int*)o, 1, n);
|
||||
});
|
||||
} else if (a.dtype() == float32) {
|
||||
binary_op<float>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
[](auto x, auto y) { return x / y; },
|
||||
[](const auto* s, const auto* vec, auto* o, auto n) {
|
||||
vDSP_svdiv((const float*)s, (const float*)vec, 1, (float*)o, 1, n);
|
||||
},
|
||||
[](const auto* vec, const auto* s, auto* o, auto n) {
|
||||
vDSP_vsdiv((const float*)vec, 1, (const float*)s, (float*)o, 1, n);
|
||||
},
|
||||
[](const auto* a, const auto* b, auto* o, auto n) {
|
||||
vDSP_vdiv((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n);
|
||||
});
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
auto size = in.data_size();
|
||||
vvexpf(out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Expm1::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
auto size = in.data_size();
|
||||
vvexpm1f(
|
||||
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Full::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
assert(in.dtype() == out.dtype());
|
||||
if (in.data_size() == 1 && out.dtype() == float32) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
vDSP_vfill(in.data<float>(), out.data<float>(), 1, out.size());
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Log::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
auto size = in.data_size();
|
||||
switch (base_) {
|
||||
case Base::e:
|
||||
vvlogf(
|
||||
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
|
||||
break;
|
||||
case Base::two:
|
||||
vvlog2f(
|
||||
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
|
||||
break;
|
||||
case Base::ten:
|
||||
vvlog10f(
|
||||
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
auto size = in.data_size();
|
||||
vvlog1pf(
|
||||
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
|
||||
if (a.dtype() == float32) {
|
||||
binary_op<float>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
[](auto x, auto y) { return x * y; },
|
||||
[](const auto* s, const auto* vec, auto* o, auto n) {
|
||||
vDSP_vsmul((const float*)vec, 1, (const float*)s, (float*)o, 1, n);
|
||||
},
|
||||
[](const auto* vec, const auto* s, auto* o, auto n) {
|
||||
vDSP_vsmul((const float*)vec, 1, (const float*)s, (float*)o, 1, n);
|
||||
},
|
||||
[](const auto* a, const auto* b, auto* o, auto n) {
|
||||
vDSP_vmul((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n);
|
||||
});
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Negative::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (in.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
vDSP_vneg(in.data<float>(), 1, out.data<float>(), 1, in.data_size());
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Power::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
if (out.dtype() == float32 && a.flags().row_contiguous &&
|
||||
b.flags().row_contiguous) {
|
||||
int size = a.size();
|
||||
if (a.is_donatable() && a.itemsize() == out.itemsize()) {
|
||||
out.copy_shared_buffer(a);
|
||||
} else if (b.is_donatable() && b.itemsize() == out.itemsize()) {
|
||||
out.copy_shared_buffer(b);
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
}
|
||||
vvpowf(out.data<float>(), b.data<float>(), a.data<float>(), &size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (reduce_type_ == Scan::Sum && out.dtype() == float32 &&
|
||||
in.flags().row_contiguous && in.strides()[axis_] == 1 && !inclusive_) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
int stride = in.shape(axis_);
|
||||
int count = in.size() / stride;
|
||||
const float* input = in.data<float>();
|
||||
float* output = out.data<float>();
|
||||
float s = 1.0;
|
||||
if (!reverse_) {
|
||||
for (int i = 0; i < count; i++) {
|
||||
vDSP_vrsum(input - 1, 1, &s, output, 1, stride);
|
||||
input += stride;
|
||||
output += stride;
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < count; i++) {
|
||||
input += stride - 1;
|
||||
output += stride - 1;
|
||||
vDSP_vrsum(input + 1, -1, &s, output, -1, stride);
|
||||
input++;
|
||||
output++;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
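// Not from the mlx sources: vDSP_vrsum is documented as
// C[n] = S * (A[1] + ... + A[n]), so C[0] == 0 and A[0] is skipped. Passing
// `input - 1` above therefore shifts the window so each output is the sum of
// all *earlier* inputs, i.e. an exclusive prefix sum. Scalar equivalent:
#include <cstddef>

void exclusive_scan(const float* in, float* out, std::size_t n) {
  float acc = 0.0f;
  for (std::size_t i = 0; i < n; ++i) {
    out[i] = acc; // sum of in[0..i-1]; empty sum for i == 0
    acc += in[i];
  }
}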
|
||||
|
||||
void Sin::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
int size = in.data_size();
|
||||
vvsinf(out.data<float>(), in.data<float>(), &size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Sinh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
int size = in.data_size();
|
||||
vvsinhf(out.data<float>(), in.data<float>(), &size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Square::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (in.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
auto size = in.data_size();
|
||||
vDSP_vsq(in.data<float>(), 1, out.data<float>(), 1, size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Sqrt::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (in.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
int size = in.data_size();
|
||||
if (recip_) {
|
||||
vvrsqrtf(out.data<float>(), in.data<float>(), &size);
|
||||
} else {
|
||||
vvsqrtf(out.data<float>(), in.data<float>(), &size);
|
||||
}
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
|
||||
if (a.dtype() == float32) {
|
||||
binary_op<float>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
[](auto x, auto y) { return x - y; },
|
||||
[](const auto* s, const auto* vec, auto* o, auto n) {
|
||||
float minus_1 = -1;
|
||||
vDSP_vsmsa(
|
||||
(const float*)vec, 1, &minus_1, (const float*)s, (float*)o, 1, n);
|
||||
},
|
||||
[](const auto* vec, const auto* s, auto* o, auto n) {
|
||||
float val = -(*s);
|
||||
vDSP_vsadd((const float*)vec, 1, &val, (float*)o, 1, n);
|
||||
},
|
||||
[](const auto* a, const auto* b, auto* o, auto n) {
|
||||
vDSP_vsub((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n);
|
||||
});
|
||||
} else if (a.dtype() == int32) {
|
||||
binary_op<int>(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
[](auto x, auto y) { return x - y; },
|
||||
UseDefaultBinaryOp(),
|
||||
[](const auto* vec, const auto* s, auto* o, auto n) {
|
||||
int val = -(*s);
|
||||
vDSP_vsaddi((const int*)vec, 1, &val, (int*)o, 1, n);
|
||||
},
|
||||
UseDefaultBinaryOp());
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Tan::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
int size = in.data_size();
|
||||
vvtanf(out.data<float>(), in.data<float>(), &size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Tanh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
int size = in.data_size();
|
||||
vvtanhf(out.data<float>(), in.data<float>(), &size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
mlx/backend/accelerate/quantized.cpp (new file, 117 lines)
@@ -0,0 +1,117 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include <simd/vector.h>
|
||||
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
void _qmm_t_4_64(
|
||||
float* result,
|
||||
const float* x,
|
||||
const uint32_t* w,
|
||||
const float* scales,
|
||||
const float* biases,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int B,
|
||||
bool batched_w) {
|
||||
constexpr int bits = 4;
|
||||
constexpr int group_size = 64;
|
||||
constexpr int bitmask = (1 << bits) - 1;
|
||||
constexpr int pack_factor = 32 / bits;
|
||||
constexpr int packs_in_group = group_size / pack_factor;
|
||||
|
||||
int w_els = N * K / pack_factor;
|
||||
int g_els = w_els * pack_factor / group_size;
|
||||
|
||||
for (int i = 0; i < B; i++) {
|
||||
for (int m = 0; m < M; m++) {
|
||||
const uint32_t* w_local = w;
|
||||
const float* scales_local = scales;
|
||||
const float* biases_local = biases;
|
||||
|
||||
for (int n = 0; n < N; n++) {
|
||||
const simd_float16* x_local = (simd_float16*)x;
|
||||
simd_float16 sum = 0;
|
||||
for (int k = 0; k < K; k += group_size) {
|
||||
float scale = *scales_local++;
|
||||
float bias = *biases_local++;
|
||||
|
||||
for (int kw = 0; kw < packs_in_group; kw += 2) {
|
||||
// TODO: vectorize this properly
|
||||
simd_uint16 wi;
|
||||
for (int e = 0; e < 2; e++) {
|
||||
uint32_t wii = *w_local++;
|
||||
for (int p = 0; p < 8; p++) {
|
||||
wi[e * 8 + p] = wii & bitmask;
|
||||
wii >>= bits;
|
||||
}
|
||||
}
|
||||
simd_float16 wf = simd_float(wi);
|
||||
wf *= scale;
|
||||
wf += bias;
|
||||
|
||||
sum += (*x_local) * wf;
|
||||
x_local++;
|
||||
}
|
||||
}
|
||||
|
||||
*result = simd_reduce_add(sum);
|
||||
result++;
|
||||
}
|
||||
|
||||
x += K;
|
||||
}
|
||||
if (batched_w) {
|
||||
w += w_els;
|
||||
scales += g_els;
|
||||
biases += g_els;
|
||||
}
|
||||
}
|
||||
}
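// Not from the mlx sources: the scalar core of the 4-bit decode above. One
// uint32_t packs eight 4-bit weights (bits = 4, pack_factor = 8), lowest
// bits first; each group of 64 weights shares one scale and bias.
#include <cstdint>

void unpack_dequant_u32(
    uint32_t packed, float scale, float bias, float out[8]) {
  constexpr uint32_t bitmask = (1u << 4) - 1;
  for (int p = 0; p < 8; ++p) {
    out[p] = static_cast<float>(packed & bitmask) * scale + bias;
    packed >>= 4;
  }
}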
|
||||
|
||||
} // namespace
|
||||
|
||||
void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 4);
|
||||
|
||||
auto& x = inputs[0];
|
||||
auto& w = inputs[1];
|
||||
auto& scales = inputs[2];
|
||||
auto& biases = inputs[3];
|
||||
|
||||
bool condition =
|
||||
(transpose_ && x.flags().row_contiguous && w.flags().row_contiguous &&
|
||||
scales.flags().row_contiguous && biases.flags().row_contiguous &&
|
||||
x.dtype() == float32 && bits_ == 4 && group_size_ == 64);
|
||||
|
||||
if (condition) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
int K = x.shape(-1);
|
||||
int M = x.shape(-2);
|
||||
int N = out.shape(-1);
|
||||
int B = x.size() / K / M;
|
||||
bool batched_w = w.ndim() > 2;
|
||||
_qmm_t_4_64(
|
||||
out.data<float>(),
|
||||
x.data<float>(),
|
||||
w.data<uint32_t>(),
|
||||
scales.data<float>(),
|
||||
biases.data<float>(),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
B,
|
||||
batched_w);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
mlx/backend/accelerate/reduce.cpp (new file, 139 lines)
@@ -0,0 +1,139 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include <Accelerate/Accelerate.h>
|
||||
#include <simd/vector.h>
|
||||
|
||||
#include "mlx/backend/common/reduce.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T, typename VT>
|
||||
struct MinReduction {
|
||||
T operator()(const T& a, const T& b) {
|
||||
return std::min(a, b);
|
||||
}
|
||||
|
||||
VT operator()(VT a, VT b) {
|
||||
return simd_min(a, b);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename VT>
|
||||
struct MaxReduction {
|
||||
T operator()(const T& a, const T& b) {
|
||||
return std::max(a, b);
|
||||
}
|
||||
|
||||
VT operator()(VT a, VT b) {
|
||||
return simd_max(a, b);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename VT>
|
||||
struct SumReduction {
|
||||
T operator()(const T& a, const T& b) {
|
||||
return a + b;
|
||||
}
|
||||
|
||||
VT operator()(VT a, VT b) {
|
||||
return a + b;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename VT, int N, typename Reduction>
|
||||
struct StridedReduce {
|
||||
void operator()(const T* x, T* accum, int size, size_t stride) {
|
||||
Reduction op;
|
||||
|
||||
for (int i = 0; i < size; i++) {
|
||||
size_t s = stride;
|
||||
T* a = accum;
|
||||
while (s >= N) {
|
||||
*(VT*)a = op((*(VT*)x), (*(VT*)a));
|
||||
x += N;
|
||||
a += N;
|
||||
s -= N;
|
||||
}
|
||||
while (s-- > 0) {
|
||||
*a = op(*a, *x);
|
||||
a++;
|
||||
x++;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
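// Not from the mlx sources: StridedReduce reduces along the leading axis of
// a row-major `size x stride` block by keeping `stride` partial accumulators
// and sweeping them N SIMD lanes at a time. Scalar equivalent of what it
// computes:
#include <cstddef>

void strided_sum(
    const float* x, float* accum, int size, std::size_t stride) {
  for (int i = 0; i < size; ++i) {
    for (std::size_t j = 0; j < stride; ++j) {
      accum[j] += x[i * stride + j];
    }
  }
}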
|
||||
|
||||
} // namespace
|
||||
|
||||
void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
|
||||
if (in.dtype() == float32) {
|
||||
if (reduce_type_ == Reduce::Sum) {
|
||||
reduction_op<float, float>(
|
||||
in,
|
||||
out,
|
||||
axes_,
|
||||
0,
|
||||
StridedReduce<
|
||||
float,
|
||||
simd_float16,
|
||||
16,
|
||||
SumReduction<float, simd_float16>>(),
|
||||
[](const auto* x, auto* accum, int size) {
|
||||
float acc;
|
||||
vDSP_sve((const float*)x, 1, &acc, size);
|
||||
(*accum) += acc;
|
||||
},
|
||||
[](auto* accum, auto x) { *accum += x; });
|
||||
return;
|
||||
} else if (reduce_type_ == Reduce::Max) {
|
||||
reduction_op<float, float>(
|
||||
in,
|
||||
out,
|
||||
axes_,
|
||||
-std::numeric_limits<float>::infinity(),
|
||||
StridedReduce<
|
||||
float,
|
||||
simd_float16,
|
||||
16,
|
||||
MaxReduction<float, simd_float16>>(),
|
||||
[](const auto* x, auto* accum, int size) {
|
||||
float max;
|
||||
vDSP_maxv((const float*)x, 1, &max, size);
|
||||
(*accum) = (*accum < max) ? max : *accum;
|
||||
},
|
||||
[](auto* accum, auto x) { (*accum) = (*accum < x) ? x : *accum; });
|
||||
return;
|
||||
} else if (reduce_type_ == Reduce::Min) {
|
||||
reduction_op<float, float>(
|
||||
in,
|
||||
out,
|
||||
axes_,
|
||||
std::numeric_limits<float>::infinity(),
|
||||
StridedReduce<
|
||||
float,
|
||||
simd_float16,
|
||||
16,
|
||||
MinReduction<float, simd_float16>>(),
|
||||
[](const auto* x, auto* accum, int size) {
|
||||
float min;
|
||||
vDSP_minv((const float*)x, 1, &min, size);
|
||||
(*accum) = (*accum > min) ? min : *accum;
|
||||
},
|
||||
[](auto* accum, auto x) { (*accum) = (*accum > x) ? x : *accum; });
|
||||
return;
|
||||
}
|
||||
}
|
||||
// TODO: Add integer addition and min/max using the templates above and
|
||||
// simd_int16 and friends.
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
mlx/backend/accelerate/softmax.cpp (new file, 393 lines)
@@ -0,0 +1,393 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
#include <limits>
|
||||
|
||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
|
||||
#include <simd/math.h>
|
||||
#include <simd/vector.h>
|
||||
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
/**
|
||||
* Compute exp(x) in an optimizer friendly way as follows:
|
||||
*
|
||||
* First change the problem to computing 2**y where y = x / ln(2).
|
||||
*
|
||||
* Now we will compute 2**y as 2**y1 * 2**y2 where y1 is the integer part
|
||||
* `ipart` and y2 is fractional part. For the integer part we perform bit
|
||||
* shifting and for the fractional part we use a polynomial approximation.
|
||||
*
|
||||
* The algorithm and constants of the polynomial taken from
|
||||
* https://github.com/akohlmey/fastermath/blob/master/src/exp.c which took them
|
||||
* from Cephes math library.
|
||||
*
|
||||
* Note: The implementation below is a general fast exp. There could be faster
|
||||
* implementations for numbers strictly < 0.
|
||||
*/
|
||||
inline simd_float16 simd_fast_exp(simd_float16 x_init) {
|
||||
auto x = x_init * 1.442695; // multiply with log_2(e)
|
||||
simd_float16 ipart, fpart;
|
||||
simd_int16 epart;
|
||||
x = simd_clamp(x, -80, 80);
|
||||
ipart = simd::floor(x + 0.5);
|
||||
fpart = x - ipart;
|
||||
|
||||
x = 1.535336188319500e-4f;
|
||||
x = x * fpart + 1.339887440266574e-3f;
|
||||
x = x * fpart + 9.618437357674640e-3f;
|
||||
x = x * fpart + 5.550332471162809e-2f;
|
||||
x = x * fpart + 2.402264791363012e-1f;
|
||||
x = x * fpart + 6.931472028550421e-1f;
|
||||
x = x * fpart + 1.000000000000000f;
|
||||
|
||||
// generate 2**ipart in the floating point representation using integer
|
||||
// bitshifting
|
||||
epart = (simd_int(ipart) + 127) << 23;
|
||||
|
||||
// Avoid suppressing NaNs
|
||||
simd_int16 eq = (x_init == x_init);
|
||||
return simd_bitselect(x_init, (*(simd_float16*)&epart) * x, eq);
|
||||
}
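// Not from the mlx sources: a scalar rendition of simd_fast_exp with the
// same constants, to make the algorithm in the comment above concrete:
// exp(x) = 2^(x*log2(e)) = 2^ipart * 2^fpart, with 2^fpart from the
// polynomial and 2^ipart built by writing (ipart + 127) into the exponent
// field of a float.
#include <cmath>
#include <cstdint>
#include <cstring>

float fast_exp_scalar(float x_init) {
  float x = x_init * 1.442695f; // multiply with log_2(e)
  x = std::fmin(std::fmax(x, -80.0f), 80.0f);
  float ipart = std::floor(x + 0.5f);
  float fpart = x - ipart;

  float p = 1.535336188319500e-4f;
  p = p * fpart + 1.339887440266574e-3f;
  p = p * fpart + 9.618437357674640e-3f;
  p = p * fpart + 5.550332471162809e-2f;
  p = p * fpart + 2.402264791363012e-1f;
  p = p * fpart + 6.931472028550421e-1f;
  p = p * fpart + 1.0f;

  // 2**ipart via the exponent bits (unsigned wrap-around handles ipart < 0).
  uint32_t bits = (static_cast<uint32_t>(static_cast<int32_t>(ipart)) + 127u)
      << 23;
  float two_ipart;
  std::memcpy(&two_ipart, &bits, sizeof(two_ipart));
  return std::isnan(x_init) ? x_init : two_ipart * p;
}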
|
||||
|
||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||
/**
|
||||
* The ARM neon equivalent of the fast exp above.
|
||||
*/
|
||||
inline float16x8_t neon_fast_exp(float16x8_t x) {
|
||||
x = vmulq_f16(x, vdupq_n_f16(float16_t(1.442695f))); // multiply with log_2(e)
|
||||
x = vmaxq_f16(x, vdupq_n_f16(float16_t(-14.f))); // clamp under with -14
|
||||
x = vminq_f16(x, vdupq_n_f16(float16_t(14.f))); // clamp over with 14
|
||||
|
||||
float16x8_t ipart = vrndmq_f16(vaddq_f16(x, vdupq_n_f16(float16_t(0.5f))));
|
||||
float16x8_t fpart = vsubq_f16(x, ipart);
|
||||
|
||||
x = vdupq_n_f16(float16_t(1.535336188319500e-4f));
|
||||
x = vfmaq_f16(vdupq_n_f16(float16_t(1.339887440266574e-3f)), x, fpart);
|
||||
x = vfmaq_f16(vdupq_n_f16(float16_t(9.618437357674640e-3f)), x, fpart);
|
||||
x = vfmaq_f16(vdupq_n_f16(float16_t(5.550332471162809e-2f)), x, fpart);
|
||||
x = vfmaq_f16(vdupq_n_f16(float16_t(2.402264791363012e-1f)), x, fpart);
|
||||
x = vfmaq_f16(vdupq_n_f16(float16_t(6.931472028550421e-1f)), x, fpart);
|
||||
x = vfmaq_f16(vdupq_n_f16(float16_t(1.000000000000000f)), x, fpart);
|
||||
|
||||
// generate 2**ipart in the floating point representation using integer
|
||||
// bitshifting
|
||||
int16x8_t epart = vcvtq_s16_f16(ipart);
|
||||
epart = vaddq_s16(epart, vdupq_n_s16(15));
|
||||
epart = vshlq_n_s16(epart, 10);
|
||||
|
||||
return vmulq_f16(vreinterpretq_f16_s16(epart), x);
|
||||
}
|
||||
|
||||
/**
|
||||
* Implementation of folding maximum for ARM neon. This should possibly be
|
||||
* refactored out of softmax.cpp at some point.
|
||||
*/
|
||||
inline float16_t neon_reduce_max(float16x8_t x) {
|
||||
float16x4_t y;
|
||||
y = vpmax_f16(vget_low_f16(x), vget_high_f16(x));
|
||||
y = vpmax_f16(y, y);
|
||||
y = vpmax_f16(y, y);
|
||||
return vget_lane_f16(y, 0);
|
||||
}
|
||||
|
||||
/**
|
||||
* Implementation of folding sum for ARM neon. This should possibly be
|
||||
* refactored out of softmax.cpp at some point.
|
||||
*/
|
||||
inline float16_t neon_reduce_add(float16x8_t x) {
|
||||
float16x4_t y;
|
||||
float16x4_t zero = vdup_n_f16(0);
|
||||
y = vpadd_f16(vget_low_f16(x), vget_high_f16(x));
|
||||
y = vpadd_f16(y, zero);
|
||||
y = vpadd_f16(y, zero);
|
||||
return vget_lane_f16(y, 0);
|
||||
}
|
||||
|
||||
template <typename T, typename VT>
|
||||
struct NeonFp16SimdOps {
|
||||
VT init(T a) {
|
||||
return vdupq_n_f16(a);
|
||||
}
|
||||
|
||||
VT load(const T* a) {
|
||||
return vld1q_f16(a);
|
||||
}
|
||||
|
||||
void store(T* dst, VT x) {
|
||||
vst1q_f16(dst, x);
|
||||
}
|
||||
|
||||
VT max(VT a, VT b) {
|
||||
return vmaxq_f16(a, b);
|
||||
}
|
||||
|
||||
VT exp(VT x) {
|
||||
return neon_fast_exp(x);
|
||||
}
|
||||
|
||||
VT add(VT a, VT b) {
|
||||
return vaddq_f16(a, b);
|
||||
}
|
||||
|
||||
VT sub(VT a, T b) {
|
||||
return vsubq_f16(a, vdupq_n_f16(b));
|
||||
}
|
||||
|
||||
VT mul(VT a, VT b) {
|
||||
return vmulq_f16(a, b);
|
||||
}
|
||||
|
||||
VT mul(VT a, T b) {
|
||||
return vmulq_f16(a, vdupq_n_f16(b));
|
||||
}
|
||||
|
||||
T reduce_max(VT x) {
|
||||
return neon_reduce_max(x);
|
||||
}
|
||||
|
||||
T reduce_add(VT x) {
|
||||
return neon_reduce_add(x);
|
||||
}
|
||||
};
|
||||
|
||||
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||
|
||||
template <typename T, typename VT>
|
||||
struct AccelerateSimdOps {
|
||||
VT init(T a) {
|
||||
return a;
|
||||
}
|
||||
|
||||
VT load(const T* a) {
|
||||
return *(VT*)a;
|
||||
}
|
||||
|
||||
void store(T* dst, VT x) {
|
||||
*(VT*)dst = x;
|
||||
}
|
||||
|
||||
VT max(VT a, VT b) {
|
||||
return simd_max(a, b);
|
||||
}
|
||||
|
||||
VT exp(VT x) {
|
||||
return simd_fast_exp(x);
|
||||
}
|
||||
|
||||
VT add(VT a, VT b) {
|
||||
return a + b;
|
||||
}
|
||||
|
||||
VT sub(VT a, T b) {
|
||||
return a - b;
|
||||
}
|
||||
|
||||
VT mul(VT a, VT b) {
|
||||
return a * b;
|
||||
}
|
||||
|
||||
VT mul(VT a, T b) {
|
||||
return a * b;
|
||||
}
|
||||
|
||||
T reduce_max(VT x) {
|
||||
return simd_reduce_max(x);
|
||||
}
|
||||
|
||||
T reduce_add(VT x) {
|
||||
return simd_reduce_add(x);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename AccT, typename VT, typename Ops, int N>
|
||||
void softmax(const array& in, array& out) {
|
||||
Ops ops;
|
||||
|
||||
const T* in_ptr = in.data<T>();
|
||||
T* out_ptr = out.data<T>();
|
||||
int M = in.shape().back();
|
||||
int L = in.data_size() / M;
|
||||
const T* current_in_ptr;
|
||||
T* current_out_ptr;
|
||||
|
||||
for (int i = 0; i < L; i++, in_ptr += M, out_ptr += M) {
|
||||
// Find the maximum
|
||||
current_in_ptr = in_ptr;
|
||||
VT vmaximum = ops.init(-std::numeric_limits<float>::infinity());
|
||||
size_t s = M;
|
||||
while (s >= N) {
|
||||
VT vals;
|
||||
if constexpr (std::is_same<T, AccT>::value) {
|
||||
vals = ops.load(current_in_ptr);
|
||||
} else {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
vals[i] = static_cast<AccT>(current_in_ptr[i]);
|
||||
}
|
||||
}
|
||||
vmaximum = ops.max(vals, vmaximum);
|
||||
current_in_ptr += N;
|
||||
s -= N;
|
||||
}
|
||||
AccT maximum = ops.reduce_max(vmaximum);
|
||||
while (s-- > 0) {
|
||||
maximum = std::max(maximum, static_cast<AccT>(*current_in_ptr));
|
||||
current_in_ptr++;
|
||||
}
|
||||
|
||||
// Compute the normalizer and the exponentials
|
||||
VT vnormalizer = ops.init(0.0);
|
||||
current_out_ptr = out_ptr;
|
||||
current_in_ptr = in_ptr;
|
||||
s = M;
|
||||
while (s >= N) {
|
||||
VT vexp;
|
||||
if constexpr (std::is_same<T, AccT>::value) {
|
||||
vexp = ops.load(current_in_ptr);
|
||||
} else {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
vexp[i] = static_cast<AccT>(current_in_ptr[i]);
|
||||
}
|
||||
}
|
||||
vexp = ops.exp(ops.sub(vexp, maximum));
|
||||
if constexpr (std::is_same<T, AccT>::value) {
|
||||
ops.store(current_out_ptr, vexp);
|
||||
}
|
||||
vnormalizer = ops.add(vnormalizer, vexp);
|
||||
current_in_ptr += N;
|
||||
current_out_ptr += N;
|
||||
s -= N;
|
||||
}
|
||||
AccT normalizer = ops.reduce_add(vnormalizer);
|
||||
while (s-- > 0) {
|
||||
AccT _exp = std::exp(*current_in_ptr - maximum);
|
||||
if (std::is_same<T, AccT>::value) {
|
||||
*current_out_ptr = _exp;
|
||||
}
|
||||
normalizer += _exp;
|
||||
current_in_ptr++;
|
||||
current_out_ptr++;
|
||||
}
|
||||
normalizer = 1 / normalizer;
|
||||
|
||||
// Normalize
|
||||
current_out_ptr = out_ptr;
|
||||
current_in_ptr = in_ptr;
|
||||
s = M;
|
||||
while (s >= N) {
|
||||
if constexpr (std::is_same<T, AccT>::value) {
|
||||
ops.store(current_out_ptr, ops.mul(*(VT*)current_out_ptr, normalizer));
|
||||
} else {
|
||||
VT vexp;
|
||||
for (int i = 0; i < N; ++i) {
|
||||
vexp[i] = static_cast<AccT>(current_in_ptr[i]);
|
||||
}
|
||||
vexp = ops.mul(ops.exp(ops.sub(vexp, maximum)), normalizer);
|
||||
for (int i = 0; i < N; ++i) {
|
||||
current_out_ptr[i] = vexp[i];
|
||||
}
|
||||
current_in_ptr += N;
|
||||
}
|
||||
current_out_ptr += N;
|
||||
s -= N;
|
||||
}
|
||||
while (s-- > 0) {
|
||||
if constexpr (std::is_same<T, AccT>::value) {
|
||||
*current_out_ptr *= normalizer;
|
||||
} else {
|
||||
AccT _exp = std::exp(*current_in_ptr - maximum);
|
||||
*current_out_ptr = static_cast<T>(_exp * normalizer);
|
||||
current_in_ptr++;
|
||||
}
|
||||
current_out_ptr++;
|
||||
}
|
||||
}
|
||||
}
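// Not from the mlx sources: the routine above is the standard three-pass,
// max-subtracted softmax. A plain scalar reference for comparison and
// testing:
#include <algorithm>
#include <cmath>
#include <cstddef>

void softmax_rows(
    const float* in, float* out, std::size_t rows, std::size_t cols) {
  for (std::size_t r = 0; r < rows; ++r, in += cols, out += cols) {
    float maximum = *std::max_element(in, in + cols); // pass 1: max
    float normalizer = 0.0f;
    for (std::size_t c = 0; c < cols; ++c) { // pass 2: exps and their sum
      out[c] = std::exp(in[c] - maximum);
      normalizer += out[c];
    }
    for (std::size_t c = 0; c < cols; ++c) { // pass 3: normalize
      out[c] /= normalizer;
    }
  }
}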
|
||||
|
||||
} // namespace
|
||||
|
||||
void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
|
||||
// Make sure that the last dimension is contiguous
|
||||
auto check_input = [](array x) {
|
||||
bool no_copy = x.strides()[x.ndim() - 1] == 1;
|
||||
if (x.ndim() > 1) {
|
||||
auto s = x.strides()[x.ndim() - 2];
|
||||
no_copy &= (s == 0 || s == x.shape().back());
|
||||
}
|
||||
if (no_copy) {
|
||||
return x;
|
||||
} else {
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
copy(x, x_copy, CopyType::General);
|
||||
return x_copy;
|
||||
}
|
||||
};
|
||||
array in = check_input(std::move(inputs[0]));
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(in.data_size() * in.itemsize()),
|
||||
in.data_size(),
|
||||
in.strides(),
|
||||
in.flags());
|
||||
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
case uint8:
|
||||
case uint16:
|
||||
case uint32:
|
||||
case uint64:
|
||||
case int8:
|
||||
case int16:
|
||||
case int32:
|
||||
case int64:
|
||||
throw std::invalid_argument(
|
||||
"Softmax is defined only for floating point types");
|
||||
break;
|
||||
case float32:
|
||||
softmax<
|
||||
float,
|
||||
float,
|
||||
simd_float16,
|
||||
AccelerateSimdOps<float, simd_float16>,
|
||||
16>(in, out);
|
||||
break;
|
||||
case float16:
|
||||
if (precise_) {
|
||||
softmax<
|
||||
float16_t,
|
||||
float,
|
||||
simd_float16,
|
||||
AccelerateSimdOps<float, simd_float16>,
|
||||
16>(in, out);
|
||||
} else {
|
||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||
softmax<
|
||||
float16_t,
|
||||
float16_t,
|
||||
float16x8_t,
|
||||
NeonFp16SimdOps<float16_t, float16x8_t>,
|
||||
8>(in, out);
|
||||
#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||
eval(inputs, out); // Redirect to common backend for consistency
|
||||
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||
}
|
||||
break;
|
||||
case bfloat16:
|
||||
eval(inputs, out);
|
||||
break;
|
||||
case complex64:
|
||||
eval(inputs, out);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
mlx/backend/accelerate/utils.h (new file, 28 lines)
@@ -0,0 +1,28 @@
// Copyright © 2023-2024 Apple Inc.

#pragma once

#include <Accelerate/Accelerate.h>
#include "mlx/dtype.h"

namespace mlx::core {

BNNSDataType to_bnns_dtype(Dtype mlx_dtype) {
  uint32_t size_bits = size_of(mlx_dtype) * 8;
  switch (kindof(mlx_dtype)) {
    case Dtype::Kind::b:
      return BNNSDataTypeBoolean;
    case Dtype::Kind::u:
      return BNNSDataType(BNNSDataTypeUIntBit | size_bits);
    case Dtype::Kind::i:
      return BNNSDataType(BNNSDataTypeIntBit | size_bits);
    case Dtype::Kind::f:
      return BNNSDataType(BNNSDataTypeFloatBit | size_bits);
    case Dtype::Kind::V:
      return BNNSDataTypeBFloat16;
    case Dtype::Kind::c:
      throw std::invalid_argument("BNNS does not support complex types");
  }
}

} // namespace mlx::core
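The conversion works because BNNS encodes each data type as a kind flag OR'd with the bit width; a short sketch (relying on the BNNS convention that, e.g., BNNSDataTypeFloat16 == BNNSDataTypeFloatBit | 16):

// float16 has kind 'f' and itemsize 2, so size_bits = 16 and
// to_bnns_dtype(float16) yields BNNSDataTypeFloatBit | 16,
// i.e. BNNSDataTypeFloat16.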
@@ -1,9 +1,71 @@
if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
  set(COMPILER ${CMAKE_C_COMPILER})
  set(CLANG TRUE)
else()
  set(COMPILER ${CMAKE_CXX_COMPILER})
endif()

if(MSVC)
  set(SHELL_EXT ps1)
  set(SHELL_CMD powershell -ExecutionPolicy Bypass -File)
else()
  set(SHELL_EXT sh)
  set(SHELL_CMD /bin/bash)
endif()

add_custom_command(
  OUTPUT compiled_preamble.cpp
  COMMAND
    ${SHELL_CMD} ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.${SHELL_EXT}
    ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp ${COMPILER}
    ${PROJECT_SOURCE_DIR} ${CLANG} ${CMAKE_SYSTEM_PROCESSOR}
  DEPENDS make_compiled_preamble.${SHELL_EXT}
          compiled_preamble.h
          ${PROJECT_SOURCE_DIR}/mlx/types/half_types.h
          ${PROJECT_SOURCE_DIR}/mlx/types/fp16.h
          ${PROJECT_SOURCE_DIR}/mlx/types/bf16.h
          ${PROJECT_SOURCE_DIR}/mlx/types/complex.h
          ops.h)

add_custom_target(cpu_compiled_preamble DEPENDS compiled_preamble.cpp)

add_dependencies(mlx cpu_compiled_preamble)

target_sources(
  mlx
-  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/broadcasting.cpp
+  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
-         ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/eigh.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/masked_mm.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/reduce_utils.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
-         ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp)
          ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
+         ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp
+         ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
          ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp)

if(IOS)
  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_nocpu.cpp)
else()
  target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_cpu.cpp
                             ${CMAKE_CURRENT_SOURCE_DIR}/jit_compiler.cpp)
endif()
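For concreteness, the custom command above expands to a single shell call of the generator script with the output path, compiler, source root, clang flag, and processor; on an arm64 Mac a configure would run roughly `/bin/bash <src>/make_compiled_preamble.sh <build>/compiled_preamble.cpp /usr/bin/clang <project_root> TRUE arm64` (the concrete paths here are hypothetical, the argument order comes from the COMMAND line above).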
mlx/backend/common/arange.h (new file, 74 lines)
@@ -0,0 +1,74 @@
// Copyright © 2023 Apple Inc.

#pragma once

#include "mlx/allocator.h"
#include "mlx/array.h"

namespace mlx::core {

namespace {

template <typename T>
void arange(T start, T next, array& out, size_t size) {
  auto ptr = out.data<T>();
  auto step_size = next - start;
  for (int i = 0; i < size; ++i) {
    ptr[i] = start;
    start += step_size;
  }
}

} // namespace

void arange(
    const std::vector<array>& inputs,
    array& out,
    double start,
    double step) {
  assert(inputs.size() == 0);
  out.set_data(allocator::malloc_or_wait(out.nbytes()));
  switch (out.dtype()) {
    case bool_:
      throw std::runtime_error("Bool type unsupported for arange.");
      break;
    case uint8:
      arange<uint8_t>(start, start + step, out, out.size());
      break;
    case uint16:
      arange<uint16_t>(start, start + step, out, out.size());
      break;
    case uint32:
      arange<uint32_t>(start, start + step, out, out.size());
      break;
    case uint64:
      arange<uint64_t>(start, start + step, out, out.size());
      break;
    case int8:
      arange<int8_t>(start, start + step, out, out.size());
      break;
    case int16:
      arange<int16_t>(start, start + step, out, out.size());
      break;
    case int32:
      arange<int32_t>(start, start + step, out, out.size());
      break;
    case int64:
      arange<int64_t>(start, start + step, out, out.size());
      break;
    case float16:
      arange<float16_t>(start, start + step, out, out.size());
      break;
    case float32:
      arange<float>(start, start + step, out, out.size());
      break;
    case bfloat16:
      arange<bfloat16_t>(start, start + step, out, out.size());
      break;
    case complex64:
      arange<complex64_t>(start, start + step, out, out.size());
      break;
  }
}

} // namespace mlx::core
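Note that the templated helper reconstructs the step as `next - start` in the output dtype and then accumulates it, so each element is a running sum rather than `start + i * step`. A small sketch:

// Sketch of the accumulation in arange<T> above; for floats the running sum
// can drift slightly from start + i * step over very long ranges.
void arange_sketch() {
  float out[5];
  float v = 0.0f;    // start
  float step = 0.5f; // next - start
  for (int i = 0; i < 5; ++i) {
    out[i] = v; // 0.0, 0.5, 1.0, 1.5, 2.0
    v += step;
  }
}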
mlx/backend/common/arg_reduce.cpp (new file, 112 lines)
@@ -0,0 +1,112 @@
// Copyright © 2023 Apple Inc.

#include <cassert>

#include "mlx/primitives.h"
#include "utils.h"

namespace mlx::core {

namespace {

template <typename InT, typename OpT>
void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
  auto axis_size = in.shape()[axis];
  auto axis_stride = in.strides()[axis];
  Strides strides = in.strides();
  Shape shape = in.shape();
  strides.erase(strides.begin() + axis);
  shape.erase(shape.begin() + axis);
  for (uint32_t i = 0; i < out.size(); ++i) {
    auto loc = elem_to_loc(i, shape, strides);
    auto in_ptr = in.data<InT>() + loc;
    uint32_t ind_v = 0;
    InT v = (*in_ptr);
    for (uint32_t j = 0; j < axis_size; ++j, in_ptr += axis_stride) {
      op(j, (*in_ptr), &ind_v, &v);
    }
    out.data<uint32_t>()[i] = ind_v;
  }
}

template <typename InT>
void arg_reduce_dispatch(
    const array& in,
    array& out,
    ArgReduce::ReduceType rtype,
    int axis) {
  switch (rtype) {
    case ArgReduce::ArgMin: {
      auto op = [](auto ind_x, auto x, auto ind_y, auto y) {
        if (x < (*y)) {
          (*y) = x;
          (*ind_y) = ind_x;
        }
      };
      arg_reduce<InT>(in, out, op, axis);
      break;
    }
    case ArgReduce::ArgMax: {
      auto op = [](auto ind_x, auto x, auto ind_y, auto y) {
        if (x > (*y)) {
          (*y) = x;
          (*ind_y) = ind_x;
        }
      };
      arg_reduce<InT>(in, out, op, axis);
      break;
    }
  }
}

} // namespace

void ArgReduce::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  auto& in = inputs[0];
  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  switch (in.dtype()) {
    case bool_:
      arg_reduce_dispatch<bool>(in, out, reduce_type_, axis_);
      break;
    case uint8:
      arg_reduce_dispatch<uint8_t>(in, out, reduce_type_, axis_);
      break;
    case uint16:
      arg_reduce_dispatch<uint16_t>(in, out, reduce_type_, axis_);
      break;
    case uint32:
      arg_reduce_dispatch<uint32_t>(in, out, reduce_type_, axis_);
      break;
    case uint64:
      arg_reduce_dispatch<uint64_t>(in, out, reduce_type_, axis_);
      break;
    case int8:
      arg_reduce_dispatch<int8_t>(in, out, reduce_type_, axis_);
      break;
    case int16:
      arg_reduce_dispatch<int16_t>(in, out, reduce_type_, axis_);
      break;
    case int32:
      arg_reduce_dispatch<int32_t>(in, out, reduce_type_, axis_);
      break;
    case int64:
      arg_reduce_dispatch<int64_t>(in, out, reduce_type_, axis_);
      break;
    case float16:
      arg_reduce_dispatch<float16_t>(in, out, reduce_type_, axis_);
      break;
    case float32:
      arg_reduce_dispatch<float>(in, out, reduce_type_, axis_);
      break;
    case bfloat16:
      arg_reduce_dispatch<bfloat16_t>(in, out, reduce_type_, axis_);
      break;
    case complex64:
      arg_reduce_dispatch<complex64_t>(in, out, reduce_type_, axis_);
      break;
  }
}

} // namespace mlx::core
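To see the traversal in `arg_reduce` concretely, a worked example as comments:

// Worked example (sketch): ArgMax over axis 1 of a row-contiguous 2x3 input
//   {{1, 9, 4},
//    {7, 2, 5}}
// gives axis_size = 3, axis_stride = 1; after erasing axis 1 the iteration
// shape/strides are {2}/{3}, so out[0] scans {1, 9, 4} -> index 1 and
// out[1] scans {7, 2, 5} -> index 0. Indices are always written as uint32.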
mlx/backend/common/binary.cpp (new file, 331 lines)
@@ -0,0 +1,331 @@
// Copyright © 2023 Apple Inc.

#include <cassert>
#include <cmath>
#include <sstream>

#include "mlx/allocator.h"
#include "mlx/backend/common/binary.h"
#include "mlx/backend/common/binary_two.h"
#include "mlx/backend/common/ops.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"

namespace mlx::core {

namespace {

template <typename T, typename U, typename Op>
void comparison_op(const array& a, const array& b, array& out, Op op) {
  DefaultScalarVector<T, U, Op> opsv(op);
  DefaultVectorScalar<T, U, Op> opvs(op);
  DefaultVectorVector<T, U, Op> opvv(op);
  binary_op<T, U>(a, b, out, op, opsv, opvs, opvv);
}

template <typename Op>
void comparison_op(const array& a, const array& b, array& out, Op op) {
  switch (a.dtype()) {
    case bool_:
      comparison_op<bool, bool>(a, b, out, op);
      break;
    case uint8:
      comparison_op<uint8_t, bool>(a, b, out, op);
      break;
    case uint16:
      comparison_op<uint16_t, bool>(a, b, out, op);
      break;
    case uint32:
      comparison_op<uint32_t, bool>(a, b, out, op);
      break;
    case uint64:
      comparison_op<uint64_t, bool>(a, b, out, op);
      break;
    case int8:
      comparison_op<int8_t, bool>(a, b, out, op);
      break;
    case int16:
      comparison_op<int16_t, bool>(a, b, out, op);
      break;
    case int32:
      comparison_op<int32_t, bool>(a, b, out, op);
      break;
    case int64:
      comparison_op<int64_t, bool>(a, b, out, op);
      break;
    case float16:
      comparison_op<float16_t, bool>(a, b, out, op);
      break;
    case float32:
      comparison_op<float, bool>(a, b, out, op);
      break;
    case bfloat16:
      comparison_op<bfloat16_t, bool>(a, b, out, op);
      break;
    case complex64:
      comparison_op<complex64_t, bool>(a, b, out, op);
      break;
  }
}

} // namespace

void Add::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  binary(a, b, out, detail::Add());
}

void DivMod::eval(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  auto integral_op = [](auto x, auto y) {
    return std::make_pair(x / y, x % y);
  };
  auto float_op = [](auto x, auto y) {
    return std::make_pair(std::trunc(x / y), std::fmod(x, y));
  };
  switch (outputs[0].dtype()) {
    case bool_:
      binary_op<bool>(a, b, outputs, integral_op);
      break; // was missing: without it the bool_ case fell through to uint8
    case uint8:
      binary_op<uint8_t>(a, b, outputs, integral_op);
      break;
    case uint16:
      binary_op<uint16_t>(a, b, outputs, integral_op);
      break;
    case uint32:
      binary_op<uint32_t>(a, b, outputs, integral_op);
      break;
    case uint64:
      binary_op<uint64_t>(a, b, outputs, integral_op);
      break;
    case int8:
      binary_op<int8_t>(a, b, outputs, integral_op);
      break;
    case int16:
      binary_op<int16_t>(a, b, outputs, integral_op);
      break;
    case int32:
      binary_op<int32_t>(a, b, outputs, integral_op);
      break;
    case int64:
      binary_op<int64_t>(a, b, outputs, integral_op);
      break;
    case float16:
      binary_op<float16_t>(a, b, outputs, float_op);
      break;
    case float32:
      binary_op<float>(a, b, outputs, float_op);
      break;
    case bfloat16:
      binary_op<bfloat16_t>(a, b, outputs, float_op);
      break;
    case complex64:
      // Should never get here
      throw std::runtime_error("[DivMod] Complex type not supported");
      break;
  }
}
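A quick worked example of the two DivMod lambdas (the float path truncates the quotient toward zero and keeps the dividend's sign on the remainder):

#include <cmath>
#include <utility>

void divmod_example() {
  auto integral = std::make_pair(7 / 2, 7 % 2); // (3, 1)
  auto floating = std::make_pair(
      std::trunc(7.5f / 2.0f), std::fmod(7.5f, 2.0f)); // (3.0f, 1.5f)
  (void)integral;
  (void)floating;
}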

void Divide::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  binary(a, b, out, detail::Divide());
}

void Remainder::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  binary(a, b, out, detail::Remainder());
}

void Equal::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  if (equal_nan_) {
    comparison_op(inputs[0], inputs[1], out, detail::NaNEqual());
  } else {
    comparison_op(inputs[0], inputs[1], out, detail::Equal());
  }
}

void Greater::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  comparison_op(inputs[0], inputs[1], out, detail::Greater());
}

void GreaterEqual::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  comparison_op(inputs[0], inputs[1], out, detail::GreaterEqual());
}

void Less::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  comparison_op(inputs[0], inputs[1], out, detail::Less());
}

void LessEqual::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  comparison_op(inputs[0], inputs[1], out, detail::LessEqual());
}

void LogAddExp::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  if (out.dtype() == float32) {
    binary_op<float>(a, b, out, detail::LogAddExp());
  } else if (out.dtype() == float16) {
    binary_op<float16_t>(a, b, out, detail::LogAddExp());
  } else if (out.dtype() == bfloat16) {
    binary_op<bfloat16_t>(a, b, out, detail::LogAddExp());
  } else if (issubdtype(out.dtype(), inexact)) {
    std::ostringstream err;
    err << "[logaddexp] Does not support " << out.dtype();
    throw std::invalid_argument(err.str());
  } else {
    throw std::invalid_argument(
        "[logaddexp] Cannot compute logaddexp for arrays with"
        " non floating point type.");
  }
}

void LogicalAnd::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2); // LogicalAnd requires two input arrays
  auto& in1 = inputs[0];
  auto& in2 = inputs[1];
  binary(in1, in2, out, detail::LogicalAnd());
}

void LogicalOr::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2); // LogicalOr requires two input arrays
  auto& in1 = inputs[0];
  auto& in2 = inputs[1];
  binary(in1, in2, out, detail::LogicalOr());
}

void Maximum::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  binary(a, b, out, detail::Maximum());
}

void Minimum::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  binary(a, b, out, detail::Minimum());
}

void Multiply::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  binary(a, b, out, detail::Multiply());
}

void NotEqual::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  comparison_op(inputs[0], inputs[1], out, detail::NotEqual());
}

void Power::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  binary(a, b, out, detail::Power());
}

void Subtract::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  binary(a, b, out, detail::Subtract());
}

void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  auto& a = inputs[0];
  auto& b = inputs[1];
  auto dispatch_type = [&a, &b, &out](auto op) {
    switch (out.dtype()) {
      case bool_:
        binary_op<bool>(a, b, out, op);
        break; // was missing: without it the bool_ case fell through to uint8
      case uint8:
        binary_op<uint8_t>(a, b, out, op);
        break;
      case uint16:
        binary_op<uint16_t>(a, b, out, op);
        break;
      case uint32:
        binary_op<uint32_t>(a, b, out, op);
        break;
      case uint64:
        binary_op<uint64_t>(a, b, out, op);
        break;
      case int8:
        binary_op<int8_t>(a, b, out, op);
        break;
      case int16:
        binary_op<int16_t>(a, b, out, op);
        break;
      case int32:
        binary_op<int32_t>(a, b, out, op);
        break;
      case int64:
        binary_op<int64_t>(a, b, out, op);
        break;
      default:
        throw std::runtime_error(
            "[BitwiseBinary::eval_cpu] Type not supported");
        break;
    }
  };
  switch (op_) {
    case BitwiseBinary::And:
      dispatch_type(detail::BitwiseAnd());
      break;
    case BitwiseBinary::Or:
      dispatch_type(detail::BitwiseOr());
      break;
    case BitwiseBinary::Xor:
      dispatch_type(detail::BitwiseXor());
      break;
    case BitwiseBinary::LeftShift:
      dispatch_type(detail::LeftShift());
      break;
    case BitwiseBinary::RightShift:
      dispatch_type(detail::RightShift());
      break;
  }
}

void ArcTan2::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 2);
  const auto& a = inputs[0];
  const auto& b = inputs[1];
  if (out.dtype() == float32) {
    binary_op<float>(a, b, out, detail::ArcTan2());
  } else if (out.dtype() == float16) {
    binary_op<float16_t>(a, b, out, detail::ArcTan2());
  } else if (out.dtype() == bfloat16) {
    binary_op<bfloat16_t>(a, b, out, detail::ArcTan2());
  } else if (issubdtype(out.dtype(), inexact)) {
    std::ostringstream err;
    err << "[arctan2] Does not support " << out.dtype();
    throw std::invalid_argument(err.str());
  } else {
    throw std::invalid_argument(
        "[arctan2] Cannot compute inverse tangent for arrays"
        " with non floating point type.");
  }
}

} // namespace mlx::core
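One detail worth noting from the file above: every comparison instantiates comparison_op<InT, bool>, so the output is always a boolean array regardless of input dtype. A short sketch:

// Sketch: Less::eval on two float32 inputs conceptually runs
// comparison_op<float, bool>(a, b, out, detail::Less()),
// writing one bool per element into the bool-typed output.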
@@ -1,6 +1,7 @@
// Copyright © 2023 Apple Inc.

#pragma once
#include <cassert>

#include "mlx/allocator.h"
#include "mlx/array.h"
@@ -8,6 +9,8 @@

namespace mlx::core {

namespace {

enum class BinaryOpType {
  ScalarScalar,
  ScalarVector,
@@ -16,7 +19,7 @@ enum class BinaryOpType {
  General,
};

-inline BinaryOpType get_binary_op_type(const array& a, const array& b) {
+BinaryOpType get_binary_op_type(const array& a, const array& b) {
  BinaryOpType bopt;
  if (a.data_size() == 1 && b.data_size() == 1) {
    bopt = BinaryOpType::ScalarScalar;
@@ -34,24 +37,29 @@ inline BinaryOpType get_binary_op_type(const array& a, const array& b) {
  return bopt;
}

-inline void set_binary_op_output_data(
+void set_binary_op_output_data(
    const array& a,
    const array& b,
    array& out,
-    BinaryOpType bopt) {
+    BinaryOpType bopt,
+    bool donate_with_move = false) {
  bool b_donatable = is_donatable(b, out);
  bool a_donatable = is_donatable(a, out);
  switch (bopt) {
    case BinaryOpType::ScalarScalar:
      out.set_data(
-          allocator::malloc(out.itemsize()), 1, a.strides(), a.flags());
+          allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags());
      break;
    case BinaryOpType::ScalarVector:
      if (b_donatable) {
-        out.copy_shared_buffer(b);
+        if (donate_with_move) {
+          out.move_shared_buffer(b);
+        } else {
+          out.copy_shared_buffer(b);
+        }
      } else {
        out.set_data(
-            allocator::malloc(b.data_size() * out.itemsize()),
+            allocator::malloc_or_wait(b.data_size() * out.itemsize()),
            b.data_size(),
            b.strides(),
            b.flags());
@@ -59,10 +67,14 @@ inline void set_binary_op_output_data(
      break;
    case BinaryOpType::VectorScalar:
      if (a_donatable) {
-        out.copy_shared_buffer(a);
+        if (donate_with_move) {
+          out.move_shared_buffer(a);
+        } else {
+          out.copy_shared_buffer(a);
+        }
      } else {
        out.set_data(
-            allocator::malloc(a.data_size() * out.itemsize()),
+            allocator::malloc_or_wait(a.data_size() * out.itemsize()),
            a.data_size(),
            a.strides(),
            a.flags());
@@ -70,12 +82,20 @@ inline void set_binary_op_output_data(
      break;
    case BinaryOpType::VectorVector:
      if (a_donatable) {
-        out.copy_shared_buffer(a);
+        if (donate_with_move) {
+          out.move_shared_buffer(a);
+        } else {
+          out.copy_shared_buffer(a);
+        }
      } else if (b_donatable) {
-        out.copy_shared_buffer(b);
+        if (donate_with_move) {
+          out.move_shared_buffer(b);
+        } else {
+          out.copy_shared_buffer(b);
+        }
      } else {
        out.set_data(
-            allocator::malloc(a.data_size() * out.itemsize()),
+            allocator::malloc_or_wait(a.data_size() * out.itemsize()),
            a.data_size(),
            a.strides(),
            a.flags());
@@ -83,15 +103,428 @@ inline void set_binary_op_output_data(
      break;
    case BinaryOpType::General:
      if (a_donatable && a.flags().row_contiguous && a.size() == out.size()) {
-        out.copy_shared_buffer(a);
+        if (donate_with_move) {
+          out.move_shared_buffer(a);
+        } else {
+          out.copy_shared_buffer(a);
+        }
      } else if (
          b_donatable && b.flags().row_contiguous && b.size() == out.size()) {
-        out.copy_shared_buffer(b);
+        if (donate_with_move) {
+          out.move_shared_buffer(b);
+        } else {
+          out.copy_shared_buffer(b);
+        }
      } else {
-        out.set_data(allocator::malloc(out.nbytes()));
+        out.set_data(allocator::malloc_or_wait(out.nbytes()));
      }
      break;
  }
}
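The donation logic above is what lets binary ops run without allocating: when an input's buffer can be safely reused it becomes the output buffer. A hedged sketch of the distinction the diff introduces:

// Sketch (conceptual, not exact MLX calls): if `b` is donatable the output
// aliases b's buffer and no allocation happens. With donate_with_move the
// buffer handle is also moved out of `b`, so the input stops holding a
// reference; with a plain copy_shared_buffer both arrays share it.
// out = add(a, b)            // b used once, same itemsize -> reuse b's memory
// out = add(a, b); use(b);   // b still referenced -> not donatable, allocate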

struct UseDefaultBinaryOp {};

template <typename T, typename U, typename Op>
struct DefaultVectorScalar {
  Op op;

  DefaultVectorScalar(Op op_) : op(op_) {}

  void operator()(const T* a, const T* b, U* dst, int size) {
    T scalar = *b;
    while (size-- > 0) {
      *dst = op(*a, scalar);
      dst++;
      a++;
    }
  }
};

template <typename T, typename U, typename Op>
struct DefaultScalarVector {
  Op op;

  DefaultScalarVector(Op op_) : op(op_) {}

  void operator()(const T* a, const T* b, U* dst, int size) {
    T scalar = *a;
    while (size-- > 0) {
      *dst = op(scalar, *b);
      dst++;
      b++;
    }
  }
};

template <typename T, typename U, typename Op>
struct DefaultVectorVector {
  Op op;

  DefaultVectorVector(Op op_) : op(op_) {}

  void operator()(const T* a, const T* b, U* dst, int size) {
    while (size-- > 0) {
      *dst = op(*a, *b);
      dst++;
      a++;
      b++;
    }
  }
};

template <typename T, typename U, typename Op, int D, bool Strided>
void binary_op_dims(
    const T* a,
    const T* b,
    U* out,
    Op op,
    const Shape& shape,
    const Strides& a_strides,
    const Strides& b_strides,
    const Strides& out_strides,
    int axis) {
  auto stride_a = a_strides[axis];
  auto stride_b = b_strides[axis];
  auto stride_out = out_strides[axis];
  auto N = shape[axis];

  for (int i = 0; i < N; i++) {
    if constexpr (D > 1) {
      binary_op_dims<T, U, Op, D - 1, Strided>(
          a, b, out, op, shape, a_strides, b_strides, out_strides, axis + 1);
    } else {
      if constexpr (Strided) {
        op(a, b, out, stride_out);
      } else {
        *out = op(*a, *b);
      }
    }
    out += stride_out;
    a += stride_a;
    b += stride_b;
  }
}

template <typename T, typename U, bool Strided, typename Op>
void binary_op_dispatch_dims(
    const array& a,
    const array& b,
    array& out,
    Op op,
    int dim,
    const Shape& shape,
    const Strides& a_strides,
    const Strides& b_strides,
    const Strides& out_strides) {
  const T* a_ptr = a.data<T>();
  const T* b_ptr = b.data<T>();
  U* out_ptr = out.data<U>();
  switch (dim) {
    case 1:
      binary_op_dims<T, U, Op, 1, Strided>(
          a_ptr,
          b_ptr,
          out_ptr,
          op,
          shape,
          a_strides,
          b_strides,
          out_strides,
          0);
      return;
    case 2:
      binary_op_dims<T, U, Op, 2, Strided>(
          a_ptr,
          b_ptr,
          out_ptr,
          op,
          shape,
          a_strides,
          b_strides,
          out_strides,
          0);
      return;
    case 3:
      binary_op_dims<T, U, Op, 3, Strided>(
          a_ptr,
          b_ptr,
          out_ptr,
          op,
          shape,
          a_strides,
          b_strides,
          out_strides,
          0);
      return;
  }

  ContiguousIterator a_it(shape, a_strides, dim - 3);
  ContiguousIterator b_it(shape, b_strides, dim - 3);
  auto stride = out_strides[dim - 4];
  for (int64_t elem = 0; elem < a.size(); elem += stride) {
    binary_op_dims<T, U, Op, 3, Strided>(
        a_ptr + a_it.loc,
        b_ptr + b_it.loc,
        out_ptr + elem,
        op,
        shape,
        a_strides,
        b_strides,
        out_strides,
        dim - 3);
    a_it.step();
    b_it.step();
  }
}
template <
    typename T,
    typename U,
    typename Op,
    typename OpSV,
    typename OpVS,
    typename OpVV>
void binary_op(
    const array& a,
    const array& b,
    array& out,
    Op op,
    OpSV opsv,
    OpVS opvs,
    OpVV opvv) {
  auto bopt = get_binary_op_type(a, b);
  set_binary_op_output_data(a, b, out, bopt);

  // The full computation is scalar scalar so call the base op once
  if (bopt == BinaryOpType::ScalarScalar) {
    *(out.data<U>()) = op(*a.data<T>(), *b.data<T>());
    return;
  }

  // The full computation is scalar vector so delegate to the op
  if (bopt == BinaryOpType::ScalarVector) {
    opsv(a.data<T>(), b.data<T>(), out.data<U>(), b.data_size());
    return;
  }

  // The full computation is vector scalar so delegate to the op
  if (bopt == BinaryOpType::VectorScalar) {
    opvs(a.data<T>(), b.data<T>(), out.data<U>(), a.data_size());
    return;
  }

  // The full computation is vector vector so delegate to the op
  if (bopt == BinaryOpType::VectorVector) {
    opvv(a.data<T>(), b.data<T>(), out.data<U>(), out.size());
    return;
  }

  // General computation so let's try to optimize
  auto [new_shape, new_strides] = collapse_contiguous_dims(
      a.shape(), {a.strides(), b.strides(), out.strides()});
  const auto& a_strides = new_strides[0];
  const auto& b_strides = new_strides[1];
  const auto& strides = new_strides[2];

  // Get the left-most dim such that the array is row contiguous after
  auto leftmost_rc_dim = [&strides](const auto& arr_strides) {
    int d = arr_strides.size() - 1;
    for (; d >= 0 && arr_strides[d] == strides[d]; d--) {
    }
    return d + 1;
  };
  auto a_rc_dim = leftmost_rc_dim(a_strides);
  auto b_rc_dim = leftmost_rc_dim(b_strides);

  // Get the left-most dim such that the array is a broadcasted "scalar" after
  auto leftmost_s_dim = [](const auto& arr_strides) {
    int d = arr_strides.size() - 1;
    for (; d >= 0 && arr_strides[d] == 0; d--) {
    }
    return d + 1;
  };
  auto a_s_dim = leftmost_s_dim(a_strides);
  auto b_s_dim = leftmost_s_dim(b_strides);

  auto ndim = new_shape.size();

  // Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
  int dim = ndim;
  if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
    bopt = BinaryOpType::VectorVector;
    dim = d;
    // Case 2: LxM and Fx1 where L and F are broadcastable and M is row
    // contiguous
  } else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
    bopt = BinaryOpType::VectorScalar;
    dim = d;
    // Case 3: Lx1 and FxM where L and F are broadcastable and M is row
    // contiguous
  } else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
    bopt = BinaryOpType::ScalarVector;
    dim = d;
  }

  // Can be sure dim > 0 since otherwise we would have used one of the fully
  // contiguous methods above. Except for the case that the flags do not
  // correspond to the underlying contiguity.
  if (dim == 0 || strides[dim - 1] < 16) {
    bopt = BinaryOpType::General;
    dim = ndim;
  }

  switch (bopt) {
    case BinaryOpType::VectorVector:
      binary_op_dispatch_dims<T, U, true>(
          a, b, out, opvv, dim, new_shape, a_strides, b_strides, strides);
      break;
    case BinaryOpType::VectorScalar:
      binary_op_dispatch_dims<T, U, true>(
          a, b, out, opvs, dim, new_shape, a_strides, b_strides, strides);
      break;
    case BinaryOpType::ScalarVector:
      binary_op_dispatch_dims<T, U, true>(
          a, b, out, opsv, dim, new_shape, a_strides, b_strides, strides);
      break;
    default:
      binary_op_dispatch_dims<T, U, false>(
          a, b, out, op, dim, new_shape, a_strides, b_strides, strides);
      break;
  }
}
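To ground the case analysis above, a hedged worked example as comments:

// Worked example (sketch): a has shape (8, 1024), row contiguous, and b is
// a (1, 1024) row broadcast to (8, 1024), so b_strides = (0, 1) and the
// output strides are (1024, 1). Then a_rc_dim = 0, b_rc_dim = 1, so
// dim = max(0, 1) = 1 < ndim = 2 and Case 1 applies: the op is re-tagged
// VectorVector and opvv runs once per row over 1024 contiguous elements
// instead of falling back to the fully general element-by-element path.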

template <typename T, typename Op, typename OpSV, typename OpVS, typename OpVV>
void binary_op(
    const array& a,
    const array& b,
    array& out,
    Op op,
    OpSV opsv,
    OpVS opvs,
    OpVV opvv) {
  // TODO: The following mess of constexpr evaluations can probably be achieved
  // with template specializations and overloading. Would it be simpler?

  if constexpr (std::is_same<decltype(opsv), UseDefaultBinaryOp>::value) {
    if constexpr (std::is_same<decltype(opvs), UseDefaultBinaryOp>::value) {
      if constexpr (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
        // All ops are UseDefaultBinaryOp (why oh why would someone call that?)
        binary_op<T, T>(
            a,
            b,
            out,
            op,
            DefaultScalarVector<T, T, Op>(op),
            DefaultVectorScalar<T, T, Op>(op),
            DefaultVectorVector<T, T, Op>(op));
      } else {
        // opsv and opvs were UseDefaultBinaryOp
        binary_op<T, T>(
            a,
            b,
            out,
            op,
            DefaultScalarVector<T, T, Op>(op),
            DefaultVectorScalar<T, T, Op>(op),
            opvv);
      }
    } else if constexpr (std::is_same<decltype(opvv), UseDefaultBinaryOp>::
                             value) {
      // opsv and opvv were UseDefaultBinaryOp
      binary_op<T, T>(
          a,
          b,
          out,
          op,
          DefaultScalarVector<T, T, Op>(op),
          opvs,
          DefaultVectorVector<T, T, Op>(op));
    } else {
      // opsv was UseDefaultBinaryOp
      binary_op<T, T>(
          a, b, out, op, DefaultScalarVector<T, T, Op>(op), opvs, opvv);
    }
  } else if constexpr (std::is_same<decltype(opvs), UseDefaultBinaryOp>::
                           value) {
    // Note: `constexpr` is required here too, otherwise the untaken branch
    // would still be instantiated with an op that has no call operator.
    if constexpr (std::is_same<decltype(opvv), UseDefaultBinaryOp>::value) {
      // opvs and opvv were UseDefaultBinaryOp
      binary_op<T, T>(
          a,
          b,
          out,
          op,
          opsv,
          DefaultVectorScalar<T, T, Op>(op),
          DefaultVectorVector<T, T, Op>(op));
    } else {
      // opvs was UseDefaultBinaryOp
      binary_op<T, T>(
          a, b, out, op, opsv, DefaultVectorScalar<T, T, Op>(op), opvv);
    }
  } else if constexpr (std::is_same<decltype(opvv), UseDefaultBinaryOp>::
                           value) {
    // opvv was UseDefaultBinaryOp
    binary_op<T, T>(
        a, b, out, op, opsv, opvs, DefaultVectorVector<T, T, Op>(op));
  } else {
    // All ops provided
    binary_op<T, T>(a, b, out, op, opsv, opvs, opvv);
  }
}

template <typename T, typename Op>
void binary_op(const array& a, const array& b, array& out, Op op) {
  DefaultScalarVector<T, T, Op> opsv(op);
  DefaultVectorScalar<T, T, Op> opvs(op);
  DefaultVectorVector<T, T, Op> opvv(op);
  binary_op<T, T>(a, b, out, op, opsv, opvs, opvv);
}

template <typename... Ops>
void binary(const array& a, const array& b, array& out, Ops... ops) {
  switch (out.dtype()) {
    case bool_:
      binary_op<bool>(a, b, out, ops...);
      break;
    case uint8:
      binary_op<uint8_t>(a, b, out, ops...);
      break;
    case uint16:
      binary_op<uint16_t>(a, b, out, ops...);
      break;
    case uint32:
      binary_op<uint32_t>(a, b, out, ops...);
      break;
    case uint64:
      binary_op<uint64_t>(a, b, out, ops...);
      break;
    case int8:
      binary_op<int8_t>(a, b, out, ops...);
      break;
    case int16:
      binary_op<int16_t>(a, b, out, ops...);
      break;
    case int32:
      binary_op<int32_t>(a, b, out, ops...);
      break;
    case int64:
      binary_op<int64_t>(a, b, out, ops...);
      break;
    case float16:
      binary_op<float16_t>(a, b, out, ops...);
      break;
    case float32:
      binary_op<float>(a, b, out, ops...);
      break;
    case bfloat16:
      binary_op<bfloat16_t>(a, b, out, ops...);
      break;
    case complex64:
      binary_op<complex64_t>(a, b, out, ops...);
      break;
  }
}

} // namespace

} // namespace mlx::core
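Since the Default* adaptors wrap any callable, the typical entry point is a generic lambda; a hedged usage sketch (caller-side setup of the arrays is assumed):

// binary(a, b, out, [](auto x, auto y) { return x + y; });
// The dtype switch above instantiates the lambda for the concrete element
// type, and the Scalar/Vector adaptors handle the broadcast shapes.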
@@ -2,8 +2,8 @@

#pragma once

#include "mlx/backend/common/binary.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/binary.h"

namespace mlx::core {

@@ -58,14 +58,14 @@ void binary_op_dispatch_dims(
    Op op) {
  auto [shape, strides] = collapse_contiguous_dims(
      a.shape(), {a.strides(), b.strides(), out_a.strides()});
-  const auto& a_strides = strides[0];
-  const auto& b_strides = strides[1];
-  const auto& out_strides = strides[2];
  const T* a_ptr = a.data<T>();
  const T* b_ptr = b.data<T>();
  U* out_a_ptr = out_a.data<U>();
  U* out_b_ptr = out_b.data<U>();

+  const auto& a_strides = strides[0];
+  const auto& b_strides = strides[1];
+  const auto& out_strides = strides[2];
  int ndim = shape.size();
  switch (ndim) {
    case 1:
@@ -120,10 +120,14 @@ template <typename T, typename U = T, typename Op>
void binary_op(
    const array& a,
    const array& b,
-    array& out_a,
-    array& out_b,
-    Op op,
-    BinaryOpType bopt) {
+    std::vector<array>& outputs,
+    Op op) {
+  auto bopt = get_binary_op_type(a, b);
+  auto& out_a = outputs[0];
+  auto& out_b = outputs[1];
  set_binary_op_output_data(a, b, out_a, bopt);
  set_binary_op_output_data(a, b, out_b, bopt);

  // The general case dispatches straight to the strided implementation
  if (bopt == BinaryOpType::General) {
    binary_op_dispatch_dims<T, U, Op>(a, b, out_a, out_b, op);
@@ -137,14 +141,14 @@ void binary_op(
  if (bopt == BinaryOpType::ScalarScalar) {
    std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
  } else if (bopt == BinaryOpType::ScalarVector) {
-    for (size_t i = 0; i < b.data_size(); ++i) {
+    for (size_t i = 0; i < b.size(); ++i) {
      std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
      out_a_ptr++;
      out_b_ptr++;
      b_ptr++;
    }
  } else if (bopt == BinaryOpType::VectorScalar) {
-    for (size_t i = 0; i < a.data_size(); ++i) {
+    for (size_t i = 0; i < a.size(); ++i) {
      std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
      out_a_ptr++;
      out_b_ptr++;
@@ -161,6 +165,55 @@ void binary_op(
  }
}

template <typename Op>
void binary(
    const array& a,
    const array& b,
    std::vector<array>& outputs,
    Op op) {
  switch (outputs[0].dtype()) {
    case bool_:
      binary_op<bool>(a, b, outputs, op);
      break;
    case uint8:
      binary_op<uint8_t>(a, b, outputs, op);
      break;
    case uint16:
      binary_op<uint16_t>(a, b, outputs, op);
      break;
    case uint32:
      binary_op<uint32_t>(a, b, outputs, op);
      break;
    case uint64:
      binary_op<uint64_t>(a, b, outputs, op);
      break;
    case int8:
      binary_op<int8_t>(a, b, outputs, op);
      break;
    case int16:
      binary_op<int16_t>(a, b, outputs, op);
      break;
    case int32:
      binary_op<int32_t>(a, b, outputs, op);
      break;
    case int64:
      binary_op<int64_t>(a, b, outputs, op);
      break;
    case float16:
      binary_op<float16_t>(a, b, outputs, op);
      break;
    case float32:
      binary_op<float>(a, b, outputs, op);
      break;
    case bfloat16:
      binary_op<bfloat16_t>(a, b, outputs, op);
      break;
    case complex64:
      binary_op<complex64_t>(a, b, outputs, op);
      break;
  }
}

} // namespace

} // namespace mlx::core
@@ -1,24 +0,0 @@
// Copyright © 2024 Apple Inc.

#include "mlx/backend/common/utils.h"

namespace mlx::core {

void broadcast(const array& in, array& out) {
  if (out.size() == 0) {
    out.set_data(nullptr);
    return;
  }
  Strides strides(out.ndim(), 0);
  int diff = out.ndim() - in.ndim();
  for (int i = in.ndim() - 1; i >= 0; --i) {
    strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
  }
  auto flags = in.flags();
  if (out.size() > in.size()) {
    flags.row_contiguous = flags.col_contiguous = false;
  }
  out.copy_shared_buffer(in, strides, flags, in.data_size());
}

} // namespace mlx::core
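A worked example of the stride computation in `broadcast`:

// Broadcasting an input of shape (3, 1) with strides (1, 1) to an output
// of shape (2, 3, 4) gives diff = 1 and output strides (0, 1, 0): size-1
// axes and the new leading axis repeat data via stride 0, so the result
// is a view over the original buffer and no memory is copied.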
@@ -1,11 +0,0 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include "mlx/array.h"

namespace mlx::core {

void broadcast(const array& in, array& out);

} // namespace mlx::core
@@ -1,157 +0,0 @@
// Copyright © 2025 Apple Inc.

#pragma once

#include <cassert>
#include <functional>
#include <map>

namespace mlx::core {

template <typename T>
class BufferCache {
 public:
  BufferCache(
      size_t page_size,
      std::function<size_t(T*)> get_size,
      std::function<void(T*)> free)
      : page_size_(page_size),
        get_size_(std::move(get_size)),
        free_(std::move(free)) {}

  ~BufferCache() {
    clear();
  }

  BufferCache(const BufferCache&) = delete;
  BufferCache& operator=(const BufferCache&) = delete;

  T* reuse_from_cache(size_t size) {
    // Find the closest buffer in pool.
    auto it = buffer_pool_.lower_bound(size);
    if (it == buffer_pool_.end() ||
        it->first >= std::min(2 * size, size + 2 * page_size_)) {
      return nullptr;
    }

    // Collect from the cache.
    T* buf = it->second->buf;
    pool_size_ -= it->first;

    // Remove from record.
    remove_from_list(it->second);
    buffer_pool_.erase(it);
    return buf;
  }

  void recycle_to_cache(T* buf) {
    assert(buf);
    // Add to cache.
    BufferHolder* bh = new BufferHolder(buf);
    add_at_head(bh);
    size_t size = get_size_(buf);
    pool_size_ += size;
    buffer_pool_.emplace(size, bh);
  }

  int release_cached_buffers(size_t min_bytes_to_free) {
    if (min_bytes_to_free >= 0.9 * pool_size_) {
      return clear();
    } else {
      int n_release = 0;
      size_t total_bytes_freed = 0;

      while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
        // Release buffer.
        size_t size = get_size_(tail_->buf);
        total_bytes_freed += size;
        free_(tail_->buf);
        n_release++;

        // Remove from record.
        auto its = buffer_pool_.equal_range(size);
        auto it = std::find_if(its.first, its.second, [this](const auto& el) {
          return el.second == tail_;
        });
        assert(it != buffer_pool_.end());
        buffer_pool_.erase(it);
        remove_from_list(tail_);
      }

      pool_size_ -= total_bytes_freed;
      return n_release;
    }
  }

  int clear() {
    int n_release = 0;
    for (auto& [size, holder] : buffer_pool_) {
      free_(holder->buf);
      n_release++;
      delete holder;
    }
    buffer_pool_.clear();
    pool_size_ = 0;
    head_ = nullptr;
    tail_ = nullptr;
    return n_release;
  }

  size_t cache_size() const {
    return pool_size_;
  }

  size_t page_size() const {
    return page_size_;
  }

 private:
  struct BufferHolder {
   public:
    explicit BufferHolder(T* buf_) : buf(buf_) {}

    BufferHolder* prev{nullptr};
    BufferHolder* next{nullptr};
    T* buf;
  };

  void add_at_head(BufferHolder* to_add) {
    if (!head_) {
      head_ = to_add;
      tail_ = to_add;
    } else {
      head_->prev = to_add;
      to_add->next = head_;
      head_ = to_add;
    }
  }

  void remove_from_list(BufferHolder* to_remove) {
    if (to_remove->prev && to_remove->next) { // if middle
      to_remove->prev->next = to_remove->next;
      to_remove->next->prev = to_remove->prev;
    } else if (to_remove->prev && to_remove == tail_) { // if tail
      tail_ = to_remove->prev;
      tail_->next = nullptr;
    } else if (to_remove == head_ && to_remove->next) { // if head
      head_ = to_remove->next;
      head_->prev = nullptr;
    } else if (to_remove == head_ && to_remove == tail_) { // if only element
      head_ = nullptr;
      tail_ = nullptr;
    }

    delete to_remove;
  }

  std::multimap<size_t, BufferHolder*> buffer_pool_;
  BufferHolder* head_{nullptr};
  BufferHolder* tail_{nullptr};
  size_t pool_size_{0};

  const size_t page_size_;
  std::function<size_t(T*)> get_size_;
  std::function<void(T*)> free_;
};

} // namespace mlx::core
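A minimal usage sketch of the cache, assuming a hypothetical Buffer type that records its own allocation size (BufferCache itself never inspects T):

#include <cstddef>

struct Buffer {
  std::size_t size; // hypothetical: lets get_size report the allocation
};

void buffer_cache_sketch() {
  BufferCache<Buffer> cache(
      /* page_size = */ 4096,
      /* get_size = */ [](Buffer* b) { return b->size; },
      /* free = */ [](Buffer* b) { delete b; });

  cache.recycle_to_cache(new Buffer{8192}); // park a freed buffer for reuse
  Buffer* again = cache.reuse_from_cache(8192); // exact-size request hits
  delete again; // the caller owns buffers handed back by the cache
}

Note the reuse test: a cached buffer is only returned if its size is below min(2 * size, size + 2 * page_size), which bounds how much memory an oversized hit can waste.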
mlx/backend/common/cholesky.cpp (new file, 74 lines)
@@ -0,0 +1,74 @@
// Copyright © 2023-2024 Apple Inc.

#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"

namespace mlx::core {

void cholesky_impl(const array& a, array& factor, bool upper) {
  // Lapack uses the column-major convention. We take advantage of the fact
  // that the matrix should be symmetric:
  //   (A)ᵀ = A
  // and that a column-major lower triangular matrix is a row-major upper
  // triangular matrix, so uplo is the opposite of what we would expect from
  // upper

  char uplo = (upper) ? 'L' : 'U';

  // The decomposition is computed in place, so just copy the input to the
  // output.
  copy(
      a,
      factor,
      a.flags().row_contiguous ? CopyType::Vector : CopyType::General);

  const int N = a.shape(-1);
  const size_t num_matrices = a.size() / (N * N);

  float* matrix = factor.data<float>();

  for (int i = 0; i < num_matrices; i++) {
    // Compute Cholesky factorization.
    int info;
    MLX_LAPACK_FUNC(spotrf)
    (
        /* uplo = */ &uplo,
        /* n = */ &N,
        /* a = */ matrix,
        /* lda = */ &N,
        /* info = */ &info);

    // TODO: We do nothing when the matrix is not positive semi-definite
    // because throwing an error would result in a crash. If we figure out how
    // to catch errors from the implementation we should throw.
    if (info < 0) {
      std::stringstream msg;
      msg << "[cholesky] Cholesky decomposition failed with error code "
          << info;
      throw std::runtime_error(msg.str());
    }

    // Zero out the upper/lower triangle while advancing the pointer to the
    // next matrix at the same time.
    for (int row = 0; row < N; row++) {
      if (upper) {
        std::fill(matrix, matrix + row, 0);
      } else {
        std::fill(matrix + row + 1, matrix + N, 0);
      }
      matrix += N;
    }
  }
}

void Cholesky::eval(const std::vector<array>& inputs, array& output) {
  if (inputs[0].dtype() != float32) {
    throw std::runtime_error("[Cholesky::eval] only supports float32.");
  }
  cholesky_impl(inputs[0], output, upper_);
}

} // namespace mlx::core
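The uplo flip works because a row-major matrix read column-major is its transpose, and the input is symmetric, so the same bytes describe A either way; only the triangle label swaps. A hedged usage sketch through the public linalg API (assuming mlx::core::linalg as included above):

// auto L = linalg::cholesky(a);                  // lower factor, the default
// auto U = linalg::cholesky(a, /* upper = */ true); // upper factor
// Internally upper=true passes uplo='L' to spotrf for the reason above.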
@@ -1,7 +1,6 @@
// Copyright © 2024 Apple Inc.
#include <cassert>

#include "mlx/backend/common/broadcasting.h"
#include "mlx/backend/common/utils.h"
#include "mlx/primitives.h"

@@ -40,20 +39,31 @@ void AsStrided::eval(const std::vector<array>& inputs, array& out) {
  // rely on data_size anyway.
  size_t data_size = out.size();

-  return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
+  return move_or_copy(in, out, strides_, flags, data_size, offset_);
}

void Broadcast::eval(const std::vector<array>& inputs, array& out) {
  broadcast(inputs[0], out);
}

void BroadcastAxes::eval(const std::vector<array>& inputs, array& out) {
-  broadcast(inputs[0], out);
+  assert(inputs.size() == 1);
+  const auto& in = inputs[0];
+  if (out.size() == 0) {
+    out.set_data(nullptr);
+    return;
+  }
+  Strides strides(out.ndim(), 0);
+  int diff = out.ndim() - in.ndim();
+  for (int i = in.ndim() - 1; i >= 0; --i) {
+    strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
+  }
+  auto flags = in.flags();
+  if (out.size() > in.size()) {
+    flags.row_contiguous = flags.col_contiguous = false;
+  }
+  move_or_copy(in, out, strides, flags, in.data_size());
}

void Copy::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
-  out.copy_shared_buffer(inputs[0]);
+  move_or_copy(inputs[0], out);
}

void CustomTransforms::eval(
@@ -62,7 +72,7 @@ void CustomTransforms::eval(
  assert(inputs.size() > outputs.size());
  for (int i = 0, j = inputs.size() - outputs.size(); i < outputs.size();
       i++, j++) {
-    outputs[i].copy_shared_buffer(inputs[j]);
+    move_or_copy(inputs[j], outputs[i]);
  }
}

@@ -71,7 +81,7 @@ void Depends::eval(
    std::vector<array>& outputs) {
  assert(inputs.size() > outputs.size());
  for (int i = 0; i < outputs.size(); i++) {
-    outputs[i].copy_shared_buffer(inputs[i]);
+    move_or_copy(inputs[i], outputs[i]);
  }
}

@@ -82,12 +92,12 @@ void ExpandDims::eval(const std::vector<array>& inputs, array& out) {
  for (auto ax : axes_) {
    strides.insert(strides.begin() + ax, 1);
  }
-  out.copy_shared_buffer(in, strides, in.flags(), in.data_size());
+  move_or_copy(in, out, strides, in.flags(), in.data_size());
}

void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
-  out.set_data(allocator::malloc(out.nbytes()));
+  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  double numel = 1;
  for (auto ax : axes_) {
@@ -135,9 +145,6 @@ void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
    case bfloat16:
      *out.data<bfloat16_t>() = static_cast<bfloat16_t>(numel);
      break;
    case float64:
      *out.data<double>() = static_cast<double>(numel);
      break;
    case complex64:
      *out.data<complex64_t>() = static_cast<complex64_t>(numel);
      break;
@@ -194,7 +201,7 @@ void shared_buffer_reshape(
    auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
    flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
  }
-  out.copy_shared_buffer(in, out_strides, flags, in.data_size());
+  move_or_copy(in, out, out_strides, flags, in.data_size());
}

void Split::eval(
@@ -260,12 +267,12 @@ void Squeeze::eval(const std::vector<array>& inputs, array& out) {
      strides.push_back(in.strides(i));
    }
  }
-  out.copy_shared_buffer(in, strides, in.flags(), in.data_size());
+  move_or_copy(in, out, strides, in.flags(), in.data_size());
}

void StopGradient::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
-  out.copy_shared_buffer(inputs[0]);
+  move_or_copy(inputs[0], out);
}

void Transpose::eval(const std::vector<array>& inputs, array& out) {
@@ -299,7 +306,7 @@ void Transpose::eval(const std::vector<array>& inputs, array& out) {
      b_stride *= out.shape(ri);
    }
  }
-  out.copy_shared_buffer(in, out_strides, flags, in.data_size());
+  move_or_copy(in, out, out_strides, flags, in.data_size());
}

} // namespace mlx::core
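Every pair in this hunk swaps a plain copy_shared_buffer for move_or_copy. From context, the helper behaves roughly as below (a sketch reconstructed from the call sites, not the verbatim implementation):

// Sketch: reuse the input's buffer, moving the handle when the input can be
// donated so no extra reference is kept, otherwise sharing it.
// void move_or_copy(const array& in, array& out) {
//   if (in.is_donatable()) {
//     out.move_shared_buffer(in);
//   } else {
//     out.copy_shared_buffer(in);
//   }
// }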
@@ -1,7 +1,8 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/graph_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
@@ -14,8 +15,6 @@ void print_constant(std::ostream& os, const array& x) {
|
||||
return print_float_constant<float16_t>(os, x);
|
||||
case bfloat16:
|
||||
return print_float_constant<bfloat16_t>(os, x);
|
||||
case float64:
|
||||
return print_float_constant<double>(os, x);
|
||||
case complex64:
|
||||
return print_complex_constant<complex64_t>(os, x);
|
||||
case int8:
|
||||
@@ -52,8 +51,6 @@ std::string get_type_string(Dtype d) {
|
||||
return "float16_t";
|
||||
case bfloat16:
|
||||
return "bfloat16_t";
|
||||
case float64:
|
||||
return "double";
|
||||
case complex64:
|
||||
return "complex64_t";
|
||||
case bool_:
|
||||
@@ -82,6 +79,55 @@ std::string get_type_string(Dtype d) {
|
||||
}
|
||||
}
|
||||
|
||||
std::string build_lib_name(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs,
|
||||
const std::vector<array>& tape,
|
||||
const std::unordered_set<uintptr_t>& constant_ids) {
|
||||
NodeNamer namer;
|
||||
std::ostringstream os;
|
||||
std::ostringstream constant_hasher;
|
||||
|
||||
// Fill the input names. This is not really necessary, I just like having A,
|
||||
// B, C, ... as the inputs.
|
||||
for (auto& x : inputs) {
|
||||
namer.get_name(x);
|
||||
}
|
||||
|
||||
// The primitives describing the tape. For unary and binary primitives this
|
||||
// must be enough to describe the full computation.
|
||||
for (auto& a : tape) {
|
||||
// name and type of output
|
||||
os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize();
|
||||
// computation performed
|
||||
a.primitive().print(os);
|
||||
// name of inputs to the function
|
||||
for (auto& inp : a.inputs()) {
|
||||
os << namer.get_name(inp);
|
||||
}
|
||||
}
|
||||
os << "_";
|
||||
|
||||
for (auto& x : inputs) {
|
||||
if (constant_ids.find(x.id()) != constant_ids.end()) {
|
||||
os << "C";
|
||||
print_constant(constant_hasher, x);
|
||||
} else {
|
||||
os << (is_scalar(x) ? "S" : "V");
|
||||
}
|
||||
}
|
||||
os << "_";
|
||||
for (auto& x : inputs) {
|
||||
if (constant_ids.find(x.id()) != constant_ids.end()) {
|
||||
continue;
|
||||
}
|
||||
os << kindof(x.dtype()) << x.itemsize();
|
||||
}
|
||||
os << "_" << std::hash<std::string>{}(constant_hasher.str());
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
bool compiled_check_contiguity(
|
||||
const std::vector<array>& inputs,
|
||||
const Shape& shape) {
|
||||
@@ -113,8 +159,10 @@ bool compiled_check_contiguity(
|
||||
void compiled_allocate_outputs(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
const std::function<bool(size_t)>& is_constant,
|
||||
bool contiguous) {
|
||||
const std::vector<array>& inputs_,
|
||||
const std::unordered_set<uintptr_t>& constant_ids_,
|
||||
bool contiguous,
|
||||
bool move_buffers /* = false */) {
|
||||
if (contiguous) {
|
||||
int o = 0;
|
||||
Strides strides;
|
||||
@@ -128,8 +176,13 @@ void compiled_allocate_outputs(
|
||||
// - Donatable
|
||||
// - Not a constant
|
||||
if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) &&
|
||||
in.is_donatable() && is_constant(i)) {
|
||||
outputs[o++].copy_shared_buffer(in);
|
||||
in.is_donatable() &&
|
||||
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
|
||||
if (move_buffers) {
|
||||
outputs[o++].move_shared_buffer(in);
|
||||
} else {
|
||||
outputs[o++].copy_shared_buffer(in);
|
||||
}
|
||||
}
|
||||
// Get representative input flags to properly set non-donated outputs
|
||||
if (strides.empty() && in.size() == outputs[0].size()) {
|
||||
@@ -140,7 +193,7 @@ void compiled_allocate_outputs(
|
||||
}
|
||||
for (; o < outputs.size(); ++o) {
|
||||
outputs[o].set_data(
|
||||
allocator::malloc(data_size * outputs[o].itemsize()),
|
||||
allocator::malloc_or_wait(data_size * outputs[o].itemsize()),
|
||||
data_size,
|
||||
strides,
|
||||
flags);
|
||||
@@ -156,86 +209,21 @@ void compiled_allocate_outputs(
|
||||
// - Not a constant
|
||||
if (in.flags().row_contiguous && in.size() == outputs[o].size() &&
|
||||
in.itemsize() == outputs[o].itemsize() && in.is_donatable() &&
|
||||
is_constant(i)) {
|
||||
outputs[o].copy_shared_buffer(
|
||||
in, outputs[o].strides(), in.flags(), in.data_size());
|
||||
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
|
||||
if (move_buffers) {
|
||||
outputs[o].move_shared_buffer(
|
||||
in, outputs[o].strides(), in.flags(), in.data_size());
|
||||
} else {
|
||||
outputs[o].copy_shared_buffer(
|
||||
in, outputs[o].strides(), in.flags(), in.data_size());
|
||||
}
|
||||
o++;
|
||||
}
|
||||
}
|
||||
for (; o < outputs.size(); ++o) {
|
||||
outputs[o].set_data(allocator::malloc(outputs[o].nbytes()));
|
||||
outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes()));
|
||||
}
|
||||
}
|
||||
}
|
||||

std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(
    const std::vector<array>& inputs,
    const array& out,
    const std::function<bool(size_t)>& is_constant) {
  const Shape& shape = out.shape();
  bool contiguous = compiled_check_contiguity(inputs, shape);
  if (contiguous) {
    return {true, shape, {}};
  }

  std::vector<Strides> strides_vec{out.strides()};
  for (size_t i = 0; i < inputs.size(); ++i) {
    // Skip constants.
    if (is_constant(i)) {
      continue;
    }

    // Skip scalar inputs.
    const auto& x = inputs[i];
    if (is_scalar(x)) {
      continue;
    }

    // Broadcast the inputs to the output shape.
    Strides xstrides;
    size_t j = 0;
    for (; j < shape.size() - x.ndim(); ++j) {
      if (shape[j] == 1) {
        xstrides.push_back(out.strides()[j]);
      } else {
        xstrides.push_back(0);
      }
    }
    for (size_t i = 0; i < x.ndim(); ++i, ++j) {
      if (x.shape(i) == 1) {
        if (shape[j] == 1) {
          xstrides.push_back(out.strides()[j]);
        } else {
          xstrides.push_back(0);
        }
      } else {
        xstrides.push_back(x.strides()[i]);
      }
    }
    strides_vec.push_back(std::move(xstrides));
  }

  auto tup = collapse_contiguous_dims(shape, strides_vec, INT32_MAX);
  return {false, std::move(std::get<0>(tup)), std::move(std::get<1>(tup))};
}
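Note: a small worked instance of the stride padding above, with values assumed for illustration:

// Suppose the output shape is {2, 3, 4} (row-contiguous strides
// {12, 4, 1}) and an input x has shape {3, 1} with strides {1, 1}.
// x is missing the leading axis and is broadcast along the trailing
// one, so the two loops build
//   xstrides = {0, x.strides()[0], 0} = {0, 1, 0}
// i.e. stride 0 for every broadcast axis, so stepping along those
// axes re-reads the same elements.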

bool compiled_use_large_index(
    const std::vector<array>& inputs,
    const std::vector<array>& outputs,
    bool contiguous) {
  if (contiguous) {
    size_t max_size = 0;
    for (const auto& in : inputs) {
      max_size = std::max(max_size, in.data_size());
    }
    return max_size > UINT32_MAX;
  } else {
    size_t max_size = 0;
    for (const auto& o : outputs) {
      max_size = std::max(max_size, o.size());
    }
    return max_size > UINT32_MAX;
  }
}

} // namespace mlx::core

@@ -1,8 +1,9 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once

#include <functional>
#include <iomanip>
#include <sstream>
#include <unordered_set>

#include "mlx/array.h"
#include "mlx/primitives.h"
@@ -10,20 +11,24 @@
namespace mlx::core {

inline bool is_static_cast(const Primitive& p) {
  return (typeid(p) == typeid(Broadcast) || typeid(p) == typeid(AsType));
  return (
      typeid(p) == typeid(Broadcast) || typeid(p) == typeid(Copy) ||
      typeid(p) == typeid(StopGradient) || typeid(p) == typeid(AsType));
}

std::string build_lib_name(
    const std::vector<array>& inputs,
    const std::vector<array>& outputs,
    const std::vector<array>& tape,
    const std::unordered_set<uintptr_t>& constant_ids);

std::string get_type_string(Dtype d);

template <typename T>
void print_float_constant(std::ostream& os, const array& x) {
  auto old_precision = os.precision();
  if constexpr (std::is_same_v<T, double>) {
    os << std::setprecision(std::numeric_limits<double>::digits10 + 1);
  } else {
    os << std::setprecision(std::numeric_limits<float>::digits10 + 1);
  }
  os << x.item<T>() << std::setprecision(old_precision);
  os << std::setprecision(std::numeric_limits<float>::digits10 + 1)
     << x.item<T>() << std::setprecision(old_precision);
}

template <typename T>
@@ -57,19 +62,9 @@ bool compiled_check_contiguity(
void compiled_allocate_outputs(
    const std::vector<array>& inputs,
    std::vector<array>& outputs,
    const std::function<bool(size_t)>& is_constant,
    bool contiguous);

// Collapse contiguous dims ignoring scalars and constants.
std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(
    const std::vector<array>& inputs,
    const array& out,
    const std::function<bool(size_t)>& is_constant);

// Return whether the kernel should use large index.
bool compiled_use_large_index(
    const std::vector<array>& inputs,
    const std::vector<array>& outputs,
    bool contiguous);
    const std::vector<array>& inputs_,
    const std::unordered_set<uintptr_t>& constant_ids_,
    bool contiguous,
    bool move_buffers = false);

} // namespace mlx::core

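Note on the setprecision calls in print_float_constant above: digits10 + 1 is seven significant digits for float, which is usually, though not always, enough for the printed constant to parse back to the same value; max_digits10 (nine for float) is the value that guarantees an exact round trip. A standalone sketch, not MLX code:

#include <iomanip>
#include <iostream>
#include <limits>
#include <sstream>
#include <string>

int main() {
  float x = 0.1f; // not exactly representable in binary
  std::ostringstream os;
  // max_digits10 guarantees the printed text parses back to the same bits.
  os << std::setprecision(std::numeric_limits<float>::max_digits10) << x;
  std::cout << (std::stof(os.str()) == x) << " " << os.str() << std::endl;
  return 0;
}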
@@ -7,12 +7,9 @@
#include <mutex>
#include <shared_mutex>

#include <fmt/format.h>

#include "mlx/backend/common/compiled.h"
#include "mlx/backend/cpu/compiled_preamble.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/jit_compiler.h"
#include "mlx/backend/common/compiled_preamble.h"
#include "mlx/backend/common/jit_compiler.h"
#include "mlx/device.h"
#include "mlx/graph_utils.h"

@@ -40,10 +37,7 @@ struct CompilerCache {
  std::shared_mutex mtx;
};

static CompilerCache& cache() {
  static CompilerCache cache_;
  return cache_;
};
static CompilerCache cache{};

// GPU compile is always available if the GPU is available and since we are in
// this file CPU compile is also available.
@@ -59,16 +53,14 @@ void* compile(
    const std::string& kernel_name,
    const std::function<std::string(void)>& source_builder) {
  {
    std::shared_lock lock(cache().mtx);
    if (auto it = cache().kernels.find(kernel_name);
        it != cache().kernels.end()) {
    std::shared_lock lock(cache.mtx);
    if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
      return it->second;
    }
  }

  std::unique_lock lock(cache().mtx);
  if (auto it = cache().kernels.find(kernel_name);
      it != cache().kernels.end()) {
  std::unique_lock lock(cache.mtx);
  if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
    return it->second;
  }
  std::string source_code = source_builder();
@@ -113,22 +105,22 @@ void* compile(
  source_file << source_code;
  source_file.close();

  try {
    JitCompiler::exec(JitCompiler::build_command(
        output_dir, source_file_name, shared_lib_name));
  } catch (const std::exception& error) {
    throw std::runtime_error(fmt::format(
        "[Compile::eval_cpu] Failed to compile function {0}: {1}",
        kernel_name,
        error.what()));
  std::string command = JitCompiler::build_command(
      output_dir, source_file_name, shared_lib_name);
  auto return_code = system(command.c_str());
  if (return_code) {
    std::ostringstream msg;
    msg << "[Compile::eval_cpu] Failed to compile function " << kernel_name
        << " with error code " << return_code << "." << std::endl;
    throw std::runtime_error(msg.str());
  }
}

// load library
cache().libs.emplace_back(shared_lib_path);
cache.libs.emplace_back(shared_lib_path);

// Load function
void* fun = dlsym(cache().libs.back().lib, kernel_name.c_str());
void* fun = dlsym(cache.libs.back().lib, kernel_name.c_str());
if (!fun) {
  std::ostringstream msg;
  msg << "[Compile::eval_cpu] Failed to load compiled function "
@@ -136,7 +128,7 @@ void* compile(
      << dlerror();
  throw std::runtime_error(msg.str());
}
cache().kernels.insert({kernel_name, fun});
cache.kernels.insert({kernel_name, fun});
return fun;
}

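Note: the lookup in compile above is the classic read-mostly cache idiom: take a shared lock for the fast path, then a unique lock and re-check before inserting. A standalone sketch of the same pattern (names hypothetical, not MLX API):

#include <shared_mutex>
#include <string>
#include <unordered_map>

struct Cache {
  std::unordered_map<std::string, void*> kernels;
  std::shared_mutex mtx;
};

void* lookup_or_build(Cache& c, const std::string& key, void* (*build)()) {
  {
    // Fast path: many readers may hold the shared lock concurrently.
    std::shared_lock lock(c.mtx);
    if (auto it = c.kernels.find(key); it != c.kernels.end()) {
      return it->second;
    }
  }
  // Slow path: exclusive lock, then re-check in case another thread
  // inserted the entry between the two lock scopes.
  std::unique_lock lock(c.mtx);
  if (auto it = c.kernels.find(key); it != c.kernels.end()) {
    return it->second;
  }
  return c.kernels[key] = build();
}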
@@ -146,9 +138,18 @@ inline void build_kernel(
    const std::vector<array>& inputs,
    const std::vector<array>& outputs,
    const std::vector<array>& tape,
    const std::function<bool(size_t)>& is_constant,
    const std::unordered_set<uintptr_t>& constant_ids,
    bool contiguous,
    int ndim) {
  // All outputs should have the exact same shape and will be row contiguous
  auto output_shape = outputs[0].shape();
  auto output_strides = outputs[0].strides();

  // Constants are scalars that are captured by value and cannot change
  auto is_constant = [&constant_ids](const array& x) {
    return constant_ids.find(x.id()) != constant_ids.end();
  };

  NodeNamer namer;

#ifdef _MSC_VER
@@ -157,28 +158,25 @@ inline void build_kernel(
#endif

  // Start the kernel
  os << "void " << kernel_name
     << "(int* shape, int64_t** strides, void** args) {" << std::endl;
  os << "void " << kernel_name << "(void** args) {" << std::endl;

  // Add the input arguments
  int cnt = 0;
  int strides_index = 1;
  for (size_t i = 0; i < inputs.size(); ++i) {
  for (auto& x : inputs) {
    auto& xname = namer.get_name(x);

    // Skip constants from the input list
    if (is_constant(i)) {
    if (is_constant(x)) {
      continue;
    }

    const auto& x = inputs[i];
    auto& xname = namer.get_name(x);

    auto tstr = get_type_string(x.dtype());
    os << " " << tstr << "* " << xname << " = (" << tstr << "*)args[" << cnt++
       << "];" << std::endl;
    // Scalars and contiguous need no strides
    if (!is_scalar(x) && !contiguous) {
      os << " const int64_t* " << xname << "_strides = strides["
         << strides_index++ << "];" << std::endl;
      os << " const size_t* " << xname << "_strides = (size_t*)args[" << cnt++
         << "];" << std::endl;
    }
  }

@@ -188,8 +186,10 @@ inline void build_kernel(
    os << " " << tstr << "* " << namer.get_name(x) << " = (" << tstr
       << "*)args[" << cnt++ << "];" << std::endl;
  }
  // Add output size
  if (contiguous) {
  // Add output strides and shape to extract the indices.
  if (!contiguous) {
    os << " const int* shape = (int*)args[" << cnt++ << "];" << std::endl;
  } else {
    os << " const size_t size = (size_t)args[" << cnt++ << "];" << std::endl;
  }

@@ -203,11 +203,10 @@ inline void build_kernel(
  }

  // Read the inputs in tmps
  for (size_t i = 0; i < inputs.size(); ++i) {
    const auto& x = inputs[i];
  for (auto& x : inputs) {
    auto& xname = namer.get_name(x);

    if (is_constant(i)) {
    if (is_constant(x)) {
      os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = ";
      print_constant(os, x);
      os << ";" << std::endl;
@@ -231,7 +230,7 @@ inline void build_kernel(
      os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_"
         << namer.get_name(x.inputs()[0]) << ");" << std::endl;
    } else {
      os << x.primitive().name();
      x.primitive().print(os);
      os << "()(";
      for (int i = 0; i < x.inputs().size() - 1; i++) {
        os << "tmp_" << namer.get_name(x.inputs()[i]) << ", ";
@@ -257,9 +256,8 @@ inline void build_kernel(
  } else {
    for (int d = ndim - 1; d >= 0; --d) {
      // Update pointers
      for (size_t i = 0; i < inputs.size(); ++i) {
        const auto& x = inputs[i];
        if (is_constant(i) || is_scalar(x)) {
      for (auto& x : inputs) {
        if (is_constant(x) || is_scalar(x)) {
          continue;
        }
        auto& xname = namer.get_name(x);
@@ -281,33 +279,63 @@ inline void build_kernel(
void Compiled::eval_cpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  auto& encoder = cpu::get_command_encoder(stream());
  if (kernel_lib_.empty()) {
    kernel_lib_ = build_lib_name(inputs_, outputs_, tape_, constant_ids_);
  }

  // Collapse contiguous dims to route to a faster kernel if possible. Also
  // handle all broadcasting.
  auto [contiguous, shape, strides] =
      compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);
  // Figure out which kernel we are using
  auto& shape = outputs[0].shape();
  auto contiguous = compiled_check_contiguity(inputs, shape);

  // Collect function input arguments.
  // Handle all broadcasting and collect function input arguments
  std::vector<void*> args;
  for (size_t i = 0; i < inputs.size(); ++i) {
    if (is_constant_(i)) {
  std::vector<std::vector<size_t>> strides;
  for (int i = 0; i < inputs.size(); i++) {
    // Skip constants.
    if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
      continue;
    }
    const auto& x = inputs[i];
    encoder.set_input_array(x);
    auto& x = inputs[i];
    args.push_back((void*)x.data<void>());

    if (contiguous || is_scalar(x)) {
      continue;
    }

    // Broadcast the input to the output shape.
    std::vector<size_t> xstrides;
    int j = 0;
    for (; j < shape.size() - x.ndim(); j++) {
      if (shape[j] == 1) {
        xstrides.push_back(outputs[0].strides()[j]);
      } else {
        xstrides.push_back(0);
      }
    }
    for (int i = 0; i < x.ndim(); i++, j++) {
      if (x.shape(i) == 1) {
        if (shape[j] == 1) {
          xstrides.push_back(outputs[0].strides()[j]);
        } else {
          xstrides.push_back(0);
        }
      } else {
        xstrides.push_back(x.strides()[i]);
      }
    }
    strides.push_back(std::move(xstrides));
    args.push_back(strides.back().data());
  }

  // Get the kernel name from the lib
  int ndim = shape.size();
  auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_");
  if (!contiguous) {
    kernel_name += std::to_string(ndim);
    kernel_name += std::to_string(shape.size());
  }

  // Get the function
  auto fn_ptr = compile(kernel_name, [&, contiguous = contiguous]() {
  auto fn_ptr = compile(kernel_name, [&]() {
    std::ostringstream kernel;
    kernel << get_kernel_preamble() << std::endl;
    kernel << "extern \"C\" {" << std::endl;
@@ -317,7 +345,7 @@ void Compiled::eval_cpu(
        inputs_,
        outputs_,
        tape_,
        is_constant_,
        constant_ids_,
        contiguous,
        ndim);
    // Close extern "C"
@@ -325,26 +353,19 @@ void Compiled::eval_cpu(
    return kernel.str();
  });

  compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous);
  compiled_allocate_outputs(
      inputs, outputs, inputs_, constant_ids_, contiguous, false);

  for (auto& x : outputs) {
    args.push_back(x.data<void>());
    encoder.set_output_array(x);
  }
  if (contiguous) {
  if (!contiguous) {
    args.push_back((void*)outputs[0].shape().data());
  } else {
    args.push_back((void*)outputs[0].data_size());
  }
  auto fun = reinterpret_cast<void (*)(int*, int64_t**, void**)>(fn_ptr);
  encoder.dispatch([fun,
                    args = std::move(args),
                    strides = std::move(strides),
                    shape = std::move(shape)]() mutable {
    SmallVector<int64_t*> strides_ptrs;
    for (auto& s : strides) {
      strides_ptrs.push_back(s.data());
    }
    fun(shape.data(), strides_ptrs.data(), args.data());
  });
  auto fun = (void (*)(void**))fn_ptr;
  fun(args.data());
}

} // namespace mlx::core
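Note: to make the codegen above concrete, here is roughly what the emitted source for a contiguous two-input float32 add could look like under the (void** args) calling convention. The kernel name, variable names, and the use of a plain + instead of the printed primitive functor are illustrative assumptions, not verbatim output:

// Hypothetical generated source:
extern "C" {
void add_VV_f4f4_contiguous(void** args) {
  float* in0 = (float*)args[0];
  float* in1 = (float*)args[1];
  float* out0 = (float*)args[2];
  const size_t size = (size_t)args[3];
  for (size_t i = 0; i < size; ++i) {
    float tmp_in0 = in0[i];
    float tmp_in1 = in1[i];
    out0[i] = tmp_in0 + tmp_in1;
  }
}
}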
@@ -1,7 +1,6 @@
// Copyright © 2023-2024 Apple Inc.

#include "mlx/compile_impl.h"
#include "mlx/primitives.h"
#include "mlx/backend/common/compiled.h"

namespace mlx::core {

@@ -18,7 +17,7 @@ void Compiled::eval_cpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  throw std::runtime_error(
      "[Compiled::eval_cpu] CPU compilation not supported on the platform.");
      "[Compiled::eval_cpu] CPU compialtion not supported on the platform.");
}

} // namespace mlx::core
@@ -5,8 +5,7 @@
// clang-format off
#include "mlx/types/half_types.h"
#include "mlx/types/complex.h"
#include "mlx/backend/cpu/unary_ops.h"
#include "mlx/backend/cpu/binary_ops.h"
#include "mlx/backend/common/ops.h"
// clang-format on

const char* get_kernel_preamble();
mlx/backend/common/conv.cpp (new file, 1183 lines)
File diff suppressed because it is too large
@@ -3,10 +3,8 @@
#include <numeric>

#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"

namespace mlx::core {

@@ -14,19 +12,18 @@ namespace {

template <typename SrcT, typename DstT>
void copy_single(const array& src, array& dst) {
  auto src_ptr = src.data<SrcT>();
  auto val = static_cast<DstT>(src.data<SrcT>()[0]);
  auto dst_ptr = dst.data<DstT>();
  auto size = dst.size();
  auto val = static_cast<DstT>(src_ptr[0]);
  std::fill_n(dst_ptr, size, val);
  for (int i = 0; i < dst.size(); ++i) {
    dst_ptr[i] = val;
  }
}

template <typename SrcT, typename DstT>
void copy_vector(const array& src, array& dst) {
  auto src_ptr = src.data<SrcT>();
  auto dst_ptr = dst.data<DstT>();
  auto size = src.data_size();
  std::copy(src_ptr, src_ptr + size, dst_ptr);
  std::copy(src_ptr, src_ptr + src.data_size(), dst_ptr);
}

template <typename SrcT, typename DstT, int D>
@@ -61,57 +58,36 @@ void copy_general_general(
    const Strides& i_strides,
    const Strides& o_strides,
    int64_t i_offset,
    int64_t o_offset,
    const std::optional<array>& dynamic_i_offset,
    const std::optional<array>& dynamic_o_offset) {
  auto src_ptr = src.data<SrcT>() + i_offset;
  auto dst_ptr = dst.data<DstT>() + o_offset;
  auto i_offset_ptr =
      dynamic_i_offset ? dynamic_i_offset->data<int64_t>() : nullptr;
  auto o_offset_ptr =
      dynamic_o_offset ? dynamic_o_offset->data<int64_t>() : nullptr;
  auto size = src.size();
    int64_t o_offset) {
  if (data_shape.empty()) {
    auto val = static_cast<DstT>(*src_ptr);
    auto val = static_cast<DstT>(*(src.data<SrcT>() + i_offset));
    auto dst_ptr = dst.data<DstT>() + o_offset;
    *dst_ptr = val;
    return;
  }
  auto [shape, strides] =
      collapse_contiguous_dims(data_shape, {i_strides, o_strides});

  auto src_ptr = src.data<SrcT>() + i_offset;
  auto dst_ptr = dst.data<DstT>() + o_offset;
  int ndim = shape.size();
  if (ndim < 3) {
    if (i_offset_ptr) {
      src_ptr += i_offset_ptr[0];
    }
    if (o_offset_ptr) {
      dst_ptr += o_offset_ptr[0];
    }

    if (ndim == 1) {
      copy_dims<SrcT, DstT, 1>(
          src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
    } else if (ndim == 2) {
      copy_dims<SrcT, DstT, 2>(
          src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
    } else if (ndim == 3) {
      copy_dims<SrcT, DstT, 3>(
          src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
    }
  if (ndim == 1) {
    copy_dims<SrcT, DstT, 1>(
        src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
    return;
  } else if (ndim == 2) {
    copy_dims<SrcT, DstT, 2>(
        src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
    return;
  } else if (ndim == 3) {
    copy_dims<SrcT, DstT, 3>(
        src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
    return;
  }
  if (i_offset_ptr) {
    src_ptr += i_offset_ptr[0];
  }
  if (o_offset_ptr) {
    dst_ptr += o_offset_ptr[0];
  }

  ContiguousIterator in(shape, strides[0], ndim - 3);
  ContiguousIterator out(shape, strides[1], ndim - 3);
  auto stride = std::accumulate(
      shape.end() - 3, shape.end(), 1, std::multiplies<int64_t>());
  for (int64_t elem = 0; elem < size; elem += stride) {
  for (int64_t elem = 0; elem < src.size(); elem += stride) {
    copy_dims<SrcT, DstT, 3>(
        src_ptr + in.loc,
        dst_ptr + out.loc,
@@ -127,15 +103,7 @@ void copy_general_general(

template <typename SrcT, typename DstT>
inline void copy_general_general(const array& src, array& dst) {
  copy_general_general<SrcT, DstT>(
      src,
      dst,
      src.shape(),
      src.strides(),
      dst.strides(),
      0,
      0,
      std::nullopt,
      std::nullopt);
      src, dst, src.shape(), src.strides(), dst.strides(), 0, 0);
}

template <typename SrcT, typename DstT>
@@ -146,9 +114,7 @@ void copy_general(
    const Strides& i_strides,
    const Strides&,
    int64_t i_offset,
    int64_t o_offset,
    const std::optional<array>& dynamic_i_offset,
    const std::optional<array>& dynamic_o_offset) {
    int64_t o_offset) {
  copy_general_general<SrcT, DstT>(
      src,
      dst,
@@ -156,9 +122,7 @@ void copy_general(
      i_strides,
      make_contiguous_strides(data_shape),
      i_offset,
      o_offset,
      dynamic_i_offset,
      dynamic_o_offset);
      o_offset);
}

template <typename SrcT, typename DstT>
@@ -170,9 +134,7 @@ inline void copy_general(const array& src, array& dst) {
      src.strides(),
      make_contiguous_strides(src.shape()),
      0,
      0,
      std::nullopt,
      std::nullopt);
      0);
}

template <typename SrcT, typename DstT, typename... Args>
@@ -229,9 +191,6 @@ void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
    case float32:
      copy<SrcT, float>(src, dst, ctype, std::forward<Args>(args)...);
      break;
    case float64:
      copy<SrcT, double>(src, dst, ctype, std::forward<Args>(args)...);
      break;
    case bfloat16:
      copy<SrcT, bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...);
      break;
@@ -281,9 +240,6 @@ inline void copy_inplace_dispatch(
    case float32:
      copy<float>(src, dst, ctype, std::forward<Args>(args)...);
      break;
    case float64:
      copy<double>(src, dst, ctype, std::forward<Args>(args)...);
      break;
    case bfloat16:
      copy<bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...);
      break;
@@ -295,34 +251,38 @@ inline void copy_inplace_dispatch(

} // namespace

void copy_cpu_inplace(
    const array& src,
    array& dst,
    CopyType ctype,
    Stream stream) {
  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_input_array(src);
  encoder.set_output_array(dst);
  encoder.dispatch(
      [src = array::unsafe_weak_copy(src),
       dst = array::unsafe_weak_copy(dst),
       ctype]() mutable { copy_inplace_dispatch(src, dst, ctype); });
void copy_inplace(const array& src, array& dst, CopyType ctype) {
  copy_inplace_dispatch(src, dst, ctype);
}

void copy_cpu(const array& src, array& dst, CopyType ctype, Stream stream) {
  bool donated = set_copy_output_data(src, dst, ctype);
  if (donated && src.dtype() == dst.dtype()) {
    // If the output has the same type as the input then there is nothing to
    // copy, just use the buffer.
    return;
void copy(const array& src, array& dst, CopyType ctype) {
  // Allocate the output
  switch (ctype) {
    case CopyType::Vector:
      if (src.is_donatable() && src.itemsize() == dst.itemsize()) {
        dst.copy_shared_buffer(src);
      } else {
        auto size = src.data_size();
        dst.set_data(
            allocator::malloc_or_wait(size * dst.itemsize()),
            size,
            src.strides(),
            src.flags());
      }
      break;
    case CopyType::Scalar:
    case CopyType::General:
    case CopyType::GeneralGeneral:
      dst.set_data(allocator::malloc_or_wait(dst.nbytes()));
      break;
  }
  if (ctype == CopyType::GeneralGeneral) {
    ctype = CopyType::General;
  }
  copy_cpu_inplace(src, dst, ctype, stream);
  copy_inplace(src, dst, ctype);
}

void copy_cpu_inplace(
void copy_inplace(
    const array& src,
    array& dst,
    const Shape& data_shape,
@@ -330,57 +290,24 @@ void copy_cpu_inplace(
    const Strides& o_strides,
    int64_t i_offset,
    int64_t o_offset,
    CopyType ctype,
    Stream stream,
    const std::optional<array>& dynamic_i_offset, /* = std::nullopt */
    const std::optional<array>& dynamic_o_offset /* = std::nullopt */) {
  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_input_array(src);
  encoder.set_output_array(dst);
  auto weak_copy_if_set = [](auto x) -> std::optional<array> {
    if (x) {
      return array::unsafe_weak_copy(*x);
    } else {
      return std::nullopt;
    }
  };
  encoder.dispatch(
      [src = array::unsafe_weak_copy(src),
       dst = array::unsafe_weak_copy(dst),
       data_shape,
       i_strides,
       o_strides,
       i_offset,
       o_offset,
       ctype,
       dynamic_i_offset = weak_copy_if_set(dynamic_i_offset),
       dynamic_o_offset = weak_copy_if_set(dynamic_o_offset)]() mutable {
        switch (ctype) {
          case CopyType::General:
          case CopyType::GeneralGeneral:
            copy_inplace_dispatch(
                src,
                dst,
                ctype,
                data_shape,
                i_strides,
                o_strides,
                i_offset,
                o_offset,
                dynamic_i_offset,
                dynamic_o_offset);
            break;
          case CopyType::Scalar:
          case CopyType::Vector:
            copy_inplace_dispatch(src, dst, ctype);
        }
      });
}

array contiguous_copy_cpu(const array& arr, Stream stream) {
  array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
  copy_cpu(arr, arr_copy, CopyType::General, stream);
  return arr_copy;
    CopyType ctype) {
  switch (ctype) {
    case CopyType::General:
    case CopyType::GeneralGeneral:
      copy_inplace_dispatch(
          src,
          dst,
          ctype,
          data_shape,
          i_strides,
          o_strides,
          i_offset,
          o_offset);
      break;
    case CopyType::Scalar:
    case CopyType::Vector:
      copy_inplace_dispatch(src, dst, ctype);
  }
}

} // namespace mlx::core
@@ -2,6 +2,7 @@

#pragma once

#include "mlx/array.h"
#include "mlx/backend/common/utils.h"

namespace mlx::core {

@@ -22,25 +23,17 @@ enum class CopyType {
  GeneralGeneral
};

inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) {
  if (ctype == CopyType::Vector) {
    // If the input is donatable, we are doing a vector copy and the types
    // have the same size, then the input buffer can hold the output.
    if (is_donatable(in, out)) {
      out.copy_shared_buffer(in);
      return true;
    } else {
      out.set_data(
          allocator::malloc(in.data_size() * out.itemsize()),
          in.data_size(),
          in.strides(),
          in.flags());
      return false;
    }
  } else {
    out.set_data(allocator::malloc(out.nbytes()));
    return false;
  }
}
void copy(const array& src, array& dst, CopyType ctype);
void copy_inplace(const array& src, array& dst, CopyType ctype);

void copy_inplace(
    const array& src,
    array& dst,
    const Shape& data_shape,
    const Strides& i_strides,
    const Strides& o_strides,
    int64_t i_offset,
    int64_t o_offset,
    CopyType ctype);

} // namespace mlx::core

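Note: the four CopyType cases above roughly mean: Scalar (one value broadcast into dst), Vector (a flat element-for-element copy over the input's data), General (arbitrarily strided src, contiguous dst), and GeneralGeneral (arbitrary strides on both sides). A hedged sketch of how a caller might pick one; the real selection logic lives in MLX's callers, so this is illustrative only:

// Sketch under stated assumptions, not MLX's actual selection code.
CopyType choose_copy_type(const array& src, const array& dst) {
  if (src.data_size() == 1) {
    return CopyType::Scalar; // one element, broadcast into dst
  } else if (src.flags().contiguous && dst.flags().contiguous) {
    return CopyType::Vector; // both flat, element for element
  } else if (dst.flags().contiguous) {
    return CopyType::General; // strided read, contiguous write
  }
  return CopyType::GeneralGeneral; // strided on both sides
}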
mlx/backend/common/default_primitives.cpp (new file, 197 lines)
@@ -0,0 +1,197 @@
// Copyright © 2023-2024 Apple Inc.

#include <cstring>

#include "mlx/array.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/backend/common/utils.h"
#include "mlx/primitives.h"

#define DEFAULT(primitive)                                                 \
  void primitive::eval_cpu(const std::vector<array>& inputs, array& out) { \
    primitive::eval(inputs, out);                                          \
  }

#define DEFAULT_MULTI(primitive)                                       \
  void primitive::eval_cpu(                                            \
      const std::vector<array>& inputs, std::vector<array>& outputs) { \
    primitive::eval(inputs, outputs);                                  \
  }

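Note: each DEFAULT(...) line below stamps out a one-line eval_cpu that forwards to the primitive's generic eval. For example, DEFAULT(Abs) expands (modulo whitespace) to:

// Expansion of DEFAULT(Abs):
void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
  Abs::eval(inputs, out);
}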
namespace mlx::core {

DEFAULT(Abs)
DEFAULT(Add)
DEFAULT(Arange)
DEFAULT(ArcCos)
DEFAULT(ArcCosh)
DEFAULT(ArcSin)
DEFAULT(ArcSinh)
DEFAULT(ArcTan)
DEFAULT(ArcTan2)
DEFAULT(ArcTanh)
DEFAULT(ArgPartition)
DEFAULT(ArgReduce)
DEFAULT(ArgSort)
DEFAULT(AsType)
DEFAULT(AsStrided)
DEFAULT(Broadcast)
DEFAULT(BlockMaskedMM)
DEFAULT(GatherMM)
DEFAULT(GatherQMM)
DEFAULT_MULTI(DivMod)
DEFAULT(Ceil)
DEFAULT(Concatenate)
DEFAULT(Conjugate)
DEFAULT(Convolution)
DEFAULT(Copy)
DEFAULT(Cos)
DEFAULT(Cosh)
DEFAULT_MULTI(CustomTransforms)
DEFAULT_MULTI(Depends)
DEFAULT(Divide)
DEFAULT(NumberOfElements)
DEFAULT(Remainder)
DEFAULT(Equal)
DEFAULT(Erf)
DEFAULT(ErfInv)
DEFAULT(Exp)
DEFAULT(ExpandDims)
DEFAULT(Expm1)
DEFAULT(FFT)
DEFAULT(Floor)
DEFAULT(Full)
DEFAULT(Gather)
DEFAULT(Greater)
DEFAULT(GreaterEqual)
DEFAULT(Hadamard)
DEFAULT(Less)
DEFAULT(LessEqual)
DEFAULT(Load)
DEFAULT(Log)
DEFAULT(Log1p)
DEFAULT(LogicalNot)
DEFAULT(LogicalAnd)
DEFAULT(LogicalOr)
DEFAULT(LogAddExp)
DEFAULT(Maximum)
DEFAULT(Minimum)
DEFAULT(Multiply)
DEFAULT(Negative)
DEFAULT(NotEqual)
DEFAULT(Pad)
DEFAULT(Partition)
DEFAULT(Power)
DEFAULT_MULTI(QRF)
DEFAULT(QuantizedMatmul)
DEFAULT(RandomBits)
DEFAULT(Reduce)
DEFAULT(Round)
DEFAULT(Scan)
DEFAULT(Scatter)
DEFAULT(Select)
DEFAULT(Sigmoid)
DEFAULT(Sign)
DEFAULT(Sin)
DEFAULT(Sinh)
DEFAULT(Slice)
DEFAULT(SliceUpdate)
DEFAULT(Softmax)
DEFAULT(Sort)
DEFAULT_MULTI(Split)
DEFAULT(Square)
DEFAULT(Squeeze)
DEFAULT(Sqrt)
DEFAULT(StopGradient)
DEFAULT(Subtract)
DEFAULT_MULTI(SVD)
DEFAULT(Tan)
DEFAULT(Tanh)
DEFAULT(Transpose)
DEFAULT(Inverse)
DEFAULT(Cholesky)
DEFAULT_MULTI(Eigh)

namespace {

inline void matmul_common_general(
    const array& a_pre,
    const array& b_pre,
    array& out,
    float alpha = 1.0f,
    float beta = 0.0f) {
  auto check_transpose = [](const array& arr) {
    auto stx = arr.strides()[arr.ndim() - 2];
    auto sty = arr.strides()[arr.ndim() - 1];
    if (stx == arr.shape(-1) && sty == 1) {
      return std::make_tuple(false, stx, arr);
    } else if (stx == 1 && sty == arr.shape(-2)) {
      return std::make_tuple(true, sty, arr);
    } else {
      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
      copy(arr, arr_copy, CopyType::General);
      stx = arr.shape(-1);
      return std::make_tuple(false, stx, arr_copy);
    }
  };

  auto [a_transposed, lda, a] = check_transpose(a_pre);
  auto [b_transposed, ldb, b] = check_transpose(b_pre);
  size_t M = a.shape(-2);
  size_t N = b.shape(-1);
  size_t K = a.shape(-1);
  if (M == 0 || N == 0) {
    return;
  }
  if (K == 0) {
    std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
    return;
  }

  for (int i = 0; i < (a.size() / (M * K)); ++i) {
    cblas_sgemm(
        CblasRowMajor,
        a_transposed ? CblasTrans : CblasNoTrans, // transA
        b_transposed ? CblasTrans : CblasNoTrans, // transB
        M,
        N,
        K,
        alpha, // alpha
        a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides()),
        lda,
        b.data<float>() + elem_to_loc(K * N * i, b.shape(), b.strides()),
        ldb,
        beta, // beta
        out.data<float>() + M * N * i,
        out.shape(-1) // ldc
    );
  }
}

} // namespace

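Note: the check_transpose lambda above feeds cblas_sgemm: a row-major matrix passes through untransposed with the leading dimension equal to its row stride, a column-major one is flagged as transposed with the leading dimension equal to its column stride, and anything else is densified first. A hedged numeric illustration (values assumed):

// A 3x4 float matrix:
// Row-major:    strides {4, 1} -> stx == shape(-1) == 4, sty == 1
//               => no transpose, lda = 4.
// Column-major: strides {1, 3} -> stx == 1, sty == shape(-2) == 3
//               => CblasTrans,  lda = 3.
// Anything else (e.g. a non-contiguous slice) is first copied into a
// dense row-major buffer by the CopyType::General copy above.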
void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
  if (out.dtype() != float32) {
    throw std::runtime_error(
        "[Matmul::eval_cpu] Currently only supports float32.");
  }
  out.set_data(allocator::malloc_or_wait(out.nbytes()));
  return matmul_common_general(inputs[0], inputs[1], out);
}

void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
  if (out.dtype() != float32) {
    throw std::runtime_error(
        "[AddMM::eval_cpu] Currently only supports float32.");
  }

  // Fill output with C
  auto& c = inputs[2];
  CopyType ctype = c.data_size() == 1 ? CopyType::Scalar : CopyType::General;
  copy(c, out, ctype);

  return matmul_common_general(inputs[0], inputs[1], out, alpha_, beta_);
}

} // namespace mlx::core
mlx/backend/common/eigh.cpp (new file, 117 lines)
@@ -0,0 +1,117 @@
// Copyright © 2023-2024 Apple Inc.

#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"

namespace mlx::core {

namespace {

void ssyevd(
    char jobz,
    char uplo,
    float* a,
    int N,
    float* w,
    float* work,
    int lwork,
    int* iwork,
    int liwork) {
  int info;
  MLX_LAPACK_FUNC(ssyevd)
  (
      /* jobz = */ &jobz,
      /* uplo = */ &uplo,
      /* n = */ &N,
      /* a = */ a,
      /* lda = */ &N,
      /* w = */ w,
      /* work = */ work,
      /* lwork = */ &lwork,
      /* iwork = */ iwork,
      /* liwork = */ &liwork,
      /* info = */ &info);
  if (info != 0) {
    std::stringstream msg;
    msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
        << info;
    throw std::runtime_error(msg.str());
  }
}

} // namespace

void Eigh::eval(const std::vector<array>& inputs, std::vector<array>& outputs) {
  const auto& a = inputs[0];
  auto& values = outputs[0];

  auto vectors = compute_eigenvectors_
      ? outputs[1]
      : array(a.shape(), a.dtype(), nullptr, {});

  values.set_data(allocator::malloc_or_wait(values.nbytes()));

  copy(
      a,
      vectors,
      a.flags().row_contiguous ? CopyType::Vector : CopyType::General);

  if (compute_eigenvectors_) {
    // Set the strides and flags so the eigenvectors
    // are in the columns of the output
    auto flags = vectors.flags();
    auto strides = vectors.strides();
    auto ndim = a.ndim();
    std::swap(strides[ndim - 1], strides[ndim - 2]);

    if (a.size() > 1) {
      flags.row_contiguous = false;
      if (ndim > 2) {
        flags.col_contiguous = false;
      } else {
        flags.col_contiguous = true;
      }
    }
    vectors.move_shared_buffer(vectors, strides, flags, vectors.data_size());
  }

  auto vec_ptr = vectors.data<float>();
  auto eig_ptr = values.data<float>();

  char jobz = compute_eigenvectors_ ? 'V' : 'N';
  auto N = a.shape(-1);

  // Work query
  int lwork;
  int liwork;
  {
    float work;
    int iwork;
    ssyevd(jobz, uplo_[0], nullptr, N, nullptr, &work, -1, &iwork, -1);
    lwork = static_cast<int>(work);
    liwork = iwork;
  }

  auto work_buf = array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
  auto iwork_buf = array::Data{allocator::malloc_or_wait(sizeof(int) * liwork)};
  for (size_t i = 0; i < a.size() / (N * N); ++i) {
    ssyevd(
        jobz,
        uplo_[0],
        vec_ptr,
        N,
        eig_ptr,
        static_cast<float*>(work_buf.buffer.raw_ptr()),
        lwork,
        static_cast<int*>(iwork_buf.buffer.raw_ptr()),
        liwork);
    vec_ptr += N * N;
    eig_ptr += N;
  }
}

} // namespace mlx::core
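Note: LAPACK's ssyevd writes eigenvectors into the columns of a column-major matrix; the stride swap above exposes them correctly without moving any data. A standalone sketch of the same trick:

#include <cstdio>

int main() {
  // A 2x3 row-major matrix and its strides.
  float a[6] = {1, 2, 3, 4, 5, 6};
  int strides[2] = {3, 1};
  // Swapping the strides yields a transposed 3x2 *view*: element (i, j)
  // of the view reads a[i * tstrides[0] + j * tstrides[1]], no copy.
  int tstrides[2] = {strides[1], strides[0]};
  // view(2, 1) == original a(1, 2) == 6
  std::printf("%g\n", a[2 * tstrides[0] + 1 * tstrides[1]]);
  return 0;
}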
mlx/backend/common/erf.cpp (new file, 40 lines)
@@ -0,0 +1,40 @@
// Copyright © 2023 Apple Inc.

#include <cmath>

namespace mlx::core {

/* Approximation to the inverse error function.
 * Based on code from:
 * https://stackoverflow.com/questions/27229371/inverse-error-function-in-c#answer-49743348
 */
float erfinv(float a) {
  auto t = std::fma(a, 0.0f - a, 1.0f);
  t = std::log(t);
  float p;
  if (std::abs(t) > 6.125f) { // maximum ulp error = 2.35793
    p = 3.03697567e-10f; // 0x1.4deb44p-32
    p = std::fma(p, t, 2.93243101e-8f); // 0x1.f7c9aep-26
    p = std::fma(p, t, 1.22150334e-6f); // 0x1.47e512p-20
    p = std::fma(p, t, 2.84108955e-5f); // 0x1.dca7dep-16
    p = std::fma(p, t, 3.93552968e-4f); // 0x1.9cab92p-12
    p = std::fma(p, t, 3.02698812e-3f); // 0x1.8cc0dep-9
    p = std::fma(p, t, 4.83185798e-3f); // 0x1.3ca920p-8
    p = std::fma(p, t, -2.64646143e-1f); // -0x1.0eff66p-2
    p = std::fma(p, t, 8.40016484e-1f); // 0x1.ae16a4p-1
  } else { // maximum ulp error = 2.35002
    p = 5.43877832e-9f; // 0x1.75c000p-28
    p = std::fma(p, t, 1.43285448e-7f); // 0x1.33b402p-23
    p = std::fma(p, t, 1.22774793e-6f); // 0x1.499232p-20
    p = std::fma(p, t, 1.12963626e-7f); // 0x1.e52cd2p-24
    p = std::fma(p, t, -5.61530760e-5f); // -0x1.d70bd0p-15
    p = std::fma(p, t, -1.47697632e-4f); // -0x1.35be90p-13
    p = std::fma(p, t, 2.31468678e-3f); // 0x1.2f6400p-9
    p = std::fma(p, t, 1.15392581e-2f); // 0x1.7a1e50p-7
    p = std::fma(p, t, -2.32015476e-1f); // -0x1.db2aeep-3
    p = std::fma(p, t, 8.86226892e-1f); // 0x1.c5bf88p-1
  }
  return a * p;
}

} // namespace mlx::core
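Note: a quick standalone sanity check for the polynomial approximation above, assuming it is linked against the translation unit defining erfinv:

#include <cmath>
#include <cstdio>

namespace mlx::core {
float erfinv(float a); // defined in the file above
}

int main() {
  // erf(erfinv(x)) should reproduce x to within a few ulps on (-1, 1).
  for (float x = -0.9f; x <= 0.91f; x += 0.3f) {
    std::printf("x=% .2f  erf(erfinv(x))=% .6f\n",
                x, std::erf(mlx::core::erfinv(x)));
  }
  return 0;
}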
mlx/backend/common/fft.cpp (new file, 87 lines)
@@ -0,0 +1,87 @@
// Copyright © 2023 Apple Inc.

#include <numeric>

#include "mlx/3rdparty/pocketfft.h"
#include "mlx/allocator.h"
#include "mlx/primitives.h"

namespace mlx::core {

void FFT::eval(const std::vector<array>& inputs, array& out) {
  auto& in = inputs[0];
  std::vector<std::ptrdiff_t> strides_in(
      in.strides().begin(), in.strides().end());
  for (auto& s : strides_in) {
    s *= in.itemsize();
  }
  std::vector<std::ptrdiff_t> strides_out(
      out.strides().begin(), out.strides().end());
  for (auto& s : strides_out) {
    s *= out.itemsize();
  }

  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  std::vector<size_t> shape;
  if (out.dtype() == float32) {
    shape.insert(shape.end(), out.shape().begin(), out.shape().end());
  } else {
    shape.insert(shape.end(), in.shape().begin(), in.shape().end());
  }

  float scale = 1.0f;
  if (inverse_) {
    size_t nelem = std::accumulate(
        axes_.begin(), axes_.end(), 1, [&shape](auto x, auto y) {
          return x * shape[y];
        });
    scale /= nelem;
  }
  if (in.dtype() == complex64 && out.dtype() == complex64) {
    auto in_ptr =
        reinterpret_cast<const std::complex<float>*>(in.data<complex64_t>());
    auto out_ptr =
        reinterpret_cast<std::complex<float>*>(out.data<complex64_t>());
    pocketfft::c2c(
        shape,
        strides_in,
        strides_out,
        axes_,
        !inverse_,
        in_ptr,
        out_ptr,
        scale);
  } else if (in.dtype() == float32 && out.dtype() == complex64) {
    auto in_ptr = in.data<float>();
    auto out_ptr =
        reinterpret_cast<std::complex<float>*>(out.data<complex64_t>());
    pocketfft::r2c(
        shape,
        strides_in,
        strides_out,
        axes_,
        !inverse_,
        in_ptr,
        out_ptr,
        scale);
  } else if (in.dtype() == complex64 && out.dtype() == float32) {
    auto in_ptr =
        reinterpret_cast<const std::complex<float>*>(in.data<complex64_t>());
    auto out_ptr = out.data<float>();
    pocketfft::c2r(
        shape,
        strides_in,
        strides_out,
        axes_,
        !inverse_,
        in_ptr,
        out_ptr,
        scale);
  } else {
    throw std::runtime_error(
        "[FFT] Received unexpected input and output type combination.");
  }
}

} // namespace mlx::core
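Note: pocketfft takes strides in bytes (hence the itemsize multiplication above), and the scale computed for the inverse path is the usual 1/N normalization where N is the product of the transformed axes' lengths. A small worked instance, values assumed:

// With shape = {4, 8} and axes_ = {0, 1}, the accumulate above yields
// nelem = 4 * 8 = 32, so an inverse transform is scaled by 1.0f / 32;
// a forward transform leaves scale at 1.0f.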
@@ -2,19 +2,18 @@

#include <cassert>

#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/hadamard.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/primitives.h"

namespace mlx::core {

// n = 2^k component
template <typename T>
void hadamard_n(T* out, int n, int m, float scale, size_t size) {
  for (int b = 0; b < size / n; b++) {
void hadamard_n(array& out, int n, int m, float scale) {
  for (int b = 0; b < out.size() / n; b++) {
    size_t loc = b * n;
    T* data_ptr = out + loc;
    T* data_ptr = out.data<T>() + loc;
    int h = 1;
    int n_over_2 = n / 2;
    while (h < n) {
@@ -37,7 +36,7 @@ void hadamard_n(T* out, int n, int m, float scale, size_t size) {

// m component
template <typename T>
void hadamard_m(T* out, int n, int m, float scale, size_t size) {
void hadamard_m(array& out, int n, int m, float scale) {
  auto h_matrices = hadamard_matrices();
  auto& matrix = h_matrices[m];
  auto start = 1;
@@ -52,9 +51,9 @@ void hadamard_m(T* out, int n, int m, float scale, size_t size) {
    end = matrix.find('\n', start);
  }

  for (int b = 0; b < size / m / n; b++) {
  for (int b = 0; b < out.size() / m / n; b++) {
    size_t loc = b * n * m;
    T* data_ptr = out + loc;
    T* data_ptr = out.data<T>() + loc;
    for (int i = 0; i < n; i++) {
      std::vector<float> out(m);
      for (int j = 0; j < m; j++) {
@@ -75,47 +74,34 @@ void hadamard_m(T* out, int n, int m, float scale, size_t size) {
}

template <typename T>
void hadamard(array& out, int n, int m, float scale, Stream stream) {
  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_output_array(out);
  auto out_ptr = out.data<T>();
  encoder.dispatch([out_ptr, size = out.size(), n, m, scale]() {
    float n_scale = m > 1 ? 1.0 : scale;
    hadamard_n<T>(out_ptr, n, m, n_scale, size);
    if (m > 1) {
      hadamard_m<T>(out_ptr, n, m, scale, size);
    }
  });
void hadamard(array& out, int n, int m, float scale) {
  float n_scale = m > 1 ? 1.0 : scale;
  hadamard_n<T>(out, n, m, n_scale);
  if (m > 1) {
    hadamard_m<T>(out, n, m, scale);
  }
}

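Note: the while-loop elided by the hunk in hadamard_n above is the standard in-place fast Walsh-Hadamard butterfly. A self-contained sketch of that recurrence for the 2^k component (unscaled):

#include <cstdio>

// Minimal fast Walsh-Hadamard transform over n = 2^k elements.
void fwht(float* x, int n) {
  for (int h = 1; h < n; h *= 2) {
    for (int i = 0; i < n; i += 2 * h) {
      for (int j = i; j < i + h; ++j) {
        float a = x[j];
        float b = x[j + h];
        x[j] = a + b;     // butterfly: sum
        x[j + h] = a - b; // butterfly: difference
      }
    }
  }
}

int main() {
  float x[4] = {1, 0, 0, 0};
  fwht(x, 4); // a delta transforms to the all-ones row: {1, 1, 1, 1}
  std::printf("%g %g %g %g\n", x[0], x[1], x[2], x[3]);
  return 0;
}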
void Hadamard::eval_cpu(const std::vector<array>& inputs, array& out) {
void Hadamard::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  auto& in = inputs[0];

  // Copy input to output
  if (in.flags().row_contiguous && in.is_donatable()) {
    out.copy_shared_buffer(in);
  } else {
    copy_cpu(
        in,
        out,
        in.flags().row_contiguous ? CopyType::Vector : CopyType::General,
        stream());
  }
  copy(in, out, CopyType::General);

  int axis = out.ndim() - 1;
  auto [n, m] = decompose_hadamard(out.shape(axis));

  switch (in.dtype()) {
    case float32:
      return hadamard<float>(out, n, m, scale_, stream());
      return hadamard<float>(out, n, m, scale_);
    case float16:
      return hadamard<float16_t>(out, n, m, scale_, stream());
      return hadamard<float16_t>(out, n, m, scale_);
    case bfloat16:
      return hadamard<bfloat16_t>(out, n, m, scale_, stream());
      return hadamard<bfloat16_t>(out, n, m, scale_);
    default:
      throw std::invalid_argument("[hadamard] Unsupported type.");
  }
}

} // namespace mlx::core
} // namespace mlx::core
@@ -99,11 +99,7 @@ inline std::pair<int, int> decompose_hadamard(int n) {
          "[hadamard] Only supports n = m*2^k where m in (1, 12, 20, 28).");
    }
  }
  if (n > (1 << 26)) {
    throw std::invalid_argument(
        "[hadamard] Only supports n = m*2^k where k <= 26");
  }
  return {n, m};
}

} // namespace mlx::core
} // namespace mlx::core
mlx/backend/common/indexing.cpp (new file, 393 lines)
@@ -0,0 +1,393 @@
// Copyright © 2023 Apple Inc.
#include <algorithm>
#include <cassert>
#include <cmath>

#include "mlx/allocator.h"
#include "mlx/primitives.h"

#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/utils.h"

namespace mlx::core {

template <typename IdxT>
inline size_t offset_neg_idx(IdxT idx, size_t size) {
  return (idx < 0) ? idx + size : idx;
}

template <>
inline size_t offset_neg_idx(bool idx, size_t) {
  return idx;
}

template <>
inline size_t offset_neg_idx(uint32_t idx, size_t) {
  return idx;
}

template <typename T, typename IdxT>
void gather(
    const array& src,
    const std::vector<array>& inds,
    array& out,
    const std::vector<int>& axes,
    const Shape& slice_sizes) {
  // If the array is row contiguous then we can do a contiguous copy given
  // two conditions on the slice size:
  // - Any number of leading ones in the slice sizes are allowed
  // - All other slice sizes match the corresponding dimension except the
  //   first non-singleton slice size
  // If the array is col contiguous then the reverse is the case:
  // - Any number of trailing ones in the slice sizes are allowed
  // - All other slice sizes match the corresponding dimension except the
  //   first non-singleton slice size from the end

  bool can_copy = false;
  if (src.flags().row_contiguous) {
    can_copy = true;

    // Ignore leading 1s
    int i = 0;
    for (; i < slice_sizes.size() && slice_sizes[i] == 1; ++i)
      ;

    // Check the remaining
    i++;
    for (; i < src.ndim() && can_copy; ++i) {
      can_copy = (src.shape(i) == slice_sizes[i]);
    }
  } else if (src.flags().col_contiguous) {
    can_copy = true;

    // Ignore trailing 1s
    int i = slice_sizes.size() - 1;
    for (; i >= 0 && slice_sizes[i] == 1; --i)
      ;

    // Skip the next slice size and check the remaining
    i--;
    for (; i >= 0 && can_copy; --i) {
      can_copy = (src.shape(i) == slice_sizes[i]);
    }
  }
  size_t slice_size = 1;
  for (auto s : slice_sizes) {
    slice_size *= s;
  }
  size_t ind_size = slice_size == 0 ? 0 : out.size() / slice_size;
  const T* src_ptr = src.data<T>();
  T* dst_ptr = out.data<T>();
  size_t out_idx = 0;

  std::vector<ContiguousIterator> its(inds.begin(), inds.end());
  ContiguousIterator src_it;
  if (!can_copy && src.ndim() > 0) {
    src_it = ContiguousIterator(slice_sizes, src.strides(), src.ndim());
  }
  for (int idx = 0; idx < ind_size; idx++) {
    size_t src_idx = 0;
    for (int ii = 0; ii < inds.size(); ++ii) {
      auto ax = axes[ii];
      auto idx_loc = its[ii].loc;
      its[ii].step();
      auto idx_val =
          offset_neg_idx(inds[ii].data<IdxT>()[idx_loc], src.shape(ax));
      src_idx += (idx_val * src.strides()[ax]);
    }

    if (slice_size == 1) {
      dst_ptr[out_idx++] = src_ptr[src_idx];
    } else if (can_copy) {
      std::copy(
          src_ptr + src_idx, src_ptr + src_idx + slice_size, dst_ptr + out_idx);
      out_idx += slice_size;
    } else {
      for (int jj = 0; jj < slice_size; jj++) {
        dst_ptr[out_idx++] = src_ptr[src_idx + src_it.loc];
        src_it.step();
      }
      src_it.reset();
    }
  }
}

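Note: to pin down the indexing above with a concrete case (shapes assumed for illustration), consider gathering from a row-contiguous float src of shape {5, 4} along axis 0, with one index array of shape {3} and slice_sizes {1, 4}:

// slice_size = 1 * 4 = 4 and ind_size = out.size() / 4 = 3, so the
// outer loop runs once per index. For index value v, src_idx =
// v * src.strides()[0] = v * 4, and because src is row contiguous with
// the trailing slice size matching dimension 1, can_copy is true: each
// iteration is a single std::copy of 4 contiguous floats into out.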
template <typename IdxT>
void dispatch_gather(
    const array& src,
    const std::vector<array>& inds,
    array& out,
    const std::vector<int>& axes,
    const Shape& size) {
  switch (out.dtype()) {
    case bool_:
      gather<bool, IdxT>(src, inds, out, axes, size);
      break;
    case uint8:
      gather<uint8_t, IdxT>(src, inds, out, axes, size);
      break;
    case uint16:
      gather<uint16_t, IdxT>(src, inds, out, axes, size);
      break;
    case uint32:
      gather<uint32_t, IdxT>(src, inds, out, axes, size);
      break;
    case uint64:
      gather<uint64_t, IdxT>(src, inds, out, axes, size);
      break;
    case int8:
      gather<int8_t, IdxT>(src, inds, out, axes, size);
      break;
    case int16:
      gather<int16_t, IdxT>(src, inds, out, axes, size);
      break;
    case int32:
      gather<int32_t, IdxT>(src, inds, out, axes, size);
      break;
    case int64:
      gather<int64_t, IdxT>(src, inds, out, axes, size);
      break;
    case float16:
      gather<float16_t, IdxT>(src, inds, out, axes, size);
      break;
    case float32:
      gather<float, IdxT>(src, inds, out, axes, size);
      break;
    case bfloat16:
      gather<bfloat16_t, IdxT>(src, inds, out, axes, size);
      break;
    case complex64:
      gather<complex64_t, IdxT>(src, inds, out, axes, size);
      break;
  }
}

void Gather::eval(const std::vector<array>& inputs, array& out) {
  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  auto& src = inputs[0];
  std::vector<array> inds(inputs.begin() + 1, inputs.end());

  if (inds.empty()) {
    dispatch_gather<bool>(src, inds, out, axes_, slice_sizes_);
    return;
  }

  switch (inds[0].dtype()) {
    case bool_:
      dispatch_gather<bool>(src, inds, out, axes_, slice_sizes_);
      break;
    case uint8:
      dispatch_gather<uint8_t>(src, inds, out, axes_, slice_sizes_);
      break;
    case uint16:
      dispatch_gather<uint16_t>(src, inds, out, axes_, slice_sizes_);
      break;
    case uint32:
      dispatch_gather<uint32_t>(src, inds, out, axes_, slice_sizes_);
      break;
    case uint64:
      dispatch_gather<uint64_t>(src, inds, out, axes_, slice_sizes_);
      break;
    case int8:
      dispatch_gather<int8_t>(src, inds, out, axes_, slice_sizes_);
      break;
    case int16:
      dispatch_gather<int16_t>(src, inds, out, axes_, slice_sizes_);
      break;
    case int32:
      dispatch_gather<int32_t>(src, inds, out, axes_, slice_sizes_);
      break;
    case int64:
      dispatch_gather<int64_t>(src, inds, out, axes_, slice_sizes_);
      break;
    case float16:
    case float32:
    case bfloat16:
    case complex64:
      throw std::runtime_error(
          "[Gather::eval] Cannot gather with floating point indices.");
      break;
  }
}

template <typename InT, typename IdxT, typename OpT>
void scatter(
    const array& updates,
    array& out,
    const std::vector<array>& inds,
    const std::vector<int>& axes,
    const OpT& op) {
  int nind = inds.size();
  auto inds_ndim = updates.ndim() - out.ndim();
  size_t n_updates = nind ? inds[0].size() : 1;

  Shape update_shape(
      updates.shape().begin() + inds_ndim, updates.shape().end());
  size_t update_size = 1;
  for (auto us : update_shape) {
    update_size *= us;
  }

  std::vector<ContiguousIterator> its(inds.begin(), inds.end());
  ContiguousIterator update_it(updates);
  ContiguousIterator out_it(update_shape, out.strides(), out.ndim());

  for (int i = 0; i < n_updates; ++i) {
    size_t out_offset = 0;
    for (int j = 0; j < nind; ++j) {
      auto ax = axes[j];
      auto idx_loc = its[j].loc;
      its[j].step();
      auto idx_val =
          offset_neg_idx(inds[j].data<IdxT>()[idx_loc], out.shape(ax));
      out_offset += (idx_val * out.strides()[ax]);
    }
    update_it.seek(i * update_size);
    for (int j = 0; j < update_size; ++j) {
      op(updates.data<InT>()[update_it.loc],
         out.data<InT>() + out_offset + out_it.loc);
      update_it.step();
      out_it.step();
    }
    out_it.reset();
    update_it.reset();
  }
}

template <typename InT, typename IdxT>
|
||||
void dispatch_scatter_inds(
|
||||
array& out,
|
||||
const std::vector<array>& indices,
|
||||
const array& updates,
|
||||
const std::vector<int>& axes,
|
||||
Scatter::ReduceType rtype) {
|
||||
switch (rtype) {
|
||||
case Scatter::None:
|
||||
scatter<InT, IdxT>(
|
||||
updates, out, indices, axes, [](auto x, auto* y) { (*y) = x; });
|
||||
break;
|
||||
case Scatter::Sum:
|
||||
scatter<InT, IdxT>(
|
||||
updates, out, indices, axes, [](auto x, auto* y) { (*y) += x; });
|
||||
break;
|
||||
case Scatter::Prod:
|
||||
scatter<InT, IdxT>(
|
||||
updates, out, indices, axes, [](auto x, auto* y) { (*y) *= x; });
|
||||
break;
|
||||
case Scatter::Max:
|
||||
scatter<InT, IdxT>(updates, out, indices, axes, [](auto x, auto* y) {
|
||||
(*y) = (*y > x) ? *y : x;
|
||||
});
|
||||
break;
|
||||
case Scatter::Min:
|
||||
scatter<InT, IdxT>(updates, out, indices, axes, [](auto x, auto* y) {
|
||||
(*y) = (*y < x) ? *y : x;
|
||||
});
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename InT>
|
||||
void dispatch_scatter(
|
||||
array& out,
|
||||
const std::vector<array>& inds,
|
||||
const array& updates,
|
||||
const std::vector<int>& axes,
|
||||
Scatter::ReduceType rtype) {
|
||||
if (inds.empty()) {
|
||||
dispatch_scatter_inds<InT, bool>(out, inds, updates, axes, rtype);
|
||||
return;
|
||||
}
|
||||
|
||||
switch (inds[0].dtype()) {
|
||||
case bool_:
|
||||
dispatch_scatter_inds<InT, bool>(out, inds, updates, axes, rtype);
|
||||
break;
|
||||
case uint8:
|
||||
dispatch_scatter_inds<InT, uint8_t>(out, inds, updates, axes, rtype);
|
||||
break;
|
||||
case uint16:
|
||||
dispatch_scatter_inds<InT, uint16_t>(out, inds, updates, axes, rtype);
|
||||
break;
|
||||
case uint32:
|
||||
dispatch_scatter_inds<InT, uint32_t>(out, inds, updates, axes, rtype);
|
||||
break;
|
||||
case uint64:
|
||||
dispatch_scatter_inds<InT, uint64_t>(out, inds, updates, axes, rtype);
|
||||
break;
|
||||
case int8:
|
||||
dispatch_scatter_inds<InT, int8_t>(out, inds, updates, axes, rtype);
|
||||
break;
|
||||
case int16:
|
||||
dispatch_scatter_inds<InT, int16_t>(out, inds, updates, axes, rtype);
|
||||
break;
|
||||
case int32:
|
||||
dispatch_scatter_inds<InT, int32_t>(out, inds, updates, axes, rtype);
|
||||
break;
|
||||
case int64:
|
||||
dispatch_scatter_inds<InT, int64_t>(out, inds, updates, axes, rtype);
|
||||
break;
|
||||
case float16:
|
||||
case float32:
|
||||
case bfloat16:
|
||||
case complex64:
|
||||
throw std::runtime_error(
|
||||
"[Scatter::eval_cpu] Cannot scatter with floating point indices.");
|
||||
}
|
||||
}
|
||||
|
||||
void Scatter::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() >= 2);
|
||||
|
||||
auto& src = inputs[0];
|
||||
std::vector<array> inds(inputs.begin() + 1, inputs.end() - 1);
|
||||
auto& updates = inputs.back();
|
||||
|
||||
// Copy src into out (copy allocates memory for out)
|
||||
copy(src, out, CopyType::General);
|
||||
|
||||
switch (src.dtype()) {
|
||||
case bool_:
|
||||
dispatch_scatter<bool>(out, inds, updates, axes_, reduce_type_);
|
||||
break;
|
||||
case uint8:
|
||||
dispatch_scatter<uint8_t>(out, inds, updates, axes_, reduce_type_);
|
||||
break;
|
||||
case uint16:
|
||||
dispatch_scatter<uint16_t>(out, inds, updates, axes_, reduce_type_);
|
||||
break;
|
||||
case uint32:
|
||||
dispatch_scatter<uint32_t>(out, inds, updates, axes_, reduce_type_);
|
||||
break;
|
||||
case uint64:
|
||||
dispatch_scatter<uint64_t>(out, inds, updates, axes_, reduce_type_);
|
||||
break;
|
||||
case int8:
|
||||
dispatch_scatter<int8_t>(out, inds, updates, axes_, reduce_type_);
|
||||
break;
|
||||
case int16:
|
||||
dispatch_scatter<int16_t>(out, inds, updates, axes_, reduce_type_);
|
||||
break;
|
||||
case int32:
|
||||
dispatch_scatter<int32_t>(out, inds, updates, axes_, reduce_type_);
|
||||
break;
|
||||
case int64:
|
||||
dispatch_scatter<int64_t>(out, inds, updates, axes_, reduce_type_);
|
||||
break;
|
||||
case float16:
|
||||
dispatch_scatter<float16_t>(out, inds, updates, axes_, reduce_type_);
|
||||
break;
|
||||
case float32:
|
||||
dispatch_scatter<float>(out, inds, updates, axes_, reduce_type_);
|
||||
break;
|
||||
case bfloat16:
|
||||
dispatch_scatter<bfloat16_t>(out, inds, updates, axes_, reduce_type_);
|
||||
break;
|
||||
case complex64:
|
||||
dispatch_scatter<complex64_t>(out, inds, updates, axes_, reduce_type_);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
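The dtype switch above is the runtime half of a double dispatch: it selects a template instantiation of scatter<InT, IdxT>, while the reduce operation arrives as a lambda. A minimal standalone sketch of the same pattern, using a hypothetical scatter_axis0 helper over std::vector rather than mlx arrays:

#include <cstdio>
#include <vector>

enum class Reduce { None, Sum, Max };

// Apply `op` at each target position: out[idx[i]] op= upd[i].
template <typename T, typename Op>
void scatter_axis0(
    std::vector<T>& out,
    const std::vector<int>& idx,
    const std::vector<T>& upd,
    Op op) {
  for (size_t i = 0; i < idx.size(); ++i) {
    op(upd[i], &out[idx[i]]);
  }
}

// Runtime enum -> compile-time lambda, mirroring dispatch_scatter_inds.
template <typename T>
void dispatch(
    std::vector<T>& out,
    const std::vector<int>& idx,
    const std::vector<T>& upd,
    Reduce r) {
  switch (r) {
    case Reduce::None:
      scatter_axis0(out, idx, upd, [](T x, T* y) { *y = x; });
      break;
    case Reduce::Sum:
      scatter_axis0(out, idx, upd, [](T x, T* y) { *y += x; });
      break;
    case Reduce::Max:
      scatter_axis0(out, idx, upd, [](T x, T* y) { *y = (*y > x) ? *y : x; });
      break;
  }
}

int main() {
  std::vector<float> out = {0, 0, 0, 0};
  dispatch(out, {1, 3, 1}, {2.f, 5.f, 4.f}, Reduce::Sum);
  for (float v : out) {
    std::printf("%g ", v); // prints: 0 6 0 5
  }
  std::printf("\n");
}

Passing the op as a lambda lets one templated loop serve every reduce type with no virtual dispatch; the compiler can inline each op into its own instantiation.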
mlx/backend/common/inverse.cpp (new file, 120 lines)
@@ -0,0 +1,120 @@
// Copyright © 2023-2024 Apple Inc.

#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/primitives.h"

int strtri_wrapper(char uplo, char diag, float* matrix, int N) {
  int info;
  MLX_LAPACK_FUNC(strtri)
  (
      /* uplo = */ &uplo,
      /* diag = */ &diag,
      /* N = */ &N,
      /* a = */ matrix,
      /* lda = */ &N,
      /* info = */ &info);
  return info;
}

namespace mlx::core {

void general_inv(array& inv, int N, int i) {
  int info;
  auto ipiv = array::Data{allocator::malloc_or_wait(sizeof(int) * N)};
  // Compute the LU factorization.
  sgetrf_(
      /* m = */ &N,
      /* n = */ &N,
      /* a = */ inv.data<float>() + N * N * i,
      /* lda = */ &N,
      /* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
      /* info = */ &info);

  if (info != 0) {
    std::stringstream ss;
    ss << "inverse_impl: LU factorization failed with error code " << info;
    throw std::runtime_error(ss.str());
  }

  static const int lwork_query = -1;
  float workspace_size = 0;

  // Compute the workspace size.
  sgetri_(
      /* n = */ &N,
      /* a = */ nullptr,
      /* lda = */ &N,
      /* ipiv = */ nullptr,
      /* work = */ &workspace_size,
      /* lwork = */ &lwork_query,
      /* info = */ &info);

  if (info != 0) {
    std::stringstream ss;
    ss << "inverse_impl: LU workspace calculation failed with error code "
       << info;
    throw std::runtime_error(ss.str());
  }

  const int lwork = workspace_size;
  auto scratch = array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};

  // Compute the inverse.
  sgetri_(
      /* n = */ &N,
      /* a = */ inv.data<float>() + N * N * i,
      /* lda = */ &N,
      /* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
      /* work = */ static_cast<float*>(scratch.buffer.raw_ptr()),
      /* lwork = */ &lwork,
      /* info = */ &info);

  if (info != 0) {
    std::stringstream ss;
    ss << "inverse_impl: inversion failed with error code " << info;
    throw std::runtime_error(ss.str());
  }
}

void tri_inv(array& inv, int N, int i, bool upper) {
  // Lapack is column-major, so "upper" in row-major is lower in column-major.
  const char uplo = upper ? 'L' : 'U';
  const char diag = 'N';
  int info = strtri_wrapper(uplo, diag, inv.data<float>() + N * N * i, N);
  if (info != 0) {
    std::stringstream ss;
    ss << "inverse_impl: triangular inversion failed with error code " << info;
    throw std::runtime_error(ss.str());
  }
}

void inverse_impl(const array& a, array& inv, bool tri, bool upper) {
  // Lapack uses the column-major convention. We take advantage of the
  // following identity to avoid transposing (see
  // https://math.stackexchange.com/a/340234):
  //   (A⁻¹)ᵀ = (Aᵀ)⁻¹

  // The inverse is computed in place, so just copy the input to the output.
  copy(a, inv, a.flags().row_contiguous ? CopyType::Vector : CopyType::General);

  const int N = a.shape(-1);
  const size_t num_matrices = a.size() / (N * N);

  for (int i = 0; i < num_matrices; i++) {
    if (tri) {
      tri_inv(inv, N, i, upper);
    } else {
      general_inv(inv, N, i);
    }
  }
}

void Inverse::eval(const std::vector<array>& inputs, array& output) {
  if (inputs[0].dtype() != float32) {
    throw std::runtime_error("[Inverse::eval] only supports float32.");
  }
  inverse_impl(inputs[0], output, tri_, upper_);
}

} // namespace mlx::core
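general_inv above follows LAPACK's two-call workspace convention: a first sgetri_ call with lwork = -1 only reports the optimal workspace size, and the real call follows with a buffer of that size. A self-contained sketch of that convention, offered as an illustration rather than mlx code (it assumes a LAPACK library is linked, e.g. g++ demo.cpp -llapack, and writes out the extern declarations instead of pulling in a header):

#include <cstdio>
#include <vector>

extern "C" {
void sgetrf_(const int* m, const int* n, float* a, const int* lda,
             int* ipiv, int* info);
void sgetri_(const int* n, float* a, const int* lda, const int* ipiv,
             float* work, const int* lwork, int* info);
}

int main() {
  const int N = 2;
  // Column-major 2x2 matrix A = [[4, 7], [2, 6]]; det(A) = 10.
  std::vector<float> a = {4.f, 2.f, 7.f, 6.f};
  std::vector<int> ipiv(N);
  int info = 0;

  // 1. LU-factorize in place.
  sgetrf_(&N, &N, a.data(), &N, ipiv.data(), &info);

  // 2. Workspace query: lwork = -1 writes the optimal size into wkopt.
  int lwork = -1;
  float wkopt = 0;
  sgetri_(&N, a.data(), &N, ipiv.data(), &wkopt, &lwork, &info);
  lwork = static_cast<int>(wkopt);
  std::vector<float> work(lwork);

  // 3. Invert using the queried workspace.
  sgetri_(&N, a.data(), &N, ipiv.data(), work.data(), &lwork, &info);

  // inv(A) = [[0.6, -0.7], [-0.2, 0.4]], stored column-major:
  std::printf("%g %g %g %g\n", a[0], a[1], a[2], a[3]); // 0.6 -0.2 -0.7 0.4
}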
@@ -1,8 +1,7 @@
// Copyright © 2024 Apple Inc.

#include "mlx/backend/cpu/jit_compiler.h"
#include "mlx/backend/common/jit_compiler.h"

#include <algorithm>
#include <sstream>
#include <vector>

@@ -25,6 +24,29 @@ std::vector<std::string> str_split(const std::string& str, char delimiter) {
  return tokens;
}

// Run a command and get its output.
std::string exec(const std::string& cmd) {
  std::unique_ptr<FILE, decltype(&_pclose)> pipe(
      _popen(cmd.c_str(), "r"), _pclose);
  if (!pipe) {
    throw std::runtime_error("popen() failed.");
  }
  char buffer[128];
  std::string ret;
  while (fgets(buffer, sizeof(buffer), pipe.get())) {
    ret += buffer;
  }
  // Trim trailing whitespace.
  ret.erase(
      std::find_if(
          ret.rbegin(),
          ret.rend(),
          [](unsigned char ch) { return !std::isspace(ch); })
          .base(),
      ret.end());
  return ret;
}

// Get path information about MSVC.
struct VisualStudioInfo {
  VisualStudioInfo() {
@@ -34,7 +56,7 @@ struct VisualStudioInfo {
    arch = "x64";
#endif
    // Get the path of Visual Studio.
    std::string vs_path = JitCompiler::exec(fmt::format(
    std::string vs_path = exec(fmt::format(
        "\"{0}\\Microsoft Visual Studio\\Installer\\vswhere.exe\""
        " -property installationPath",
        std::getenv("ProgramFiles(x86)")));
@@ -42,7 +64,7 @@ struct VisualStudioInfo {
      throw std::runtime_error("Cannot find Visual Studio.");
    }
    // Read the environment variables from vcvarsall.
    std::string envs = JitCompiler::exec(fmt::format(
    std::string envs = exec(fmt::format(
        "\"{0}\\VC\\Auxiliary\\Build\\vcvarsall.bat\" {1} >NUL && set",
        vs_path,
        arch));
@@ -55,7 +77,7 @@ struct VisualStudioInfo {
      std::string value = line.substr(pos + 1);
      if (name == "LIB") {
        libpaths = str_split(value, ';');
      } else if (name == "VCToolsInstallDir" || name == "VCTOOLSINSTALLDIR") {
      } else if (name == "VCToolsInstallDir") {
        cl_exe = fmt::format("{0}\\bin\\Host{1}\\{1}\\cl.exe", value, arch);
      }
    }
@@ -88,7 +110,7 @@ std::string JitCompiler::build_command(
      "\""
      "cd /D \"{0}\" && "
      "\"{1}\" /LD /EHsc /MD /Ox /nologo /std:c++17 \"{2}\" "
      "/link /out:\"{3}\" {4} 2>&1"
      "/link /out:\"{3}\" {4} >nul"
      "\"",
      dir.string(),
      info.cl_exe,
@@ -97,57 +119,10 @@ std::string JitCompiler::build_command(
      libpaths);
#else
  return fmt::format(
      "g++ -std=c++17 -O3 -Wall -fPIC -shared \"{0}\" -o \"{1}\" 2>&1",
      "g++ -std=c++17 -O3 -Wall -fPIC -shared '{0}' -o '{1}'",
      (dir / source_file_name).string(),
      (dir / shared_lib_name).string());
#endif
}

std::string JitCompiler::exec(const std::string& cmd) {
#ifdef _MSC_VER
  FILE* pipe = _popen(cmd.c_str(), "r");
#else
  FILE* pipe = popen(cmd.c_str(), "r");
#endif
  if (!pipe) {
    throw std::runtime_error("popen() failed.");
  }
  char buffer[128];
  std::string ret;
  while (fgets(buffer, sizeof(buffer), pipe)) {
    ret += buffer;
  }
  // Trim trailing whitespace.
  ret.erase(
      std::find_if(
          ret.rbegin(),
          ret.rend(),
          [](unsigned char ch) { return !std::isspace(ch); })
          .base(),
      ret.end());

#ifdef _MSC_VER
  int status = _pclose(pipe);
#else
  int status = pclose(pipe);
#endif
  if (status == -1) {
    throw std::runtime_error("pclose() failed.");
  }
#if defined(_WIN32) || defined(__FreeBSD__)
  int code = status;
#else
  int code = WEXITSTATUS(status);
#endif
  if (code != 0) {
    throw std::runtime_error(fmt::format(
        "Failed to execute command with return code {0}: \"{1}\", "
        "the output is: {2}",
        code,
        cmd,
        ret));
  }
  return ret;
}

} // namespace mlx::core
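The exec helper above wraps the popen/pclose pattern; note that on most POSIX systems the raw status returned by pclose must go through WEXITSTATUS before it is a comparable exit code. A trimmed, POSIX-only usage sketch of the same idea (illustrative, not the cross-platform version above):

#include <cctype>
#include <cstdio>
#include <stdexcept>
#include <string>
#include <sys/wait.h>

// Run a shell command, return its trimmed stdout, throw on nonzero exit.
std::string exec(const std::string& cmd) {
  FILE* pipe = popen(cmd.c_str(), "r");
  if (!pipe) {
    throw std::runtime_error("popen() failed.");
  }
  char buffer[128];
  std::string ret;
  while (fgets(buffer, sizeof(buffer), pipe)) {
    ret += buffer;
  }
  int status = pclose(pipe);
  if (status == -1 || WEXITSTATUS(status) != 0) {
    throw std::runtime_error("command failed: " + cmd);
  }
  // Trim trailing whitespace.
  while (!ret.empty() && std::isspace(static_cast<unsigned char>(ret.back()))) {
    ret.pop_back();
  }
  return ret;
}

int main() {
  std::printf("[%s]\n", exec("echo hello").c_str()); // [hello]
}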
@@ -12,9 +12,6 @@ class JitCompiler {
      const std::filesystem::path& dir,
      const std::string& source_file_name,
      const std::string& shared_lib_name);

  // Run a command and get its output.
  static std::string exec(const std::string& cmd);
};

} // namespace mlx::core
mlx/backend/common/lapack.h (new file, 33 lines)
@@ -0,0 +1,33 @@
// Copyright © 2023-2024 Apple Inc.

#pragma once

// Required for Visual Studio.
// https://github.com/OpenMathLib/OpenBLAS/blob/develop/docs/install.md
#ifdef _MSC_VER
#include <complex>
#define LAPACK_COMPLEX_CUSTOM
#define lapack_complex_float std::complex<float>
#define lapack_complex_double std::complex<double>
#endif

#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <cblas.h>
#include <lapack.h>
#endif

#if defined(LAPACK_GLOBAL) || defined(LAPACK_NAME)

// This is to work around a change in the function signatures of lapack >= 3.9.1
// where functions taking char* also include a strlen argument, see a similar
// change in OpenCV:
// https://github.com/opencv/opencv/blob/1eb061f89de0fb85c4c75a2deeb0f61a961a63ad/cmake/OpenCVFindLAPACK.cmake#L57
#define MLX_LAPACK_FUNC(f) LAPACK_##f

#else

#define MLX_LAPACK_FUNC(f) f##_

#endif
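MLX_LAPACK_FUNC maps one portable spelling onto whichever symbol scheme the linked LAPACK exposes: LAPACK_##f with lapack >= 3.9.1 headers, or the classic trailing-underscore Fortran name otherwise. A toy demonstration of the token pasting, with a stand-in function in place of a real LAPACK symbol:

#include <cstdio>

// The "classic" branch of the macro: append a trailing underscore.
#define MLX_LAPACK_FUNC(f) f##_

// Stand-in for the Fortran symbol a real build would link against.
void strtri_() {
  std::puts("called strtri_");
}

int main() {
  MLX_LAPACK_FUNC(strtri)(); // expands to strtri_();
}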
@@ -1,10 +1,12 @@
// Copyright © 2023 Apple Inc.

#include <algorithm>
#include <cassert>
#include <utility>

#include "mlx/allocator.h"
#include "mlx/backend/common/load.h"
#include "mlx/primitives.h"
#include "mlx/scheduler.h"

namespace {

@@ -27,31 +29,33 @@ void swap_endianness(uint8_t* data_bytes, size_t N) {

namespace mlx::core {

void Load::eval_cpu(const std::vector<array>& inputs, array& out) {
  out.set_data(allocator::malloc(out.nbytes()));
  auto read_task = [out_ptr = out.data<char>(),
                    size = out.size(),
                    itemsize = out.itemsize(),
                    offset = offset_,
                    reader = reader_,
                    swap_endianness_ = swap_endianness_]() mutable {
    reader->read(out_ptr, size * itemsize, offset);
    if (swap_endianness_) {
      switch (itemsize) {
        case 2:
          swap_endianness<2>(reinterpret_cast<uint8_t*>(out_ptr), size);
          break;
        case 4:
          swap_endianness<4>(reinterpret_cast<uint8_t*>(out_ptr), size);
          break;
        case 8:
          swap_endianness<8>(reinterpret_cast<uint8_t*>(out_ptr), size);
          break;
      }
void load(
    array& out,
    size_t offset,
    const std::shared_ptr<io::Reader>& reader,
    bool swap_endianness_) {
  reader->read(out.data<char>(), out.nbytes(), offset);

  if (swap_endianness_) {
    switch (out.itemsize()) {
      case 2:
        swap_endianness<2>(out.data<uint8_t>(), out.data_size());
        break;
      case 4:
        swap_endianness<4>(out.data<uint8_t>(), out.data_size());
        break;
      case 8:
        swap_endianness<8>(out.data<uint8_t>(), out.data_size());
        break;
    }
  };
  auto fut = io::thread_pool().enqueue(std::move(read_task)).share();
  scheduler::enqueue(stream(), [fut = std::move(fut)]() { fut.wait(); });
}
}

void Load::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 0);
  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  load(out, offset_, reader_, swap_endianness_);
}

} // namespace mlx::core
mlx/backend/common/load.h (new file, 14 lines)
@@ -0,0 +1,14 @@
// Copyright © 2024 Apple Inc.

#include "mlx/array.h"
#include "mlx/io/load.h"

namespace mlx::core {

void load(
    array& out,
    size_t offset,
    const std::shared_ptr<io::Reader>& reader,
    bool swap_endianness);

} // namespace mlx::core
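Only the signature of swap_endianness appears as hunk context in load.cpp above; a plausible minimal body for it, offered as an assumption rather than the mlx source:

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstdio>

// Reverse the bytes of each of the N fixed-size items in place.
template <std::size_t item_size>
void swap_endianness(uint8_t* data_bytes, std::size_t N) {
  for (std::size_t i = 0; i < N; ++i) {
    std::reverse(data_bytes + i * item_size, data_bytes + (i + 1) * item_size);
  }
}

int main() {
  uint32_t v[2] = {0x11223344u, 0xaabbccddu};
  swap_endianness<4>(reinterpret_cast<uint8_t*>(v), 2);
  std::printf("%08x %08x\n", v[0], v[1]); // 44332211 ddccbbaa
}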
@@ -8,7 +8,7 @@ $CL = $args[1]
$SRCDIR = $args[2]

# Get the command result as an array.
$CONTENT = & $CL /std:c++17 /EP "/I$SRCDIR" /Tp "$SRCDIR/mlx/backend/cpu/compiled_preamble.h"
$CONTENT = & $CL /std:c++17 /EP "/I$SRCDIR" /Tp "$SRCDIR/mlx/backend/common/compiled_preamble.h"
# Remove empty lines.
# Otherwise there would be too many empty lines, making the result unreadable.
$CONTENT = $CONTENT | Where-Object { $_.Trim() -ne '' }
Some files were not shown because too many files have changed in this diff.