mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-11 22:44:38 +08:00
Compare commits
2 Commits
qmm
...
fences_mus
Author | SHA1 | Date | |
---|---|---|---|
![]() |
127de8821e | ||
![]() |
3ad9031a7f |
@@ -7,6 +7,15 @@ parameters:
|
||||
nightly_build:
|
||||
type: boolean
|
||||
default: false
|
||||
weekly_build:
|
||||
type: boolean
|
||||
default: false
|
||||
test_release:
|
||||
type: boolean
|
||||
default: false
|
||||
linux_release:
|
||||
type: boolean
|
||||
default: false
|
||||
|
||||
jobs:
|
||||
build_documentation:
|
||||
@@ -15,8 +24,8 @@ jobs:
|
||||
type: boolean
|
||||
default: false
|
||||
macos:
|
||||
xcode: "16.2.0"
|
||||
resource_class: m2pro.medium
|
||||
xcode: "15.2.0"
|
||||
resource_class: macos.m1.medium.gen1
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
@@ -29,7 +38,7 @@ jobs:
|
||||
pip install --upgrade pip
|
||||
pip install --upgrade cmake
|
||||
pip install -r docs/requirements.txt
|
||||
pip install . -v
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install . -v
|
||||
- when:
|
||||
condition:
|
||||
not: << parameters.upload-docs >>
|
||||
@@ -61,9 +70,9 @@ jobs:
|
||||
git push -f origin gh-pages
|
||||
|
||||
linux_build_and_test:
|
||||
machine:
|
||||
image: ubuntu-2204:current
|
||||
resource_class: large
|
||||
docker:
|
||||
- image: cimg/python:3.9
|
||||
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
@@ -75,34 +84,34 @@ jobs:
|
||||
- run:
|
||||
name: Install dependencies
|
||||
command: |
|
||||
export DEBIAN_FRONTEND=noninteractive
|
||||
export NEEDRESTART_MODE=a
|
||||
sudo apt-get update
|
||||
sudo apt-get upgrade -y
|
||||
pip install --upgrade cmake
|
||||
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev
|
||||
sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
|
||||
pip install nanobind==2.4.0
|
||||
pip install numpy
|
||||
sudo apt-get update
|
||||
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
||||
- run:
|
||||
name: Install Python package
|
||||
command: |
|
||||
pip install -e ".[dev]"
|
||||
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||
python3 setup.py build_ext --inplace
|
||||
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||
python3 setup.py develop
|
||||
- run:
|
||||
name: Generate package stubs
|
||||
command: |
|
||||
echo "stubs"
|
||||
pip install typing_extensions
|
||||
python setup.py generate_stubs
|
||||
python setup.py generate_stubs
|
||||
- run:
|
||||
name: Run Python tests
|
||||
command: |
|
||||
python -m unittest discover python/tests -v
|
||||
mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
|
||||
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
|
||||
if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
|
||||
python3 -m unittest discover python/tests -v
|
||||
- run:
|
||||
name: Build CPP only
|
||||
command: |
|
||||
mkdir -p build && cd build
|
||||
mkdir -p build && cd build
|
||||
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
|
||||
make -j `nproc`
|
||||
- run:
|
||||
@@ -113,15 +122,10 @@ jobs:
|
||||
parameters:
|
||||
xcode_version:
|
||||
type: string
|
||||
default: "16.2.0"
|
||||
macosx_deployment_target:
|
||||
type: string
|
||||
default: ""
|
||||
default: "15.2.0"
|
||||
macos:
|
||||
xcode: << parameters.xcode_version >>
|
||||
environment:
|
||||
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
|
||||
resource_class: m2pro.medium
|
||||
resource_class: macos.m1.medium.gen1
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
@@ -142,14 +146,13 @@ jobs:
|
||||
name: Install Python package
|
||||
command: |
|
||||
source env/bin/activate
|
||||
DEBUG=1 CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
|
||||
pip install -e . -v
|
||||
DEBUG=1 CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install -e . -v
|
||||
- run:
|
||||
name: Generate package stubs
|
||||
command: |
|
||||
source env/bin/activate
|
||||
pip install typing_extensions
|
||||
python setup.py generate_stubs
|
||||
python setup.py generate_stubs
|
||||
- run:
|
||||
name: Run Python tests
|
||||
command: |
|
||||
@@ -157,8 +160,7 @@ jobs:
|
||||
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
|
||||
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
|
||||
mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
|
||||
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
|
||||
if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
|
||||
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py
|
||||
- run:
|
||||
name: Build example extension
|
||||
command: |
|
||||
@@ -193,34 +195,13 @@ jobs:
|
||||
name: Run Python tests with JIT
|
||||
command: |
|
||||
source env/bin/activate
|
||||
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
|
||||
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
|
||||
pip install -e . -v
|
||||
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
|
||||
METAL_DEBUG_ERROR_MODE=0 \
|
||||
python -m xmlrunner discover -v python/tests -o test-results/gpu_jit
|
||||
|
||||
cuda_build_and_test:
|
||||
machine:
|
||||
image: linux-cuda-12:2023.11.1
|
||||
resource_class: gpu.nvidia.small.gen2
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: Install Python package
|
||||
command: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
||||
python3 -m venv env
|
||||
source env/bin/activate
|
||||
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
|
||||
pip install -e ".[dev]"
|
||||
- run:
|
||||
name: Run Python tests
|
||||
command: |
|
||||
source env/bin/activate
|
||||
LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
|
||||
LOW_MEMORY=1 DEVICE=gpu python -m tests discover python/tests -v
|
||||
|
||||
build_release:
|
||||
parameters:
|
||||
python_version:
|
||||
@@ -228,18 +209,13 @@ jobs:
|
||||
default: "3.9"
|
||||
xcode_version:
|
||||
type: string
|
||||
default: "16.2.0"
|
||||
default: "15.2.0"
|
||||
build_env:
|
||||
type: string
|
||||
default: ""
|
||||
macosx_deployment_target:
|
||||
type: string
|
||||
default: ""
|
||||
macos:
|
||||
xcode: << parameters.xcode_version >>
|
||||
resource_class: m2pro.medium
|
||||
environment:
|
||||
MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
|
||||
resource_class: macos.m1.medium.gen1
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
@@ -260,30 +236,22 @@ jobs:
|
||||
name: Install Python package
|
||||
command: |
|
||||
source env/bin/activate
|
||||
env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \
|
||||
DEV_RELEASE=1 \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
|
||||
pip install . -v
|
||||
- run:
|
||||
name: Generate package stubs
|
||||
command: |
|
||||
source env/bin/activate
|
||||
pip install typing_extensions
|
||||
python setup.py generate_stubs
|
||||
python setup.py generate_stubs
|
||||
- run:
|
||||
name: Build Python package
|
||||
command: |
|
||||
source env/bin/activate
|
||||
python setup.py clean --all
|
||||
<< parameters.build_env >> MLX_BUILD_STAGE=1 python -m build -w
|
||||
- when:
|
||||
condition:
|
||||
equal: ["3.9", << parameters.python_version >>]
|
||||
steps:
|
||||
- run:
|
||||
name: Build common package
|
||||
command: |
|
||||
source env/bin/activate
|
||||
python setup.py clean --all
|
||||
<< parameters.build_env >> MLX_BUILD_STAGE=2 python -m build -w
|
||||
<< parameters.build_env >> \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` \
|
||||
python -m build -w
|
||||
- when:
|
||||
condition: << parameters.build_env >>
|
||||
steps:
|
||||
@@ -300,100 +268,52 @@ jobs:
|
||||
python_version:
|
||||
type: string
|
||||
default: "3.9"
|
||||
build_env:
|
||||
extra_env:
|
||||
type: string
|
||||
default: ""
|
||||
machine:
|
||||
image: ubuntu-2204:current
|
||||
resource_class: large
|
||||
default: "DEV_RELEASE=1"
|
||||
docker:
|
||||
- image: ubuntu:20.04
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: Build wheel
|
||||
command: |
|
||||
PYTHON=python<< parameters.python_version >>
|
||||
export DEBIAN_FRONTEND=noninteractive
|
||||
export NEEDRESTART_MODE=a
|
||||
sudo apt-get update
|
||||
sudo apt-get upgrade -y
|
||||
TZ=Etc/UTC sudo apt-get -y install tzdata
|
||||
sudo apt-get install -y apt-utils
|
||||
sudo apt-get install -y software-properties-common
|
||||
sudo add-apt-repository -y ppa:deadsnakes/ppa
|
||||
sudo apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full
|
||||
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev
|
||||
sudo apt-get install -y build-essential git
|
||||
apt-get update
|
||||
apt-get upgrade -y
|
||||
DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC apt-get -y install tzdata
|
||||
apt-get install -y apt-utils
|
||||
apt-get install -y software-properties-common
|
||||
add-apt-repository -y ppa:deadsnakes/ppa
|
||||
apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full
|
||||
apt-get install -y libblas-dev liblapack-dev liblapacke-dev
|
||||
apt-get install -y build-essential git
|
||||
$PYTHON -m venv env
|
||||
source env/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install --upgrade cmake
|
||||
pip install nanobind==2.4.0
|
||||
pip install --upgrade setuptools
|
||||
pip install numpy
|
||||
pip install auditwheel
|
||||
pip install patchelf
|
||||
pip install build
|
||||
pip install twine
|
||||
<< parameters.build_env >> pip install ".[dev]" -v
|
||||
<< parameters.extra_env >> \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||
pip install . -v
|
||||
pip install typing_extensions
|
||||
python setup.py generate_stubs
|
||||
python setup.py clean --all
|
||||
MLX_BUILD_STAGE=1 << parameters.build_env >> python -m build -w
|
||||
bash python/scripts/repair_linux.sh
|
||||
- when:
|
||||
condition:
|
||||
equal: ["3.9", << parameters.python_version >>]
|
||||
steps:
|
||||
- run:
|
||||
name: Build common package
|
||||
command: |
|
||||
source env/bin/activate
|
||||
python setup.py clean --all
|
||||
<< parameters.build_env >> MLX_BUILD_STAGE=2 \
|
||||
python -m build -w
|
||||
auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_x86_64
|
||||
- when:
|
||||
condition: << parameters.build_env >>
|
||||
steps:
|
||||
- run:
|
||||
name: Upload packages
|
||||
command: |
|
||||
source env/bin/activate
|
||||
twine upload wheelhouse/*.whl
|
||||
- store_artifacts:
|
||||
path: wheelhouse/
|
||||
|
||||
build_cuda_release:
|
||||
parameters:
|
||||
build_env:
|
||||
type: string
|
||||
default: ""
|
||||
machine:
|
||||
image: linux-cuda-12:2024.11.1
|
||||
resource_class: gpu.nvidia.small.gen2
|
||||
steps:
|
||||
- checkout
|
||||
python setup.py generate_stubs
|
||||
<< parameters.extra_env >> \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
|
||||
python -m build --wheel
|
||||
auditwheel show dist/*
|
||||
auditwheel repair dist/* --plat manylinux_2_31_x86_64
|
||||
- run:
|
||||
name: Build wheel
|
||||
name: Upload package
|
||||
command: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
||||
sudo apt-get install zip
|
||||
python -m venv env
|
||||
source env/bin/activate
|
||||
pip install auditwheel
|
||||
pip install patchelf
|
||||
pip install build
|
||||
pip install twine
|
||||
<< parameters.build_env >> MLX_BUILD_STAGE=2 \
|
||||
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
|
||||
python -m build -w
|
||||
bash python/scripts/repair_cuda.sh
|
||||
- when:
|
||||
condition: << parameters.build_env >>
|
||||
steps:
|
||||
- run:
|
||||
name: Upload package
|
||||
command: |
|
||||
source env/bin/activate
|
||||
twine upload wheelhouse/*.whl
|
||||
twine upload wheelhouse/*
|
||||
- store_artifacts:
|
||||
path: wheelhouse/
|
||||
|
||||
@@ -405,19 +325,22 @@ workflows:
|
||||
pattern: "^(?!pull/)[-\\w]+$"
|
||||
value: << pipeline.git.branch >>
|
||||
- not: << pipeline.parameters.nightly_build >>
|
||||
- not: << pipeline.parameters.weekly_build >>
|
||||
- not: << pipeline.parameters.test_release >>
|
||||
jobs:
|
||||
- mac_build_and_test:
|
||||
matrix:
|
||||
parameters:
|
||||
macosx_deployment_target: ["13.5", "14.0"]
|
||||
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
|
||||
- linux_build_and_test
|
||||
- cuda_build_and_test
|
||||
- build_documentation
|
||||
|
||||
build_pypi_release:
|
||||
when:
|
||||
and:
|
||||
- not: << pipeline.parameters.nightly_build >>
|
||||
- not: << pipeline.parameters.weekly_build >>
|
||||
- not: << pipeline.parameters.test_release >>
|
||||
jobs:
|
||||
- build_release:
|
||||
filters:
|
||||
@@ -428,70 +351,8 @@ workflows:
|
||||
matrix:
|
||||
parameters:
|
||||
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||
macosx_deployment_target: ["13.5", "14.0", "15.0"]
|
||||
xcode_version: ["15.0.0", "15.2.0"]
|
||||
build_env: ["PYPI_RELEASE=1"]
|
||||
xcode_version: ["16.2.0", "15.0.0"]
|
||||
exclude:
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.9"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.10"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.11"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.12"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.13"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.9"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.10"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.11"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.12"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.13"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.9"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.10"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.11"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.12"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.13"
|
||||
build_env: "PYPI_RELEASE=1"
|
||||
- build_documentation:
|
||||
filters:
|
||||
tags:
|
||||
@@ -499,25 +360,6 @@ workflows:
|
||||
branches:
|
||||
ignore: /.*/
|
||||
upload-docs: true
|
||||
- build_linux_release:
|
||||
filters:
|
||||
tags:
|
||||
only: /^v.*/
|
||||
branches:
|
||||
ignore: /.*/
|
||||
matrix:
|
||||
parameters:
|
||||
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||
build_env: ["PYPI_RELEASE=1"]
|
||||
- build_cuda_release:
|
||||
filters:
|
||||
tags:
|
||||
only: /^v.*/
|
||||
branches:
|
||||
ignore: /.*/
|
||||
matrix:
|
||||
parameters:
|
||||
build_env: ["PYPI_RELEASE=1"]
|
||||
|
||||
prb:
|
||||
when:
|
||||
@@ -533,11 +375,9 @@ workflows:
|
||||
requires: [ hold ]
|
||||
matrix:
|
||||
parameters:
|
||||
macosx_deployment_target: ["13.5", "14.0"]
|
||||
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
|
||||
- linux_build_and_test:
|
||||
requires: [ hold ]
|
||||
- cuda_build_and_test:
|
||||
requires: [ hold ]
|
||||
nightly_build:
|
||||
when:
|
||||
and:
|
||||
@@ -548,56 +388,27 @@ workflows:
|
||||
matrix:
|
||||
parameters:
|
||||
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||
macosx_deployment_target: ["13.5", "14.0", "15.0"]
|
||||
xcode_version: ["16.2.0", "15.0.0"]
|
||||
exclude:
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.9"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.10"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.11"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.12"
|
||||
- macosx_deployment_target: "13.5"
|
||||
xcode_version: "16.2.0"
|
||||
python_version: "3.13"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.9"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.10"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.11"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.12"
|
||||
- macosx_deployment_target: "14.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.13"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.9"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.10"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.11"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.12"
|
||||
- macosx_deployment_target: "15.0"
|
||||
xcode_version: "15.0.0"
|
||||
python_version: "3.13"
|
||||
xcode_version: ["15.0.0", "15.2.0"]
|
||||
weekly_build:
|
||||
when:
|
||||
and:
|
||||
- equal: [ main, << pipeline.git.branch >> ]
|
||||
- << pipeline.parameters.weekly_build >>
|
||||
jobs:
|
||||
- build_release:
|
||||
matrix:
|
||||
parameters:
|
||||
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
|
||||
build_env: ["DEV_RELEASE=1"]
|
||||
linux_test_release:
|
||||
when:
|
||||
and:
|
||||
- equal: [ main, << pipeline.git.branch >> ]
|
||||
- << pipeline.parameters.linux_release >>
|
||||
jobs:
|
||||
- build_linux_release:
|
||||
matrix:
|
||||
parameters:
|
||||
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||
- build_cuda_release
|
||||
extra_env: ["PYPI_RELEASE=1"]
|
||||
|
1
.gitignore
vendored
1
.gitignore
vendored
@@ -36,7 +36,6 @@ share/python-wheels/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
uv.lock
|
||||
|
||||
# vim
|
||||
*.swp
|
||||
|
@@ -19,7 +19,6 @@ MLX was developed with contributions from the following individuals:
|
||||
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
|
||||
- Paul Paczuski: Improved stability of BCE loss calculation
|
||||
- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
|
||||
- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer.
|
||||
|
||||
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
|
||||
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
|
||||
|
@@ -9,7 +9,6 @@ if(NOT MLX_VERSION)
|
||||
string(REGEX MATCH "#define MLX_VERSION_PATCH ([0-9]+)" _ "${_mlx_h_version}")
|
||||
set(_patch ${CMAKE_MATCH_1})
|
||||
set(MLX_PROJECT_VERSION "${_major}.${_minor}.${_patch}")
|
||||
set(MLX_VERSION ${MLX_PROJECT_VERSION})
|
||||
else()
|
||||
string(REGEX REPLACE "^([0-9]+\.[0-9]+\.[0-9]+).*" "\\1" MLX_PROJECT_VERSION
|
||||
${MLX_VERSION})
|
||||
@@ -22,7 +21,7 @@ project(
|
||||
|
||||
# ----------------------------- Setup -----------------------------
|
||||
set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
|
||||
set(CMAKE_CXX_STANDARD 20)
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
|
||||
set(CMAKE_INSTALL_MESSAGE NEVER)
|
||||
@@ -34,7 +33,6 @@ option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
|
||||
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
|
||||
option(MLX_BUILD_METAL "Build metal backend" ON)
|
||||
option(MLX_BUILD_CPU "Build cpu backend" ON)
|
||||
option(MLX_BUILD_CUDA "Build cuda backend" OFF)
|
||||
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
|
||||
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
|
||||
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
|
||||
@@ -43,6 +41,8 @@ option(MLX_BUILD_BLAS_FROM_SOURCE "Build OpenBLAS from source code" OFF)
|
||||
option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
|
||||
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
|
||||
|
||||
add_compile_definitions("MLX_VERSION=${MLX_VERSION}")
|
||||
|
||||
# --------------------- Processor tests -------------------------
|
||||
message(
|
||||
STATUS
|
||||
@@ -64,8 +64,10 @@ if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
|
||||
message(WARNING "Building for x86_64 arch is not officially supported.")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
else()
|
||||
set(MLX_BUILD_METAL OFF)
|
||||
message(WARNING "MLX is prioritised for Apple silicon systems using macOS.")
|
||||
endif()
|
||||
|
||||
# ----------------------------- Lib -----------------------------
|
||||
@@ -75,6 +77,7 @@ include(FetchContent)
|
||||
cmake_policy(SET CMP0135 NEW)
|
||||
|
||||
add_library(mlx)
|
||||
set_target_properties(mlx PROPERTIES COMPILE_WARNING_AS_ERROR ON)
|
||||
|
||||
if(MLX_BUILD_METAL)
|
||||
set(METAL_LIB "-framework Metal")
|
||||
@@ -82,10 +85,6 @@ if(MLX_BUILD_METAL)
|
||||
set(QUARTZ_LIB "-framework QuartzCore")
|
||||
endif()
|
||||
|
||||
if(MLX_BUILD_CUDA)
|
||||
enable_language(CUDA)
|
||||
endif()
|
||||
|
||||
if(MLX_BUILD_METAL AND NOT METAL_LIB)
|
||||
message(STATUS "Metal not found. Unable to build GPU")
|
||||
set(MLX_BUILD_METAL OFF)
|
||||
@@ -215,6 +214,24 @@ else()
|
||||
set(MLX_BUILD_ACCELERATE OFF)
|
||||
endif()
|
||||
|
||||
find_package(MPI)
|
||||
if(MPI_FOUND)
|
||||
execute_process(
|
||||
COMMAND zsh "-c" "mpirun --version"
|
||||
OUTPUT_VARIABLE MPI_VERSION
|
||||
ERROR_QUIET)
|
||||
if(${MPI_VERSION} MATCHES ".*Open MPI.*")
|
||||
target_include_directories(mlx PRIVATE ${MPI_INCLUDE_PATH})
|
||||
elseif(MPI_VERSION STREQUAL "")
|
||||
set(MPI_FOUND FALSE)
|
||||
message(
|
||||
WARNING "MPI found but mpirun is not available. Building without MPI.")
|
||||
else()
|
||||
set(MPI_FOUND FALSE)
|
||||
message(WARNING "MPI which is not OpenMPI found. Building without MPI.")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
message(STATUS "Downloading json")
|
||||
FetchContent_Declare(
|
||||
json
|
||||
@@ -229,9 +246,6 @@ target_include_directories(
|
||||
mlx PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
|
||||
$<INSTALL_INTERFACE:include>)
|
||||
|
||||
# Do not add mlx_EXPORTS define for shared library.
|
||||
set_target_properties(mlx PROPERTIES DEFINE_SYMBOL "")
|
||||
|
||||
FetchContent_Declare(
|
||||
fmt
|
||||
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
|
||||
|
@@ -5,26 +5,26 @@ possible.
|
||||
|
||||
## Pull Requests
|
||||
|
||||
1. Fork and submit pull requests to the repo.
|
||||
1. Fork and submit pull requests to the repo.
|
||||
2. If you've added code that should be tested, add tests.
|
||||
3. If a change is likely to impact efficiency, run some of the benchmarks before
|
||||
and after the change. Examples of benchmarks can be found in `benchmarks/python/`.
|
||||
4. If you've changed APIs, update the documentation.
|
||||
5. Every PR should have passing tests and at least one review.
|
||||
5. Every PR should have passing tests and at least one review.
|
||||
6. For code formatting install `pre-commit` using something like `pip install pre-commit` and run `pre-commit install`.
|
||||
This should install hooks for running `black` and `clang-format` to ensure
|
||||
consistent style for C++ and python code.
|
||||
|
||||
|
||||
You can also run the formatters manually as follows:
|
||||
|
||||
```shell
|
||||
clang-format -i file.cpp
|
||||
```
|
||||
|
||||
```shell
|
||||
black file.py
|
||||
```
|
||||
|
||||
|
||||
```
|
||||
clang-format -i file.cpp
|
||||
```
|
||||
|
||||
```
|
||||
black file.py
|
||||
```
|
||||
|
||||
or run `pre-commit run --all-files` to check all files in the repo.
|
||||
|
||||
## Issues
|
||||
|
@@ -1,6 +1,4 @@
|
||||
include CMakeLists.txt
|
||||
include mlx.pc.in
|
||||
recursive-include mlx/ *
|
||||
include cmake/*
|
||||
include python/src/*
|
||||
include python/mlx/py.typed # support type hinting as in PEP-561
|
||||
|
@@ -1,6 +1,5 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
|
@@ -192,22 +192,6 @@ void time_reductions() {
|
||||
|
||||
auto argmin_along_1 = [&a]() { return mx::argmin(a, 1, false); };
|
||||
TIME(argmin_along_1);
|
||||
|
||||
auto indices = mx::array({1});
|
||||
auto updates = mx::reshape(mx::array({NAN}), {1, 1, 1});
|
||||
std::vector<int> axes{0};
|
||||
auto b = scatter(a, {indices}, updates, axes);
|
||||
mx::eval(b);
|
||||
|
||||
auto max_along_0 = [&b]() { return mx::max(b, 0, false); };
|
||||
TIME(max_along_0);
|
||||
auto max_along_1 = [&b]() { return mx::max(b, 1, false); };
|
||||
TIME(max_along_1);
|
||||
|
||||
auto min_along_0 = [&b]() { return mx::min(b, 0, false); };
|
||||
TIME(min_along_0);
|
||||
auto min_along_1 = [&b]() { return mx::min(b, 1, false); };
|
||||
TIME(min_along_1);
|
||||
}
|
||||
|
||||
void time_gather_scatter() {
|
||||
|
@@ -5,7 +5,6 @@ import os
|
||||
import time
|
||||
|
||||
import torch
|
||||
import torch.cuda
|
||||
import torch.mps
|
||||
|
||||
|
||||
@@ -45,10 +44,8 @@ def bench(f, *args):
|
||||
|
||||
|
||||
def sync_if_needed(x):
|
||||
if x.device == torch.device("mps"):
|
||||
if x.device != torch.device("cpu"):
|
||||
torch.mps.synchronize()
|
||||
elif x.device == torch.device("cuda"):
|
||||
torch.cuda.synchronize()
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -102,14 +99,6 @@ def reduction(op, axis, x):
|
||||
sync_if_needed(x)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sum_and_add(axis, x, y):
|
||||
z = x.sum(axis=axis, keepdims=True)
|
||||
for i in range(50):
|
||||
z = (z + y).sum(axis=axis, keepdims=True)
|
||||
sync_if_needed(x)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def softmax(axis, x):
|
||||
ys = []
|
||||
@@ -351,11 +340,7 @@ if __name__ == "__main__":
|
||||
args.axis.pop(0)
|
||||
|
||||
torch.set_num_threads(1)
|
||||
device = "mps"
|
||||
if torch.cuda.is_available():
|
||||
device = "cuda"
|
||||
if args.cpu:
|
||||
device = "cpu"
|
||||
device = "cpu" if args.cpu else "mps"
|
||||
|
||||
types = args.dtype
|
||||
if not types:
|
||||
@@ -475,8 +460,5 @@ if __name__ == "__main__":
|
||||
elif args.benchmark == "selu":
|
||||
print(bench(selu, x))
|
||||
|
||||
elif args.benchmark == "sum_and_add":
|
||||
print(bench(sum_and_add, axis, *xs))
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown benchmark `{args.benchmark}`.")
|
||||
|
@@ -1,107 +0,0 @@
|
||||
import math
|
||||
import time
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
N_warmup = 10
|
||||
N_iter_bench = 100
|
||||
N_iter_func = 5
|
||||
|
||||
|
||||
def bench(f, a, b):
|
||||
for i in range(N_warmup):
|
||||
f(a, b)
|
||||
torch.mps.synchronize()
|
||||
|
||||
s = time.perf_counter_ns()
|
||||
for i in range(N_iter_bench):
|
||||
f(a, b)
|
||||
e = time.perf_counter_ns()
|
||||
return (e - s) * 1e-9
|
||||
|
||||
|
||||
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
|
||||
def mx_conv_2D(a, b):
|
||||
ys = []
|
||||
for i in range(N_iter_func):
|
||||
y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
|
||||
ys.append(y)
|
||||
mx.eval(ys)
|
||||
return ys
|
||||
|
||||
return mx_conv_2D
|
||||
|
||||
|
||||
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
|
||||
@torch.no_grad()
|
||||
def pt_conv_2D(a, b):
|
||||
ys = []
|
||||
for i in range(N_iter_func):
|
||||
y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
|
||||
ys.append(y)
|
||||
torch.mps.synchronize()
|
||||
return ys
|
||||
|
||||
return pt_conv_2D
|
||||
|
||||
|
||||
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
|
||||
scale = 1.0 / math.sqrt(kH * kH * C)
|
||||
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
|
||||
b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
|
||||
np_dtype
|
||||
)
|
||||
|
||||
a_mx = mx.array(a_np)
|
||||
b_mx = mx.array(b_np)
|
||||
|
||||
a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
|
||||
b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("mps")
|
||||
|
||||
torch.mps.synchronize()
|
||||
|
||||
f_mx = make_mx_conv_2D(strides, padding, groups)
|
||||
f_pt = make_pt_conv_2D(strides, padding, groups)
|
||||
|
||||
time_torch = bench(f_pt, a_pt, b_pt)
|
||||
time_mlx = bench(f_mx, a_mx, b_mx)
|
||||
|
||||
out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
|
||||
out_pt = torch.conv2d(
|
||||
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
|
||||
)
|
||||
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
|
||||
out_pt = out_pt.numpy(force=True)
|
||||
|
||||
atol = 2e-5 if np_dtype == np.float32 else 1e-4
|
||||
|
||||
if not np.allclose(out_pt, out_mx, atol=atol):
|
||||
print(
|
||||
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
|
||||
)
|
||||
|
||||
return time_mlx, time_torch
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
dtype = "float32"
|
||||
shapes = (
|
||||
(4, 32, 32, 21, 3, 3, 128),
|
||||
(4, 32, 32, 21, 3, 3, 37),
|
||||
(4, 32, 32, 370, 3, 3, 370),
|
||||
(4, 32, 32, 370, 7, 7, 128),
|
||||
(2, 320, 640, 21, 7, 7, 21),
|
||||
)
|
||||
for N, H, W, C, kh, kw, O in shapes:
|
||||
time_mlx, time_torch = bench_shape(
|
||||
N, H, W, C, kh, kw, O, (1, 1), (0, 0), 1, dtype
|
||||
)
|
||||
diff = time_torch / time_mlx - 1.0
|
||||
|
||||
print(
|
||||
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kh:2d}, {kw:2d}, {C:3d}), {dtype}, {100. * diff:+5.2f}%"
|
||||
)
|
||||
if time_mlx >= 2.0 * time_torch:
|
||||
print("ATTENTION ^^^^^^^")
|
@@ -1,6 +1,7 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import argparse
|
||||
from time import time
|
||||
|
||||
import mlx.core as mx
|
||||
import torch
|
||||
|
@@ -1,74 +0,0 @@
|
||||
# Copyright © 2025 Apple Inc.
|
||||
|
||||
import mlx.core as mx
|
||||
from time_utils import time_fn
|
||||
|
||||
N = 1024
|
||||
D = 1024
|
||||
M = 1024
|
||||
E = 32
|
||||
I = 4
|
||||
|
||||
|
||||
def gather_sort(x, indices):
|
||||
N, M = indices.shape
|
||||
indices = indices.flatten()
|
||||
order = mx.argsort(indices)
|
||||
inv_order = mx.argsort(order)
|
||||
return x.flatten(0, -3)[order // M], indices[order], inv_order
|
||||
|
||||
|
||||
def scatter_unsort(x, inv_order, shape=None):
|
||||
x = x[inv_order]
|
||||
if shape is not None:
|
||||
x = mx.unflatten(x, 0, shape)
|
||||
return x
|
||||
|
||||
|
||||
def gather_mm_simulate(x, w, indices):
|
||||
x, idx, inv_order = gather_sort(x, indices)
|
||||
for i in range(2):
|
||||
y = mx.concatenate([x[i] @ w[j].T for i, j in enumerate(idx.tolist())], axis=0)
|
||||
x = y[:, None]
|
||||
x = scatter_unsort(x, inv_order, indices.shape)
|
||||
return x
|
||||
|
||||
|
||||
def time_gather_mm():
|
||||
x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
|
||||
w1 = mx.random.normal((E, M, D)) / 1024**0.5
|
||||
w2 = mx.random.normal((E, D, M)) / 1024**0.5
|
||||
indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
|
||||
sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
|
||||
mx.eval(x, w1, w2, indices, sorted_indices)
|
||||
|
||||
def gather_mm(x, w1, w2, indices, sort):
|
||||
idx = indices
|
||||
inv_order = None
|
||||
if sort:
|
||||
x, idx, inv_order = gather_sort(x, indices)
|
||||
x = mx.gather_mm(x, w1.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
|
||||
x = mx.gather_mm(x, w2.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
|
||||
if sort:
|
||||
x = scatter_unsort(x, inv_order, indices.shape)
|
||||
return x
|
||||
|
||||
time_fn(gather_mm, x, w1, w2, indices, False)
|
||||
time_fn(gather_mm, x, w1, w2, sorted_indices, False)
|
||||
time_fn(gather_mm, x, w1, w2, indices, True)
|
||||
|
||||
x = mx.random.normal((N * I, D)) / 1024**0.5
|
||||
w1 = mx.random.normal((M, D)) / 1024**0.5
|
||||
w2 = mx.random.normal((D, M)) / 1024**0.5
|
||||
mx.eval(x, w1, w2)
|
||||
|
||||
def equivalent_matmul(x, w1, w2):
|
||||
x = x @ w1.T
|
||||
x = x @ w2.T
|
||||
return x
|
||||
|
||||
time_fn(equivalent_matmul, x, w1, w2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
time_gather_mm()
|
@@ -1,84 +0,0 @@
|
||||
# Copyright © 2025 Apple Inc.
|
||||
|
||||
import mlx.core as mx
|
||||
from time_utils import time_fn
|
||||
|
||||
N = 1024
|
||||
D = 1024
|
||||
M = 1024
|
||||
E = 32
|
||||
I = 4
|
||||
|
||||
|
||||
def gather_sort(x, indices):
|
||||
N, M = indices.shape
|
||||
indices = indices.flatten()
|
||||
order = mx.argsort(indices)
|
||||
inv_order = mx.argsort(order)
|
||||
return x.flatten(0, -3)[order // M], indices[order], inv_order
|
||||
|
||||
|
||||
def scatter_unsort(x, inv_order, shape=None):
|
||||
x = x[inv_order]
|
||||
if shape is not None:
|
||||
x = mx.unflatten(x, 0, shape)
|
||||
return x
|
||||
|
||||
|
||||
def gather_mm_simulate(x, w, indices):
|
||||
x, idx, inv_order = gather_sort(x, indices)
|
||||
for i in range(2):
|
||||
y = mx.concatenate(
|
||||
[
|
||||
mx.quantized_matmul(x[i], w[0][j], w[1][j], w[2][j], transpose=True)
|
||||
for i, j in enumerate(idx.tolist())
|
||||
],
|
||||
axis=0,
|
||||
)
|
||||
x = y[:, None]
|
||||
x = scatter_unsort(x, inv_order, indices.shape)
|
||||
return x
|
||||
|
||||
|
||||
def time_gather_qmm():
|
||||
x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
|
||||
w1 = mx.random.normal((E, M, D)) / 1024**0.5
|
||||
w2 = mx.random.normal((E, D, M)) / 1024**0.5
|
||||
w1 = mx.quantize(w1)
|
||||
w2 = mx.quantize(w2)
|
||||
indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
|
||||
sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
|
||||
mx.eval(x, w1, w2, indices, sorted_indices)
|
||||
|
||||
def gather_mm(x, w1, w2, indices, sort):
|
||||
idx = indices
|
||||
inv_order = None
|
||||
if sort:
|
||||
x, idx, inv_order = gather_sort(x, indices)
|
||||
x = mx.gather_qmm(x, *w1, transpose=True, rhs_indices=idx, sorted_indices=sort)
|
||||
x = mx.gather_qmm(x, *w2, transpose=True, rhs_indices=idx, sorted_indices=sort)
|
||||
if sort:
|
||||
x = scatter_unsort(x, inv_order, indices.shape)
|
||||
return x
|
||||
|
||||
time_fn(gather_mm, x, w1, w2, indices, False)
|
||||
time_fn(gather_mm, x, w1, w2, sorted_indices, False)
|
||||
time_fn(gather_mm, x, w1, w2, indices, True)
|
||||
|
||||
x = mx.random.normal((N * I, D)) / 1024**0.5
|
||||
w1 = mx.random.normal((M, D)) / 1024**0.5
|
||||
w2 = mx.random.normal((D, M)) / 1024**0.5
|
||||
w1 = mx.quantize(w1)
|
||||
w2 = mx.quantize(w2)
|
||||
mx.eval(x, w1, w2)
|
||||
|
||||
def equivalent_matmul(x, w1, w2):
|
||||
x = mx.quantized_matmul(x, *w1, transpose=True)
|
||||
x = mx.quantized_matmul(x, *w2, transpose=True)
|
||||
return x
|
||||
|
||||
time_fn(equivalent_matmul, x, w1, w2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
time_gather_qmm()
|
@@ -1,7 +1,5 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from functools import partial
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from time_utils import time_fn
|
||||
@@ -20,63 +18,51 @@ def layer_norm(x, w, b, eps):
|
||||
return y
|
||||
|
||||
|
||||
def time_layer_norm(N, dt):
|
||||
L = 1024
|
||||
def time_layer_norm():
|
||||
f1 = lambda x, w, b, y: (layer_norm(x, w, b, 1e-5) * y).sum()
|
||||
f2 = lambda x, w, b, y: (mx.fast.layer_norm(x, w, b, 1e-5) * y).sum()
|
||||
g1 = mx.grad(f1, argnums=(0, 1, 2))
|
||||
g2 = mx.grad(f2, argnums=(0, 1, 2))
|
||||
|
||||
x = mx.random.uniform(shape=(8, L, N)).astype(dt)
|
||||
w = mx.random.uniform(shape=(N,)).astype(dt)
|
||||
b = mx.random.uniform(shape=(N,)).astype(dt)
|
||||
y = mx.random.uniform(shape=(8, L, N)).astype(dt)
|
||||
x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
|
||||
w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
|
||||
b = mx.random.uniform(shape=(4096,)).astype(mx.float16)
|
||||
y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
|
||||
mx.eval(x, w, b, y)
|
||||
|
||||
def layer_norm_loop(f, x, w, b):
|
||||
for _ in range(32):
|
||||
x = f(x, w, b)
|
||||
return x
|
||||
|
||||
time_fn(layer_norm_loop, partial(layer_norm, eps=1e-5), x, w, b)
|
||||
time_fn(layer_norm_loop, partial(mx.fast.layer_norm, eps=1e-5), x, w, b)
|
||||
|
||||
def layer_norm_grad_loop(g, x, w, b):
|
||||
def layer_norm_loop(g, x, w, b):
|
||||
gx, gw, gb = x, w, b
|
||||
for _ in range(32):
|
||||
gx, gw, gb = g(gx, gw, gb, y)
|
||||
return gx, gw, gb
|
||||
|
||||
time_fn(layer_norm_grad_loop, g1, x, w, b)
|
||||
time_fn(layer_norm_grad_loop, g2, x, w, b)
|
||||
time_fn(layer_norm_grad_loop, mx.compile(g1), x, w, b)
|
||||
time_fn(layer_norm_grad_loop, mx.compile(g2), x, w, b)
|
||||
time_fn(layer_norm_loop, g1, x, w, b)
|
||||
time_fn(layer_norm_loop, g2, x, w, b)
|
||||
time_fn(layer_norm_loop, mx.compile(g1), x, w, b)
|
||||
time_fn(layer_norm_loop, mx.compile(g2), x, w, b)
|
||||
|
||||
f1 = lambda x, y: (layer_norm(x, None, None, 1e-5) * y).sum()
|
||||
f2 = lambda x, y: (mx.fast.layer_norm(x, None, None, 1e-5) * y).sum()
|
||||
g1 = mx.grad(f1, argnums=(0,))
|
||||
g2 = mx.grad(f2, argnums=(0,))
|
||||
|
||||
x = mx.random.uniform(shape=(8, L, N)).astype(dt)
|
||||
w = mx.random.uniform(shape=(N,)).astype(dt)
|
||||
b = mx.random.uniform(shape=(N,)).astype(dt)
|
||||
y = mx.random.uniform(shape=(8, L, N)).astype(dt)
|
||||
x = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
|
||||
w = mx.random.uniform(shape=(4096,)).astype(mx.float16)
|
||||
b = mx.random.uniform(shape=(4096,)).astype(mx.float16)
|
||||
y = mx.random.uniform(shape=(8, 1024, 4096)).astype(mx.float16)
|
||||
mx.eval(x, w, b, y)
|
||||
|
||||
def layer_norm_grad_x_loop(g, x):
|
||||
def layer_norm_loop(g, x):
|
||||
gx = x
|
||||
for _ in range(32):
|
||||
gx = g(gx, y)
|
||||
return gx
|
||||
|
||||
time_fn(layer_norm_grad_x_loop, g1, x)
|
||||
time_fn(layer_norm_grad_x_loop, g2, x)
|
||||
time_fn(layer_norm_grad_x_loop, mx.compile(g1), x)
|
||||
time_fn(layer_norm_grad_x_loop, mx.compile(g2), x)
|
||||
time_fn(layer_norm_loop, g1, x)
|
||||
time_fn(layer_norm_loop, g2, x)
|
||||
time_fn(layer_norm_loop, mx.compile(g1), x)
|
||||
time_fn(layer_norm_loop, mx.compile(g2), x)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
for dt in [mx.float32, mx.float16, mx.bfloat16]:
|
||||
for n in [1024, 2048, 4096, 8192, 8192 + 1024]:
|
||||
print(dt, n)
|
||||
time_layer_norm(n, dt)
|
||||
time_layer_norm()
|
||||
|
@@ -28,34 +28,11 @@ def bench(f, *args):
|
||||
return (e - s) * 1e-9
|
||||
|
||||
|
||||
def prepare_inputs(B, qL, kL, D, qH, kH, mask, transpose, dtype):
|
||||
np_dtype = getattr(np, dtype)
|
||||
|
||||
shape_q = (B, qL, qH, D) if transpose else (B, qH, qL, D)
|
||||
shape_kv = (B, kL, kH, D) if transpose else (B, kH, kL, D)
|
||||
|
||||
scale = 1.0 / math.sqrt(D)
|
||||
|
||||
q_np = np.random.normal(0.0, 1.0, shape_q).astype(np_dtype)
|
||||
k_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)
|
||||
v_np = np.random.normal(0.0, scale, shape_kv).astype(np_dtype)
|
||||
|
||||
q_mx = mx.array(q_np)
|
||||
k_mx = mx.array(k_np)
|
||||
v_mx = mx.array(v_np)
|
||||
|
||||
if mask is not None:
|
||||
if mask == "additive":
|
||||
mask_np = np.random.normal(0.0, 1.0, (B, qH, qL, kL)).astype(np_dtype)
|
||||
mask = mx.array(mask_np)
|
||||
elif mask == "bool":
|
||||
mask_np = np.random.uniform(0.0, 1.0, (B, qH, qL, kL)) < 0.5
|
||||
mask = mx.array(mask_np)
|
||||
|
||||
return q_mx, k_mx, v_mx, scale, mask
|
||||
def mlx_sdpa_fused_inner(q, k, v, scale):
|
||||
return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=None)
|
||||
|
||||
|
||||
def mlx_ref_attn(q, k, v, scale=1.0, mask=None):
|
||||
def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False):
|
||||
q_dtype = q.dtype
|
||||
q = q * mx.array(scale, q_dtype)
|
||||
n_q_heads = q.shape[-3]
|
||||
@@ -64,7 +41,6 @@ def mlx_ref_attn(q, k, v, scale=1.0, mask=None):
|
||||
|
||||
B = q.shape[0]
|
||||
L = q.shape[2]
|
||||
kL = k.shape[2]
|
||||
|
||||
if n_repeats > 1:
|
||||
q = mx.reshape(q, [B, n_kv_heads, n_repeats, L, -1])
|
||||
@@ -72,27 +48,10 @@ def mlx_ref_attn(q, k, v, scale=1.0, mask=None):
|
||||
v = mx.expand_dims(v, 2)
|
||||
|
||||
scores = q @ mx.swapaxes(k, -1, -2)
|
||||
|
||||
if mask is not None:
|
||||
|
||||
if mask == "causal":
|
||||
q_offset = max(0, kL - L)
|
||||
q_indices = mx.arange(q_offset, q_offset + L)
|
||||
k_indices = mx.arange(kL)
|
||||
mask = q_indices[:, None] >= k_indices[None]
|
||||
|
||||
if n_repeats > 1 and mask.ndim >= 3:
|
||||
if mask.shape[-3] == 1:
|
||||
mask = mx.expand_dims(mask, -3)
|
||||
else:
|
||||
mask = mx.unflatten(mask, -3, (n_kv_heads, n_repeats))
|
||||
|
||||
if mask.dtype == mx.bool_:
|
||||
scores = mx.where(mask, scores, -np.float32(np.inf))
|
||||
else:
|
||||
scores += mask
|
||||
|
||||
scores = mx.softmax(scores, axis=-1, precise=True)
|
||||
if f32softmax:
|
||||
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(q_dtype)
|
||||
else:
|
||||
scores = mx.softmax(scores, axis=-1)
|
||||
|
||||
out = scores @ v
|
||||
if n_repeats > 1:
|
||||
@@ -101,55 +60,74 @@ def mlx_ref_attn(q, k, v, scale=1.0, mask=None):
|
||||
return out
|
||||
|
||||
|
||||
def mlx_fused_attn(q, k, v, scale, mask):
|
||||
return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
|
||||
|
||||
|
||||
def do_attention(f, q, k, v, scale, mask=None, transpose=False):
|
||||
if transpose:
|
||||
q_t = mx.transpose(q, (0, 2, 1, 3))
|
||||
k_t = mx.transpose(k, (0, 2, 1, 3))
|
||||
v_t = mx.transpose(v, (0, 2, 1, 3))
|
||||
o_t = f(q_t, k_t, v_t, scale=scale, mask=mask)
|
||||
return mx.transpose(o_t, (0, 2, 1, 3))
|
||||
else:
|
||||
return f(q, k, v, scale=scale, mask=mask)
|
||||
|
||||
|
||||
def do_attention_bench(f, q, k, v, scale, mask=None, transpose=False):
|
||||
def mlx_spda_unfused(q, k, v, scale, transpose):
|
||||
q_out = q
|
||||
if transpose:
|
||||
k = mx.transpose(k, (0, 2, 1, 3))
|
||||
v = mx.transpose(v, (0, 2, 1, 3))
|
||||
|
||||
for i in range(N_iter_func):
|
||||
q_out = do_attention(f, q_out, k, v, scale, mask=mask, transpose=transpose)
|
||||
if transpose:
|
||||
q_out = mx.transpose(q_out, (0, 2, 1, 3))
|
||||
q_out = mlx_sdpa_unfused_inner(q_out, k, v, scale)
|
||||
if transpose:
|
||||
q_out = mx.transpose(q_out, (0, 2, 1, 3))
|
||||
|
||||
mx.eval(q_out)
|
||||
return q_out
|
||||
|
||||
|
||||
def bench_shape(
|
||||
B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, dtype, transpose=True, mask_in=None
|
||||
):
|
||||
q_mx, k_mx, v_mx, scale, mask = prepare_inputs(
|
||||
B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, mask_in, transpose, dtype
|
||||
def mlx_spda_fused(q, k, v, scale, transpose):
|
||||
q_out = q
|
||||
if transpose:
|
||||
k = mx.transpose(k, (0, 2, 1, 3))
|
||||
v = mx.transpose(v, (0, 2, 1, 3))
|
||||
|
||||
for i in range(N_iter_func):
|
||||
if transpose:
|
||||
q_out = mx.transpose(q_out, (0, 2, 1, 3))
|
||||
q_out = mlx_sdpa_fused_inner(q_out, k, v, scale)
|
||||
if transpose:
|
||||
q_out = mx.transpose(q_out, (0, 2, 1, 3))
|
||||
|
||||
mx.eval(q_out)
|
||||
return q_out
|
||||
|
||||
|
||||
def bench_shape(B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose=True):
|
||||
shape_q = (
|
||||
(B, qsl, n_q_heads, head_dim) if transpose else (B, n_q_heads, qsl, head_dim)
|
||||
)
|
||||
shape_kv = (
|
||||
(B, ksl, n_kv_heads, head_dim) if transpose else (B, n_kv_heads, ksl, head_dim)
|
||||
)
|
||||
|
||||
time_mlx_unfused = bench(
|
||||
do_attention_bench, mlx_ref_attn, q_mx, k_mx, v_mx, scale, mask, transpose
|
||||
)
|
||||
time_mlx_fused = bench(
|
||||
do_attention_bench, mlx_fused_attn, q_mx, k_mx, v_mx, scale, mask, transpose
|
||||
)
|
||||
q_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_q).astype(np_dtype)
|
||||
k_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype)
|
||||
v_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype)
|
||||
|
||||
o_mlx_fused = do_attention(mlx_ref_attn, q_mx, k_mx, v_mx, scale, mask, transpose)
|
||||
o_mlx_unfused = do_attention(
|
||||
mlx_fused_attn, q_mx, k_mx, v_mx, scale, mask, transpose
|
||||
)
|
||||
scale = math.sqrt(1.0 / head_dim)
|
||||
|
||||
atol = 1e-5 if dtype == "float32" else 2e-4
|
||||
q_mx = mx.array(q_np)
|
||||
k_mx = mx.array(k_np)
|
||||
v_mx = mx.array(v_np)
|
||||
|
||||
if not mx.allclose(o_mlx_fused, o_mlx_unfused, atol=atol, rtol=atol):
|
||||
time_mlx_unfused = bench(mlx_spda_unfused, q_mx, k_mx, v_mx, scale, transpose)
|
||||
time_mlx_fused = bench(mlx_spda_fused, q_mx, k_mx, v_mx, scale, transpose)
|
||||
|
||||
if transpose:
|
||||
q_mx = mx.transpose(q_mx, (0, 2, 1, 3))
|
||||
k_mx = mx.transpose(k_mx, (0, 2, 1, 3))
|
||||
v_mx = mx.transpose(v_mx, (0, 2, 1, 3))
|
||||
|
||||
o_mlx_fused = mlx_sdpa_fused_inner(q_mx, k_mx, v_mx, scale)
|
||||
o_mlx_unfused = mlx_sdpa_unfused_inner(q_mx, k_mx, v_mx, scale, f32softmax=True)
|
||||
|
||||
atol = 1e-5 if np_dtype == np.float32 else 1e-4
|
||||
|
||||
if not mx.allclose(o_mlx_fused, o_mlx_unfused, atol=atol):
|
||||
print(
|
||||
f"Failed at (B: {B}, qsl: {qsl}, ksl: {ksl}, head_dim: {head_dim}, n_qh: {n_q_heads}, n_kvh: {n_kv_heads}, mask: {mask_in}) [tpose = {transpose}] with max(|a - b|) = {mx.max(mx.abs(o_mlx_unfused - o_mlx_fused)):3.2e}"
|
||||
f"Failed at (B: {B}, qsl: {qsl}, ksl: {ksl}, head_dim: {head_dim}, n_qh: {n_q_heads}, n_kvh: {n_kv_heads}) [tpose = {transpose}] with max(|a - b|) = {mx.max(mx.abs(o_mlx_unfused - o_mlx_fused)):3.2e}"
|
||||
)
|
||||
|
||||
return time_mlx_fused, time_mlx_unfused
|
||||
@@ -173,51 +151,39 @@ if __name__ == "__main__":
|
||||
( 1, 128, 128, 64, 32, 32),
|
||||
( 1, 256, 256, 64, 32, 32),
|
||||
( 1, 512, 512, 64, 32, 32),
|
||||
( 1, 1024, 1024, 64, 32, 8),
|
||||
( 1, 2048, 2048, 64, 32, 8),
|
||||
( 1, 4096, 4096, 64, 32, 8),
|
||||
( 1, 1024, 1024, 64, 32, 32),
|
||||
( 1, 2048, 2048, 64, 32, 32),
|
||||
( 1, 4096, 4096, 64, 32, 32),
|
||||
)
|
||||
|
||||
shapes_80 = (
|
||||
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
|
||||
( 1, 1024, 1024, 80, 32, 8),
|
||||
( 1, 2048, 2048, 80, 32, 8),
|
||||
( 1, 4096, 4096, 80, 32, 8),
|
||||
( 1, 1024, 1024, 80, 32, 32),
|
||||
( 1, 2048, 2048, 80, 32, 32),
|
||||
( 1, 4096, 4096, 80, 32, 32),
|
||||
)
|
||||
|
||||
shapes_128 = (
|
||||
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
|
||||
( 1, 1024, 1024, 128, 32, 8),
|
||||
( 1, 2048, 2048, 128, 32, 8),
|
||||
( 1, 4096, 4096, 128, 32, 8),
|
||||
( 1, 1024, 1024, 128, 32, 32),
|
||||
( 1, 2048, 2048, 128, 32, 32),
|
||||
( 1, 4096, 4096, 128, 32, 32),
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
shapes = shapes_64 + shapes_80 + shapes_128
|
||||
|
||||
masks = [None, "bool", "causal"]
|
||||
|
||||
print(
|
||||
" B, qsl, ksl, hdim, n_qh, n_kvh, t, dtype, mask, t_unfs, t_fuse, diff%"
|
||||
)
|
||||
print(" B, qsl, ksl, hdim, n_qh, n_kvh, tpose, dtype, t_unfs, t_fuse, diff%")
|
||||
|
||||
for dtype in dtypes:
|
||||
for transpose in transposes:
|
||||
for B, qsl, ksl, head_dim, n_q_heads, n_kv_heads in shapes:
|
||||
for mask_in in masks:
|
||||
time_mlx_fused, time_mlx_unfused = bench_shape(
|
||||
B,
|
||||
qsl,
|
||||
ksl,
|
||||
head_dim,
|
||||
n_q_heads,
|
||||
n_kv_heads,
|
||||
dtype,
|
||||
transpose,
|
||||
mask_in,
|
||||
)
|
||||
diff = time_mlx_unfused / time_mlx_fused - 1.0
|
||||
t_str = 1 if transpose else 0
|
||||
print(
|
||||
f"{B:3d}, {qsl:5d}, {ksl:5d}, {head_dim:4d}, {n_q_heads:4d}, {n_kv_heads:5d}, {t_str:1d}, {dtype}, {str(mask_in):>8}, {time_mlx_unfused: 2.3f}, {time_mlx_fused: 2.3f}, {100. * diff:+5.2f}%"
|
||||
)
|
||||
np_dtype = getattr(np, dtype)
|
||||
time_mlx_fused, time_mlx_unfused = bench_shape(
|
||||
B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose
|
||||
)
|
||||
diff = time_mlx_unfused / time_mlx_fused - 1.0
|
||||
t_str = 1 if transpose else 0
|
||||
print(
|
||||
f"{B:3d}, {qsl:5d}, {ksl:5d}, {head_dim:4d}, {n_q_heads:4d}, {n_kv_heads:5d}, {t_str:5d}, {dtype}, {time_mlx_unfused: 2.3f}, {time_mlx_fused: 2.3f}, {100. * diff:+5.2f}%"
|
||||
)
|
||||
|
@@ -51,20 +51,6 @@ def time_maximum():
|
||||
time_fn(mx.maximum, a, b)
|
||||
|
||||
|
||||
def time_max():
|
||||
a = mx.random.uniform(shape=(32, 1024, 1024))
|
||||
a[1, 1] = mx.nan
|
||||
mx.eval(a)
|
||||
time_fn(mx.max, a, 0)
|
||||
|
||||
|
||||
def time_min():
|
||||
a = mx.random.uniform(shape=(32, 1024, 1024))
|
||||
a[1, 1] = mx.nan
|
||||
mx.eval(a)
|
||||
time_fn(mx.min, a, 0)
|
||||
|
||||
|
||||
def time_negative():
|
||||
a = mx.random.uniform(shape=(10000, 1000))
|
||||
mx.eval(a)
|
||||
@@ -122,8 +108,6 @@ if __name__ == "__main__":
|
||||
|
||||
time_add()
|
||||
time_matmul()
|
||||
time_min()
|
||||
time_max()
|
||||
time_maximum()
|
||||
time_exp()
|
||||
time_negative()
|
||||
|
@@ -11,14 +11,13 @@ include(CMakeParseArguments)
|
||||
# Args: TARGET: Custom target to be added for the metal library TITLE: Name of
|
||||
# the .metallib OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib SOURCES: List
|
||||
# of source files INCLUDE_DIRS: List of include dirs DEPS: List of dependency
|
||||
# files (like headers) DEBUG: Boolean, if true, enables debug compile options
|
||||
# for this specific library. If not provided, uses global MLX_METAL_DEBUG.
|
||||
# files (like headers)
|
||||
#
|
||||
# clang format on
|
||||
|
||||
macro(mlx_build_metallib)
|
||||
# Parse args
|
||||
set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY DEBUG)
|
||||
set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY)
|
||||
set(multiValueArgs SOURCES INCLUDE_DIRS DEPS)
|
||||
cmake_parse_arguments(MTLLIB "" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
|
||||
|
||||
@@ -27,10 +26,6 @@ macro(mlx_build_metallib)
|
||||
|
||||
# Collect compile options
|
||||
set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math -Wno-c++17-extensions)
|
||||
if(MLX_METAL_DEBUG OR MTLLIB_DEBUG)
|
||||
set(MTLLIB_COMPILE_OPTIONS ${MTLLIB_COMPILE_OPTIONS} -gline-tables-only
|
||||
-frecord-sources)
|
||||
endif()
|
||||
|
||||
# Prepare metallib build command
|
||||
add_custom_command(
|
||||
|
@@ -13,7 +13,7 @@ EXCLUDE_PATTERNS = */private/*
|
||||
CREATE_SUBDIRS = NO
|
||||
FULL_PATH_NAMES = YES
|
||||
RECURSIVE = YES
|
||||
GENERATE_HTML = NO
|
||||
GENERATE_HTML = YES
|
||||
GENERATE_LATEX = NO
|
||||
GENERATE_XML = YES
|
||||
XML_PROGRAMLISTING = YES
|
||||
|
@@ -10,7 +10,7 @@ import mlx.core as mx
|
||||
# -- Project information -----------------------------------------------------
|
||||
|
||||
project = "MLX"
|
||||
copyright = "2023, Apple"
|
||||
copyright = "2023, MLX Contributors"
|
||||
author = "MLX Contributors"
|
||||
version = ".".join(mx.__version__.split(".")[:3])
|
||||
release = version
|
||||
|
@@ -8,26 +8,23 @@ MLX supports writing custom Metal kernels through the Python and C++ APIs.
|
||||
Simple Example
|
||||
--------------
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
Let's write a custom kernel that computes ``exp`` elementwise:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
T tmp = inp[elem];
|
||||
out[elem] = metal::exp(tmp);
|
||||
"""
|
||||
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="myexp",
|
||||
input_names=["inp"],
|
||||
output_names=["out"],
|
||||
source=source,
|
||||
)
|
||||
|
||||
def exp_elementwise(a: mx.array):
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
T tmp = inp[elem];
|
||||
out[elem] = metal::exp(tmp);
|
||||
"""
|
||||
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="myexp",
|
||||
input_names=["inp"],
|
||||
output_names=["out"],
|
||||
source=source,
|
||||
)
|
||||
outputs = kernel(
|
||||
inputs=[a],
|
||||
template=[("T", mx.float32)],
|
||||
@@ -42,13 +39,8 @@ Let's write a custom kernel that computes ``exp`` elementwise:
|
||||
b = exp_elementwise(a)
|
||||
assert mx.allclose(b, mx.exp(a))
|
||||
|
||||
Every time you make a kernel, a new Metal library is created and possibly
|
||||
JIT compiled. To reduce the overhead from that, build the kernel once with
|
||||
:func:`fast.metal_kernel` and then use it many times.
|
||||
|
||||
.. note::
|
||||
Only pass the body of the Metal kernel in ``source``. The function
|
||||
signature is generated automatically.
|
||||
We are only required to pass the body of the Metal kernel in ``source``.
|
||||
|
||||
The full function signature will be generated using:
|
||||
|
||||
@@ -86,51 +78,44 @@ Putting this all together, the generated function signature for ``myexp`` is as
|
||||
|
||||
template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;
|
||||
|
||||
Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads
|
||||
<https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_
|
||||
function. This means we will launch ``mx.prod(grid)`` threads, subdivided into
|
||||
``threadgroup`` size threadgroups. For optimal performance, each thread group
|
||||
dimension should be less than or equal to the corresponding grid dimension.
|
||||
Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads <https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_ function.
|
||||
This means we will launch ``mx.prod(grid)`` threads, subdivided into ``threadgroup`` size threadgroups.
|
||||
For optimal performance, each thread group dimension should be less than or equal to the corresponding grid dimension.
|
||||
|
||||
Passing ``verbose=True`` to :func:`ast.metal_kernel.__call__` will print the
|
||||
generated code for debugging purposes.
|
||||
Passing ``verbose=True`` to ``mx.fast.metal_kernel.__call__`` will print the generated code for debugging purposes.
|
||||
|
||||
Using Shape/Strides
|
||||
-------------------
|
||||
|
||||
:func:`fast.metal_kernel` supports an argument ``ensure_row_contiguous`` which
|
||||
is ``True`` by default. This will copy the array inputs if needed
|
||||
before the kernel is launched to ensure that the memory layout is row
|
||||
contiguous. Generally this makes writing the kernel easier, since we don't
|
||||
have to worry about gaps or the ordering of the dims when indexing.
|
||||
``mx.fast.metal_kernel`` supports an argument ``ensure_row_contiguous`` which is ``True`` by default.
|
||||
This will copy the ``mx.array`` inputs if needed before the kernel is launched to ensure that the memory layout is row contiguous.
|
||||
Generally this makes writing the kernel easier, since we don't have to worry about gaps or the ordering of the dims
|
||||
when indexing.
|
||||
|
||||
If we want to avoid this copy, :func:`fast.metal_kernel` automatically passes
|
||||
``a_shape``, ``a_strides`` and ``a_ndim`` for each input array ``a`` if any are
|
||||
present in ``source``. We can then use MLX's built in indexing utils to fetch
|
||||
the right elements for each thread.
|
||||
If we want to avoid this copy, ``metal_kernel`` automatically passes ``a_shape``, ``a_strides`` and ``a_ndim`` for each
|
||||
input array ``a`` if any are present in ``source``.
|
||||
We can then use MLX's built in indexing utils to fetch the right elements for each thread.
|
||||
|
||||
Let's convert ``myexp`` above to support arbitrarily strided arrays without
|
||||
relying on a copy from ``ensure_row_contiguous``:
|
||||
Let's convert ``myexp`` above to support arbitrarily strided arrays without relying on a copy from ``ensure_row_contiguous``:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
// Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
|
||||
uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
|
||||
T tmp = inp[loc];
|
||||
// Output arrays are always row contiguous
|
||||
out[elem] = metal::exp(tmp);
|
||||
"""
|
||||
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="myexp_strided",
|
||||
input_names=["inp"],
|
||||
output_names=["out"],
|
||||
source=source
|
||||
)
|
||||
|
||||
def exp_elementwise(a: mx.array):
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
// Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
|
||||
uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
|
||||
T tmp = inp[loc];
|
||||
// Output arrays are always row contiguous
|
||||
out[elem] = metal::exp(tmp);
|
||||
"""
|
||||
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="myexp_strided",
|
||||
input_names=["inp"],
|
||||
output_names=["out"],
|
||||
source=source
|
||||
)
|
||||
outputs = kernel(
|
||||
inputs=[a],
|
||||
template=[("T", mx.float32)],
|
||||
@@ -157,139 +142,137 @@ We'll start with the following MLX implementation using standard ops:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def grid_sample_ref(x, grid):
|
||||
N, H_in, W_in, _ = x.shape
|
||||
ix = ((grid[..., 0] + 1) * W_in - 1) / 2
|
||||
iy = ((grid[..., 1] + 1) * H_in - 1) / 2
|
||||
def grid_sample_ref(x, grid):
|
||||
N, H_in, W_in, _ = x.shape
|
||||
ix = ((grid[..., 0] + 1) * W_in - 1) / 2
|
||||
iy = ((grid[..., 1] + 1) * H_in - 1) / 2
|
||||
|
||||
ix_nw = mx.floor(ix).astype(mx.int32)
|
||||
iy_nw = mx.floor(iy).astype(mx.int32)
|
||||
ix_nw = mx.floor(ix).astype(mx.int32)
|
||||
iy_nw = mx.floor(iy).astype(mx.int32)
|
||||
|
||||
ix_ne = ix_nw + 1
|
||||
iy_ne = iy_nw
|
||||
ix_ne = ix_nw + 1
|
||||
iy_ne = iy_nw
|
||||
|
||||
ix_sw = ix_nw
|
||||
iy_sw = iy_nw + 1
|
||||
ix_sw = ix_nw
|
||||
iy_sw = iy_nw + 1
|
||||
|
||||
ix_se = ix_nw + 1
|
||||
iy_se = iy_nw + 1
|
||||
ix_se = ix_nw + 1
|
||||
iy_se = iy_nw + 1
|
||||
|
||||
nw = (ix_se - ix) * (iy_se - iy)
|
||||
ne = (ix - ix_sw) * (iy_sw - iy)
|
||||
sw = (ix_ne - ix) * (iy - iy_ne)
|
||||
se = (ix - ix_nw) * (iy - iy_nw)
|
||||
nw = (ix_se - ix) * (iy_se - iy)
|
||||
ne = (ix - ix_sw) * (iy_sw - iy)
|
||||
sw = (ix_ne - ix) * (iy - iy_ne)
|
||||
se = (ix - ix_nw) * (iy - iy_nw)
|
||||
|
||||
I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :]
|
||||
I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :]
|
||||
I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :]
|
||||
I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :]
|
||||
I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :]
|
||||
I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :]
|
||||
I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :]
|
||||
I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :]
|
||||
|
||||
mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1)
|
||||
mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1)
|
||||
mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1)
|
||||
mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1)
|
||||
mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1)
|
||||
mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1)
|
||||
mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1)
|
||||
mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1)
|
||||
|
||||
I_nw *= mask_nw[..., None]
|
||||
I_ne *= mask_ne[..., None]
|
||||
I_sw *= mask_sw[..., None]
|
||||
I_se *= mask_se[..., None]
|
||||
I_nw *= mask_nw[..., None]
|
||||
I_ne *= mask_ne[..., None]
|
||||
I_sw *= mask_sw[..., None]
|
||||
I_se *= mask_se[..., None]
|
||||
|
||||
output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se
|
||||
output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se
|
||||
|
||||
return output
|
||||
return output
|
||||
|
||||
Now let's use :func:`custom_function` together with :func:`fast.metal_kernel`
|
||||
Now let's use ``mx.custom_function`` together with ``mx.fast.metal_kernel``
|
||||
to write a fast GPU kernel for both the forward and backward passes.
|
||||
|
||||
First we'll implement the forward pass as a fused kernel:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
int H = x_shape[1];
|
||||
int W = x_shape[2];
|
||||
int C = x_shape[3];
|
||||
int gH = grid_shape[1];
|
||||
int gW = grid_shape[2];
|
||||
@mx.custom_function
|
||||
def grid_sample(x, grid):
|
||||
|
||||
int w_stride = C;
|
||||
int h_stride = W * w_stride;
|
||||
int b_stride = H * h_stride;
|
||||
assert x.ndim == 4, "`x` must be 4D."
|
||||
assert grid.ndim == 4, "`grid` must be 4D."
|
||||
|
||||
uint grid_idx = elem / C * 2;
|
||||
float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
|
||||
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
|
||||
B, _, _, C = x.shape
|
||||
_, gN, gM, D = grid.shape
|
||||
out_shape = (B, gN, gM, C)
|
||||
|
||||
int ix_nw = floor(ix);
|
||||
int iy_nw = floor(iy);
|
||||
assert D == 2, "Last dim of `grid` must be size 2."
|
||||
|
||||
int ix_ne = ix_nw + 1;
|
||||
int iy_ne = iy_nw;
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
int H = x_shape[1];
|
||||
int W = x_shape[2];
|
||||
int C = x_shape[3];
|
||||
int gH = grid_shape[1];
|
||||
int gW = grid_shape[2];
|
||||
|
||||
int ix_sw = ix_nw;
|
||||
int iy_sw = iy_nw + 1;
|
||||
int w_stride = C;
|
||||
int h_stride = W * w_stride;
|
||||
int b_stride = H * h_stride;
|
||||
|
||||
int ix_se = ix_nw + 1;
|
||||
int iy_se = iy_nw + 1;
|
||||
uint grid_idx = elem / C * 2;
|
||||
float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
|
||||
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
|
||||
|
||||
T nw = (ix_se - ix) * (iy_se - iy);
|
||||
T ne = (ix - ix_sw) * (iy_sw - iy);
|
||||
T sw = (ix_ne - ix) * (iy - iy_ne);
|
||||
T se = (ix - ix_nw) * (iy - iy_nw);
|
||||
int ix_nw = floor(ix);
|
||||
int iy_nw = floor(iy);
|
||||
|
||||
int batch_idx = elem / C / gH / gW * b_stride;
|
||||
int channel_idx = elem % C;
|
||||
int base_idx = batch_idx + channel_idx;
|
||||
int ix_ne = ix_nw + 1;
|
||||
int iy_ne = iy_nw;
|
||||
|
||||
T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];
|
||||
T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];
|
||||
T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];
|
||||
T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];
|
||||
int ix_sw = ix_nw;
|
||||
int iy_sw = iy_nw + 1;
|
||||
|
||||
I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0;
|
||||
I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0;
|
||||
I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;
|
||||
I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;
|
||||
int ix_se = ix_nw + 1;
|
||||
int iy_se = iy_nw + 1;
|
||||
|
||||
out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
|
||||
"""
|
||||
T nw = (ix_se - ix) * (iy_se - iy);
|
||||
T ne = (ix - ix_sw) * (iy_sw - iy);
|
||||
T sw = (ix_ne - ix) * (iy - iy_ne);
|
||||
T se = (ix - ix_nw) * (iy - iy_nw);
|
||||
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="grid_sample",
|
||||
input_names=["x", "grid"],
|
||||
output_names=["out"],
|
||||
source=source,
|
||||
)
|
||||
int batch_idx = elem / C / gH / gW * b_stride;
|
||||
int channel_idx = elem % C;
|
||||
int base_idx = batch_idx + channel_idx;
|
||||
|
||||
@mx.custom_function
|
||||
def grid_sample(x, grid):
|
||||
T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];
|
||||
T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];
|
||||
T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];
|
||||
T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];
|
||||
|
||||
assert x.ndim == 4, "`x` must be 4D."
|
||||
assert grid.ndim == 4, "`grid` must be 4D."
|
||||
I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0;
|
||||
I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0;
|
||||
I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;
|
||||
I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;
|
||||
|
||||
B, _, _, C = x.shape
|
||||
_, gN, gM, D = grid.shape
|
||||
out_shape = (B, gN, gM, C)
|
||||
|
||||
assert D == 2, "Last dim of `grid` must be size 2."
|
||||
|
||||
outputs = kernel(
|
||||
inputs=[x, grid],
|
||||
template=[("T", x.dtype)],
|
||||
output_shapes=[out_shape],
|
||||
output_dtypes=[x.dtype],
|
||||
grid=(np.prod(out_shape), 1, 1),
|
||||
threadgroup=(256, 1, 1),
|
||||
)
|
||||
return outputs[0]
|
||||
out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
|
||||
"""
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="grid_sample",
|
||||
input_names=["x", "grid"],
|
||||
output_names=["out"],
|
||||
source=source,
|
||||
)
|
||||
outputs = kernel(
|
||||
inputs=[x, grid],
|
||||
template=[("T", x.dtype)],
|
||||
output_shapes=[out_shape],
|
||||
output_dtypes=[x.dtype],
|
||||
grid=(np.prod(out_shape), 1, 1),
|
||||
threadgroup=(256, 1, 1),
|
||||
)
|
||||
return outputs[0]
|
||||
|
||||
For a reasonably sized input such as:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
x.shape = (8, 1024, 1024, 64)
|
||||
grid.shape = (8, 256, 256, 2)
|
||||
x.shape = (8, 1024, 1024, 64)
|
||||
grid.shape = (8, 256, 256, 2)
|
||||
|
||||
On an M1 Max, we see a big performance improvement:
|
||||
|
||||
@@ -298,11 +281,11 @@ On an M1 Max, we see a big performance improvement:
|
||||
Grid Sample VJP
|
||||
---------------
|
||||
|
||||
Since we decorated ``grid_sample`` with :func:`custom_function`, we can now
|
||||
define its custom vjp transform so MLX can differentiate it.
|
||||
Since we decorated ``grid_sample`` with ``mx.custom_function``, we can now define
|
||||
its custom vjp transform so MLX can differentiate it.
|
||||
|
||||
The backwards pass requires atomically updating ``x_grad``/``grid_grad`` and so
|
||||
requires a few extra :func:`fast.metal_kernel` features:
|
||||
requires a few extra ``mx.fast.metal_kernel`` features:
|
||||
|
||||
* ``init_value=0``
|
||||
Initialize all of the kernel's outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel.
|
||||
@@ -316,129 +299,128 @@ We can then implement the backwards pass as follows:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
int H = x_shape[1];
|
||||
int W = x_shape[2];
|
||||
int C = x_shape[3];
|
||||
// Pad C to the nearest larger simdgroup size multiple
|
||||
int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;
|
||||
@grid_sample.vjp
|
||||
def grid_sample_vjp(primals, cotangent, _):
|
||||
x, grid = primals
|
||||
B, _, _, C = x.shape
|
||||
_, gN, gM, D = grid.shape
|
||||
|
||||
int gH = grid_shape[1];
|
||||
int gW = grid_shape[2];
|
||||
assert D == 2, "Last dim of `grid` must be size 2."
|
||||
|
||||
int w_stride = C;
|
||||
int h_stride = W * w_stride;
|
||||
int b_stride = H * h_stride;
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
int H = x_shape[1];
|
||||
int W = x_shape[2];
|
||||
int C = x_shape[3];
|
||||
// Pad C to the nearest larger simdgroup size multiple
|
||||
int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;
|
||||
|
||||
uint grid_idx = elem / C_padded * 2;
|
||||
float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
|
||||
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
|
||||
int gH = grid_shape[1];
|
||||
int gW = grid_shape[2];
|
||||
|
||||
int ix_nw = floor(ix);
|
||||
int iy_nw = floor(iy);
|
||||
int w_stride = C;
|
||||
int h_stride = W * w_stride;
|
||||
int b_stride = H * h_stride;
|
||||
|
||||
int ix_ne = ix_nw + 1;
|
||||
int iy_ne = iy_nw;
|
||||
uint grid_idx = elem / C_padded * 2;
|
||||
float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
|
||||
float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
|
||||
|
||||
int ix_sw = ix_nw;
|
||||
int iy_sw = iy_nw + 1;
|
||||
int ix_nw = floor(ix);
|
||||
int iy_nw = floor(iy);
|
||||
|
||||
int ix_se = ix_nw + 1;
|
||||
int iy_se = iy_nw + 1;
|
||||
int ix_ne = ix_nw + 1;
|
||||
int iy_ne = iy_nw;
|
||||
|
||||
T nw = (ix_se - ix) * (iy_se - iy);
|
||||
T ne = (ix - ix_sw) * (iy_sw - iy);
|
||||
T sw = (ix_ne - ix) * (iy - iy_ne);
|
||||
T se = (ix - ix_nw) * (iy - iy_nw);
|
||||
int ix_sw = ix_nw;
|
||||
int iy_sw = iy_nw + 1;
|
||||
|
||||
int batch_idx = elem / C_padded / gH / gW * b_stride;
|
||||
int channel_idx = elem % C_padded;
|
||||
int base_idx = batch_idx + channel_idx;
|
||||
int ix_se = ix_nw + 1;
|
||||
int iy_se = iy_nw + 1;
|
||||
|
||||
T gix = T(0);
|
||||
T giy = T(0);
|
||||
if (channel_idx < C) {
|
||||
int cot_index = elem / C_padded * C + channel_idx;
|
||||
T cot = cotangent[cot_index];
|
||||
if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {
|
||||
int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;
|
||||
atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);
|
||||
T nw = (ix_se - ix) * (iy_se - iy);
|
||||
T ne = (ix - ix_sw) * (iy_sw - iy);
|
||||
T sw = (ix_ne - ix) * (iy - iy_ne);
|
||||
T se = (ix - ix_nw) * (iy - iy_nw);
|
||||
|
||||
T I_nw = x[offset];
|
||||
gix -= I_nw * (iy_se - iy) * cot;
|
||||
giy -= I_nw * (ix_se - ix) * cot;
|
||||
}
|
||||
if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) {
|
||||
int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;
|
||||
atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed);
|
||||
int batch_idx = elem / C_padded / gH / gW * b_stride;
|
||||
int channel_idx = elem % C_padded;
|
||||
int base_idx = batch_idx + channel_idx;
|
||||
|
||||
T I_ne = x[offset];
|
||||
gix += I_ne * (iy_sw - iy) * cot;
|
||||
giy -= I_ne * (ix - ix_sw) * cot;
|
||||
}
|
||||
if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) {
|
||||
int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;
|
||||
atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed);
|
||||
T gix = T(0);
|
||||
T giy = T(0);
|
||||
if (channel_idx < C) {
|
||||
int cot_index = elem / C_padded * C + channel_idx;
|
||||
T cot = cotangent[cot_index];
|
||||
if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {
|
||||
int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;
|
||||
atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);
|
||||
|
||||
T I_sw = x[offset];
|
||||
gix -= I_sw * (iy - iy_ne) * cot;
|
||||
giy += I_sw * (ix_ne - ix) * cot;
|
||||
}
|
||||
if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) {
|
||||
int offset = base_idx + iy_se * h_stride + ix_se * w_stride;
|
||||
atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed);
|
||||
T I_nw = x[offset];
|
||||
gix -= I_nw * (iy_se - iy) * cot;
|
||||
giy -= I_nw * (ix_se - ix) * cot;
|
||||
}
|
||||
if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) {
|
||||
int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;
|
||||
atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed);
|
||||
|
||||
T I_se = x[offset];
|
||||
gix += I_se * (iy - iy_nw) * cot;
|
||||
giy += I_se * (ix - ix_nw) * cot;
|
||||
}
|
||||
}
|
||||
T I_ne = x[offset];
|
||||
gix += I_ne * (iy_sw - iy) * cot;
|
||||
giy -= I_ne * (ix - ix_sw) * cot;
|
||||
}
|
||||
if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) {
|
||||
int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;
|
||||
atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed);
|
||||
|
||||
T gix_mult = W / 2;
|
||||
T giy_mult = H / 2;
|
||||
T I_sw = x[offset];
|
||||
gix -= I_sw * (iy - iy_ne) * cot;
|
||||
giy += I_sw * (ix_ne - ix) * cot;
|
||||
}
|
||||
if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) {
|
||||
int offset = base_idx + iy_se * h_stride + ix_se * w_stride;
|
||||
atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed);
|
||||
|
||||
// Reduce across each simdgroup first.
|
||||
// This is much faster than relying purely on atomics.
|
||||
gix = simd_sum(gix);
|
||||
giy = simd_sum(giy);
|
||||
T I_se = x[offset];
|
||||
gix += I_se * (iy - iy_nw) * cot;
|
||||
giy += I_se * (ix - ix_nw) * cot;
|
||||
}
|
||||
}
|
||||
|
||||
if (thread_index_in_simdgroup == 0) {
|
||||
atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);
|
||||
atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);
|
||||
}
|
||||
"""
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="grid_sample_grad",
|
||||
input_names=["x", "grid", "cotangent"],
|
||||
output_names=["x_grad", "grid_grad"],
|
||||
source=source,
|
||||
atomic_outputs=True,
|
||||
)
|
||||
T gix_mult = W / 2;
|
||||
T giy_mult = H / 2;
|
||||
|
||||
@grid_sample.vjp
|
||||
def grid_sample_vjp(primals, cotangent, _):
|
||||
x, grid = primals
|
||||
B, _, _, C = x.shape
|
||||
_, gN, gM, D = grid.shape
|
||||
// Reduce across each simdgroup first.
|
||||
// This is much faster than relying purely on atomics.
|
||||
gix = simd_sum(gix);
|
||||
giy = simd_sum(giy);
|
||||
|
||||
assert D == 2, "Last dim of `grid` must be size 2."
|
||||
|
||||
# pad the output channels to simd group size
|
||||
# so that our `simd_sum`s don't overlap.
|
||||
simdgroup_size = 32
|
||||
C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
|
||||
grid_size = B * gN * gM * C_padded
|
||||
outputs = kernel(
|
||||
inputs=[x, grid, cotangent],
|
||||
template=[("T", x.dtype)],
|
||||
output_shapes=[x.shape, grid.shape],
|
||||
output_dtypes=[x.dtype, x.dtype],
|
||||
grid=(grid_size, 1, 1),
|
||||
threadgroup=(256, 1, 1),
|
||||
init_value=0,
|
||||
)
|
||||
return outputs[0], outputs[1]
|
||||
if (thread_index_in_simdgroup == 0) {
|
||||
atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);
|
||||
atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);
|
||||
}
|
||||
"""
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="grid_sample_grad",
|
||||
input_names=["x", "grid", "cotangent"],
|
||||
output_names=["x_grad", "grid_grad"],
|
||||
source=source,
|
||||
atomic_outputs=True,
|
||||
)
|
||||
# pad the output channels to simd group size
|
||||
# so that our `simd_sum`s don't overlap.
|
||||
simdgroup_size = 32
|
||||
C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
|
||||
grid_size = B * gN * gM * C_padded
|
||||
outputs = kernel(
|
||||
inputs=[x, grid, cotangent],
|
||||
template=[("T", x.dtype)],
|
||||
output_shapes=[x.shape, grid.shape],
|
||||
output_dtypes=[x.dtype, x.dtype],
|
||||
grid=(grid_size, 1, 1),
|
||||
threadgroup=(256, 1, 1),
|
||||
init_value=0,
|
||||
)
|
||||
return outputs[0], outputs[1]
|
||||
|
||||
There's an even larger speed up for the vjp:
|
||||
|
||||
|
@@ -93,9 +93,9 @@ Primitives
|
||||
^^^^^^^^^^^
|
||||
|
||||
A :class:`Primitive` is part of the computation graph of an :class:`array`. It
|
||||
defines how to create output arrays given input arrays. Further, a
|
||||
defines how to create outputs arrays given a input arrays. Further, a
|
||||
:class:`Primitive` has methods to run on the CPU or GPU and for function
|
||||
transformations such as ``vjp`` and ``jvp``. Let's go back to our example to be
|
||||
transformations such as ``vjp`` and ``jvp``. Lets go back to our example to be
|
||||
more concrete:
|
||||
|
||||
.. code-block:: C++
|
||||
@@ -128,7 +128,7 @@ more concrete:
|
||||
/** The vector-Jacobian product. */
|
||||
std::vector<array> vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const array& cotan,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) override;
|
||||
|
||||
@@ -138,13 +138,13 @@ more concrete:
|
||||
* representing the vectorized computation and the axis which
|
||||
* corresponds to the output vectorized dimension.
|
||||
*/
|
||||
std::pair<std::vector<array>, std::vector<int>> vmap(
|
||||
virtual std::pair<std::vector<array>, std::vector<int>> vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) override;
|
||||
|
||||
/** The name of primitive. */
|
||||
const char* name() const override {
|
||||
return "Axpby";
|
||||
/** Print the primitive. */
|
||||
void print(std::ostream& os) override {
|
||||
os << "Axpby";
|
||||
}
|
||||
|
||||
/** Equivalence check **/
|
||||
@@ -247,7 +247,9 @@ point-wise. This is captured in the templated function :meth:`axpby_impl`.
|
||||
float alpha_,
|
||||
float beta_,
|
||||
mx::Stream stream) {
|
||||
out.set_data(mx::allocator::malloc(out.nbytes()));
|
||||
// Allocate the output with `malloc_or_wait` which synchronously allocates
|
||||
// memory, potentially waiting if the system is under memory pressure
|
||||
out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
// Get the CPU command encoder and register input and output arrays
|
||||
auto& encoder = mx::cpu::get_command_encoder(stream);
|
||||
@@ -391,17 +393,17 @@ below.
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
// Allocate output memory
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
// Resolve name of kernel
|
||||
std::ostringstream kname;
|
||||
kname << "axpby_" << "general_" << type_to_name(out);
|
||||
|
||||
// Load the metal library
|
||||
auto lib = d.get_library("mlx_ext");
|
||||
// Make sure the metal library is available
|
||||
d.register_library("mlx_ext");
|
||||
|
||||
// Make a kernel from this metal library
|
||||
auto kernel = d.get_kernel(kname.str(), lib);
|
||||
auto kernel = d.get_kernel(kname.str(), "mlx_ext");
|
||||
|
||||
// Prepare to encode kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
@@ -469,7 +471,7 @@ one we just defined:
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
// Forward mode diff that pushes along the tangents
|
||||
// The jvp transform on the primitive can be built with ops
|
||||
// The jvp transform on the primitive can built with ops
|
||||
// that are scheduled on the same stream as the primitive
|
||||
|
||||
// If argnums = {0}, we only push along x in which case the
|
||||
@@ -481,7 +483,7 @@ one we just defined:
|
||||
auto scale_arr = array(scale, tangents[0].dtype());
|
||||
return {multiply(scale_arr, tangents[0], stream())};
|
||||
}
|
||||
// If argnums = {0, 1}, we take contributions from both
|
||||
// If, argnums = {0, 1}, we take contributions from both
|
||||
// which gives us jvp = tangent_x * alpha + tangent_y * beta
|
||||
else {
|
||||
return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())};
|
||||
@@ -735,7 +737,7 @@ Let's look at a simple script and its results:
|
||||
|
||||
print(f"c shape: {c.shape}")
|
||||
print(f"c dtype: {c.dtype}")
|
||||
print(f"c is correct: {mx.all(c == 6.0).item()}")
|
||||
print(f"c correct: {mx.all(c == 6.0).item()}")
|
||||
|
||||
Output:
|
||||
|
||||
@@ -743,7 +745,7 @@ Output:
|
||||
|
||||
c shape: [3, 4]
|
||||
c dtype: float32
|
||||
c is correct: True
|
||||
c correctness: True
|
||||
|
||||
Results
|
||||
^^^^^^^
|
||||
|
@@ -70,7 +70,6 @@ are the CPU and GPU.
|
||||
python/fft
|
||||
python/linalg
|
||||
python/metal
|
||||
python/memory_management
|
||||
python/nn
|
||||
python/optimizers
|
||||
python/distributed
|
||||
|
@@ -23,24 +23,13 @@ To install from PyPI you must meet the following requirements:
|
||||
MLX is only available on devices running macOS >= 13.5
|
||||
It is highly recommended to use macOS 14 (Sonoma)
|
||||
|
||||
CUDA
|
||||
^^^^
|
||||
|
||||
MLX has a CUDA backend which you can use on any Linux platform with CUDA 12
|
||||
and SM 7.0 (Volta) and up. To install MLX with CUDA support, run:
|
||||
MLX is also available on conda-forge. To install MLX with conda do:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
pip install "mlx[cuda]"
|
||||
conda install conda-forge::mlx
|
||||
|
||||
CPU-only (Linux)
|
||||
^^^^^^^^^^^^^^^^
|
||||
|
||||
For a CPU-only version of MLX that runs on Linux use:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
pip install "mlx[cpu]"
|
||||
|
||||
Troubleshooting
|
||||
^^^^^^^^^^^^^^^
|
||||
@@ -76,8 +65,6 @@ Build Requirements
|
||||
Python API
|
||||
^^^^^^^^^^
|
||||
|
||||
.. _python install:
|
||||
|
||||
To build and install the MLX python library from source, first, clone MLX from
|
||||
`its GitHub repo <https://github.com/ml-explore/mlx>`_:
|
||||
|
||||
@@ -89,20 +76,20 @@ Then simply build and install MLX using pip:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
pip install .
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=8 pip install .
|
||||
|
||||
For developing, install the package with development dependencies, and use an
|
||||
editable install:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
pip install -e ".[dev]"
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=8 pip install -e ".[dev]"
|
||||
|
||||
Once the development dependencies are installed, you can build faster with:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
python setup.py build_ext --inplace
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=8 python setup.py build_ext --inplace
|
||||
|
||||
Run the tests with:
|
||||
|
||||
@@ -120,8 +107,6 @@ IDE:
|
||||
C++ API
|
||||
^^^^^^^
|
||||
|
||||
.. _cpp install:
|
||||
|
||||
Currently, MLX must be built and installed from source.
|
||||
|
||||
Similarly to the python library, to build and install the MLX C++ library start
|
||||
@@ -200,7 +185,6 @@ should point to the path to the built metal library.
|
||||
|
||||
xcrun -sdk macosx --show-sdk-version
|
||||
|
||||
|
||||
Binary Size Minimization
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
@@ -229,50 +213,6 @@ be anwywhere from a few hundred millisecond to a few seconds depending on the
|
||||
application. Once a kernel is compiled, it will be cached by the system. The
|
||||
Metal kernel cache persists across reboots.
|
||||
|
||||
Linux
|
||||
^^^^^
|
||||
|
||||
To build from source on Linux (CPU only), install the BLAS and LAPACK headers.
|
||||
For example on Ubuntu, run the following:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
apt-get update -y
|
||||
apt-get install libblas-dev liblapack-dev liblapacke-dev -y
|
||||
|
||||
From here follow the instructions to install either the :ref:`Python <python
|
||||
install>` or :ref:`C++ <cpp install>` APIs.
|
||||
|
||||
CUDA
|
||||
^^^^
|
||||
|
||||
To build from source on Linux with CUDA, install the BLAS and LAPACK headers
|
||||
and the CUDA toolkit. For example on Ubuntu, run the following:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
|
||||
dpkg -i cuda-keyring_1.1-1_all.deb
|
||||
apt-get update -y
|
||||
apt-get -y install cuda-toolkit-12-9
|
||||
apt-get install libblas-dev liblapack-dev liblapacke-dev -y
|
||||
|
||||
|
||||
When building either the Python or C++ APIs make sure to pass the cmake flag
|
||||
``MLX_BUILD_CUDA=ON``. For example, to build the Python API run:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON" pip install -e ".[dev]"
|
||||
|
||||
To build the C++ package run:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
mkdir -p build && cd build
|
||||
cmake .. -DMLX_BUILD_CUDA=ON && make -j
|
||||
|
||||
|
||||
Troubleshooting
|
||||
^^^^^^^^^^^^^^^
|
||||
|
||||
|
@@ -19,8 +19,6 @@ Array
|
||||
array.ndim
|
||||
array.shape
|
||||
array.size
|
||||
array.real
|
||||
array.imag
|
||||
array.abs
|
||||
array.all
|
||||
array.any
|
||||
@@ -40,7 +38,6 @@ Array
|
||||
array.log10
|
||||
array.log1p
|
||||
array.log2
|
||||
array.logcumsumexp
|
||||
array.logsumexp
|
||||
array.max
|
||||
array.mean
|
||||
|
@@ -20,5 +20,3 @@ FFT
|
||||
irfft2
|
||||
rfftn
|
||||
irfftn
|
||||
fftshift
|
||||
ifftshift
|
||||
|
@@ -16,12 +16,9 @@ Linear Algebra
|
||||
cross
|
||||
qr
|
||||
svd
|
||||
eigvals
|
||||
eig
|
||||
eigvalsh
|
||||
eigh
|
||||
lu
|
||||
lu_factor
|
||||
pinv
|
||||
solve
|
||||
solve_triangular
|
||||
|
@@ -1,16 +0,0 @@
|
||||
Memory Management
|
||||
=================
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
get_active_memory
|
||||
get_peak_memory
|
||||
reset_peak_memory
|
||||
get_cache_memory
|
||||
set_memory_limit
|
||||
set_cache_limit
|
||||
set_wired_limit
|
||||
clear_cache
|
@@ -8,5 +8,13 @@ Metal
|
||||
|
||||
is_available
|
||||
device_info
|
||||
get_active_memory
|
||||
get_peak_memory
|
||||
reset_peak_memory
|
||||
get_cache_memory
|
||||
set_memory_limit
|
||||
set_cache_limit
|
||||
set_wired_limit
|
||||
clear_cache
|
||||
start_capture
|
||||
stop_capture
|
||||
|
@@ -36,12 +36,10 @@ Operations
|
||||
bitwise_or
|
||||
bitwise_xor
|
||||
block_masked_mm
|
||||
broadcast_arrays
|
||||
broadcast_to
|
||||
ceil
|
||||
clip
|
||||
concatenate
|
||||
contiguous
|
||||
conj
|
||||
conjugate
|
||||
convolve
|
||||
@@ -103,7 +101,6 @@ Operations
|
||||
log10
|
||||
log1p
|
||||
logaddexp
|
||||
logcumsumexp
|
||||
logical_not
|
||||
logical_and
|
||||
logical_or
|
||||
|
@@ -18,5 +18,3 @@ Common Optimizers
|
||||
AdamW
|
||||
Adamax
|
||||
Lion
|
||||
MultiOptimizer
|
||||
Muon
|
||||
|
@@ -9,7 +9,6 @@ Transforms
|
||||
:toctree: _autosummary
|
||||
|
||||
eval
|
||||
async_eval
|
||||
compile
|
||||
custom_function
|
||||
disable_compile
|
||||
|
@@ -107,16 +107,6 @@ same array:
|
||||
>>> a
|
||||
array([1, 2, 0], dtype=int32)
|
||||
|
||||
|
||||
Note, unlike NumPy, updates to the same location are nondeterministic:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
>>> a = mx.array([1, 2, 3])
|
||||
>>> a[[0, 0]] = mx.array([4, 5])
|
||||
|
||||
The first element of ``a`` could be ``4`` or ``5``.
|
||||
|
||||
Transformations of functions which use in-place updates are allowed and work as
|
||||
expected. For example:
|
||||
|
||||
|
@@ -72,7 +72,9 @@ void axpby_impl(
|
||||
float alpha_,
|
||||
float beta_,
|
||||
mx::Stream stream) {
|
||||
out.set_data(mx::allocator::malloc(out.nbytes()));
|
||||
// Allocate the output with `malloc_or_wait` which synchronously allocates
|
||||
// memory, potentially waiting if the system is under memory pressure
|
||||
out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
// Get the CPU command encoder and register input and output arrays
|
||||
auto& encoder = mx::cpu::get_command_encoder(stream);
|
||||
@@ -158,12 +160,12 @@ void Axpby::eval_gpu(
|
||||
// Allocate output memory with strides based on specialization
|
||||
if (contiguous_kernel) {
|
||||
out.set_data(
|
||||
mx::allocator::malloc(x.data_size() * out.itemsize()),
|
||||
mx::allocator::malloc_or_wait(x.data_size() * out.itemsize()),
|
||||
x.data_size(),
|
||||
x.strides(),
|
||||
x.flags());
|
||||
} else {
|
||||
out.set_data(mx::allocator::malloc(out.nbytes()));
|
||||
out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
|
||||
}
|
||||
|
||||
// Resolve name of kernel (corresponds to axpby.metal)
|
||||
@@ -172,11 +174,11 @@ void Axpby::eval_gpu(
|
||||
kname << (contiguous_kernel ? "contiguous_" : "general_");
|
||||
kname << type_to_name(out);
|
||||
|
||||
// Load the metal library
|
||||
auto lib = d.get_library("mlx_ext");
|
||||
// Make sure the metal library is available
|
||||
d.register_library("mlx_ext");
|
||||
|
||||
// Make a kernel from this metal library
|
||||
auto kernel = d.get_kernel(kname.str(), lib);
|
||||
auto kernel = d.get_kernel(kname.str(), "mlx_ext");
|
||||
|
||||
// Prepare to encode kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
|
@@ -74,9 +74,9 @@ class Axpby : public mx::Primitive {
|
||||
const std::vector<mx::array>& inputs,
|
||||
const std::vector<int>& axes) override;
|
||||
|
||||
/** The name of primitive. */
|
||||
const char* name() const override {
|
||||
return "Axpby";
|
||||
/** Print the primitive. */
|
||||
void print(std::ostream& os) override {
|
||||
os << "Axpby";
|
||||
}
|
||||
|
||||
/** Equivalence check **/
|
||||
|
@@ -5,7 +5,6 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/dtype_utils.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/export.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
|
||||
@@ -18,13 +17,9 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/version.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h)
|
||||
|
||||
# Define MLX_VERSION only in the version.cpp file.
|
||||
add_library(mlx_version OBJECT ${CMAKE_CURRENT_SOURCE_DIR}/version.cpp)
|
||||
target_compile_definitions(mlx_version PRIVATE MLX_VERSION="${MLX_VERSION}")
|
||||
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:mlx_version>)
|
||||
|
||||
if(MSVC)
|
||||
# Disable some MSVC warnings to speed up compilation.
|
||||
target_compile_options(mlx PUBLIC /wd4068 /wd4244 /wd4267 /wd4804)
|
||||
@@ -49,19 +44,5 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
|
||||
if(MLX_BUILD_METAL)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)
|
||||
else()
|
||||
target_sources(mlx
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/no_metal.cpp)
|
||||
endif()
|
||||
|
||||
if(MLX_BUILD_CUDA)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda)
|
||||
else()
|
||||
target_sources(mlx
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda/no_cuda.cpp)
|
||||
endif()
|
||||
|
||||
if(MLX_BUILD_METAL OR MLX_BUILD_CUDA)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/gpu)
|
||||
else()
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_gpu)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_metal)
|
||||
endif()
|
||||
|
@@ -4,11 +4,12 @@
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/scheduler.h"
|
||||
|
||||
namespace mlx::core::allocator {
|
||||
|
||||
Buffer malloc(size_t size) {
|
||||
auto buffer = allocator().malloc(size);
|
||||
auto buffer = allocator().malloc(size, /* allow_swap */ true);
|
||||
if (size && !buffer.ptr()) {
|
||||
std::ostringstream msg;
|
||||
msg << "[malloc] Unable to allocate " << size << " bytes.";
|
||||
@@ -21,4 +22,45 @@ void free(Buffer buffer) {
|
||||
allocator().free(buffer);
|
||||
}
|
||||
|
||||
Buffer CommonAllocator::malloc(size_t size, bool) {
|
||||
void* ptr = std::malloc(size + sizeof(size_t));
|
||||
if (ptr != nullptr) {
|
||||
*static_cast<size_t*>(ptr) = size;
|
||||
}
|
||||
return Buffer{ptr};
|
||||
}
|
||||
|
||||
void CommonAllocator::free(Buffer buffer) {
|
||||
std::free(buffer.ptr());
|
||||
}
|
||||
|
||||
size_t CommonAllocator::size(Buffer buffer) const {
|
||||
if (buffer.ptr() == nullptr) {
|
||||
return 0;
|
||||
}
|
||||
return *static_cast<size_t*>(buffer.ptr());
|
||||
}
|
||||
|
||||
Buffer malloc_or_wait(size_t size) {
|
||||
auto buffer = allocator().malloc(size);
|
||||
|
||||
while (size && !buffer.ptr() && scheduler::n_active_tasks() > 0) {
|
||||
scheduler::wait_for_one();
|
||||
buffer = allocator().malloc(size);
|
||||
}
|
||||
|
||||
// Try swapping if needed
|
||||
if (size && !buffer.ptr()) {
|
||||
buffer = allocator().malloc(size, /* allow_swap = */ true);
|
||||
}
|
||||
|
||||
if (size && !buffer.ptr()) {
|
||||
std::ostringstream msg;
|
||||
msg << "[malloc_or_wait] Unable to allocate " << size << " bytes.";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
return buffer;
|
||||
}
|
||||
|
||||
} // namespace mlx::core::allocator
|
||||
|
@@ -32,10 +32,14 @@ Buffer malloc(size_t size);
|
||||
|
||||
void free(Buffer buffer);
|
||||
|
||||
// Wait for running tasks to finish and free up memory
|
||||
// if allocation fails
|
||||
Buffer malloc_or_wait(size_t size);
|
||||
|
||||
class Allocator {
|
||||
/** Abstract base class for a memory allocator. */
|
||||
public:
|
||||
virtual Buffer malloc(size_t size) = 0;
|
||||
virtual Buffer malloc(size_t size, bool allow_swap = false) = 0;
|
||||
virtual void free(Buffer buffer) = 0;
|
||||
virtual size_t size(Buffer buffer) const = 0;
|
||||
|
||||
@@ -49,4 +53,16 @@ class Allocator {
|
||||
|
||||
Allocator& allocator();
|
||||
|
||||
class CommonAllocator : public Allocator {
|
||||
/** A general CPU allocator. */
|
||||
public:
|
||||
virtual Buffer malloc(size_t size, bool allow_swap = false) override;
|
||||
virtual void free(Buffer buffer) override;
|
||||
virtual size_t size(Buffer buffer) const override;
|
||||
|
||||
private:
|
||||
CommonAllocator() = default;
|
||||
friend Allocator& allocator();
|
||||
};
|
||||
|
||||
} // namespace mlx::core::allocator
|
||||
|
@@ -56,18 +56,6 @@ std::vector<array> array::make_arrays(
|
||||
return outputs;
|
||||
}
|
||||
|
||||
array array::unsafe_weak_copy(const array& other) {
|
||||
auto cpy = array(other.shape(), other.dtype(), nullptr, {});
|
||||
cpy.set_data(
|
||||
other.buffer(),
|
||||
other.data_size(),
|
||||
other.strides(),
|
||||
other.flags(),
|
||||
[](auto) {});
|
||||
cpy.array_desc_->data_ptr = other.array_desc_->data_ptr;
|
||||
return cpy;
|
||||
}
|
||||
|
||||
array::array(std::initializer_list<float> data)
|
||||
: array_desc_(std::make_shared<ArrayDesc>(
|
||||
Shape{static_cast<ShapeElem>(data.size())},
|
||||
|
19
mlx/array.h
19
mlx/array.h
@@ -199,13 +199,6 @@ class array {
|
||||
const std::shared_ptr<Primitive>& primitive,
|
||||
const std::vector<array>& inputs);
|
||||
|
||||
/**
|
||||
* Get a new array that refers to the same data as the input but with a
|
||||
* non-owning pointer to it. Note the array is detached from the graph and has
|
||||
* no inputs, siblings or primitive.
|
||||
*/
|
||||
static array unsafe_weak_copy(const array& other);
|
||||
|
||||
/** A unique identifier for an array. */
|
||||
std::uintptr_t id() const {
|
||||
return reinterpret_cast<std::uintptr_t>(array_desc_.get());
|
||||
@@ -224,10 +217,6 @@ class array {
|
||||
// Not copyable
|
||||
Data(const Data& d) = delete;
|
||||
Data& operator=(const Data& d) = delete;
|
||||
Data(Data&& o) : buffer(o.buffer), d(o.d) {
|
||||
o.buffer = allocator::Buffer(nullptr);
|
||||
o.d = [](allocator::Buffer) {};
|
||||
}
|
||||
~Data() {
|
||||
d(buffer);
|
||||
}
|
||||
@@ -343,11 +332,11 @@ class array {
|
||||
return allocator::allocator().size(buffer());
|
||||
}
|
||||
|
||||
// Return the shared pointer to the array::Data struct
|
||||
const std::shared_ptr<Data>& data_shared_ptr() const {
|
||||
// Return a copy of the shared pointer
|
||||
// to the array::Data struct
|
||||
std::shared_ptr<Data> data_shared_ptr() const {
|
||||
return array_desc_->data;
|
||||
}
|
||||
|
||||
// Return a raw pointer to the arrays data
|
||||
template <typename T>
|
||||
T* data() {
|
||||
@@ -360,7 +349,7 @@ class array {
|
||||
}
|
||||
|
||||
enum Status {
|
||||
// The output of a computation which has not been scheduled.
|
||||
// The ouptut of a computation which has not been scheduled.
|
||||
// For example, the status of `x` in `auto x = a + b`.
|
||||
unscheduled,
|
||||
|
||||
|
@@ -1,7 +1,6 @@
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/broadcasting.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
|
||||
|
@@ -44,14 +44,14 @@ inline void set_binary_op_output_data(
|
||||
switch (bopt) {
|
||||
case BinaryOpType::ScalarScalar:
|
||||
out.set_data(
|
||||
allocator::malloc(out.itemsize()), 1, a.strides(), a.flags());
|
||||
allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags());
|
||||
break;
|
||||
case BinaryOpType::ScalarVector:
|
||||
if (b_donatable) {
|
||||
out.copy_shared_buffer(b);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc(b.data_size() * out.itemsize()),
|
||||
allocator::malloc_or_wait(b.data_size() * out.itemsize()),
|
||||
b.data_size(),
|
||||
b.strides(),
|
||||
b.flags());
|
||||
@@ -62,7 +62,7 @@ inline void set_binary_op_output_data(
|
||||
out.copy_shared_buffer(a);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc(a.data_size() * out.itemsize()),
|
||||
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
|
||||
a.data_size(),
|
||||
a.strides(),
|
||||
a.flags());
|
||||
@@ -75,7 +75,7 @@ inline void set_binary_op_output_data(
|
||||
out.copy_shared_buffer(b);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc(a.data_size() * out.itemsize()),
|
||||
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
|
||||
a.data_size(),
|
||||
a.strides(),
|
||||
a.flags());
|
||||
@@ -88,7 +88,7 @@ inline void set_binary_op_output_data(
|
||||
b_donatable && b.flags().row_contiguous && b.size() == out.size()) {
|
||||
out.copy_shared_buffer(b);
|
||||
} else {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
@@ -1,24 +0,0 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void broadcast(const array& in, array& out) {
|
||||
if (out.size() == 0) {
|
||||
out.set_data(nullptr);
|
||||
return;
|
||||
}
|
||||
Strides strides(out.ndim(), 0);
|
||||
int diff = out.ndim() - in.ndim();
|
||||
for (int i = in.ndim() - 1; i >= 0; --i) {
|
||||
strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
|
||||
}
|
||||
auto flags = in.flags();
|
||||
if (out.size() > in.size()) {
|
||||
flags.row_contiguous = flags.col_contiguous = false;
|
||||
}
|
||||
out.copy_shared_buffer(in, strides, flags, in.data_size());
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -1,11 +0,0 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/array.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void broadcast(const array& in, array& out);
|
||||
|
||||
} // namespace mlx::core
|
@@ -1,157 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cassert>
|
||||
#include <functional>
|
||||
#include <map>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
template <typename T>
|
||||
class BufferCache {
|
||||
public:
|
||||
BufferCache(
|
||||
size_t page_size,
|
||||
std::function<size_t(T*)> get_size,
|
||||
std::function<void(T*)> free)
|
||||
: page_size_(page_size),
|
||||
get_size_(std::move(get_size)),
|
||||
free_(std::move(free)) {}
|
||||
|
||||
~BufferCache() {
|
||||
clear();
|
||||
}
|
||||
|
||||
BufferCache(const BufferCache&) = delete;
|
||||
BufferCache& operator=(const BufferCache&) = delete;
|
||||
|
||||
T* reuse_from_cache(size_t size) {
|
||||
// Find the closest buffer in pool.
|
||||
auto it = buffer_pool_.lower_bound(size);
|
||||
if (it == buffer_pool_.end() ||
|
||||
it->first >= std::min(2 * size, size + 2 * page_size_)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Collect from the cache.
|
||||
T* buf = it->second->buf;
|
||||
pool_size_ -= it->first;
|
||||
|
||||
// Remove from record.
|
||||
remove_from_list(it->second);
|
||||
buffer_pool_.erase(it);
|
||||
return buf;
|
||||
}
|
||||
|
||||
void recycle_to_cache(T* buf) {
|
||||
assert(buf);
|
||||
// Add to cache.
|
||||
BufferHolder* bh = new BufferHolder(buf);
|
||||
add_at_head(bh);
|
||||
size_t size = get_size_(buf);
|
||||
pool_size_ += size;
|
||||
buffer_pool_.emplace(size, bh);
|
||||
}
|
||||
|
||||
int release_cached_buffers(size_t min_bytes_to_free) {
|
||||
if (min_bytes_to_free >= 0.9 * pool_size_) {
|
||||
return clear();
|
||||
} else {
|
||||
int n_release = 0;
|
||||
size_t total_bytes_freed = 0;
|
||||
|
||||
while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
|
||||
// Release buffer.
|
||||
size_t size = get_size_(tail_->buf);
|
||||
total_bytes_freed += size;
|
||||
free_(tail_->buf);
|
||||
n_release++;
|
||||
|
||||
// Remove from record.
|
||||
auto its = buffer_pool_.equal_range(size);
|
||||
auto it = std::find_if(its.first, its.second, [this](const auto& el) {
|
||||
return el.second == tail_;
|
||||
});
|
||||
assert(it != buffer_pool_.end());
|
||||
buffer_pool_.erase(it);
|
||||
remove_from_list(tail_);
|
||||
}
|
||||
|
||||
pool_size_ -= total_bytes_freed;
|
||||
return n_release;
|
||||
}
|
||||
}
|
||||
|
||||
int clear() {
|
||||
int n_release = 0;
|
||||
for (auto& [size, holder] : buffer_pool_) {
|
||||
free_(holder->buf);
|
||||
n_release++;
|
||||
delete holder;
|
||||
}
|
||||
buffer_pool_.clear();
|
||||
pool_size_ = 0;
|
||||
head_ = nullptr;
|
||||
tail_ = nullptr;
|
||||
return n_release;
|
||||
}
|
||||
|
||||
size_t cache_size() const {
|
||||
return pool_size_;
|
||||
}
|
||||
|
||||
size_t page_size() const {
|
||||
return page_size_;
|
||||
}
|
||||
|
||||
private:
|
||||
struct BufferHolder {
|
||||
public:
|
||||
explicit BufferHolder(T* buf_) : buf(buf_) {}
|
||||
|
||||
BufferHolder* prev{nullptr};
|
||||
BufferHolder* next{nullptr};
|
||||
T* buf;
|
||||
};
|
||||
|
||||
void add_at_head(BufferHolder* to_add) {
|
||||
if (!head_) {
|
||||
head_ = to_add;
|
||||
tail_ = to_add;
|
||||
} else {
|
||||
head_->prev = to_add;
|
||||
to_add->next = head_;
|
||||
head_ = to_add;
|
||||
}
|
||||
}
|
||||
|
||||
void remove_from_list(BufferHolder* to_remove) {
|
||||
if (to_remove->prev && to_remove->next) { // if middle
|
||||
to_remove->prev->next = to_remove->next;
|
||||
to_remove->next->prev = to_remove->prev;
|
||||
} else if (to_remove->prev && to_remove == tail_) { // if tail
|
||||
tail_ = to_remove->prev;
|
||||
tail_->next = nullptr;
|
||||
} else if (to_remove == head_ && to_remove->next) { // if head
|
||||
head_ = to_remove->next;
|
||||
head_->prev = nullptr;
|
||||
} else if (to_remove == head_ && to_remove == tail_) { // if only element
|
||||
head_ = nullptr;
|
||||
tail_ = nullptr;
|
||||
}
|
||||
|
||||
delete to_remove;
|
||||
}
|
||||
|
||||
std::multimap<size_t, BufferHolder*> buffer_pool_;
|
||||
BufferHolder* head_{nullptr};
|
||||
BufferHolder* tail_{nullptr};
|
||||
size_t pool_size_{0};
|
||||
|
||||
const size_t page_size_;
|
||||
std::function<size_t(T*)> get_size_;
|
||||
std::function<void(T*)> free_;
|
||||
};
|
||||
|
||||
} // namespace mlx::core
|
@@ -1,7 +1,6 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/backend/common/broadcasting.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
@@ -43,6 +42,23 @@ void AsStrided::eval(const std::vector<array>& inputs, array& out) {
|
||||
return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
|
||||
}
|
||||
|
||||
void broadcast(const array& in, array& out) {
|
||||
if (out.size() == 0) {
|
||||
out.set_data(nullptr);
|
||||
return;
|
||||
}
|
||||
Strides strides(out.ndim(), 0);
|
||||
int diff = out.ndim() - in.ndim();
|
||||
for (int i = in.ndim() - 1; i >= 0; --i) {
|
||||
strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
|
||||
}
|
||||
auto flags = in.flags();
|
||||
if (out.size() > in.size()) {
|
||||
flags.row_contiguous = flags.col_contiguous = false;
|
||||
}
|
||||
out.copy_shared_buffer(in, strides, flags, in.data_size());
|
||||
}
|
||||
|
||||
void Broadcast::eval(const std::vector<array>& inputs, array& out) {
|
||||
broadcast(inputs[0], out);
|
||||
}
|
||||
@@ -87,7 +103,7 @@ void ExpandDims::eval(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
double numel = 1;
|
||||
for (auto ax : axes_) {
|
||||
|
@@ -1,7 +1,8 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/graph_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
@@ -14,8 +15,6 @@ void print_constant(std::ostream& os, const array& x) {
|
||||
return print_float_constant<float16_t>(os, x);
|
||||
case bfloat16:
|
||||
return print_float_constant<bfloat16_t>(os, x);
|
||||
case float64:
|
||||
return print_float_constant<double>(os, x);
|
||||
case complex64:
|
||||
return print_complex_constant<complex64_t>(os, x);
|
||||
case int8:
|
||||
@@ -52,8 +51,6 @@ std::string get_type_string(Dtype d) {
|
||||
return "float16_t";
|
||||
case bfloat16:
|
||||
return "bfloat16_t";
|
||||
case float64:
|
||||
return "double";
|
||||
case complex64:
|
||||
return "complex64_t";
|
||||
case bool_:
|
||||
@@ -82,6 +79,55 @@ std::string get_type_string(Dtype d) {
|
||||
}
|
||||
}
|
||||
|
||||
std::string build_lib_name(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs,
|
||||
const std::vector<array>& tape,
|
||||
const std::unordered_set<uintptr_t>& constant_ids) {
|
||||
NodeNamer namer;
|
||||
std::ostringstream os;
|
||||
std::ostringstream constant_hasher;
|
||||
|
||||
// Fill the input names. This is not really necessary, I just like having A,
|
||||
// B, C, ... as the inputs.
|
||||
for (auto& x : inputs) {
|
||||
namer.get_name(x);
|
||||
}
|
||||
|
||||
// The primitives describing the tape. For unary and binary primitives this
|
||||
// must be enough to describe the full computation.
|
||||
for (auto& a : tape) {
|
||||
// name and type of output
|
||||
os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize();
|
||||
// computation performed
|
||||
a.primitive().print(os);
|
||||
// name of inputs to the function
|
||||
for (auto& inp : a.inputs()) {
|
||||
os << namer.get_name(inp);
|
||||
}
|
||||
}
|
||||
os << "_";
|
||||
|
||||
for (auto& x : inputs) {
|
||||
if (constant_ids.find(x.id()) != constant_ids.end()) {
|
||||
os << "C";
|
||||
print_constant(constant_hasher, x);
|
||||
} else {
|
||||
os << (is_scalar(x) ? "S" : "V");
|
||||
}
|
||||
}
|
||||
os << "_";
|
||||
for (auto& x : inputs) {
|
||||
if (constant_ids.find(x.id()) != constant_ids.end()) {
|
||||
continue;
|
||||
}
|
||||
os << kindof(x.dtype()) << x.itemsize();
|
||||
}
|
||||
os << "_" << std::hash<std::string>{}(constant_hasher.str());
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
bool compiled_check_contiguity(
|
||||
const std::vector<array>& inputs,
|
||||
const Shape& shape) {
|
||||
@@ -113,7 +159,8 @@ bool compiled_check_contiguity(
|
||||
void compiled_allocate_outputs(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
const std::function<bool(size_t)>& is_constant,
|
||||
const std::vector<array>& inputs_,
|
||||
const std::unordered_set<uintptr_t>& constant_ids_,
|
||||
bool contiguous) {
|
||||
if (contiguous) {
|
||||
int o = 0;
|
||||
@@ -128,7 +175,8 @@ void compiled_allocate_outputs(
|
||||
// - Donatable
|
||||
// - Not a constant
|
||||
if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) &&
|
||||
in.is_donatable() && is_constant(i)) {
|
||||
in.is_donatable() &&
|
||||
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
|
||||
outputs[o++].copy_shared_buffer(in);
|
||||
}
|
||||
// Get representative input flags to properly set non-donated outputs
|
||||
@@ -140,7 +188,7 @@ void compiled_allocate_outputs(
|
||||
}
|
||||
for (; o < outputs.size(); ++o) {
|
||||
outputs[o].set_data(
|
||||
allocator::malloc(data_size * outputs[o].itemsize()),
|
||||
allocator::malloc_or_wait(data_size * outputs[o].itemsize()),
|
||||
data_size,
|
||||
strides,
|
||||
flags);
|
||||
@@ -156,86 +204,16 @@ void compiled_allocate_outputs(
|
||||
// - Not a constant
|
||||
if (in.flags().row_contiguous && in.size() == outputs[o].size() &&
|
||||
in.itemsize() == outputs[o].itemsize() && in.is_donatable() &&
|
||||
is_constant(i)) {
|
||||
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
|
||||
outputs[o].copy_shared_buffer(
|
||||
in, outputs[o].strides(), in.flags(), in.data_size());
|
||||
o++;
|
||||
}
|
||||
}
|
||||
for (; o < outputs.size(); ++o) {
|
||||
outputs[o].set_data(allocator::malloc(outputs[o].nbytes()));
|
||||
outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(
|
||||
const std::vector<array>& inputs,
|
||||
const array& out,
|
||||
const std::function<bool(size_t)>& is_constant) {
|
||||
const Shape& shape = out.shape();
|
||||
bool contiguous = compiled_check_contiguity(inputs, shape);
|
||||
if (contiguous) {
|
||||
return {true, shape, {}};
|
||||
}
|
||||
|
||||
std::vector<Strides> strides_vec{out.strides()};
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
// Skip constants.
|
||||
if (is_constant(i)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Skip scalar inputs.
|
||||
const auto& x = inputs[i];
|
||||
if (is_scalar(x)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Broadcast the inputs to the output shape.
|
||||
Strides xstrides;
|
||||
size_t j = 0;
|
||||
for (; j < shape.size() - x.ndim(); ++j) {
|
||||
if (shape[j] == 1) {
|
||||
xstrides.push_back(out.strides()[j]);
|
||||
} else {
|
||||
xstrides.push_back(0);
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < x.ndim(); ++i, ++j) {
|
||||
if (x.shape(i) == 1) {
|
||||
if (shape[j] == 1) {
|
||||
xstrides.push_back(out.strides()[j]);
|
||||
} else {
|
||||
xstrides.push_back(0);
|
||||
}
|
||||
} else {
|
||||
xstrides.push_back(x.strides()[i]);
|
||||
}
|
||||
}
|
||||
strides_vec.push_back(std::move(xstrides));
|
||||
}
|
||||
|
||||
auto tup = collapse_contiguous_dims(shape, strides_vec, INT32_MAX);
|
||||
return {false, std::move(std::get<0>(tup)), std::move(std::get<1>(tup))};
|
||||
}
|
||||
|
||||
bool compiled_use_large_index(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs,
|
||||
bool contiguous) {
|
||||
if (contiguous) {
|
||||
size_t max_size = 0;
|
||||
for (const auto& in : inputs) {
|
||||
max_size = std::max(max_size, in.data_size());
|
||||
}
|
||||
return max_size > UINT32_MAX;
|
||||
} else {
|
||||
size_t max_size = 0;
|
||||
for (const auto& o : outputs) {
|
||||
max_size = std::max(max_size, o.size());
|
||||
}
|
||||
return max_size > UINT32_MAX;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -1,8 +1,9 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
#pragma once
|
||||
|
||||
#include <functional>
|
||||
#include <iomanip>
|
||||
#include <sstream>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/primitives.h"
|
||||
@@ -13,17 +14,19 @@ inline bool is_static_cast(const Primitive& p) {
|
||||
return (typeid(p) == typeid(Broadcast) || typeid(p) == typeid(AsType));
|
||||
}
|
||||
|
||||
std::string build_lib_name(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs,
|
||||
const std::vector<array>& tape,
|
||||
const std::unordered_set<uintptr_t>& constant_ids);
|
||||
|
||||
std::string get_type_string(Dtype d);
|
||||
|
||||
template <typename T>
|
||||
void print_float_constant(std::ostream& os, const array& x) {
|
||||
auto old_precision = os.precision();
|
||||
if constexpr (std::is_same_v<T, double>) {
|
||||
os << std::setprecision(std::numeric_limits<double>::digits10 + 1);
|
||||
} else {
|
||||
os << std::setprecision(std::numeric_limits<float>::digits10 + 1);
|
||||
}
|
||||
os << x.item<T>() << std::setprecision(old_precision);
|
||||
os << std::setprecision(std::numeric_limits<float>::digits10 + 1)
|
||||
<< x.item<T>() << std::setprecision(old_precision);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
@@ -57,19 +60,8 @@ bool compiled_check_contiguity(
|
||||
void compiled_allocate_outputs(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
const std::function<bool(size_t)>& is_constant,
|
||||
bool contiguous);
|
||||
|
||||
// Collapse contiguous dims ignoring scalars and constants.
|
||||
std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(
|
||||
const std::vector<array>& inputs,
|
||||
const array& out,
|
||||
const std::function<bool(size_t)>& is_constant);
|
||||
|
||||
// Return whether the kernel should use large index.
|
||||
bool compiled_use_large_index(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs,
|
||||
const std::vector<array>& inputs_,
|
||||
const std::unordered_set<uintptr_t>& constant_ids_,
|
||||
bool contiguous);
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -2,7 +2,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/array.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
@@ -26,19 +26,19 @@ inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) {
|
||||
if (ctype == CopyType::Vector) {
|
||||
// If the input is donateable, we are doing a vector copy and the types
|
||||
// have the same size, then the input buffer can hold the output.
|
||||
if (is_donatable(in, out)) {
|
||||
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
|
||||
out.copy_shared_buffer(in);
|
||||
return true;
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc(in.data_size() * out.itemsize()),
|
||||
allocator::malloc_or_wait(in.data_size() * out.itemsize()),
|
||||
in.data_size(),
|
||||
in.strides(),
|
||||
in.flags());
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
@@ -99,11 +99,7 @@ inline std::pair<int, int> decompose_hadamard(int n) {
|
||||
"[hadamard] Only supports n = m*2^k where m in (1, 12, 20, 28).");
|
||||
}
|
||||
}
|
||||
if (n > (1 << 26)) {
|
||||
throw std::invalid_argument(
|
||||
"[hadamard] Only supports n = m*2^k where k <= 26");
|
||||
}
|
||||
return {n, m};
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
} // namespace mlx::core
|
@@ -28,7 +28,7 @@ void swap_endianness(uint8_t* data_bytes, size_t N) {
|
||||
namespace mlx::core {
|
||||
|
||||
void Load::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
auto read_task = [out_ptr = out.data<char>(),
|
||||
size = out.size(),
|
||||
itemsize = out.itemsize(),
|
||||
|
@@ -1,67 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
#include <sstream>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
inline std::tuple<Shape, Strides, Strides> collapse_batches(
|
||||
const array& a,
|
||||
const array& b) {
|
||||
if (a.ndim() == 2) {
|
||||
return {{1}, {0}, {0}};
|
||||
}
|
||||
|
||||
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
|
||||
Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
|
||||
Strides B_bstride{b.strides().begin(), b.strides().end() - 2};
|
||||
|
||||
auto [batch_shape, batch_strides] =
|
||||
collapse_contiguous_dims(A_bshape, std::vector{A_bstride, B_bstride});
|
||||
|
||||
auto a_batch_strides = batch_strides[0];
|
||||
auto b_batch_strides = batch_strides[1];
|
||||
|
||||
if (batch_shape.empty()) {
|
||||
batch_shape.push_back(1);
|
||||
a_batch_strides.push_back(0);
|
||||
b_batch_strides.push_back(0);
|
||||
}
|
||||
|
||||
return std::make_tuple(batch_shape, a_batch_strides, b_batch_strides);
|
||||
}
|
||||
|
||||
inline std::tuple<Shape, Strides, Strides, Strides>
|
||||
collapse_batches(const array& a, const array& b, const array& c) {
|
||||
if (a.ndim() == 2) {
|
||||
return {{1}, {0}, {0}, {0}};
|
||||
}
|
||||
|
||||
Shape A_bshape{a.shape().begin(), a.shape().end() - 2};
|
||||
Strides A_bstride{a.strides().begin(), a.strides().end() - 2};
|
||||
Strides B_bstride{b.strides().begin(), b.strides().end() - 2};
|
||||
Strides C_bstride{c.strides().begin(), c.strides().end() - 2};
|
||||
|
||||
auto [batch_shape, batch_strides] = collapse_contiguous_dims(
|
||||
A_bshape, std::vector{A_bstride, B_bstride, C_bstride});
|
||||
|
||||
auto A_batch_stride = batch_strides[0];
|
||||
auto B_batch_stride = batch_strides[1];
|
||||
auto C_batch_stride = batch_strides[2];
|
||||
|
||||
if (batch_shape.empty()) {
|
||||
batch_shape.push_back(1);
|
||||
A_batch_stride.push_back(0);
|
||||
B_batch_stride.push_back(0);
|
||||
C_batch_stride.push_back(0);
|
||||
}
|
||||
|
||||
return std::make_tuple(
|
||||
batch_shape, A_batch_stride, B_batch_stride, C_batch_stride);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -5,9 +5,11 @@
|
||||
namespace mlx::core {
|
||||
|
||||
std::pair<Shape, Strides> shapes_without_reduction_axes(
|
||||
Shape shape,
|
||||
Strides strides,
|
||||
const array& x,
|
||||
const std::vector<int>& axes) {
|
||||
auto shape = x.shape();
|
||||
auto strides = x.strides();
|
||||
|
||||
for (int i = axes.size() - 1; i >= 0; i--) {
|
||||
int a = axes[i];
|
||||
shape.erase(shape.begin() + a);
|
||||
@@ -17,15 +19,6 @@ std::pair<Shape, Strides> shapes_without_reduction_axes(
|
||||
return std::make_pair(shape, strides);
|
||||
}
|
||||
|
||||
std::pair<Shape, Strides> shapes_without_reduction_axes(
|
||||
const array& x,
|
||||
const std::vector<int>& axes) {
|
||||
auto shape = x.shape();
|
||||
auto strides = x.strides();
|
||||
return shapes_without_reduction_axes(
|
||||
std::move(shape), std::move(strides), axes);
|
||||
}
|
||||
|
||||
ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
|
||||
// The data is all there and we are reducing over everything
|
||||
if (x.size() == x.data_size() && axes.size() == x.ndim() &&
|
||||
|
@@ -51,9 +51,5 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes);
|
||||
std::pair<Shape, Strides> shapes_without_reduction_axes(
|
||||
const array& x,
|
||||
const std::vector<int>& axes);
|
||||
std::pair<Shape, Strides> shapes_without_reduction_axes(
|
||||
Shape shape,
|
||||
Strides strides,
|
||||
const std::vector<int>& axes);
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -48,12 +48,12 @@ inline void set_ternary_op_output_data(
|
||||
switch (topt) {
|
||||
case TernaryOpType::ScalarScalarScalar:
|
||||
out.set_data(
|
||||
allocator::malloc(out.itemsize()), 1, b.strides(), b.flags());
|
||||
allocator::malloc_or_wait(out.itemsize()), 1, b.strides(), b.flags());
|
||||
break;
|
||||
case TernaryOpType::VectorVectorVector:
|
||||
if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) {
|
||||
out.set_data(
|
||||
allocator::malloc(out.itemsize() * b.data_size()),
|
||||
allocator::malloc_or_wait(out.itemsize() * b.data_size()),
|
||||
b.data_size(),
|
||||
b.strides(),
|
||||
b.flags());
|
||||
@@ -64,7 +64,7 @@ inline void set_ternary_op_output_data(
|
||||
if (!((a.flags().row_contiguous && maybe_donate(a)) ||
|
||||
(b.flags().row_contiguous && maybe_donate(b)) ||
|
||||
(c.flags().row_contiguous && maybe_donate(c)))) {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
@@ -1,26 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
inline void set_unary_output_data(const array& in, array& out) {
|
||||
if (in.flags().contiguous) {
|
||||
if (is_donatable(in, out)) {
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc(in.data_size() * out.itemsize()),
|
||||
in.data_size(),
|
||||
in.strides(),
|
||||
in.flags());
|
||||
}
|
||||
} else {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -1,22 +1,9 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <dlfcn.h>
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
std::filesystem::path current_binary_dir() {
|
||||
static std::filesystem::path binary_dir = []() {
|
||||
Dl_info info;
|
||||
if (!dladdr(reinterpret_cast<void*>(¤t_binary_dir), &info)) {
|
||||
throw std::runtime_error("Unable to get current binary dir.");
|
||||
}
|
||||
return std::filesystem::path(info.dli_fname).parent_path();
|
||||
}();
|
||||
return binary_dir;
|
||||
}
|
||||
|
||||
std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
|
||||
const Shape& shape,
|
||||
const std::vector<Strides>& strides,
|
||||
@@ -114,118 +101,4 @@ std::pair<Shape, Strides> collapse_contiguous_dims(
|
||||
return collapse_contiguous_dims(a.shape(), a.strides(), size_cap);
|
||||
}
|
||||
|
||||
Dims get_block_dims_common(int dim0, int dim1, int dim2, int pow2 /* = 10 */) {
|
||||
int pows[3] = {0, 0, 0};
|
||||
int sum = 0;
|
||||
while (true) {
|
||||
int presum = sum;
|
||||
// Check all the pows
|
||||
if (dim0 >= (1 << (pows[0] + 1))) {
|
||||
pows[0]++;
|
||||
sum++;
|
||||
}
|
||||
if (sum == 10) {
|
||||
break;
|
||||
}
|
||||
if (dim1 >= (1 << (pows[1] + 1))) {
|
||||
pows[1]++;
|
||||
sum++;
|
||||
}
|
||||
if (sum == 10) {
|
||||
break;
|
||||
}
|
||||
if (dim2 >= (1 << (pows[2] + 1))) {
|
||||
pows[2]++;
|
||||
sum++;
|
||||
}
|
||||
if (sum == presum || sum == pow2) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
return std::make_tuple(1ul << pows[0], 1ul << pows[1], 1ul << pows[2]);
|
||||
}
|
||||
|
||||
Dims get_2d_grid_dims_common(const Shape& shape, const Strides& strides) {
|
||||
// Dims with strides of 0 are ignored as they
|
||||
// correspond to broadcasted dimensions
|
||||
size_t grid_x = 1;
|
||||
size_t grid_y = 1;
|
||||
for (int i = 0; i < shape.size(); ++i) {
|
||||
if (strides[i] == 0) {
|
||||
continue;
|
||||
}
|
||||
if (grid_x * shape[i] < UINT32_MAX) {
|
||||
grid_x *= shape[i];
|
||||
} else {
|
||||
grid_y *= shape[i];
|
||||
}
|
||||
}
|
||||
if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) {
|
||||
throw std::runtime_error("Unable to safely factor shape.");
|
||||
}
|
||||
if (grid_y > grid_x) {
|
||||
std::swap(grid_x, grid_y);
|
||||
}
|
||||
return std::make_tuple(
|
||||
static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
|
||||
}
|
||||
|
||||
Dims get_2d_grid_dims_common(
|
||||
const Shape& shape,
|
||||
const Strides& strides,
|
||||
size_t divisor) {
|
||||
// Compute the 2d grid dimensions such that the total size of the grid is
|
||||
// divided by divisor.
|
||||
size_t grid_x = 1;
|
||||
size_t grid_y = 1;
|
||||
for (int i = 0; i < shape.size(); ++i) {
|
||||
if (strides[i] == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// No need to add this shape we can just remove it from the divisor.
|
||||
if (divisor % shape[i] == 0) {
|
||||
divisor /= shape[i];
|
||||
continue;
|
||||
}
|
||||
|
||||
if (grid_x * shape[i] < UINT32_MAX) {
|
||||
grid_x *= shape[i];
|
||||
} else {
|
||||
grid_y *= shape[i];
|
||||
}
|
||||
|
||||
if (divisor > 1) {
|
||||
if (grid_x % divisor == 0) {
|
||||
grid_x /= divisor;
|
||||
divisor = 1;
|
||||
} else if (grid_y % divisor == 0) {
|
||||
grid_y /= divisor;
|
||||
divisor = 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) {
|
||||
throw std::runtime_error("Unable to safely factor shape.");
|
||||
}
|
||||
if (grid_y > grid_x) {
|
||||
std::swap(grid_x, grid_y);
|
||||
}
|
||||
if (divisor > 1) {
|
||||
grid_x = ((grid_x + divisor - 1) / divisor) * divisor;
|
||||
}
|
||||
return std::make_tuple(
|
||||
static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
|
||||
}
|
||||
|
||||
std::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, int dim2) {
|
||||
auto [bx, by, bz] = get_block_dims_common(dim0, dim1, dim2);
|
||||
auto gx = (dim0 + bx - 1) / bx;
|
||||
auto gy = (dim1 + by - 1) / by;
|
||||
auto gz = (dim2 + bz - 1) / bz;
|
||||
|
||||
return std::make_pair(
|
||||
std::make_tuple(gx, gy, gz), std::make_tuple(bx, by, bz));
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -2,17 +2,12 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <filesystem>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
#include "mlx/array.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// Return the directory that contains current shared library.
|
||||
std::filesystem::path current_binary_dir();
|
||||
|
||||
inline int64_t
|
||||
elem_to_loc(int elem, const Shape& shape, const Strides& strides) {
|
||||
int64_t loc = 0;
|
||||
@@ -75,31 +70,6 @@ std::pair<Shape, Strides> collapse_contiguous_dims(
|
||||
const array& a,
|
||||
int64_t size_cap = std::numeric_limits<int32_t>::max());
|
||||
|
||||
// Compute the thread block dimensions which fit the given
|
||||
// input dimensions.
|
||||
// - The thread block dimensions will be powers of two
|
||||
// - The thread block size will be less than 2^pow2
|
||||
using Dims = std::tuple<uint32_t, uint32_t, uint32_t>;
|
||||
Dims get_block_dims_common(int dim0, int dim1, int dim2, int pow2 = 10);
|
||||
|
||||
// Computes a 2D grid where each element is < UINT_MAX
|
||||
// Assumes:
|
||||
// - overall size (product of non-broadcasted dimensions) is < UINT_MAX^2
|
||||
// - shape and strides correspond to a contiguous (no holes) but
|
||||
// possibly broadcasted array
|
||||
Dims get_2d_grid_dims_common(const Shape& shape, const Strides& strides);
|
||||
|
||||
// Same as above but we do an implicit division with divisor.
|
||||
// Basically, equivalent to factorizing
|
||||
// Prod(s \forall s in shape if strides[s] > 0) / divisor.
|
||||
Dims get_2d_grid_dims_common(
|
||||
const Shape& shape,
|
||||
const Strides& strides,
|
||||
size_t divisor);
|
||||
|
||||
// Get both the block and a grid of blocks that covers dim0, dim1 and dim2.
|
||||
std::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, int dim2);
|
||||
|
||||
struct ContiguousIterator {
|
||||
inline void step() {
|
||||
int dims = shape_.size();
|
||||
@@ -195,11 +165,4 @@ void shared_buffer_reshape(
|
||||
const array& in,
|
||||
const Strides& out_strides,
|
||||
array& out);
|
||||
|
||||
template <typename T>
|
||||
inline std::vector<T> remove_index(std::vector<T> vec, size_t index) {
|
||||
vec.erase(std::next(vec.begin(), index));
|
||||
return vec;
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -40,13 +40,11 @@ add_dependencies(mlx cpu_compiled_preamble)
|
||||
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/available.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/eig.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/eigh.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/encoder.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||
@@ -60,7 +58,6 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||
@@ -76,8 +73,8 @@ target_sources(
|
||||
if(MLX_BUILD_ACCELERATE)
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/bnns.cpp)
|
||||
else()
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/simd_fp16.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/gemms/simd_bf16.cpp)
|
||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/no_fp16.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/gemms/no_bf16.cpp)
|
||||
endif()
|
||||
|
||||
if(IOS)
|
||||
|
@@ -11,24 +11,43 @@ namespace mlx::core {
|
||||
namespace {
|
||||
|
||||
template <typename InT, typename OpT>
|
||||
void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
|
||||
void arg_reduce(
|
||||
const array& in,
|
||||
array& out,
|
||||
const OpT& op,
|
||||
int axis,
|
||||
Stream stream) {
|
||||
auto axis_size = in.shape()[axis];
|
||||
auto axis_stride = in.strides()[axis];
|
||||
Strides strides = remove_index(in.strides(), axis);
|
||||
Shape shape = remove_index(in.shape(), axis);
|
||||
Strides strides = in.strides();
|
||||
Shape shape = in.shape();
|
||||
strides.erase(strides.begin() + axis);
|
||||
shape.erase(shape.begin() + axis);
|
||||
auto in_ptr = in.data<InT>();
|
||||
auto out_ptr = out.data<uint32_t>();
|
||||
|
||||
for (uint32_t i = 0; i < out.size(); ++i) {
|
||||
auto loc = elem_to_loc(i, shape, strides);
|
||||
auto local_in_ptr = in_ptr + loc;
|
||||
uint32_t ind_v = 0;
|
||||
InT v = (*local_in_ptr);
|
||||
for (uint32_t j = 0; j < axis_size; ++j, local_in_ptr += axis_stride) {
|
||||
op(j, (*local_in_ptr), &ind_v, &v);
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([in_ptr,
|
||||
out_ptr,
|
||||
axis_size,
|
||||
axis_stride,
|
||||
op = std::move(op),
|
||||
shape = std::move(shape),
|
||||
strides = std::move(strides),
|
||||
size = out.size()]() {
|
||||
for (uint32_t i = 0; i < size; ++i) {
|
||||
auto loc = elem_to_loc(i, shape, strides);
|
||||
auto local_in_ptr = in_ptr + loc;
|
||||
uint32_t ind_v = 0;
|
||||
InT v = (*local_in_ptr);
|
||||
for (uint32_t j = 0; j < axis_size; ++j, local_in_ptr += axis_stride) {
|
||||
op(j, (*local_in_ptr), &ind_v, &v);
|
||||
}
|
||||
out_ptr[i] = ind_v;
|
||||
}
|
||||
out_ptr[i] = ind_v;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename InT>
|
||||
@@ -36,7 +55,8 @@ void arg_reduce_dispatch(
|
||||
const array& in,
|
||||
array& out,
|
||||
ArgReduce::ReduceType rtype,
|
||||
int axis) {
|
||||
int axis,
|
||||
Stream stream) {
|
||||
switch (rtype) {
|
||||
case ArgReduce::ArgMin: {
|
||||
auto op = [](auto ind_x, auto x, auto ind_y, auto y) {
|
||||
@@ -45,7 +65,7 @@ void arg_reduce_dispatch(
|
||||
(*ind_y) = ind_x;
|
||||
}
|
||||
};
|
||||
arg_reduce<InT>(in, out, op, axis);
|
||||
arg_reduce<InT>(in, out, op, axis, stream);
|
||||
break;
|
||||
}
|
||||
case ArgReduce::ArgMax: {
|
||||
@@ -55,7 +75,7 @@ void arg_reduce_dispatch(
|
||||
(*ind_y) = ind_x;
|
||||
}
|
||||
};
|
||||
arg_reduce<InT>(in, out, op, axis);
|
||||
arg_reduce<InT>(in, out, op, axis, stream);
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -66,59 +86,52 @@ void arg_reduce_dispatch(
|
||||
void ArgReduce::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([in = array::unsafe_weak_copy(in),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
reduce_type_ = reduce_type_,
|
||||
axis_ = axis_]() mutable {
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
arg_reduce_dispatch<bool>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case uint8:
|
||||
arg_reduce_dispatch<uint8_t>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case uint16:
|
||||
arg_reduce_dispatch<uint16_t>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case uint32:
|
||||
arg_reduce_dispatch<uint32_t>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case uint64:
|
||||
arg_reduce_dispatch<uint64_t>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case int8:
|
||||
arg_reduce_dispatch<int8_t>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case int16:
|
||||
arg_reduce_dispatch<int16_t>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case int32:
|
||||
arg_reduce_dispatch<int32_t>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case int64:
|
||||
arg_reduce_dispatch<int64_t>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case float16:
|
||||
arg_reduce_dispatch<float16_t>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case float32:
|
||||
arg_reduce_dispatch<float>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case bfloat16:
|
||||
arg_reduce_dispatch<bfloat16_t>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case float64:
|
||||
arg_reduce_dispatch<double>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
case complex64:
|
||||
arg_reduce_dispatch<complex64_t>(in, out, reduce_type_, axis_);
|
||||
break;
|
||||
}
|
||||
});
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
arg_reduce_dispatch<bool>(in, out, reduce_type_, axis_, stream());
|
||||
break;
|
||||
case uint8:
|
||||
arg_reduce_dispatch<uint8_t>(in, out, reduce_type_, axis_, stream());
|
||||
break;
|
||||
case uint16:
|
||||
arg_reduce_dispatch<uint16_t>(in, out, reduce_type_, axis_, stream());
|
||||
break;
|
||||
case uint32:
|
||||
arg_reduce_dispatch<uint32_t>(in, out, reduce_type_, axis_, stream());
|
||||
break;
|
||||
case uint64:
|
||||
arg_reduce_dispatch<uint64_t>(in, out, reduce_type_, axis_, stream());
|
||||
break;
|
||||
case int8:
|
||||
arg_reduce_dispatch<int8_t>(in, out, reduce_type_, axis_, stream());
|
||||
break;
|
||||
case int16:
|
||||
arg_reduce_dispatch<int16_t>(in, out, reduce_type_, axis_, stream());
|
||||
break;
|
||||
case int32:
|
||||
arg_reduce_dispatch<int32_t>(in, out, reduce_type_, axis_, stream());
|
||||
break;
|
||||
case int64:
|
||||
arg_reduce_dispatch<int64_t>(in, out, reduce_type_, axis_, stream());
|
||||
break;
|
||||
case float16:
|
||||
arg_reduce_dispatch<float16_t>(in, out, reduce_type_, axis_, stream());
|
||||
break;
|
||||
case float32:
|
||||
arg_reduce_dispatch<float>(in, out, reduce_type_, axis_, stream());
|
||||
break;
|
||||
case bfloat16:
|
||||
arg_reduce_dispatch<bfloat16_t>(in, out, reduce_type_, axis_, stream());
|
||||
break;
|
||||
case float64:
|
||||
arg_reduce_dispatch<double>(in, out, reduce_type_, axis_, stream());
|
||||
break;
|
||||
case complex64:
|
||||
arg_reduce_dispatch<complex64_t>(in, out, reduce_type_, axis_, stream());
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -1,11 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cpu/available.h"
|
||||
|
||||
namespace mlx::core::cpu {
|
||||
|
||||
bool is_available() {
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cpu
|
@@ -1,9 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace mlx::core::cpu {
|
||||
|
||||
bool is_available();
|
||||
|
||||
} // namespace mlx::core::cpu
|
@@ -8,7 +8,6 @@
|
||||
#include "mlx/backend/cpu/binary.h"
|
||||
#include "mlx/backend/cpu/binary_ops.h"
|
||||
#include "mlx/backend/cpu/binary_two.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
@@ -17,221 +16,51 @@ namespace mlx::core {
|
||||
namespace {
|
||||
|
||||
template <typename Op>
|
||||
void binary(const array& a, const array& b, array& out, Op op, Stream stream) {
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||
b = array::unsafe_weak_copy(b),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
bopt]() mutable {
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
binary_op<bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint8:
|
||||
binary_op<uint8_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint16:
|
||||
binary_op<uint16_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint32:
|
||||
binary_op<uint32_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint64:
|
||||
binary_op<uint64_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int8:
|
||||
binary_op<int8_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int16:
|
||||
binary_op<int16_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int32:
|
||||
binary_op<int32_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int64:
|
||||
binary_op<int64_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case float16:
|
||||
binary_op<float16_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case float32:
|
||||
binary_op<float, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case float64:
|
||||
binary_op<double, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case complex64:
|
||||
binary_op<complex64_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void comparison_op(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
Op op,
|
||||
Stream stream) {
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||
b = array::unsafe_weak_copy(b),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
bopt]() mutable {
|
||||
switch (a.dtype()) {
|
||||
case bool_:
|
||||
binary_op<bool, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint8:
|
||||
binary_op<uint8_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint16:
|
||||
binary_op<uint16_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint32:
|
||||
binary_op<uint32_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint64:
|
||||
binary_op<uint64_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int8:
|
||||
binary_op<int8_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int16:
|
||||
binary_op<int16_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int32:
|
||||
binary_op<int32_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int64:
|
||||
binary_op<int64_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case float16:
|
||||
binary_op<float16_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case float32:
|
||||
binary_op<float, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case float64:
|
||||
binary_op<double, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case complex64:
|
||||
binary_op<complex64_t, bool, Op>(a, b, out, bopt);
|
||||
break;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void binary_float(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
Op op,
|
||||
Stream stream) {
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||
b = array::unsafe_weak_copy(b),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
bopt]() mutable {
|
||||
switch (out.dtype()) {
|
||||
case float16:
|
||||
binary_op<float16_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case float32:
|
||||
binary_op<float, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case float64:
|
||||
binary_op<double, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case complex64:
|
||||
binary_op<complex64_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[binary_float] Only supports floating point types.");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void binary_int(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
Op op,
|
||||
Stream stream) {
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||
b = array::unsafe_weak_copy(b),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
bopt]() mutable {
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
binary_op<bool, Op>(a, b, out, bopt);
|
||||
case uint8:
|
||||
binary_op<uint8_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint16:
|
||||
binary_op<uint16_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint32:
|
||||
binary_op<uint32_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case uint64:
|
||||
binary_op<uint64_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int8:
|
||||
binary_op<int8_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int16:
|
||||
binary_op<int16_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int32:
|
||||
binary_op<int32_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
case int64:
|
||||
binary_op<int64_t, Op>(a, b, out, bopt);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error("[binary_int] Type not supported");
|
||||
break;
|
||||
}
|
||||
});
|
||||
void comparison_op(const array& a, const array& b, array& out) {
|
||||
switch (a.dtype()) {
|
||||
case bool_:
|
||||
binary_op<bool, bool, Op>(a, b, out);
|
||||
break;
|
||||
case uint8:
|
||||
binary_op<uint8_t, bool, Op>(a, b, out);
|
||||
break;
|
||||
case uint16:
|
||||
binary_op<uint16_t, bool, Op>(a, b, out);
|
||||
break;
|
||||
case uint32:
|
||||
binary_op<uint32_t, bool, Op>(a, b, out);
|
||||
break;
|
||||
case uint64:
|
||||
binary_op<uint64_t, bool, Op>(a, b, out);
|
||||
break;
|
||||
case int8:
|
||||
binary_op<int8_t, bool, Op>(a, b, out);
|
||||
break;
|
||||
case int16:
|
||||
binary_op<int16_t, bool, Op>(a, b, out);
|
||||
break;
|
||||
case int32:
|
||||
binary_op<int32_t, bool, Op>(a, b, out);
|
||||
break;
|
||||
case int64:
|
||||
binary_op<int64_t, bool, Op>(a, b, out);
|
||||
break;
|
||||
case float16:
|
||||
binary_op<float16_t, bool, Op>(a, b, out);
|
||||
break;
|
||||
case float32:
|
||||
binary_op<float, bool, Op>(a, b, out);
|
||||
break;
|
||||
case float64:
|
||||
binary_op<double, bool, Op>(a, b, out);
|
||||
break;
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t, bool, Op>(a, b, out);
|
||||
break;
|
||||
case complex64:
|
||||
binary_op<complex64_t, bool, Op>(a, b, out);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@@ -240,7 +69,7 @@ void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, detail::Add(), stream());
|
||||
binary(a, b, out, detail::Add());
|
||||
}
|
||||
|
||||
void DivMod::eval_cpu(
|
||||
@@ -249,89 +78,70 @@ void DivMod::eval_cpu(
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
auto& out_a = outputs[0];
|
||||
auto& out_b = outputs[1];
|
||||
set_binary_op_output_data(a, b, out_a, bopt);
|
||||
set_binary_op_output_data(a, b, out_b, bopt);
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out_a);
|
||||
encoder.set_output_array(out_b);
|
||||
|
||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||
b = array::unsafe_weak_copy(b),
|
||||
out_a = array::unsafe_weak_copy(out_a),
|
||||
out_b = array::unsafe_weak_copy(out_b),
|
||||
bopt]() mutable {
|
||||
auto integral_op = [](auto x, auto y) {
|
||||
return std::make_pair(x / y, x % y);
|
||||
};
|
||||
auto float_op = [](auto x, auto y) {
|
||||
return std::make_pair(std::trunc(x / y), std::fmod(x, y));
|
||||
};
|
||||
|
||||
switch (out_a.dtype()) {
|
||||
case bool_:
|
||||
binary_op<bool>(a, b, out_a, out_b, integral_op, bopt);
|
||||
case uint8:
|
||||
binary_op<uint8_t>(a, b, out_a, out_b, integral_op, bopt);
|
||||
break;
|
||||
case uint16:
|
||||
binary_op<uint16_t>(a, b, out_a, out_b, integral_op, bopt);
|
||||
break;
|
||||
case uint32:
|
||||
binary_op<uint32_t>(a, b, out_a, out_b, integral_op, bopt);
|
||||
break;
|
||||
case uint64:
|
||||
binary_op<uint64_t>(a, b, out_a, out_b, integral_op, bopt);
|
||||
break;
|
||||
case int8:
|
||||
binary_op<int8_t>(a, b, out_a, out_b, integral_op, bopt);
|
||||
break;
|
||||
case int16:
|
||||
binary_op<int16_t>(a, b, out_a, out_b, integral_op, bopt);
|
||||
break;
|
||||
case int32:
|
||||
binary_op<int32_t>(a, b, out_a, out_b, integral_op, bopt);
|
||||
break;
|
||||
case int64:
|
||||
binary_op<int64_t>(a, b, out_a, out_b, integral_op, bopt);
|
||||
break;
|
||||
case float16:
|
||||
binary_op<float16_t>(a, b, out_a, out_b, float_op, bopt);
|
||||
break;
|
||||
case float32:
|
||||
binary_op<float>(a, b, out_a, out_b, float_op, bopt);
|
||||
break;
|
||||
case float64:
|
||||
binary_op<double>(a, b, out_a, out_b, float_op, bopt);
|
||||
break;
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t>(a, b, out_a, out_b, float_op, bopt);
|
||||
break;
|
||||
case complex64:
|
||||
// Should never get here
|
||||
throw std::runtime_error("[DivMod] Complex type not supported");
|
||||
break;
|
||||
}
|
||||
});
|
||||
auto integral_op = [](auto x, auto y) {
|
||||
return std::make_pair(x / y, x % y);
|
||||
};
|
||||
auto float_op = [](auto x, auto y) {
|
||||
return std::make_pair(std::trunc(x / y), std::fmod(x, y));
|
||||
};
|
||||
switch (outputs[0].dtype()) {
|
||||
case bool_:
|
||||
binary_op<bool>(a, b, outputs, integral_op);
|
||||
case uint8:
|
||||
binary_op<uint8_t>(a, b, outputs, integral_op);
|
||||
break;
|
||||
case uint16:
|
||||
binary_op<uint16_t>(a, b, outputs, integral_op);
|
||||
break;
|
||||
case uint32:
|
||||
binary_op<uint32_t>(a, b, outputs, integral_op);
|
||||
break;
|
||||
case uint64:
|
||||
binary_op<uint64_t>(a, b, outputs, integral_op);
|
||||
break;
|
||||
case int8:
|
||||
binary_op<int8_t>(a, b, outputs, integral_op);
|
||||
break;
|
||||
case int16:
|
||||
binary_op<int16_t>(a, b, outputs, integral_op);
|
||||
break;
|
||||
case int32:
|
||||
binary_op<int32_t>(a, b, outputs, integral_op);
|
||||
break;
|
||||
case int64:
|
||||
binary_op<int64_t>(a, b, outputs, integral_op);
|
||||
break;
|
||||
case float16:
|
||||
binary_op<float16_t>(a, b, outputs, float_op);
|
||||
break;
|
||||
case float32:
|
||||
binary_op<float>(a, b, outputs, float_op);
|
||||
break;
|
||||
case float64:
|
||||
binary_op<double>(a, b, outputs, float_op);
|
||||
break;
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t>(a, b, outputs, float_op);
|
||||
break;
|
||||
case complex64:
|
||||
// Should never get here
|
||||
throw std::runtime_error("[DivMod] Complex type not supported");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, detail::Divide(), stream());
|
||||
binary(a, b, out, detail::Divide());
|
||||
}
|
||||
|
||||
void Remainder::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, detail::Remainder(), stream());
|
||||
binary(a, b, out, detail::Remainder());
|
||||
}
|
||||
|
||||
void Equal::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -339,143 +149,181 @@ void Equal::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
if (equal_nan_) {
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||
b = array::unsafe_weak_copy(b),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
bopt]() mutable {
|
||||
switch (a.dtype()) {
|
||||
case float16:
|
||||
binary_op<float16_t, bool, detail::NaNEqual>(a, b, out, bopt);
|
||||
break;
|
||||
case float32:
|
||||
binary_op<float, bool, detail::NaNEqual>(a, b, out, bopt);
|
||||
break;
|
||||
case float64:
|
||||
binary_op<double, bool, detail::NaNEqual>(a, b, out, bopt);
|
||||
break;
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t, bool, detail::NaNEqual>(a, b, out, bopt);
|
||||
break;
|
||||
case complex64:
|
||||
binary_op<complex64_t, bool, detail::NaNEqual>(a, b, out, bopt);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[NanEqual::eval_cpu] Only for floating point types.");
|
||||
}
|
||||
});
|
||||
switch (a.dtype()) {
|
||||
case float16:
|
||||
binary_op<float16_t, bool, detail::NaNEqual>(a, b, out);
|
||||
break;
|
||||
case float32:
|
||||
binary_op<float, bool, detail::NaNEqual>(a, b, out);
|
||||
break;
|
||||
case float64:
|
||||
binary_op<double, bool, detail::NaNEqual>(a, b, out);
|
||||
break;
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t, bool, detail::NaNEqual>(a, b, out);
|
||||
break;
|
||||
case complex64:
|
||||
binary_op<complex64_t, bool, detail::NaNEqual>(a, b, out);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[NanEqual::eval_cpu] Only for floating point types.");
|
||||
}
|
||||
} else {
|
||||
comparison_op(a, b, out, detail::Equal(), stream());
|
||||
comparison_op<detail::Equal>(a, b, out);
|
||||
}
|
||||
}
|
||||
|
||||
void Greater::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
comparison_op(inputs[0], inputs[1], out, detail::Greater(), stream());
|
||||
comparison_op<detail::Greater>(inputs[0], inputs[1], out);
|
||||
}
|
||||
|
||||
void GreaterEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
comparison_op(inputs[0], inputs[1], out, detail::GreaterEqual(), stream());
|
||||
comparison_op<detail::GreaterEqual>(inputs[0], inputs[1], out);
|
||||
}
|
||||
|
||||
void Less::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
comparison_op(inputs[0], inputs[1], out, detail::Less(), stream());
|
||||
comparison_op<detail::Less>(inputs[0], inputs[1], out);
|
||||
}
|
||||
|
||||
void LessEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
comparison_op(inputs[0], inputs[1], out, detail::LessEqual(), stream());
|
||||
comparison_op<detail::LessEqual>(inputs[0], inputs[1], out);
|
||||
}
|
||||
|
||||
void LogAddExp::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary_float(a, b, out, detail::LogAddExp(), stream());
|
||||
switch (out.dtype()) {
|
||||
case float16:
|
||||
binary_op<float16_t, detail::LogAddExp>(a, b, out);
|
||||
break;
|
||||
case float32:
|
||||
binary_op<float, detail::LogAddExp>(a, b, out);
|
||||
break;
|
||||
case float64:
|
||||
binary_op<double, detail::LogAddExp>(a, b, out);
|
||||
break;
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t, detail::LogAddExp>(a, b, out);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[LogAddExp::eval_cpu] Only supports non-complex floating point types.");
|
||||
}
|
||||
}
|
||||
|
||||
void LogicalAnd::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2); // LogicalAnd requires two input arrays
|
||||
auto& in1 = inputs[0];
|
||||
auto& in2 = inputs[1];
|
||||
binary(in1, in2, out, detail::LogicalAnd(), stream());
|
||||
binary(in1, in2, out, detail::LogicalAnd());
|
||||
}
|
||||
|
||||
void LogicalOr::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2); // LogicalOr requires two input arrays
|
||||
auto& in1 = inputs[0];
|
||||
auto& in2 = inputs[1];
|
||||
binary(in1, in2, out, detail::LogicalOr(), stream());
|
||||
binary(in1, in2, out, detail::LogicalOr());
|
||||
}
|
||||
|
||||
void Maximum::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, detail::Maximum(), stream());
|
||||
binary(a, b, out, detail::Maximum());
|
||||
}
|
||||
|
||||
void Minimum::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, detail::Minimum(), stream());
|
||||
binary(a, b, out, detail::Minimum());
|
||||
}
|
||||
|
||||
void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, detail::Multiply(), stream());
|
||||
binary(a, b, out, detail::Multiply());
|
||||
}
|
||||
|
||||
void NotEqual::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
comparison_op(inputs[0], inputs[1], out, detail::NotEqual(), stream());
|
||||
comparison_op<detail::NotEqual>(inputs[0], inputs[1], out);
|
||||
}
|
||||
|
||||
void Power::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, detail::Power(), stream());
|
||||
binary(a, b, out, detail::Power());
|
||||
}
|
||||
|
||||
void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, detail::Subtract(), stream());
|
||||
binary(a, b, out, detail::Subtract());
|
||||
}
|
||||
|
||||
void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
auto dispatch_type = [&a, &b, &out](auto op) {
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
binary_op<bool>(a, b, out, op);
|
||||
case uint8:
|
||||
binary_op<uint8_t>(a, b, out, op);
|
||||
break;
|
||||
case uint16:
|
||||
binary_op<uint16_t>(a, b, out, op);
|
||||
break;
|
||||
case uint32:
|
||||
binary_op<uint32_t>(a, b, out, op);
|
||||
break;
|
||||
case uint64:
|
||||
binary_op<uint64_t>(a, b, out, op);
|
||||
break;
|
||||
case int8:
|
||||
binary_op<int8_t>(a, b, out, op);
|
||||
break;
|
||||
case int16:
|
||||
binary_op<int16_t>(a, b, out, op);
|
||||
break;
|
||||
case int32:
|
||||
binary_op<int32_t>(a, b, out, op);
|
||||
break;
|
||||
case int64:
|
||||
binary_op<int64_t>(a, b, out, op);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[BitwiseBinary::eval_cpu] Type not supported");
|
||||
break;
|
||||
}
|
||||
};
|
||||
switch (op_) {
|
||||
case BitwiseBinary::And:
|
||||
binary_int(a, b, out, detail::BitwiseAnd(), stream());
|
||||
dispatch_type(detail::BitwiseAnd());
|
||||
break;
|
||||
case BitwiseBinary::Or:
|
||||
binary_int(a, b, out, detail::BitwiseOr(), stream());
|
||||
dispatch_type(detail::BitwiseOr());
|
||||
break;
|
||||
case BitwiseBinary::Xor:
|
||||
binary_int(a, b, out, detail::BitwiseXor(), stream());
|
||||
dispatch_type(detail::BitwiseXor());
|
||||
break;
|
||||
case BitwiseBinary::LeftShift:
|
||||
binary_int(a, b, out, detail::LeftShift(), stream());
|
||||
dispatch_type(detail::LeftShift());
|
||||
break;
|
||||
case BitwiseBinary::RightShift:
|
||||
binary_int(a, b, out, detail::RightShift(), stream());
|
||||
dispatch_type(detail::RightShift());
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -484,7 +332,23 @@ void ArcTan2::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
const auto& a = inputs[0];
|
||||
const auto& b = inputs[1];
|
||||
binary_float(a, b, out, detail::ArcTan2(), stream());
|
||||
switch (out.dtype()) {
|
||||
case float16:
|
||||
binary_op<float16_t>(a, b, out, detail::ArcTan2());
|
||||
break;
|
||||
case float32:
|
||||
binary_op<float>(a, b, out, detail::ArcTan2());
|
||||
break;
|
||||
case float64:
|
||||
binary_op<double>(a, b, out, detail::ArcTan2());
|
||||
break;
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t>(a, b, out, detail::ArcTan2());
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[ArcTan2::eval_cpu] Only supports non-complex floating point types.");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -3,9 +3,12 @@
|
||||
#pragma once
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/common/binary.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#include "mlx/backend/cpu/simd/simd.h"
|
||||
|
||||
@@ -149,145 +152,218 @@ void binary_op_dispatch_dims(
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
void binary_op(const array& a, const array& b, array& out, BinaryOpType bopt) {
|
||||
void binary_op(const array& a, const array& b, array& out) {
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
|
||||
// The full computation is scalar scalar so call the base op once
|
||||
auto a_ptr = a.data<T>();
|
||||
auto b_ptr = b.data<T>();
|
||||
|
||||
auto out_ptr = out.data<U>();
|
||||
if (bopt == BinaryOpType::ScalarScalar) {
|
||||
*out_ptr = Op{}(*a_ptr, *b_ptr);
|
||||
return;
|
||||
}
|
||||
|
||||
// The full computation is scalar vector so delegate to the op
|
||||
if (bopt == BinaryOpType::ScalarVector) {
|
||||
ScalarVector<Op>{}(a_ptr, b_ptr, out_ptr, b.data_size());
|
||||
return;
|
||||
}
|
||||
|
||||
// The full computation is vector scalar so delegate to the op
|
||||
if (bopt == BinaryOpType::VectorScalar) {
|
||||
VectorScalar<Op>{}(a_ptr, b_ptr, out_ptr, a.data_size());
|
||||
return;
|
||||
}
|
||||
|
||||
// The full computation is vector vector so delegate to the op
|
||||
if (bopt == BinaryOpType::VectorVector) {
|
||||
VectorVector<Op>{}(a_ptr, b_ptr, out_ptr, a.size());
|
||||
return;
|
||||
}
|
||||
|
||||
// General computation so let's try to optimize
|
||||
auto [new_shape, new_strides] = collapse_contiguous_dims(
|
||||
a.shape(), {a.strides(), b.strides(), out.strides()});
|
||||
auto& a_strides = new_strides[0];
|
||||
auto& b_strides = new_strides[1];
|
||||
auto& strides = new_strides[2];
|
||||
|
||||
// Get the left-most dim such that the array is row contiguous after
|
||||
auto leftmost_rc_dim = [&strides](const auto& arr_strides) {
|
||||
int d = arr_strides.size() - 1;
|
||||
for (; d >= 0 && arr_strides[d] == strides[d]; d--) {
|
||||
auto& encoder = cpu::get_command_encoder(out.primitive().stream());
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([bopt,
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
out_ptr,
|
||||
a_data_size = a.data_size(),
|
||||
b_data_size = b.data_size(),
|
||||
size = a.size(),
|
||||
shape = a.shape(),
|
||||
a_strides = a.strides(),
|
||||
b_strides = b.strides(),
|
||||
strides = out.strides()]() mutable {
|
||||
if (bopt == BinaryOpType::ScalarScalar) {
|
||||
*out_ptr = Op{}(*a_ptr, *b_ptr);
|
||||
return;
|
||||
}
|
||||
return d + 1;
|
||||
};
|
||||
auto a_rc_dim = leftmost_rc_dim(a_strides);
|
||||
auto b_rc_dim = leftmost_rc_dim(b_strides);
|
||||
|
||||
// Get the left-most dim such that the array is a broadcasted "scalar" after
|
||||
auto leftmost_s_dim = [](const auto& arr_strides) {
|
||||
int d = arr_strides.size() - 1;
|
||||
for (; d >= 0 && arr_strides[d] == 0; d--) {
|
||||
// The full computation is scalar vector so delegate to the op
|
||||
if (bopt == BinaryOpType::ScalarVector) {
|
||||
ScalarVector<Op>{}(a_ptr, b_ptr, out_ptr, b_data_size);
|
||||
return;
|
||||
}
|
||||
return d + 1;
|
||||
};
|
||||
auto a_s_dim = leftmost_s_dim(a_strides);
|
||||
auto b_s_dim = leftmost_s_dim(b_strides);
|
||||
|
||||
auto ndim = new_shape.size();
|
||||
// The full computation is vector scalar so delegate to the op
|
||||
if (bopt == BinaryOpType::VectorScalar) {
|
||||
VectorScalar<Op>{}(a_ptr, b_ptr, out_ptr, a_data_size);
|
||||
return;
|
||||
}
|
||||
|
||||
// Case 1: LxM and FxM where L and F are broadcastable and M is row
|
||||
// contiguous
|
||||
int dim = ndim;
|
||||
if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
|
||||
bopt = BinaryOpType::VectorVector;
|
||||
dim = d;
|
||||
// Case 2: LxM and Fx1 where L and F are broadcastable and M is row
|
||||
// The full computation is vector vector so delegate to the op
|
||||
if (bopt == BinaryOpType::VectorVector) {
|
||||
VectorVector<Op>{}(a_ptr, b_ptr, out_ptr, size);
|
||||
return;
|
||||
}
|
||||
|
||||
// General computation so let's try to optimize
|
||||
auto [new_shape, new_strides] = collapse_contiguous_dims(
|
||||
shape,
|
||||
{std::move(a_strides), std::move(b_strides), std::move(strides)});
|
||||
a_strides = new_strides[0];
|
||||
b_strides = new_strides[1];
|
||||
strides = new_strides[2];
|
||||
|
||||
// Get the left-most dim such that the array is row contiguous after
|
||||
auto leftmost_rc_dim = [&strides](const auto& arr_strides) {
|
||||
int d = arr_strides.size() - 1;
|
||||
for (; d >= 0 && arr_strides[d] == strides[d]; d--) {
|
||||
}
|
||||
return d + 1;
|
||||
};
|
||||
auto a_rc_dim = leftmost_rc_dim(a_strides);
|
||||
auto b_rc_dim = leftmost_rc_dim(b_strides);
|
||||
|
||||
// Get the left-most dim such that the array is a broadcasted "scalar" after
|
||||
auto leftmost_s_dim = [](const auto& arr_strides) {
|
||||
int d = arr_strides.size() - 1;
|
||||
for (; d >= 0 && arr_strides[d] == 0; d--) {
|
||||
}
|
||||
return d + 1;
|
||||
};
|
||||
auto a_s_dim = leftmost_s_dim(a_strides);
|
||||
auto b_s_dim = leftmost_s_dim(b_strides);
|
||||
|
||||
auto ndim = new_shape.size();
|
||||
|
||||
// Case 1: LxM and FxM where L and F are broadcastable and M is row
|
||||
// contiguous
|
||||
} else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
|
||||
bopt = BinaryOpType::VectorScalar;
|
||||
dim = d;
|
||||
// Case 3: Lx1 and FxM where L and F are broadcastable and M is row
|
||||
// contiguous
|
||||
} else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
|
||||
bopt = BinaryOpType::ScalarVector;
|
||||
dim = d;
|
||||
}
|
||||
int dim = ndim;
|
||||
if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
|
||||
bopt = BinaryOpType::VectorVector;
|
||||
dim = d;
|
||||
// Case 2: LxM and Fx1 where L and F are broadcastable and M is row
|
||||
// contiguous
|
||||
} else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
|
||||
bopt = BinaryOpType::VectorScalar;
|
||||
dim = d;
|
||||
// Case 3: Lx1 and FxM where L and F are broadcastable and M is row
|
||||
// contiguous
|
||||
} else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
|
||||
bopt = BinaryOpType::ScalarVector;
|
||||
dim = d;
|
||||
}
|
||||
|
||||
// Can be sure dim > 0 since otherwise we would have used one of the fully
|
||||
// contiguous methods above. Except for the case that the flags do not
|
||||
// correspond to the underlying contiguity.
|
||||
if (dim == 0 || strides[dim - 1] < 16) {
|
||||
bopt = BinaryOpType::General;
|
||||
dim = ndim;
|
||||
}
|
||||
// Can be sure dim > 0 since otherwise we would have used one of the fully
|
||||
// contiguous methods above. Except for the case that the flags do not
|
||||
// correspond to the underlying contiguity.
|
||||
if (dim == 0 || strides[dim - 1] < 16) {
|
||||
bopt = BinaryOpType::General;
|
||||
dim = ndim;
|
||||
}
|
||||
|
||||
switch (bopt) {
|
||||
case BinaryOpType::VectorVector:
|
||||
binary_op_dispatch_dims<T, U, true, VectorVector<Op>>(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
out_ptr,
|
||||
dim,
|
||||
a.size(),
|
||||
new_shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
strides);
|
||||
break;
|
||||
case BinaryOpType::VectorScalar:
|
||||
binary_op_dispatch_dims<T, U, true, VectorScalar<Op>>(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
out_ptr,
|
||||
dim,
|
||||
a.size(),
|
||||
new_shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
strides);
|
||||
break;
|
||||
case BinaryOpType::ScalarVector:
|
||||
binary_op_dispatch_dims<T, U, true, ScalarVector<Op>>(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
out_ptr,
|
||||
dim,
|
||||
a.size(),
|
||||
new_shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
strides);
|
||||
break;
|
||||
default:
|
||||
binary_op_dispatch_dims<T, U, false, Op>(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
out_ptr,
|
||||
dim,
|
||||
a.size(),
|
||||
new_shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
strides);
|
||||
break;
|
||||
}
|
||||
switch (bopt) {
|
||||
case BinaryOpType::VectorVector:
|
||||
binary_op_dispatch_dims<T, U, true, VectorVector<Op>>(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
out_ptr,
|
||||
dim,
|
||||
size,
|
||||
new_shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
strides);
|
||||
break;
|
||||
case BinaryOpType::VectorScalar:
|
||||
binary_op_dispatch_dims<T, U, true, VectorScalar<Op>>(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
out_ptr,
|
||||
dim,
|
||||
size,
|
||||
new_shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
strides);
|
||||
break;
|
||||
case BinaryOpType::ScalarVector:
|
||||
binary_op_dispatch_dims<T, U, true, ScalarVector<Op>>(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
out_ptr,
|
||||
dim,
|
||||
size,
|
||||
new_shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
strides);
|
||||
break;
|
||||
default:
|
||||
binary_op_dispatch_dims<T, U, false, Op>(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
out_ptr,
|
||||
dim,
|
||||
size,
|
||||
new_shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
strides);
|
||||
break;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename T, typename Op>
|
||||
void binary_op(const array& a, const array& b, array& out, BinaryOpType bopt) {
|
||||
binary_op<T, T, Op>(a, b, out, bopt);
|
||||
void binary_op(const array& a, const array& b, array& out) {
|
||||
binary_op<T, T, Op>(a, b, out);
|
||||
}
|
||||
|
||||
template <typename T, typename Op>
|
||||
void binary_op(const array& a, const array& b, array& out, Op op) {
|
||||
binary_op<T, T, Op>(a, b, out);
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void binary(const array& a, const array& b, array& out, Op op) {
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
binary_op<bool, Op>(a, b, out);
|
||||
break;
|
||||
case uint8:
|
||||
binary_op<uint8_t, Op>(a, b, out);
|
||||
break;
|
||||
case uint16:
|
||||
binary_op<uint16_t, Op>(a, b, out);
|
||||
break;
|
||||
case uint32:
|
||||
binary_op<uint32_t, Op>(a, b, out);
|
||||
break;
|
||||
case uint64:
|
||||
binary_op<uint64_t, Op>(a, b, out);
|
||||
break;
|
||||
case int8:
|
||||
binary_op<int8_t, Op>(a, b, out);
|
||||
break;
|
||||
case int16:
|
||||
binary_op<int16_t, Op>(a, b, out);
|
||||
break;
|
||||
case int32:
|
||||
binary_op<int32_t, Op>(a, b, out);
|
||||
break;
|
||||
case int64:
|
||||
binary_op<int64_t, Op>(a, b, out);
|
||||
break;
|
||||
case float16:
|
||||
binary_op<float16_t, Op>(a, b, out);
|
||||
break;
|
||||
case float32:
|
||||
binary_op<float, Op>(a, b, out);
|
||||
break;
|
||||
case float64:
|
||||
binary_op<double, Op>(a, b, out);
|
||||
break;
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t, Op>(a, b, out);
|
||||
break;
|
||||
case complex64:
|
||||
binary_op<complex64_t, Op>(a, b, out);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -4,6 +4,8 @@
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cpu/binary.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
@@ -55,7 +57,14 @@ void binary_op_dispatch_dims(
|
||||
const array& b,
|
||||
array& out_a,
|
||||
array& out_b,
|
||||
Stream stream,
|
||||
Op op) {
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out_a);
|
||||
encoder.set_output_array(out_b);
|
||||
|
||||
auto [shape, strides] = collapse_contiguous_dims(
|
||||
a.shape(), {a.strides(), b.strides(), out_a.strides()});
|
||||
const T* a_ptr = a.data<T>();
|
||||
@@ -63,101 +72,197 @@ void binary_op_dispatch_dims(
|
||||
U* out_a_ptr = out_a.data<U>();
|
||||
U* out_b_ptr = out_b.data<U>();
|
||||
|
||||
const auto& a_strides = strides[0];
|
||||
const auto& b_strides = strides[1];
|
||||
const auto& out_strides = strides[2];
|
||||
int ndim = shape.size();
|
||||
switch (ndim) {
|
||||
case 1:
|
||||
binary_op_dims<T, U, Op, 1>(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
out_a_ptr,
|
||||
out_b_ptr,
|
||||
op,
|
||||
shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
out_strides,
|
||||
0);
|
||||
return;
|
||||
case 2:
|
||||
binary_op_dims<T, U, Op, 2>(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
out_a_ptr,
|
||||
out_b_ptr,
|
||||
op,
|
||||
shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
out_strides,
|
||||
0);
|
||||
return;
|
||||
}
|
||||
encoder.dispatch([a_ptr,
|
||||
b_ptr,
|
||||
out_a_ptr,
|
||||
out_b_ptr,
|
||||
size = a.size(),
|
||||
shape = std::move(shape),
|
||||
strides = std::move(strides),
|
||||
op = std::move(op)]() {
|
||||
const auto& a_strides = strides[0];
|
||||
const auto& b_strides = strides[1];
|
||||
const auto& out_strides = strides[2];
|
||||
int ndim = shape.size();
|
||||
switch (ndim) {
|
||||
case 1:
|
||||
binary_op_dims<T, U, Op, 1>(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
out_a_ptr,
|
||||
out_b_ptr,
|
||||
op,
|
||||
shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
out_strides,
|
||||
0);
|
||||
return;
|
||||
case 2:
|
||||
binary_op_dims<T, U, Op, 2>(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
out_a_ptr,
|
||||
out_b_ptr,
|
||||
op,
|
||||
shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
out_strides,
|
||||
0);
|
||||
return;
|
||||
}
|
||||
|
||||
ContiguousIterator a_it(shape, a_strides, ndim - 2);
|
||||
ContiguousIterator b_it(shape, b_strides, ndim - 2);
|
||||
auto stride = out_strides[ndim - 3];
|
||||
for (size_t elem = 0; elem < a.size(); elem += stride) {
|
||||
binary_op_dims<T, U, Op, 2>(
|
||||
a_ptr + a_it.loc,
|
||||
b_ptr + b_it.loc,
|
||||
out_a_ptr + elem,
|
||||
out_b_ptr + elem,
|
||||
op,
|
||||
shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
out_strides,
|
||||
ndim - 2);
|
||||
a_it.step();
|
||||
b_it.step();
|
||||
}
|
||||
ContiguousIterator a_it(shape, a_strides, ndim - 2);
|
||||
ContiguousIterator b_it(shape, b_strides, ndim - 2);
|
||||
auto stride = out_strides[ndim - 3];
|
||||
for (size_t elem = 0; elem < size; elem += stride) {
|
||||
binary_op_dims<T, U, Op, 2>(
|
||||
a_ptr + a_it.loc,
|
||||
b_ptr + b_it.loc,
|
||||
out_a_ptr + elem,
|
||||
out_b_ptr + elem,
|
||||
op,
|
||||
shape,
|
||||
a_strides,
|
||||
b_strides,
|
||||
out_strides,
|
||||
ndim - 2);
|
||||
a_it.step();
|
||||
b_it.step();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename T, typename U = T, typename Op>
|
||||
void binary_op(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out_a,
|
||||
array& out_b,
|
||||
Op op,
|
||||
BinaryOpType bopt) {
|
||||
std::vector<array>& outputs,
|
||||
Op op) {
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
auto& out_a = outputs[0];
|
||||
auto& out_b = outputs[1];
|
||||
set_binary_op_output_data(a, b, out_a, bopt);
|
||||
set_binary_op_output_data(a, b, out_b, bopt);
|
||||
|
||||
auto stream = out_a.primitive().stream();
|
||||
// The full computation is scalar scalar so call the base op once
|
||||
if (bopt == BinaryOpType::General) {
|
||||
binary_op_dispatch_dims<T, U, Op>(a, b, out_a, out_b, op);
|
||||
binary_op_dispatch_dims<T, U, Op>(a, b, out_a, out_b, stream, op);
|
||||
return;
|
||||
}
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out_a);
|
||||
encoder.set_output_array(out_b);
|
||||
|
||||
auto a_ptr = a.data<T>();
|
||||
auto b_ptr = b.data<T>();
|
||||
auto out_a_ptr = out_a.data<U>();
|
||||
auto out_b_ptr = out_b.data<U>();
|
||||
if (bopt == BinaryOpType::ScalarScalar) {
|
||||
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
|
||||
encoder.dispatch(
|
||||
[a_ptr, b_ptr, out_a_ptr, out_b_ptr, op = std::move(op)]() mutable {
|
||||
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
|
||||
});
|
||||
} else if (bopt == BinaryOpType::ScalarVector) {
|
||||
for (size_t i = 0; i < b.data_size(); ++i) {
|
||||
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
|
||||
out_a_ptr++;
|
||||
out_b_ptr++;
|
||||
b_ptr++;
|
||||
}
|
||||
encoder.dispatch([a_ptr,
|
||||
b_ptr,
|
||||
out_a_ptr,
|
||||
out_b_ptr,
|
||||
size = b.size(),
|
||||
op = std::move(op)]() mutable {
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
|
||||
out_a_ptr++;
|
||||
out_b_ptr++;
|
||||
b_ptr++;
|
||||
}
|
||||
});
|
||||
} else if (bopt == BinaryOpType::VectorScalar) {
|
||||
for (size_t i = 0; i < a.data_size(); ++i) {
|
||||
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
|
||||
out_a_ptr++;
|
||||
out_b_ptr++;
|
||||
a_ptr++;
|
||||
}
|
||||
encoder.dispatch([a_ptr,
|
||||
b_ptr,
|
||||
out_a_ptr,
|
||||
out_b_ptr,
|
||||
size = a.size(),
|
||||
op = std::move(op)]() mutable {
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
|
||||
out_a_ptr++;
|
||||
out_b_ptr++;
|
||||
a_ptr++;
|
||||
}
|
||||
});
|
||||
} else { // VectorVector
|
||||
for (size_t i = 0; i < a.size(); ++i) {
|
||||
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
|
||||
out_a_ptr++;
|
||||
out_b_ptr++;
|
||||
a_ptr++;
|
||||
b_ptr++;
|
||||
}
|
||||
encoder.dispatch([a_ptr,
|
||||
b_ptr,
|
||||
out_a_ptr,
|
||||
out_b_ptr,
|
||||
size = a.size(),
|
||||
op = std::move(op)]() mutable {
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
|
||||
out_a_ptr++;
|
||||
out_b_ptr++;
|
||||
a_ptr++;
|
||||
b_ptr++;
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void binary(
|
||||
const array& a,
|
||||
const array& b,
|
||||
std::vector<array>& outputs,
|
||||
Op op) {
|
||||
switch (outputs[0].dtype()) {
|
||||
case bool_:
|
||||
binary_op<bool>(a, b, outputs, op);
|
||||
break;
|
||||
case uint8:
|
||||
binary_op<uint8_t>(a, b, outputs, op);
|
||||
break;
|
||||
case uint16:
|
||||
binary_op<uint16_t>(a, b, outputs, op);
|
||||
break;
|
||||
case uint32:
|
||||
binary_op<uint32_t>(a, b, outputs, op);
|
||||
break;
|
||||
case uint64:
|
||||
binary_op<uint64_t>(a, b, outputs, op);
|
||||
break;
|
||||
case int8:
|
||||
binary_op<int8_t>(a, b, outputs, op);
|
||||
break;
|
||||
case int16:
|
||||
binary_op<int16_t>(a, b, outputs, op);
|
||||
break;
|
||||
case int32:
|
||||
binary_op<int32_t>(a, b, outputs, op);
|
||||
break;
|
||||
case int64:
|
||||
binary_op<int64_t>(a, b, outputs, op);
|
||||
break;
|
||||
case float16:
|
||||
binary_op<float16_t>(a, b, outputs, op);
|
||||
break;
|
||||
case float32:
|
||||
binary_op<float>(a, b, outputs, op);
|
||||
break;
|
||||
case float64:
|
||||
binary_op<double>(a, b, outputs, op);
|
||||
break;
|
||||
case bfloat16:
|
||||
binary_op<bfloat16_t>(a, b, outputs, op);
|
||||
break;
|
||||
case complex64:
|
||||
binary_op<complex64_t>(a, b, outputs, op);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -20,7 +20,7 @@ void cholesky_impl(const array& a, array& factor, bool upper, Stream stream) {
|
||||
|
||||
// The decomposition is computed in place, so just copy the input to the
|
||||
// output.
|
||||
copy_cpu(
|
||||
copy(
|
||||
a,
|
||||
factor,
|
||||
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
|
||||
|
@@ -40,10 +40,7 @@ struct CompilerCache {
|
||||
std::shared_mutex mtx;
|
||||
};
|
||||
|
||||
static CompilerCache& cache() {
|
||||
static CompilerCache cache_;
|
||||
return cache_;
|
||||
};
|
||||
static CompilerCache cache{};
|
||||
|
||||
// GPU compile is always available if the GPU is available and since we are in
|
||||
// this file CPU compile is also available.
|
||||
@@ -59,16 +56,14 @@ void* compile(
|
||||
const std::string& kernel_name,
|
||||
const std::function<std::string(void)>& source_builder) {
|
||||
{
|
||||
std::shared_lock lock(cache().mtx);
|
||||
if (auto it = cache().kernels.find(kernel_name);
|
||||
it != cache().kernels.end()) {
|
||||
std::shared_lock lock(cache.mtx);
|
||||
if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
|
||||
return it->second;
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_lock lock(cache().mtx);
|
||||
if (auto it = cache().kernels.find(kernel_name);
|
||||
it != cache().kernels.end()) {
|
||||
std::unique_lock lock(cache.mtx);
|
||||
if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
|
||||
return it->second;
|
||||
}
|
||||
std::string source_code = source_builder();
|
||||
@@ -125,10 +120,10 @@ void* compile(
|
||||
}
|
||||
|
||||
// load library
|
||||
cache().libs.emplace_back(shared_lib_path);
|
||||
cache.libs.emplace_back(shared_lib_path);
|
||||
|
||||
// Load function
|
||||
void* fun = dlsym(cache().libs.back().lib, kernel_name.c_str());
|
||||
void* fun = dlsym(cache.libs.back().lib, kernel_name.c_str());
|
||||
if (!fun) {
|
||||
std::ostringstream msg;
|
||||
msg << "[Compile::eval_cpu] Failed to load compiled function "
|
||||
@@ -136,7 +131,7 @@ void* compile(
|
||||
<< dlerror();
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
cache().kernels.insert({kernel_name, fun});
|
||||
cache.kernels.insert({kernel_name, fun});
|
||||
return fun;
|
||||
}
|
||||
|
||||
@@ -146,9 +141,18 @@ inline void build_kernel(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs,
|
||||
const std::vector<array>& tape,
|
||||
const std::function<bool(size_t)>& is_constant,
|
||||
const std::unordered_set<uintptr_t>& constant_ids,
|
||||
bool contiguous,
|
||||
int ndim) {
|
||||
// All outputs should have the exact same shape and will be row contiguous
|
||||
auto output_shape = outputs[0].shape();
|
||||
auto output_strides = outputs[0].strides();
|
||||
|
||||
// Constants are scalars that are captured by value and cannot change
|
||||
auto is_constant = [&constant_ids](const array& x) {
|
||||
return constant_ids.find(x.id()) != constant_ids.end();
|
||||
};
|
||||
|
||||
NodeNamer namer;
|
||||
|
||||
#ifdef _MSC_VER
|
||||
@@ -161,15 +165,14 @@ inline void build_kernel(
|
||||
|
||||
// Add the input arguments
|
||||
int cnt = 0;
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
for (auto& x : inputs) {
|
||||
auto& xname = namer.get_name(x);
|
||||
|
||||
// Skip constants from the input list
|
||||
if (is_constant(i)) {
|
||||
if (is_constant(x)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto& x = inputs[i];
|
||||
auto& xname = namer.get_name(x);
|
||||
|
||||
auto tstr = get_type_string(x.dtype());
|
||||
os << " " << tstr << "* " << xname << " = (" << tstr << "*)args[" << cnt++
|
||||
<< "];" << std::endl;
|
||||
@@ -203,11 +206,10 @@ inline void build_kernel(
|
||||
}
|
||||
|
||||
// Read the inputs in tmps
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
const auto& x = inputs[i];
|
||||
for (auto& x : inputs) {
|
||||
auto& xname = namer.get_name(x);
|
||||
|
||||
if (is_constant(i)) {
|
||||
if (is_constant(x)) {
|
||||
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = ";
|
||||
print_constant(os, x);
|
||||
os << ";" << std::endl;
|
||||
@@ -231,7 +233,7 @@ inline void build_kernel(
|
||||
os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_"
|
||||
<< namer.get_name(x.inputs()[0]) << ");" << std::endl;
|
||||
} else {
|
||||
os << x.primitive().name();
|
||||
x.primitive().print(os);
|
||||
os << "()(";
|
||||
for (int i = 0; i < x.inputs().size() - 1; i++) {
|
||||
os << "tmp_" << namer.get_name(x.inputs()[i]) << ", ";
|
||||
@@ -257,9 +259,8 @@ inline void build_kernel(
|
||||
} else {
|
||||
for (int d = ndim - 1; d >= 0; --d) {
|
||||
// Update pointers
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
const auto& x = inputs[i];
|
||||
if (is_constant(i) || is_scalar(x)) {
|
||||
for (auto& x : inputs) {
|
||||
if (is_constant(x) || is_scalar(x)) {
|
||||
continue;
|
||||
}
|
||||
auto& xname = namer.get_name(x);
|
||||
@@ -281,37 +282,65 @@ inline void build_kernel(
|
||||
void Compiled::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
if (kernel_lib_.empty()) {
|
||||
kernel_lib_ = build_lib_name(inputs_, outputs_, tape_, constant_ids_);
|
||||
}
|
||||
|
||||
// Figure out which kernel we are using
|
||||
auto& shape = outputs[0].shape();
|
||||
auto contiguous = compiled_check_contiguity(inputs, shape);
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
|
||||
// Collapse contiguous dims to route to a faster kernel if possible. Also
|
||||
// handle all broadcasting.
|
||||
auto [contiguous, shape, strides] =
|
||||
compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);
|
||||
|
||||
// Collect function input arguments.
|
||||
// Handle all broadcasting and collect function input arguments
|
||||
std::vector<void*> args;
|
||||
int strides_index = 1;
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
if (is_constant_(i)) {
|
||||
std::vector<std::vector<size_t>> strides;
|
||||
for (int i = 0; i < inputs.size(); i++) {
|
||||
// Skip constants.
|
||||
if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
|
||||
continue;
|
||||
}
|
||||
const auto& x = inputs[i];
|
||||
auto& x = inputs[i];
|
||||
encoder.set_input_array(x);
|
||||
args.push_back((void*)x.data<void>());
|
||||
if (!contiguous && !is_scalar(x)) {
|
||||
args.push_back(strides[strides_index++].data());
|
||||
|
||||
if (contiguous || is_scalar(x)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Broadcast the input to the output shape.
|
||||
std::vector<size_t> xstrides;
|
||||
int j = 0;
|
||||
for (; j < shape.size() - x.ndim(); j++) {
|
||||
if (shape[j] == 1) {
|
||||
xstrides.push_back(outputs[0].strides()[j]);
|
||||
} else {
|
||||
xstrides.push_back(0);
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < x.ndim(); i++, j++) {
|
||||
if (x.shape(i) == 1) {
|
||||
if (shape[j] == 1) {
|
||||
xstrides.push_back(outputs[0].strides()[j]);
|
||||
} else {
|
||||
xstrides.push_back(0);
|
||||
}
|
||||
} else {
|
||||
xstrides.push_back(x.strides()[i]);
|
||||
}
|
||||
}
|
||||
strides.push_back(std::move(xstrides));
|
||||
args.push_back(strides.back().data());
|
||||
}
|
||||
|
||||
// Get the kernel name from the lib
|
||||
int ndim = shape.size();
|
||||
auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_");
|
||||
if (!contiguous) {
|
||||
kernel_name += std::to_string(ndim);
|
||||
kernel_name += std::to_string(shape.size());
|
||||
}
|
||||
|
||||
// Get the function
|
||||
auto fn_ptr = compile(kernel_name, [&, contiguous = contiguous]() {
|
||||
auto fn_ptr = compile(kernel_name, [&]() {
|
||||
std::ostringstream kernel;
|
||||
kernel << get_kernel_preamble() << std::endl;
|
||||
kernel << "extern \"C\" {" << std::endl;
|
||||
@@ -321,7 +350,7 @@ void Compiled::eval_cpu(
|
||||
inputs_,
|
||||
outputs_,
|
||||
tape_,
|
||||
is_constant_,
|
||||
constant_ids_,
|
||||
contiguous,
|
||||
ndim);
|
||||
// Close extern "C"
|
||||
@@ -329,22 +358,26 @@ void Compiled::eval_cpu(
|
||||
return kernel.str();
|
||||
});
|
||||
|
||||
compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous);
|
||||
compiled_allocate_outputs(
|
||||
inputs, outputs, inputs_, constant_ids_, contiguous);
|
||||
|
||||
for (auto& x : outputs) {
|
||||
args.push_back(x.data<void>());
|
||||
encoder.set_output_array(x);
|
||||
}
|
||||
Shape out_shape;
|
||||
if (!contiguous) {
|
||||
args.push_back((void*)shape.data());
|
||||
out_shape = outputs[0].shape();
|
||||
args.push_back((void*)out_shape.data());
|
||||
} else {
|
||||
args.push_back((void*)outputs[0].data_size());
|
||||
}
|
||||
auto fun = (void (*)(void**))fn_ptr;
|
||||
encoder.dispatch([fun,
|
||||
args = std::move(args),
|
||||
strides = std::move(strides),
|
||||
shape = std::move(shape)]() mutable { fun(args.data()); });
|
||||
encoder.dispatch(
|
||||
[fun,
|
||||
args = std::move(args),
|
||||
strides = std::move(strides),
|
||||
out_shape = std::move(out_shape)]() mutable { fun(args.data()); });
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
File diff suppressed because it is too large
Load Diff
@@ -13,20 +13,29 @@ namespace mlx::core {
|
||||
namespace {
|
||||
|
||||
template <typename SrcT, typename DstT>
|
||||
void copy_single(const array& src, array& dst) {
|
||||
void copy_single(const array& src, array& dst, Stream stream) {
|
||||
auto src_ptr = src.data<SrcT>();
|
||||
auto dst_ptr = dst.data<DstT>();
|
||||
auto size = dst.size();
|
||||
auto val = static_cast<DstT>(src_ptr[0]);
|
||||
std::fill_n(dst_ptr, size, val);
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(src);
|
||||
encoder.set_output_array(dst);
|
||||
encoder.dispatch([src_ptr, dst_ptr, size = dst.size()]() {
|
||||
auto val = static_cast<DstT>(src_ptr[0]);
|
||||
std::fill_n(dst_ptr, size, val);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT>
|
||||
void copy_vector(const array& src, array& dst) {
|
||||
void copy_vector(const array& src, array& dst, Stream stream) {
|
||||
auto src_ptr = src.data<SrcT>();
|
||||
auto dst_ptr = dst.data<DstT>();
|
||||
auto size = src.data_size();
|
||||
std::copy(src_ptr, src_ptr + size, dst_ptr);
|
||||
size_t size = src.data_size();
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(src);
|
||||
encoder.set_output_array(dst);
|
||||
encoder.dispatch([src_ptr, dst_ptr, size = src.data_size()]() {
|
||||
std::copy(src_ptr, src_ptr + size, dst_ptr);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT, int D>
|
||||
@@ -57,6 +66,7 @@ template <typename SrcT, typename DstT>
|
||||
void copy_general_general(
|
||||
const array& src,
|
||||
array& dst,
|
||||
Stream stream,
|
||||
const Shape& data_shape,
|
||||
const Strides& i_strides,
|
||||
const Strides& o_strides,
|
||||
@@ -70,17 +80,47 @@ void copy_general_general(
|
||||
dynamic_i_offset ? dynamic_i_offset->data<int64_t>() : nullptr;
|
||||
auto o_offset_ptr =
|
||||
dynamic_o_offset ? dynamic_o_offset->data<int64_t>() : nullptr;
|
||||
auto size = src.size();
|
||||
if (data_shape.empty()) {
|
||||
auto val = static_cast<DstT>(*src_ptr);
|
||||
*dst_ptr = val;
|
||||
return;
|
||||
}
|
||||
auto [shape, strides] =
|
||||
collapse_contiguous_dims(data_shape, {i_strides, o_strides});
|
||||
|
||||
int ndim = shape.size();
|
||||
if (ndim < 3) {
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(src);
|
||||
encoder.set_output_array(dst);
|
||||
encoder.dispatch([src_ptr,
|
||||
dst_ptr,
|
||||
size = src.size(),
|
||||
data_shape = data_shape,
|
||||
i_strides = i_strides,
|
||||
o_strides = o_strides,
|
||||
i_offset_ptr,
|
||||
o_offset_ptr]() mutable {
|
||||
if (data_shape.empty()) {
|
||||
auto val = static_cast<DstT>(*src_ptr);
|
||||
*dst_ptr = val;
|
||||
return;
|
||||
}
|
||||
auto [shape, strides] =
|
||||
collapse_contiguous_dims(data_shape, {i_strides, o_strides});
|
||||
|
||||
int ndim = shape.size();
|
||||
if (ndim < 3) {
|
||||
if (i_offset_ptr) {
|
||||
src_ptr += i_offset_ptr[0];
|
||||
}
|
||||
if (o_offset_ptr) {
|
||||
dst_ptr += o_offset_ptr[0];
|
||||
}
|
||||
|
||||
if (ndim == 1) {
|
||||
copy_dims<SrcT, DstT, 1>(
|
||||
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
|
||||
} else if (ndim == 2) {
|
||||
copy_dims<SrcT, DstT, 2>(
|
||||
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
|
||||
} else if (ndim == 3) {
|
||||
copy_dims<SrcT, DstT, 3>(
|
||||
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (i_offset_ptr) {
|
||||
src_ptr += i_offset_ptr[0];
|
||||
}
|
||||
@@ -88,47 +128,30 @@ void copy_general_general(
|
||||
dst_ptr += o_offset_ptr[0];
|
||||
}
|
||||
|
||||
if (ndim == 1) {
|
||||
copy_dims<SrcT, DstT, 1>(
|
||||
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
|
||||
} else if (ndim == 2) {
|
||||
copy_dims<SrcT, DstT, 2>(
|
||||
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
|
||||
} else if (ndim == 3) {
|
||||
ContiguousIterator in(shape, strides[0], ndim - 3);
|
||||
ContiguousIterator out(shape, strides[1], ndim - 3);
|
||||
auto stride = std::accumulate(
|
||||
shape.end() - 3, shape.end(), 1, std::multiplies<int64_t>());
|
||||
for (int64_t elem = 0; elem < size; elem += stride) {
|
||||
copy_dims<SrcT, DstT, 3>(
|
||||
src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
|
||||
src_ptr + in.loc,
|
||||
dst_ptr + out.loc,
|
||||
shape,
|
||||
strides[0],
|
||||
strides[1],
|
||||
ndim - 3);
|
||||
in.step();
|
||||
out.step();
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (i_offset_ptr) {
|
||||
src_ptr += i_offset_ptr[0];
|
||||
}
|
||||
if (o_offset_ptr) {
|
||||
dst_ptr += o_offset_ptr[0];
|
||||
}
|
||||
|
||||
ContiguousIterator in(shape, strides[0], ndim - 3);
|
||||
ContiguousIterator out(shape, strides[1], ndim - 3);
|
||||
auto stride = std::accumulate(
|
||||
shape.end() - 3, shape.end(), 1, std::multiplies<int64_t>());
|
||||
for (int64_t elem = 0; elem < size; elem += stride) {
|
||||
copy_dims<SrcT, DstT, 3>(
|
||||
src_ptr + in.loc,
|
||||
dst_ptr + out.loc,
|
||||
shape,
|
||||
strides[0],
|
||||
strides[1],
|
||||
ndim - 3);
|
||||
in.step();
|
||||
out.step();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT>
|
||||
inline void copy_general_general(const array& src, array& dst) {
|
||||
inline void copy_general_general(const array& src, array& dst, Stream stream) {
|
||||
copy_general_general<SrcT, DstT>(
|
||||
src,
|
||||
dst,
|
||||
stream,
|
||||
src.shape(),
|
||||
src.strides(),
|
||||
dst.strides(),
|
||||
@@ -142,6 +165,7 @@ template <typename SrcT, typename DstT>
|
||||
void copy_general(
|
||||
const array& src,
|
||||
array& dst,
|
||||
Stream stream,
|
||||
const Shape& data_shape,
|
||||
const Strides& i_strides,
|
||||
const Strides&,
|
||||
@@ -152,6 +176,7 @@ void copy_general(
|
||||
copy_general_general<SrcT, DstT>(
|
||||
src,
|
||||
dst,
|
||||
stream,
|
||||
data_shape,
|
||||
i_strides,
|
||||
make_contiguous_strides(data_shape),
|
||||
@@ -162,10 +187,11 @@ void copy_general(
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT>
|
||||
inline void copy_general(const array& src, array& dst) {
|
||||
inline void copy_general(const array& src, array& dst, Stream stream) {
|
||||
copy_general_general<SrcT, DstT>(
|
||||
src,
|
||||
dst,
|
||||
stream,
|
||||
src.shape(),
|
||||
src.strides(),
|
||||
make_contiguous_strides(src.shape()),
|
||||
@@ -176,67 +202,84 @@ inline void copy_general(const array& src, array& dst) {
|
||||
}
|
||||
|
||||
template <typename SrcT, typename DstT, typename... Args>
|
||||
void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
|
||||
void copy(
|
||||
const array& src,
|
||||
array& dst,
|
||||
CopyType ctype,
|
||||
Stream stream,
|
||||
Args&&... args) {
|
||||
switch (ctype) {
|
||||
case CopyType::Scalar:
|
||||
copy_single<SrcT, DstT>(src, dst);
|
||||
copy_single<SrcT, DstT>(src, dst, stream);
|
||||
return;
|
||||
case CopyType::Vector:
|
||||
copy_vector<SrcT, DstT>(src, dst);
|
||||
copy_vector<SrcT, DstT>(src, dst, stream);
|
||||
return;
|
||||
case CopyType::General:
|
||||
copy_general<SrcT, DstT>(src, dst, std::forward<Args>(args)...);
|
||||
copy_general<SrcT, DstT>(src, dst, stream, std::forward<Args>(args)...);
|
||||
return;
|
||||
case CopyType::GeneralGeneral:
|
||||
copy_general_general<SrcT, DstT>(src, dst, std::forward<Args>(args)...);
|
||||
copy_general_general<SrcT, DstT>(
|
||||
src, dst, stream, std::forward<Args>(args)...);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcT, typename... Args>
|
||||
void copy(const array& src, array& dst, CopyType ctype, Args&&... args) {
|
||||
void copy(
|
||||
const array& src,
|
||||
array& dst,
|
||||
CopyType ctype,
|
||||
Stream stream,
|
||||
Args&&... args) {
|
||||
switch (dst.dtype()) {
|
||||
case bool_:
|
||||
copy<SrcT, bool>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
copy<SrcT, bool>(src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
case uint8:
|
||||
copy<SrcT, uint8_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
copy<SrcT, uint8_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
case uint16:
|
||||
copy<SrcT, uint16_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
copy<SrcT, uint16_t>(
|
||||
src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
case uint32:
|
||||
copy<SrcT, uint32_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
copy<SrcT, uint32_t>(
|
||||
src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
case uint64:
|
||||
copy<SrcT, uint64_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
copy<SrcT, uint64_t>(
|
||||
src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
case int8:
|
||||
copy<SrcT, int8_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
copy<SrcT, int8_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
case int16:
|
||||
copy<SrcT, int16_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
copy<SrcT, int16_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
case int32:
|
||||
copy<SrcT, int32_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
copy<SrcT, int32_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
case int64:
|
||||
copy<SrcT, int64_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
copy<SrcT, int64_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
case float16:
|
||||
copy<SrcT, float16_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
copy<SrcT, float16_t>(
|
||||
src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
case float32:
|
||||
copy<SrcT, float>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
copy<SrcT, float>(src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
case float64:
|
||||
copy<SrcT, double>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
copy<SrcT, double>(src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
case bfloat16:
|
||||
copy<SrcT, bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
copy<SrcT, bfloat16_t>(
|
||||
src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
case complex64:
|
||||
copy<SrcT, complex64_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
copy<SrcT, complex64_t>(
|
||||
src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -246,70 +289,61 @@ inline void copy_inplace_dispatch(
|
||||
const array& src,
|
||||
array& dst,
|
||||
CopyType ctype,
|
||||
Stream stream,
|
||||
Args&&... args) {
|
||||
switch (src.dtype()) {
|
||||
case bool_:
|
||||
copy<bool>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
copy<bool>(src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
case uint8:
|
||||
copy<uint8_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
copy<uint8_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
case uint16:
|
||||
copy<uint16_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
copy<uint16_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
case uint32:
|
||||
copy<uint32_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
copy<uint32_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
case uint64:
|
||||
copy<uint64_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
copy<uint64_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
case int8:
|
||||
copy<int8_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
copy<int8_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
case int16:
|
||||
copy<int16_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
copy<int16_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
case int32:
|
||||
copy<int32_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
copy<int32_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
case int64:
|
||||
copy<int64_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
copy<int64_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
case float16:
|
||||
copy<float16_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
copy<float16_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
case float32:
|
||||
copy<float>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
copy<float>(src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
case float64:
|
||||
copy<double>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
copy<double>(src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
case bfloat16:
|
||||
copy<bfloat16_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
copy<bfloat16_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
case complex64:
|
||||
copy<complex64_t>(src, dst, ctype, std::forward<Args>(args)...);
|
||||
copy<complex64_t>(src, dst, ctype, stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void copy_cpu_inplace(
|
||||
const array& src,
|
||||
array& dst,
|
||||
CopyType ctype,
|
||||
Stream stream) {
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(src);
|
||||
encoder.set_output_array(dst);
|
||||
encoder.dispatch(
|
||||
[src = array::unsafe_weak_copy(src),
|
||||
dst = array::unsafe_weak_copy(dst),
|
||||
ctype]() mutable { copy_inplace_dispatch(src, dst, ctype); });
|
||||
void copy_inplace(const array& src, array& dst, CopyType ctype, Stream stream) {
|
||||
copy_inplace_dispatch(src, dst, ctype, stream);
|
||||
}
|
||||
|
||||
void copy_cpu(const array& src, array& dst, CopyType ctype, Stream stream) {
|
||||
void copy(const array& src, array& dst, CopyType ctype, Stream stream) {
|
||||
bool donated = set_copy_output_data(src, dst, ctype);
|
||||
if (donated && src.dtype() == dst.dtype()) {
|
||||
// If the output has the same type as the input then there is nothing to
|
||||
@@ -319,10 +353,10 @@ void copy_cpu(const array& src, array& dst, CopyType ctype, Stream stream) {
|
||||
if (ctype == CopyType::GeneralGeneral) {
|
||||
ctype = CopyType::General;
|
||||
}
|
||||
copy_cpu_inplace(src, dst, ctype, stream);
|
||||
copy_inplace(src, dst, ctype, stream);
|
||||
}
|
||||
|
||||
void copy_cpu_inplace(
|
||||
void copy_inplace(
|
||||
const array& src,
|
||||
array& dst,
|
||||
const Shape& data_shape,
|
||||
@@ -334,47 +368,26 @@ void copy_cpu_inplace(
|
||||
Stream stream,
|
||||
const std::optional<array>& dynamic_i_offset, /* = std::nullopt */
|
||||
const std::optional<array>& dynamic_o_offset /* = std::nullopt */) {
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(src);
|
||||
encoder.set_output_array(dst);
|
||||
auto weak_copy_if_set = [](auto x) -> std::optional<array> {
|
||||
if (x) {
|
||||
return array::unsafe_weak_copy(*x);
|
||||
} else {
|
||||
return std::nullopt;
|
||||
}
|
||||
};
|
||||
encoder.dispatch(
|
||||
[src = array::unsafe_weak_copy(src),
|
||||
dst = array::unsafe_weak_copy(dst),
|
||||
data_shape,
|
||||
i_strides,
|
||||
o_strides,
|
||||
i_offset,
|
||||
o_offset,
|
||||
ctype,
|
||||
dynamic_i_offset = weak_copy_if_set(dynamic_i_offset),
|
||||
dynamic_o_offset = weak_copy_if_set(dynamic_o_offset)]() mutable {
|
||||
switch (ctype) {
|
||||
case CopyType::General:
|
||||
case CopyType::GeneralGeneral:
|
||||
copy_inplace_dispatch(
|
||||
src,
|
||||
dst,
|
||||
ctype,
|
||||
data_shape,
|
||||
i_strides,
|
||||
o_strides,
|
||||
i_offset,
|
||||
o_offset,
|
||||
dynamic_i_offset,
|
||||
dynamic_o_offset);
|
||||
break;
|
||||
case CopyType::Scalar:
|
||||
case CopyType::Vector:
|
||||
copy_inplace_dispatch(src, dst, ctype);
|
||||
}
|
||||
});
|
||||
switch (ctype) {
|
||||
case CopyType::General:
|
||||
case CopyType::GeneralGeneral:
|
||||
copy_inplace_dispatch(
|
||||
src,
|
||||
dst,
|
||||
ctype,
|
||||
stream,
|
||||
data_shape,
|
||||
i_strides,
|
||||
o_strides,
|
||||
i_offset,
|
||||
o_offset,
|
||||
dynamic_i_offset,
|
||||
dynamic_o_offset);
|
||||
break;
|
||||
case CopyType::Scalar:
|
||||
case CopyType::Vector:
|
||||
copy_inplace_dispatch(src, dst, ctype, stream);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -10,14 +10,10 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void copy_cpu(const array& src, array& dst, CopyType ctype, Stream stream);
|
||||
void copy_cpu_inplace(
|
||||
const array& src,
|
||||
array& dst,
|
||||
CopyType ctype,
|
||||
Stream stream);
|
||||
void copy(const array& src, array& dst, CopyType ctype, Stream stream);
|
||||
void copy_inplace(const array& src, array& dst, CopyType ctype, Stream stream);
|
||||
|
||||
void copy_cpu_inplace(
|
||||
void copy_inplace(
|
||||
const array& src,
|
||||
array& dst,
|
||||
const Shape& data_shape,
|
||||
|
@@ -14,7 +14,7 @@ std::pair<array, bool> ensure_row_contiguous(const array& arr, Stream stream) {
|
||||
return {arr, false};
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy_cpu(arr, arr_copy, CopyType::General, stream);
|
||||
copy(arr, arr_copy, CopyType::General, stream);
|
||||
return {arr_copy, true};
|
||||
}
|
||||
};
|
||||
@@ -30,12 +30,12 @@ void AllReduce::eval_cpu(
|
||||
if (in.is_donatable()) {
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
}
|
||||
return in;
|
||||
} else {
|
||||
array arr_copy(in.shape(), in.dtype(), nullptr, {});
|
||||
copy_cpu(in, arr_copy, CopyType::General, s);
|
||||
copy(in, arr_copy, CopyType::General, s);
|
||||
out.copy_shared_buffer(arr_copy);
|
||||
return arr_copy;
|
||||
}
|
||||
@@ -46,15 +46,8 @@ void AllReduce::eval_cpu(
|
||||
case Sum:
|
||||
distributed::detail::all_sum(group(), in, outputs[0], stream());
|
||||
break;
|
||||
case Max:
|
||||
distributed::detail::all_max(group(), in, outputs[0], stream());
|
||||
break;
|
||||
case Min:
|
||||
distributed::detail::all_min(group(), in, outputs[0], stream());
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"Only all reduce sum, min and max are supported for now");
|
||||
throw std::runtime_error("Only all reduce sum is supported for now");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -65,7 +58,7 @@ void AllGather::eval_cpu(
|
||||
assert(outputs.size() == 1);
|
||||
|
||||
auto [in, copied] = ensure_row_contiguous(inputs[0], stream());
|
||||
outputs[0].set_data(allocator::malloc(outputs[0].nbytes()));
|
||||
outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));
|
||||
distributed::detail::all_gather(group(), in, outputs[0], stream());
|
||||
if (copied) {
|
||||
auto& enc = cpu::get_command_encoder(stream());
|
||||
@@ -94,7 +87,7 @@ void Recv::eval_cpu(
|
||||
assert(inputs.size() == 0);
|
||||
assert(outputs.size() == 1);
|
||||
|
||||
outputs[0].set_data(allocator::malloc(outputs[0].nbytes()));
|
||||
outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));
|
||||
distributed::detail::recv(group(), outputs[0], src_, stream());
|
||||
}
|
||||
|
||||
|
@@ -1,174 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/backend/cpu/lapack.h"
|
||||
#include "mlx/linalg.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T>
|
||||
void eig_impl(
|
||||
array& a,
|
||||
array& vectors,
|
||||
array& values,
|
||||
bool compute_eigenvectors,
|
||||
Stream stream) {
|
||||
using OT = std::complex<T>;
|
||||
auto a_ptr = a.data<T>();
|
||||
auto eig_ptr = values.data<OT>();
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_output_array(values);
|
||||
OT* vec_ptr = nullptr;
|
||||
if (compute_eigenvectors) {
|
||||
encoder.set_output_array(vectors);
|
||||
vec_ptr = vectors.data<OT>();
|
||||
}
|
||||
encoder.dispatch([a_ptr,
|
||||
vec_ptr,
|
||||
eig_ptr,
|
||||
compute_eigenvectors,
|
||||
N = vectors.shape(-1),
|
||||
size = vectors.size()]() mutable {
|
||||
// Work query
|
||||
char jobr = 'N';
|
||||
char jobl = compute_eigenvectors ? 'V' : 'N';
|
||||
int n_vecs_r = 1;
|
||||
int n_vecs_l = compute_eigenvectors ? N : 1;
|
||||
int lwork = -1;
|
||||
int info;
|
||||
{
|
||||
T work;
|
||||
int iwork;
|
||||
geev<T>(
|
||||
&jobl,
|
||||
&jobr,
|
||||
&N,
|
||||
nullptr,
|
||||
&N,
|
||||
nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
&n_vecs_l,
|
||||
nullptr,
|
||||
&n_vecs_r,
|
||||
&work,
|
||||
&lwork,
|
||||
&info);
|
||||
lwork = static_cast<int>(work);
|
||||
}
|
||||
|
||||
auto eig_tmp_data = array::Data{allocator::malloc(sizeof(T) * N * 2)};
|
||||
auto vec_tmp_data =
|
||||
array::Data{allocator::malloc(vec_ptr ? sizeof(T) * N * N * 2 : 0)};
|
||||
auto eig_tmp = static_cast<T*>(eig_tmp_data.buffer.raw_ptr());
|
||||
auto vec_tmp = static_cast<T*>(vec_tmp_data.buffer.raw_ptr());
|
||||
auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)};
|
||||
for (size_t i = 0; i < size / (N * N); ++i) {
|
||||
geev<T>(
|
||||
&jobl,
|
||||
&jobr,
|
||||
&N,
|
||||
a_ptr,
|
||||
&N,
|
||||
eig_tmp,
|
||||
eig_tmp + N,
|
||||
vec_tmp,
|
||||
&n_vecs_l,
|
||||
nullptr,
|
||||
&n_vecs_r,
|
||||
static_cast<T*>(work_buf.buffer.raw_ptr()),
|
||||
&lwork,
|
||||
&info);
|
||||
for (int i = 0; i < N; ++i) {
|
||||
eig_ptr[i] = {eig_tmp[i], eig_tmp[N + i]};
|
||||
}
|
||||
if (vec_ptr) {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
if (eig_ptr[i].imag() != 0) {
|
||||
// This vector and the next are a pair
|
||||
for (int j = 0; j < N; ++j) {
|
||||
vec_ptr[i * N + j] = {
|
||||
vec_tmp[i * N + j], -vec_tmp[(i + 1) * N + j]};
|
||||
vec_ptr[(i + 1) * N + j] = {
|
||||
vec_tmp[i * N + j], vec_tmp[(i + 1) * N + j]};
|
||||
}
|
||||
i += 1;
|
||||
} else {
|
||||
for (int j = 0; j < N; ++j) {
|
||||
vec_ptr[i * N + j] = {vec_tmp[i * N + j], 0};
|
||||
}
|
||||
}
|
||||
}
|
||||
vec_ptr += N * N;
|
||||
}
|
||||
a_ptr += N * N;
|
||||
eig_ptr += N;
|
||||
if (info != 0) {
|
||||
std::stringstream msg;
|
||||
msg << "[Eig::eval_cpu] Eigenvalue decomposition failed with error code "
|
||||
<< info;
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
}
|
||||
});
|
||||
encoder.add_temporary(a);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void Eig::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
const auto& a = inputs[0];
|
||||
auto& values = outputs[0];
|
||||
|
||||
auto vectors = compute_eigenvectors_
|
||||
? outputs[1]
|
||||
: array(a.shape(), complex64, nullptr, {});
|
||||
|
||||
auto a_copy = array(a.shape(), a.dtype(), nullptr, {});
|
||||
copy_cpu(
|
||||
a,
|
||||
a_copy,
|
||||
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
|
||||
stream());
|
||||
|
||||
values.set_data(allocator::malloc(values.nbytes()));
|
||||
|
||||
if (compute_eigenvectors_) {
|
||||
// Set the strides and flags so the eigenvectors
|
||||
// are in the columns of the output
|
||||
auto flags = vectors.flags();
|
||||
auto strides = vectors.strides();
|
||||
auto ndim = a.ndim();
|
||||
std::swap(strides[ndim - 1], strides[ndim - 2]);
|
||||
|
||||
if (a.size() > 1) {
|
||||
flags.row_contiguous = false;
|
||||
if (ndim > 2) {
|
||||
flags.col_contiguous = false;
|
||||
} else {
|
||||
flags.col_contiguous = true;
|
||||
}
|
||||
}
|
||||
vectors.set_data(
|
||||
allocator::malloc(vectors.nbytes()), vectors.size(), strides, flags);
|
||||
}
|
||||
switch (a.dtype()) {
|
||||
case float32:
|
||||
eig_impl<float>(a_copy, vectors, values, compute_eigenvectors_, stream());
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error("[Eig::eval_cpu] only supports float32.");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -12,133 +12,6 @@ namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T, class Enable = void>
|
||||
struct EighWork {};
|
||||
|
||||
template <typename T>
|
||||
struct EighWork<
|
||||
T,
|
||||
typename std::enable_if<std::is_floating_point<T>::value>::type> {
|
||||
using R = T;
|
||||
|
||||
char jobz;
|
||||
char uplo;
|
||||
int N;
|
||||
int lwork;
|
||||
int liwork;
|
||||
int info;
|
||||
std::vector<array::Data> buffers;
|
||||
|
||||
EighWork(char jobz_, char uplo_, int N_)
|
||||
: jobz(jobz_), uplo(uplo_), N(N_), lwork(-1), liwork(-1) {
|
||||
T work;
|
||||
int iwork;
|
||||
syevd<T>(
|
||||
&jobz,
|
||||
&uplo,
|
||||
&N,
|
||||
nullptr,
|
||||
&N,
|
||||
nullptr,
|
||||
&work,
|
||||
&lwork,
|
||||
&iwork,
|
||||
&liwork,
|
||||
&info);
|
||||
lwork = static_cast<int>(work);
|
||||
liwork = iwork;
|
||||
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
|
||||
buffers.emplace_back(allocator::malloc(sizeof(int) * liwork));
|
||||
}
|
||||
|
||||
void run(T* vectors, T* values) {
|
||||
syevd<T>(
|
||||
&jobz,
|
||||
&uplo,
|
||||
&N,
|
||||
vectors,
|
||||
&N,
|
||||
values,
|
||||
static_cast<T*>(buffers[0].buffer.raw_ptr()),
|
||||
&lwork,
|
||||
static_cast<int*>(buffers[1].buffer.raw_ptr()),
|
||||
&liwork,
|
||||
&info);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct EighWork<std::complex<float>> {
|
||||
using T = std::complex<float>;
|
||||
using R = float;
|
||||
|
||||
char jobz;
|
||||
char uplo;
|
||||
int N;
|
||||
int lwork;
|
||||
int lrwork;
|
||||
int liwork;
|
||||
int info;
|
||||
std::vector<array::Data> buffers;
|
||||
|
||||
EighWork(char jobz_, char uplo_, int N_)
|
||||
: jobz(jobz_), uplo(uplo_), N(N_), lwork(-1), lrwork(-1), liwork(-1) {
|
||||
T work;
|
||||
R rwork;
|
||||
int iwork;
|
||||
heevd<T>(
|
||||
&jobz,
|
||||
&uplo,
|
||||
&N,
|
||||
nullptr,
|
||||
&N,
|
||||
nullptr,
|
||||
&work,
|
||||
&lwork,
|
||||
&rwork,
|
||||
&lrwork,
|
||||
&iwork,
|
||||
&liwork,
|
||||
&info);
|
||||
lwork = static_cast<int>(work.real());
|
||||
lrwork = static_cast<int>(rwork);
|
||||
liwork = iwork;
|
||||
buffers.emplace_back(allocator::malloc(sizeof(T) * lwork));
|
||||
buffers.emplace_back(allocator::malloc(sizeof(R) * lrwork));
|
||||
buffers.emplace_back(allocator::malloc(sizeof(int) * liwork));
|
||||
}
|
||||
|
||||
void run(T* vectors, R* values) {
|
||||
heevd<T>(
|
||||
&jobz,
|
||||
&uplo,
|
||||
&N,
|
||||
vectors,
|
||||
&N,
|
||||
values,
|
||||
static_cast<T*>(buffers[0].buffer.raw_ptr()),
|
||||
&lwork,
|
||||
static_cast<R*>(buffers[1].buffer.raw_ptr()),
|
||||
&lrwork,
|
||||
static_cast<int*>(buffers[2].buffer.raw_ptr()),
|
||||
&liwork,
|
||||
&info);
|
||||
if (jobz == 'V') {
|
||||
// We have pre-transposed the vectors but we also must conjugate them
|
||||
// when they are complex.
|
||||
//
|
||||
// We could vectorize this but it is so fast in comparison to heevd that
|
||||
// it doesn't really matter.
|
||||
for (int i = 0; i < N; i++) {
|
||||
for (int j = 0; j < N; j++) {
|
||||
*vectors = std::conj(*vectors);
|
||||
vectors++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
void eigh_impl(
|
||||
array& vectors,
|
||||
@@ -146,10 +19,8 @@ void eigh_impl(
|
||||
const std::string& uplo,
|
||||
bool compute_eigenvectors,
|
||||
Stream stream) {
|
||||
using R = typename EighWork<T>::R;
|
||||
|
||||
auto vec_ptr = vectors.data<T>();
|
||||
auto eig_ptr = values.data<R>();
|
||||
auto eig_ptr = values.data<T>();
|
||||
char jobz = compute_eigenvectors ? 'V' : 'N';
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
@@ -162,17 +33,50 @@ void eigh_impl(
|
||||
N = vectors.shape(-1),
|
||||
size = vectors.size()]() mutable {
|
||||
// Work query
|
||||
EighWork<T> work(jobz, uplo, N);
|
||||
int lwork = -1;
|
||||
int liwork = -1;
|
||||
int info;
|
||||
{
|
||||
T work;
|
||||
int iwork;
|
||||
syevd<T>(
|
||||
&jobz,
|
||||
&uplo,
|
||||
&N,
|
||||
nullptr,
|
||||
&N,
|
||||
nullptr,
|
||||
&work,
|
||||
&lwork,
|
||||
&iwork,
|
||||
&liwork,
|
||||
&info);
|
||||
lwork = static_cast<int>(work);
|
||||
liwork = iwork;
|
||||
}
|
||||
|
||||
// Work loop
|
||||
auto work_buf = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};
|
||||
auto iwork_buf =
|
||||
array::Data{allocator::malloc_or_wait(sizeof(int) * liwork)};
|
||||
for (size_t i = 0; i < size / (N * N); ++i) {
|
||||
work.run(vec_ptr, eig_ptr);
|
||||
syevd<T>(
|
||||
&jobz,
|
||||
&uplo,
|
||||
&N,
|
||||
vec_ptr,
|
||||
&N,
|
||||
eig_ptr,
|
||||
static_cast<T*>(work_buf.buffer.raw_ptr()),
|
||||
&lwork,
|
||||
static_cast<int*>(iwork_buf.buffer.raw_ptr()),
|
||||
&liwork,
|
||||
&info);
|
||||
vec_ptr += N * N;
|
||||
eig_ptr += N;
|
||||
if (work.info != 0) {
|
||||
if (info != 0) {
|
||||
std::stringstream msg;
|
||||
msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
|
||||
<< work.info;
|
||||
<< info;
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
}
|
||||
@@ -194,9 +98,9 @@ void Eigh::eval_cpu(
|
||||
? outputs[1]
|
||||
: array(a.shape(), a.dtype(), nullptr, {});
|
||||
|
||||
values.set_data(allocator::malloc(values.nbytes()));
|
||||
values.set_data(allocator::malloc_or_wait(values.nbytes()));
|
||||
|
||||
copy_cpu(
|
||||
copy(
|
||||
a,
|
||||
vectors,
|
||||
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
|
||||
@@ -228,10 +132,6 @@ void Eigh::eval_cpu(
|
||||
eigh_impl<double>(
|
||||
vectors, values, uplo_, compute_eigenvectors_, stream());
|
||||
break;
|
||||
case complex64:
|
||||
eigh_impl<std::complex<float>>(
|
||||
vectors, values, uplo_, compute_eigenvectors_, stream());
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[Eigh::eval_cpu] only supports float32 or float64.");
|
||||
|
@@ -9,9 +9,6 @@
|
||||
|
||||
namespace mlx::core::cpu {
|
||||
|
||||
// Number of dispatches per scheduler task
|
||||
constexpr int DISPATCHES_PER_TASK = 10;
|
||||
|
||||
struct CommandEncoder {
|
||||
CommandEncoder(Stream stream) : stream_(stream) {}
|
||||
|
||||
@@ -42,24 +39,13 @@ struct CommandEncoder {
|
||||
|
||||
template <class F, class... Args>
|
||||
void dispatch(F&& f, Args&&... args) {
|
||||
num_ops_ = (num_ops_ + 1) % DISPATCHES_PER_TASK;
|
||||
auto task = std::bind(std::forward<F>(f), std::forward<Args>(args)...);
|
||||
if (num_ops_ == 0) {
|
||||
scheduler::notify_new_task(stream_);
|
||||
auto task_wrap = [s = stream_, task = std::move(task)]() mutable {
|
||||
task();
|
||||
scheduler::notify_task_completion(s);
|
||||
};
|
||||
scheduler::enqueue(stream_, std::move(task_wrap));
|
||||
} else {
|
||||
scheduler::enqueue(stream_, std::move(task));
|
||||
}
|
||||
scheduler::enqueue(stream_, std::move(task));
|
||||
}
|
||||
|
||||
private:
|
||||
Stream stream_;
|
||||
std::vector<array> temporaries_;
|
||||
int num_ops_{0};
|
||||
};
|
||||
|
||||
CommandEncoder& get_command_encoder(Stream stream);
|
||||
|
@@ -33,8 +33,12 @@ void eval(array& arr) {
|
||||
buffers.erase(it);
|
||||
}
|
||||
auto& encoder = cpu::get_command_encoder(s);
|
||||
encoder.dispatch([buffers = std::move(buffers),
|
||||
temps = std::move(encoder.temporaries())]() {});
|
||||
scheduler::notify_new_task(s);
|
||||
encoder.dispatch([s,
|
||||
buffers = std::move(buffers),
|
||||
temps = std::move(encoder.temporaries())]() {
|
||||
scheduler::notify_task_completion(s);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core::cpu
|
||||
|
@@ -22,7 +22,7 @@ void FFT::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
s *= out.itemsize();
|
||||
}
|
||||
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
std::vector<size_t> shape;
|
||||
if (out.dtype() == float32) {
|
||||
|
27
mlx/backend/cpu/gemms/no_bf16.cpp
Normal file
27
mlx/backend/cpu/gemms/no_bf16.cpp
Normal file
@@ -0,0 +1,27 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cpu/gemm.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
template <>
|
||||
void matmul<bfloat16_t>(
|
||||
const bfloat16_t*,
|
||||
const bfloat16_t*,
|
||||
bfloat16_t*,
|
||||
bool,
|
||||
bool,
|
||||
size_t,
|
||||
size_t,
|
||||
size_t,
|
||||
float,
|
||||
float,
|
||||
size_t,
|
||||
const Shape&,
|
||||
const Strides&,
|
||||
const Shape&,
|
||||
const Strides&) {
|
||||
throw std::runtime_error("[Matmul::eval_cpu] bfloat16 not supported.");
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
27
mlx/backend/cpu/gemms/no_fp16.cpp
Normal file
27
mlx/backend/cpu/gemms/no_fp16.cpp
Normal file
@@ -0,0 +1,27 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/cpu/gemm.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
template <>
|
||||
void matmul<float16_t>(
|
||||
const float16_t*,
|
||||
const float16_t*,
|
||||
float16_t*,
|
||||
bool,
|
||||
bool,
|
||||
size_t,
|
||||
size_t,
|
||||
size_t,
|
||||
float,
|
||||
float,
|
||||
size_t,
|
||||
const Shape&,
|
||||
const Strides&,
|
||||
const Shape&,
|
||||
const Strides&) {
|
||||
throw std::runtime_error("[Matmul::eval_cpu] float16 not supported.");
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -1,45 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cpu/gemm.h"
|
||||
#include "mlx/backend/cpu/gemms/simd_gemm.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
template <>
|
||||
void matmul<bfloat16_t>(
|
||||
const bfloat16_t* a,
|
||||
const bfloat16_t* b,
|
||||
bfloat16_t* out,
|
||||
bool a_transposed,
|
||||
bool b_transposed,
|
||||
size_t lda,
|
||||
size_t ldb,
|
||||
size_t ldc,
|
||||
float alpha,
|
||||
float beta,
|
||||
size_t batch_size,
|
||||
const Shape& a_shape,
|
||||
const Strides& a_strides,
|
||||
const Shape& b_shape,
|
||||
const Strides& b_strides) {
|
||||
auto ndim = a_shape.size();
|
||||
size_t M = a_shape[ndim - 2];
|
||||
size_t N = b_shape[ndim - 1];
|
||||
size_t K = a_shape[ndim - 1];
|
||||
for (int i = 0; i < batch_size; ++i) {
|
||||
simd_gemm<bfloat16_t, float>(
|
||||
a + elem_to_loc(M * K * i, a_shape, a_strides),
|
||||
b + elem_to_loc(K * N * i, b_shape, b_strides),
|
||||
out + M * N * i,
|
||||
a_transposed,
|
||||
b_transposed,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
alpha,
|
||||
beta);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -1,45 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cpu/gemm.h"
|
||||
#include "mlx/backend/cpu/gemms/simd_gemm.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
template <>
|
||||
void matmul<float16_t>(
|
||||
const float16_t* a,
|
||||
const float16_t* b,
|
||||
float16_t* out,
|
||||
bool a_transposed,
|
||||
bool b_transposed,
|
||||
size_t lda,
|
||||
size_t ldb,
|
||||
size_t ldc,
|
||||
float alpha,
|
||||
float beta,
|
||||
size_t batch_size,
|
||||
const Shape& a_shape,
|
||||
const Strides& a_strides,
|
||||
const Shape& b_shape,
|
||||
const Strides& b_strides) {
|
||||
auto ndim = a_shape.size();
|
||||
size_t M = a_shape[ndim - 2];
|
||||
size_t N = b_shape[ndim - 1];
|
||||
size_t K = a_shape[ndim - 1];
|
||||
for (int i = 0; i < batch_size; ++i) {
|
||||
simd_gemm<float16_t, float>(
|
||||
a + elem_to_loc(M * K * i, a_shape, a_strides),
|
||||
b + elem_to_loc(K * N * i, b_shape, b_strides),
|
||||
out + M * N * i,
|
||||
a_transposed,
|
||||
b_transposed,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
alpha,
|
||||
beta);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -1,139 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/cpu/simd/simd.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
inline int ceildiv(int a, int b) {
|
||||
return (a + b - 1) / b;
|
||||
}
|
||||
|
||||
template <int block_size, typename T, typename AccT>
|
||||
void load_block(
|
||||
const T* in,
|
||||
AccT* out,
|
||||
int M,
|
||||
int N,
|
||||
int i,
|
||||
int j,
|
||||
bool transpose) {
|
||||
if (transpose) {
|
||||
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
|
||||
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
|
||||
out[jj * block_size + ii] =
|
||||
in[(i * block_size + ii) * N + j * block_size + jj];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
|
||||
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
|
||||
out[ii * block_size + jj] =
|
||||
in[(i * block_size + ii) * N + j * block_size + jj];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename AccT>
|
||||
void simd_gemm(
|
||||
const T* a,
|
||||
const T* b,
|
||||
T* c,
|
||||
bool a_trans,
|
||||
bool b_trans,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
float alpha,
|
||||
float beta) {
|
||||
constexpr int block_size = 16;
|
||||
constexpr int simd_size = simd::max_size<AccT>;
|
||||
static_assert(
|
||||
(block_size % simd_size) == 0,
|
||||
"Block size must be divisible by SIMD size");
|
||||
|
||||
int last_k_block_size = K - block_size * (K / block_size);
|
||||
int last_k_simd_block = (last_k_block_size / simd_size) * simd_size;
|
||||
for (int i = 0; i < ceildiv(M, block_size); i++) {
|
||||
for (int j = 0; j < ceildiv(N, block_size); j++) {
|
||||
AccT c_block[block_size * block_size] = {0.0};
|
||||
AccT a_block[block_size * block_size];
|
||||
AccT b_block[block_size * block_size];
|
||||
|
||||
int k = 0;
|
||||
for (; k < K / block_size; k++) {
|
||||
// Load a and b blocks
|
||||
if (a_trans) {
|
||||
load_block<block_size>(a, a_block, K, M, k, i, true);
|
||||
} else {
|
||||
load_block<block_size>(a, a_block, M, K, i, k, false);
|
||||
}
|
||||
if (b_trans) {
|
||||
load_block<block_size>(b, b_block, N, K, j, k, false);
|
||||
} else {
|
||||
load_block<block_size>(b, b_block, K, N, k, j, true);
|
||||
}
|
||||
|
||||
// Multiply and accumulate
|
||||
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
|
||||
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
|
||||
for (int kk = 0; kk < block_size; kk += simd_size) {
|
||||
auto av =
|
||||
simd::load<AccT, simd_size>(a_block + ii * block_size + kk);
|
||||
auto bv =
|
||||
simd::load<AccT, simd_size>(b_block + jj * block_size + kk);
|
||||
c_block[ii * block_size + jj] += simd::sum(av * bv);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (last_k_block_size) {
|
||||
// Load a and b blocks
|
||||
if (a_trans) {
|
||||
load_block<block_size>(a, a_block, K, M, k, i, true);
|
||||
} else {
|
||||
load_block<block_size>(a, a_block, M, K, i, k, false);
|
||||
}
|
||||
if (b_trans) {
|
||||
load_block<block_size>(b, b_block, N, K, j, k, false);
|
||||
} else {
|
||||
load_block<block_size>(b, b_block, K, N, k, j, true);
|
||||
}
|
||||
|
||||
// Multiply and accumulate
|
||||
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
|
||||
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
|
||||
int kk = 0;
|
||||
for (; kk < last_k_simd_block; kk += simd_size) {
|
||||
auto av =
|
||||
simd::load<AccT, simd_size>(a_block + ii * block_size + kk);
|
||||
auto bv =
|
||||
simd::load<AccT, simd_size>(b_block + jj * block_size + kk);
|
||||
c_block[ii * block_size + jj] += simd::sum(av * bv);
|
||||
}
|
||||
for (; kk < last_k_block_size; ++kk) {
|
||||
c_block[ii * block_size + jj] +=
|
||||
a_block[ii * block_size + kk] * b_block[jj * block_size + kk];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Store
|
||||
for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
|
||||
for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
|
||||
auto c_idx = (i * block_size + ii) * N + j * block_size + jj;
|
||||
if (beta != 0) {
|
||||
c[c_idx] = static_cast<T>(
|
||||
alpha * c_block[ii * block_size + jj] + beta * c[c_idx]);
|
||||
} else {
|
||||
c[c_idx] = static_cast<T>(alpha * c_block[ii * block_size + jj]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -96,7 +96,7 @@ void Hadamard::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
if (in.flags().row_contiguous && in.is_donatable()) {
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
copy_cpu(
|
||||
copy(
|
||||
in,
|
||||
out,
|
||||
in.flags().row_contiguous ? CopyType::Vector : CopyType::General,
|
||||
|
File diff suppressed because it is too large
Load Diff
@@ -11,7 +11,7 @@ namespace mlx::core {
|
||||
template <typename T>
|
||||
void general_inv(T* inv, int N) {
|
||||
int info;
|
||||
auto ipiv = array::Data{allocator::malloc(sizeof(int) * N)};
|
||||
auto ipiv = array::Data{allocator::malloc_or_wait(sizeof(int) * N)};
|
||||
// Compute LU factorization.
|
||||
getrf<T>(
|
||||
/* m = */ &N,
|
||||
@@ -49,7 +49,7 @@ void general_inv(T* inv, int N) {
|
||||
}
|
||||
|
||||
const int lwork = workspace_size;
|
||||
auto scratch = array::Data{allocator::malloc(sizeof(T) * lwork)};
|
||||
auto scratch = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};
|
||||
|
||||
// Compute inverse.
|
||||
getri<T>(
|
||||
@@ -115,7 +115,7 @@ void inverse_impl(
|
||||
// (A⁻¹)ᵀ = (Aᵀ)⁻¹
|
||||
|
||||
// The inverse is computed in place, so just copy the input to the output.
|
||||
copy_cpu(
|
||||
copy(
|
||||
a,
|
||||
inv,
|
||||
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
|
||||
|
@@ -2,14 +2,14 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
// Required for Visual Studio.
|
||||
// https://github.com/OpenMathLib/OpenBLAS/blob/develop/docs/install.md
|
||||
#ifdef _MSC_VER
|
||||
#include <complex>
|
||||
#define LAPACK_COMPLEX_CUSTOM
|
||||
#define lapack_complex_float std::complex<float>
|
||||
#define lapack_complex_double std::complex<double>
|
||||
#define lapack_complex_float_real(z) ((z).real())
|
||||
#define lapack_complex_float_imag(z) ((z).imag())
|
||||
#define lapack_complex_double_real(z) ((z).real())
|
||||
#define lapack_complex_double_imag(z) ((z).imag())
|
||||
#endif
|
||||
|
||||
#ifdef MLX_USE_ACCELERATE
|
||||
#include <Accelerate/Accelerate.h>
|
||||
@@ -32,7 +32,7 @@
|
||||
|
||||
#endif
|
||||
|
||||
#define INSTANTIATE_LAPACK_REAL(FUNC) \
|
||||
#define INSTANTIATE_LAPACK_TYPES(FUNC) \
|
||||
template <typename T, typename... Args> \
|
||||
void FUNC(Args... args) { \
|
||||
if constexpr (std::is_same_v<T, float>) { \
|
||||
@@ -42,24 +42,11 @@
|
||||
} \
|
||||
}
|
||||
|
||||
INSTANTIATE_LAPACK_REAL(geqrf)
|
||||
INSTANTIATE_LAPACK_REAL(orgqr)
|
||||
INSTANTIATE_LAPACK_REAL(syevd)
|
||||
INSTANTIATE_LAPACK_REAL(geev)
|
||||
INSTANTIATE_LAPACK_REAL(potrf)
|
||||
INSTANTIATE_LAPACK_REAL(gesvdx)
|
||||
INSTANTIATE_LAPACK_REAL(getrf)
|
||||
INSTANTIATE_LAPACK_REAL(getri)
|
||||
INSTANTIATE_LAPACK_REAL(trtri)
|
||||
|
||||
#define INSTANTIATE_LAPACK_COMPLEX(FUNC) \
|
||||
template <typename T, typename... Args> \
|
||||
void FUNC(Args... args) { \
|
||||
if constexpr (std::is_same_v<T, std::complex<float>>) { \
|
||||
MLX_LAPACK_FUNC(c##FUNC)(std::forward<Args>(args)...); \
|
||||
} else if constexpr (std::is_same_v<T, std::complex<double>>) { \
|
||||
MLX_LAPACK_FUNC(z##FUNC)(std::forward<Args>(args)...); \
|
||||
} \
|
||||
}
|
||||
|
||||
INSTANTIATE_LAPACK_COMPLEX(heevd)
|
||||
INSTANTIATE_LAPACK_TYPES(geqrf)
|
||||
INSTANTIATE_LAPACK_TYPES(orgqr)
|
||||
INSTANTIATE_LAPACK_TYPES(syevd)
|
||||
INSTANTIATE_LAPACK_TYPES(potrf)
|
||||
INSTANTIATE_LAPACK_TYPES(gesvdx)
|
||||
INSTANTIATE_LAPACK_TYPES(getrf)
|
||||
INSTANTIATE_LAPACK_TYPES(getri)
|
||||
INSTANTIATE_LAPACK_TYPES(trtri)
|
||||
|
@@ -1,140 +0,0 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/backend/cpu/simd/simd.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/types/limits.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
using namespace mlx::core::simd;
|
||||
|
||||
template <typename T, typename AccT>
|
||||
void logsumexp(const array& in, array& out, Stream stream) {
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
const T* in_ptr = in.data<T>();
|
||||
T* out_ptr = out.data<T>();
|
||||
|
||||
int M = in.shape().back();
|
||||
int L = in.data_size() / M;
|
||||
|
||||
encoder.dispatch([in_ptr, out_ptr, M, L]() mutable {
|
||||
constexpr int N = std::min(max_size<AccT>, max_size<T>);
|
||||
|
||||
const T* current_in_ptr;
|
||||
|
||||
for (int i = 0; i < L; i++, in_ptr += M, out_ptr += 1) {
|
||||
// Find the maximum
|
||||
current_in_ptr = in_ptr;
|
||||
Simd<AccT, N> vmaximum(-numeric_limits<AccT>::infinity());
|
||||
size_t s = M;
|
||||
while (s >= N) {
|
||||
Simd<AccT, N> vals = load<T, N>(current_in_ptr);
|
||||
vmaximum = maximum(vals, vmaximum);
|
||||
current_in_ptr += N;
|
||||
s -= N;
|
||||
}
|
||||
|
||||
AccT maximum = max(vmaximum);
|
||||
while (s-- > 0) {
|
||||
maximum = std::max(maximum, static_cast<AccT>(*current_in_ptr));
|
||||
current_in_ptr++;
|
||||
}
|
||||
|
||||
// Compute the normalizer and the exponentials
|
||||
Simd<AccT, N> vnormalizer(0.0);
|
||||
current_in_ptr = in_ptr;
|
||||
s = M;
|
||||
while (s >= N) {
|
||||
Simd<AccT, N> vexp = load<T, N>(current_in_ptr);
|
||||
vexp = exp(vexp - maximum);
|
||||
vnormalizer = vnormalizer + vexp;
|
||||
current_in_ptr += N;
|
||||
s -= N;
|
||||
}
|
||||
AccT normalizer = sum(vnormalizer);
|
||||
while (s-- > 0) {
|
||||
AccT _exp = std::exp(*current_in_ptr - maximum);
|
||||
normalizer += _exp;
|
||||
current_in_ptr++;
|
||||
}
|
||||
// Normalize
|
||||
*out_ptr = std::isinf(maximum)
|
||||
? static_cast<T>(maximum)
|
||||
: static_cast<T>(std::log(normalizer) + maximum);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void LogSumExp::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
|
||||
// Make sure that the last dimension is contiguous
|
||||
auto s = stream();
|
||||
auto& encoder = cpu::get_command_encoder(s);
|
||||
auto ensure_contiguous = [&s, &encoder](const array& x) {
|
||||
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
|
||||
return x;
|
||||
} else {
|
||||
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_cpu(x, x_copy, CopyType::General, s);
|
||||
encoder.add_temporary(x_copy);
|
||||
return x_copy;
|
||||
}
|
||||
};
|
||||
|
||||
auto in = ensure_contiguous(inputs[0]);
|
||||
if (in.flags().row_contiguous) {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
} else {
|
||||
auto n = in.shape(-1);
|
||||
auto flags = in.flags();
|
||||
auto strides = in.strides();
|
||||
for (auto& s : strides) {
|
||||
s /= n;
|
||||
}
|
||||
bool col_contig = strides[0] == 1;
|
||||
for (int i = 1; col_contig && i < strides.size(); ++i) {
|
||||
col_contig &=
|
||||
(out.shape(i) == 1 || strides[i - 1] == out.shape(i) * strides[i]);
|
||||
}
|
||||
flags.col_contiguous = col_contig;
|
||||
out.set_data(
|
||||
allocator::malloc(in.nbytes() / n),
|
||||
in.data_size() / n,
|
||||
std::move(strides),
|
||||
flags);
|
||||
}
|
||||
|
||||
switch (in.dtype()) {
|
||||
case float32:
|
||||
logsumexp<float, float>(in, out, stream());
|
||||
break;
|
||||
case float16:
|
||||
logsumexp<float16_t, float>(in, out, stream());
|
||||
break;
|
||||
case bfloat16:
|
||||
logsumexp<bfloat16_t, float>(in, out, stream());
|
||||
break;
|
||||
case float64:
|
||||
logsumexp<double, double>(in, out, stream());
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[logsumexp] only supports floating point types");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -30,8 +30,9 @@ void luf_impl(
|
||||
auto strides = lu.strides();
|
||||
strides[ndim - 1] = M;
|
||||
strides[ndim - 2] = 1;
|
||||
lu.set_data(allocator::malloc(lu.nbytes()), lu.nbytes(), strides, flags);
|
||||
copy_cpu_inplace(
|
||||
lu.set_data(
|
||||
allocator::malloc_or_wait(lu.nbytes()), lu.nbytes(), strides, flags);
|
||||
copy_inplace(
|
||||
a,
|
||||
lu,
|
||||
a.shape(),
|
||||
@@ -43,8 +44,8 @@ void luf_impl(
|
||||
stream);
|
||||
|
||||
auto a_ptr = lu.data<T>();
|
||||
pivots.set_data(allocator::malloc(pivots.nbytes()));
|
||||
row_indices.set_data(allocator::malloc(row_indices.nbytes()));
|
||||
pivots.set_data(allocator::malloc_or_wait(pivots.nbytes()));
|
||||
row_indices.set_data(allocator::malloc_or_wait(row_indices.nbytes()));
|
||||
auto pivots_ptr = pivots.data<uint32_t>();
|
||||
auto row_indices_ptr = row_indices.data<uint32_t>();
|
||||
size_t num_matrices = a.size() / (M * N);
|
||||
|
@@ -6,7 +6,6 @@
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/backend/cpu/gemm.h"
|
||||
#include "mlx/backend/cpu/lapack.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
@@ -53,58 +52,6 @@ inline void mask_matrix(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline void segmented_mm(
|
||||
const T* a,
|
||||
const T* b,
|
||||
const uint32_t* segments,
|
||||
T* out,
|
||||
bool a_transposed,
|
||||
bool b_transposed,
|
||||
size_t lda,
|
||||
size_t ldb,
|
||||
const Shape& a_shape,
|
||||
const Strides& a_strides,
|
||||
const Shape& b_shape,
|
||||
const Strides& b_strides,
|
||||
size_t num_segments,
|
||||
const Shape& segments_shape,
|
||||
const Strides& segments_strides) {
|
||||
int ndim = a_shape.size();
|
||||
Shape a_copy = a_shape;
|
||||
Shape b_copy = b_shape;
|
||||
int32_t M = a_copy[ndim - 2];
|
||||
int32_t N = b_copy[ndim - 1];
|
||||
for (int i = 0; i < num_segments; i++) {
|
||||
uint32_t k_start =
|
||||
segments[elem_to_loc(2 * i, segments_shape, segments_strides)];
|
||||
uint32_t k_end =
|
||||
segments[elem_to_loc(2 * i + 1, segments_shape, segments_strides)];
|
||||
if (k_end <= k_start) {
|
||||
std::fill_n(out + i * M * N, M * N, T(0));
|
||||
continue;
|
||||
}
|
||||
a_copy[ndim - 1] = k_end - k_start;
|
||||
b_copy[ndim - 2] = k_end - k_start;
|
||||
matmul<T>(
|
||||
a + k_start * a_strides[ndim - 1],
|
||||
b + k_start * b_strides[ndim - 2],
|
||||
out + i * M * N,
|
||||
a_transposed,
|
||||
b_transposed,
|
||||
lda,
|
||||
ldb,
|
||||
N,
|
||||
1.0,
|
||||
0.0,
|
||||
1,
|
||||
a_copy,
|
||||
a_strides,
|
||||
b_copy,
|
||||
b_strides);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -112,7 +59,7 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
throw std::runtime_error(
|
||||
"[BlockMaskedMM::eval] Currently only supports float32.");
|
||||
}
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
auto& a_pre = inputs[0];
|
||||
auto& b_pre = inputs[1];
|
||||
@@ -124,20 +71,20 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
if (!expand_all && stx == arr.shape(-1) && sty == 1) {
|
||||
if (do_copy) {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy_cpu(arr, arr_copy, CopyType::Vector, s);
|
||||
copy(arr, arr_copy, CopyType::Vector, s);
|
||||
return std::make_tuple(false, stx, arr_copy, true);
|
||||
}
|
||||
return std::make_tuple(false, stx, arr, false);
|
||||
} else if (!expand_all && stx == 1 && sty == arr.shape(-2)) {
|
||||
if (do_copy) {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy_cpu(arr, arr_copy, CopyType::Vector, s);
|
||||
copy(arr, arr_copy, CopyType::Vector, s);
|
||||
return std::make_tuple(true, sty, arr_copy, true);
|
||||
}
|
||||
return std::make_tuple(true, sty, arr, false);
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy_cpu(arr, arr_copy, CopyType::General, s);
|
||||
copy(arr, arr_copy, CopyType::General, s);
|
||||
int64_t stx = arr.shape(-1);
|
||||
return std::make_tuple(false, stx, arr_copy, true);
|
||||
}
|
||||
@@ -371,7 +318,7 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
throw std::runtime_error(
|
||||
"[GatherMM::eval] Currently only supports float32.");
|
||||
}
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
auto& a_pre = inputs[0];
|
||||
auto& b_pre = inputs[1];
|
||||
@@ -386,7 +333,7 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
return std::make_tuple(true, sty, arr);
|
||||
} else {
|
||||
temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
|
||||
copy_cpu(arr, temps.back(), CopyType::General, s);
|
||||
copy(arr, temps.back(), CopyType::General, s);
|
||||
int64_t stx = arr.shape(-1);
|
||||
return std::make_tuple(false, stx, temps.back());
|
||||
}
|
||||
@@ -490,121 +437,4 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
encoder.add_temporaries(std::move(temps));
|
||||
}
|
||||
|
||||
void SegmentedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& s = stream();
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
auto check_transpose = [&s, &encoder](const array& x) {
|
||||
auto stx = x.strides()[x.ndim() - 2];
|
||||
auto sty = x.strides()[x.ndim() - 1];
|
||||
if (stx == x.shape(-1) && sty == 1) {
|
||||
return std::make_tuple(false, stx, x);
|
||||
} else if (stx == 1 && sty == x.shape(-2)) {
|
||||
return std::make_tuple(true, sty, x);
|
||||
} else {
|
||||
array xc(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_cpu(x, xc, CopyType::General, s);
|
||||
encoder.add_temporary(xc);
|
||||
int64_t stx = x.shape(-1);
|
||||
return std::make_tuple(false, stx, xc);
|
||||
}
|
||||
};
|
||||
|
||||
auto [a_transposed, lda, a] = check_transpose(inputs[0]);
|
||||
auto [b_transposed, ldb, b] = check_transpose(inputs[1]);
|
||||
auto& segments = inputs[2];
|
||||
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_input_array(segments);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||
b = array::unsafe_weak_copy(b),
|
||||
segments = array::unsafe_weak_copy(segments),
|
||||
out_ptr = out.data<void>(),
|
||||
a_transposed = a_transposed,
|
||||
b_transposed = b_transposed,
|
||||
lda = lda,
|
||||
ldb = ldb]() {
|
||||
switch (a.dtype()) {
|
||||
case float64:
|
||||
segmented_mm<double>(
|
||||
a.data<double>(),
|
||||
b.data<double>(),
|
||||
segments.data<uint32_t>(),
|
||||
static_cast<double*>(out_ptr),
|
||||
a_transposed,
|
||||
b_transposed,
|
||||
lda,
|
||||
ldb,
|
||||
a.shape(),
|
||||
a.strides(),
|
||||
b.shape(),
|
||||
b.strides(),
|
||||
segments.size() / 2,
|
||||
segments.shape(),
|
||||
segments.strides());
|
||||
break;
|
||||
case float32:
|
||||
segmented_mm<float>(
|
||||
a.data<float>(),
|
||||
b.data<float>(),
|
||||
segments.data<uint32_t>(),
|
||||
static_cast<float*>(out_ptr),
|
||||
a_transposed,
|
||||
b_transposed,
|
||||
lda,
|
||||
ldb,
|
||||
a.shape(),
|
||||
a.strides(),
|
||||
b.shape(),
|
||||
b.strides(),
|
||||
segments.size() / 2,
|
||||
segments.shape(),
|
||||
segments.strides());
|
||||
break;
|
||||
case float16:
|
||||
segmented_mm<float16_t>(
|
||||
a.data<float16_t>(),
|
||||
b.data<float16_t>(),
|
||||
segments.data<uint32_t>(),
|
||||
static_cast<float16_t*>(out_ptr),
|
||||
a_transposed,
|
||||
b_transposed,
|
||||
lda,
|
||||
ldb,
|
||||
a.shape(),
|
||||
a.strides(),
|
||||
b.shape(),
|
||||
b.strides(),
|
||||
segments.size() / 2,
|
||||
segments.shape(),
|
||||
segments.strides());
|
||||
break;
|
||||
case bfloat16:
|
||||
segmented_mm<bfloat16_t>(
|
||||
a.data<bfloat16_t>(),
|
||||
b.data<bfloat16_t>(),
|
||||
segments.data<uint32_t>(),
|
||||
static_cast<bfloat16_t*>(out_ptr),
|
||||
a_transposed,
|
||||
b_transposed,
|
||||
lda,
|
||||
ldb,
|
||||
a.shape(),
|
||||
a.strides(),
|
||||
b.shape(),
|
||||
b.strides(),
|
||||
segments.size() / 2,
|
||||
segments.shape(),
|
||||
segments.strides());
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument(
|
||||
"Segmented mm supports only real float types.");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -81,7 +81,7 @@ void matmul_general(
|
||||
return std::make_tuple(true, sty, arr);
|
||||
} else {
|
||||
temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
|
||||
copy_cpu(arr, temps.back(), CopyType::General, stream);
|
||||
copy(arr, temps.back(), CopyType::General, stream);
|
||||
stx = arr.shape(-1);
|
||||
return std::make_tuple(false, stx, temps.back());
|
||||
}
|
||||
@@ -115,7 +115,7 @@ void matmul_general(
|
||||
}
|
||||
|
||||
void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
if (inputs[0].shape(-1) == 0) {
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.set_output_array(out);
|
||||
@@ -132,20 +132,14 @@ void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
throw std::runtime_error(
|
||||
"[AddMM::eval_cpu] Currently only supports float32.");
|
||||
}
|
||||
if (out.size() == 0) {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
return;
|
||||
}
|
||||
|
||||
// Fill output with C
|
||||
auto& c = inputs[2];
|
||||
CopyType ctype = c.data_size() == 1
|
||||
? CopyType::Scalar
|
||||
: (c.flags().row_contiguous ? CopyType::Vector : CopyType::General);
|
||||
copy_cpu(c, out, ctype, stream());
|
||||
if (inputs[0].shape(-1) == 0) {
|
||||
return;
|
||||
}
|
||||
copy(c, out, ctype, stream());
|
||||
|
||||
matmul_general(inputs[0], inputs[1], out, stream(), alpha_, beta_);
|
||||
}
|
||||
|
||||
|
@@ -21,8 +21,8 @@ namespace mlx::core {
|
||||
void reshape(const array& in, array& out) {
|
||||
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
|
||||
if (copy_necessary) {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
copy_cpu_inplace(in, out, CopyType::General, out.primitive().stream());
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
copy_inplace(in, out, CopyType::General, out.primitive().stream());
|
||||
} else {
|
||||
shared_buffer_reshape(in, out_strides, out);
|
||||
}
|
||||
@@ -39,7 +39,7 @@ static std::pair<array, bool> compute_dynamic_offset(
|
||||
if (donate) {
|
||||
offset.copy_shared_buffer(indices);
|
||||
} else {
|
||||
offset.set_data(allocator::malloc(offset.itemsize()));
|
||||
offset.set_data(allocator::malloc_or_wait(offset.itemsize()));
|
||||
}
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
@@ -124,7 +124,7 @@ void Transpose::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
void Arange::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 0);
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
throw std::runtime_error("Bool type unsupported for arange.");
|
||||
@@ -175,7 +175,7 @@ void AsType::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
|
||||
copy_cpu(in, out, ctype, stream());
|
||||
copy(in, out, ctype, stream());
|
||||
}
|
||||
|
||||
void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -186,7 +186,7 @@ void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());
|
||||
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
auto strides = out.strides();
|
||||
auto flags = out.flags();
|
||||
@@ -198,20 +198,18 @@ void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
size_t data_offset = strides[axis_] * sizes[i];
|
||||
out_slice.copy_shared_buffer(
|
||||
out, strides, flags, out_slice.size(), data_offset);
|
||||
copy_cpu_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, stream());
|
||||
copy_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, stream());
|
||||
}
|
||||
}
|
||||
|
||||
void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
constexpr size_t extra_bytes = 16384;
|
||||
if (in.buffer_size() <= out.nbytes() + extra_bytes &&
|
||||
(in.flags().row_contiguous ||
|
||||
(allow_col_major_ && in.flags().col_contiguous))) {
|
||||
if (in.flags().row_contiguous ||
|
||||
(allow_col_major_ && in.flags().col_contiguous)) {
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
copy_cpu(in, out, CopyType::General, stream());
|
||||
copy(in, out, CopyType::General, stream());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -235,7 +233,7 @@ void Full::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
} else {
|
||||
ctype = CopyType::General;
|
||||
}
|
||||
copy_cpu(in, out, ctype, stream());
|
||||
copy(in, out, ctype, stream());
|
||||
}
|
||||
|
||||
void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -251,7 +249,7 @@ void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(val.dtype() == in.dtype() && in.dtype() == out.dtype());
|
||||
|
||||
// Fill output with val
|
||||
copy_cpu(val, out, CopyType::Scalar, stream());
|
||||
copy(val, out, CopyType::Scalar, stream());
|
||||
|
||||
// Find offset for start of input values
|
||||
size_t data_offset = 0;
|
||||
@@ -266,7 +264,7 @@ void Pad::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
out, out.strides(), out.flags(), out_slice.size(), data_offset);
|
||||
|
||||
// Copy input values into the slice
|
||||
copy_cpu_inplace(in, out_slice, CopyType::GeneralGeneral, stream());
|
||||
copy_inplace(in, out_slice, CopyType::GeneralGeneral, stream());
|
||||
}
|
||||
|
||||
void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -278,7 +276,7 @@ void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
size_t elems_per_key = out.size() / num_keys;
|
||||
size_t bytes_per_key = out.itemsize() * elems_per_key;
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
auto kptr = inputs[0].data<uint32_t>();
|
||||
auto cptr = out.data<char>();
|
||||
@@ -337,10 +335,10 @@ void DynamicSlice::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
return;
|
||||
}
|
||||
auto& in = inputs[0];
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
auto [in_offset, donated] =
|
||||
compute_dynamic_offset(inputs[1], in.strides(), axes_, stream());
|
||||
copy_cpu_inplace(
|
||||
copy_inplace(
|
||||
/* const array& src = */ in,
|
||||
/* array& dst = */ out,
|
||||
/* const Shape& data_shape = */ out.shape(),
|
||||
@@ -372,11 +370,11 @@ void DynamicSliceUpdate::eval_cpu(
|
||||
auto ctype = in.flags().contiguous && in.size() == in.data_size()
|
||||
? CopyType::Vector
|
||||
: CopyType::General;
|
||||
copy_cpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
|
||||
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
|
||||
|
||||
auto [out_offset, donated] =
|
||||
compute_dynamic_offset(inputs[2], out.strides(), axes_, stream());
|
||||
copy_cpu_inplace(
|
||||
copy_inplace(
|
||||
/* const array& src = */ upd,
|
||||
/* array& dst = */ out,
|
||||
/* const std::vector<int>& data_shape = */ upd.shape(),
|
||||
@@ -412,14 +410,14 @@ void SliceUpdate::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto ctype = in.flags().contiguous && in.size() == in.data_size()
|
||||
? CopyType::Vector
|
||||
: CopyType::General;
|
||||
copy_cpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
|
||||
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
|
||||
|
||||
// Calculate out strides, initial offset and if copy needs to be made
|
||||
auto [data_offset, out_strides] =
|
||||
prepare_slice(out, start_indices_, strides_);
|
||||
|
||||
// Do copy
|
||||
copy_cpu_inplace(
|
||||
copy_inplace(
|
||||
/* const array& src = */ upd,
|
||||
/* array& dst = */ out,
|
||||
/* const std::vector<int>& data_shape = */ upd.shape(),
|
||||
@@ -452,13 +450,13 @@ void View::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
} else {
|
||||
auto tmp = array(
|
||||
in.shape(), in.dtype() == bool_ ? uint8 : in.dtype(), nullptr, {});
|
||||
tmp.set_data(allocator::malloc(tmp.nbytes()));
|
||||
tmp.set_data(allocator::malloc_or_wait(tmp.nbytes()));
|
||||
if (in.dtype() == bool_) {
|
||||
auto in_tmp = array(in.shape(), uint8, nullptr, {});
|
||||
in_tmp.copy_shared_buffer(in);
|
||||
copy_cpu_inplace(in_tmp, tmp, CopyType::General, stream());
|
||||
copy_inplace(in_tmp, tmp, CopyType::General, stream());
|
||||
} else {
|
||||
copy_cpu_inplace(in, tmp, CopyType::General, stream());
|
||||
copy_inplace(in, tmp, CopyType::General, stream());
|
||||
}
|
||||
|
||||
auto flags = out.flags();
|
||||
|
@@ -25,11 +25,12 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
|
||||
auto strides = in.strides();
|
||||
strides[in.ndim() - 2] = 1;
|
||||
strides[in.ndim() - 1] = M;
|
||||
in.set_data(allocator::malloc(in.nbytes()), in.nbytes(), strides, flags);
|
||||
copy_cpu_inplace(a, in, CopyType::GeneralGeneral, stream);
|
||||
in.set_data(
|
||||
allocator::malloc_or_wait(in.nbytes()), in.nbytes(), strides, flags);
|
||||
copy_inplace(a, in, CopyType::GeneralGeneral, stream);
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
q.set_data(allocator::malloc(q.nbytes()));
|
||||
r.set_data(allocator::malloc(r.nbytes()));
|
||||
q.set_data(allocator::malloc_or_wait(q.nbytes()));
|
||||
r.set_data(allocator::malloc_or_wait(r.nbytes()));
|
||||
|
||||
auto in_ptr = in.data<T>();
|
||||
auto r_ptr = r.data<T>();
|
||||
@@ -40,7 +41,8 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
|
||||
encoder.set_output_array(r);
|
||||
encoder.dispatch([in_ptr, q_ptr, r_ptr, M, N, lda, num_matrices]() {
|
||||
int num_reflectors = std::min(M, N);
|
||||
auto tau = allocator::malloc(sizeof(T) * num_matrices * num_reflectors);
|
||||
auto tau =
|
||||
allocator::malloc_or_wait(sizeof(T) * num_matrices * num_reflectors);
|
||||
|
||||
T optimal_work;
|
||||
int lwork = -1;
|
||||
@@ -51,7 +53,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
|
||||
|
||||
// Update workspace size
|
||||
lwork = optimal_work;
|
||||
auto work = allocator::malloc(sizeof(T) * lwork);
|
||||
auto work = allocator::malloc_or_wait(sizeof(T) * lwork);
|
||||
|
||||
// Loop over matrices
|
||||
for (int i = 0; i < num_matrices; ++i) {
|
||||
@@ -94,7 +96,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
|
||||
&lwork,
|
||||
&info);
|
||||
lwork = optimal_work;
|
||||
work = allocator::malloc(sizeof(T) * lwork);
|
||||
work = allocator::malloc_or_wait(sizeof(T) * lwork);
|
||||
|
||||
// Loop over matrices
|
||||
for (int i = 0; i < num_matrices; ++i) {
|
||||
|
@@ -13,18 +13,9 @@ namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
inline constexpr short get_pack_factor(int bits, int wsize = 8) {
|
||||
return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
|
||||
}
|
||||
|
||||
inline constexpr short get_bytes_per_pack(int bits, int wsize = 8) {
|
||||
auto power_of_2_bits = (bits & (bits - 1)) == 0;
|
||||
return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
|
||||
}
|
||||
|
||||
template <typename T, int bits>
|
||||
void extract_bits(const uint8_t* w_in, T* w_out) {
|
||||
static_assert(bits == 3 || bits == 5 || bits == 6);
|
||||
assert(bits == 3 || bits == 6);
|
||||
if (bits == 3) {
|
||||
w_out[0] = static_cast<T>(w_in[0] & 0x7);
|
||||
w_out[1] = static_cast<T>((w_in[0] & 0x38) >> 3);
|
||||
@@ -34,16 +25,6 @@ void extract_bits(const uint8_t* w_in, T* w_out) {
|
||||
w_out[5] = static_cast<T>(((w_in[1] & 0x80) >> 7) + ((w_in[2] & 0x3) << 1));
|
||||
w_out[6] = static_cast<T>((w_in[2] & 0x1c) >> 2);
|
||||
w_out[7] = static_cast<T>((w_in[2] & 0xe0) >> 5);
|
||||
} else if (bits == 5) {
|
||||
w_out[0] = static_cast<T>(w_in[0] & 0x1f);
|
||||
w_out[1] = static_cast<T>(((w_in[0] & 0xe0) >> 5) + ((w_in[1] & 0x3) << 3));
|
||||
w_out[2] = static_cast<T>((w_in[1] & 0x7c) >> 2);
|
||||
w_out[3] = static_cast<T>(((w_in[1] & 0x80) >> 7) + ((w_in[2] & 0xf) << 1));
|
||||
w_out[4] = static_cast<T>(((w_in[2] & 0xf0) >> 4) + ((w_in[3] & 0x1) << 4));
|
||||
w_out[5] = static_cast<T>((w_in[3] & 0x3e) >> 1);
|
||||
w_out[6] = static_cast<T>(((w_in[3] & 0xc0) >> 6) + ((w_in[4] & 0x7) << 2));
|
||||
w_out[7] = static_cast<T>((w_in[4] & 0xf8) >> 3);
|
||||
|
||||
} else if (bits == 6) {
|
||||
w_out[0] = static_cast<T>(w_in[0] & 0x3f);
|
||||
w_out[1] =
|
||||
@@ -65,8 +46,8 @@ void _qmm(
|
||||
int N,
|
||||
int K) {
|
||||
constexpr int bitmask = (1 << bits) - 1;
|
||||
constexpr int pack_factor = get_pack_factor(bits, 8);
|
||||
constexpr int bytes_per_pack = get_bytes_per_pack(bits);
|
||||
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
|
||||
constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
|
||||
constexpr int packs_in_group = group_size / pack_factor;
|
||||
|
||||
for (int m = 0; m < M; m++) {
|
||||
@@ -84,7 +65,7 @@ void _qmm(
|
||||
T scale = *scales_local++;
|
||||
T bias = *biases_local++;
|
||||
for (int ng = 0; ng < packs_in_group; ng++) {
|
||||
if constexpr (bits == 3 || bits == 5 || bits == 6) {
|
||||
if (bits == 3 || bits == 6) {
|
||||
T wl[pack_factor];
|
||||
extract_bits<T, bits>(w_local, wl);
|
||||
#pragma clang loop unroll(full)
|
||||
@@ -123,9 +104,8 @@ void _qmm_t(
|
||||
int N,
|
||||
int K) {
|
||||
constexpr int bitmask = (1 << bits) - 1;
|
||||
|
||||
constexpr int pack_factor = get_pack_factor(bits, 8);
|
||||
constexpr int bytes_per_pack = get_bytes_per_pack(bits);
|
||||
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
|
||||
constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
|
||||
constexpr int packs_in_group = group_size / pack_factor;
|
||||
|
||||
for (int m = 0; m < M; m++) {
|
||||
@@ -141,7 +121,7 @@ void _qmm_t(
|
||||
T bias = *biases_local++;
|
||||
|
||||
for (int kw = 0; kw < packs_in_group; kw++) {
|
||||
if constexpr (bits == 3 || bits == 5 || bits == 6) {
|
||||
if (bits == 3 || bits == 6) {
|
||||
T wl[pack_factor];
|
||||
extract_bits<T, bits>(w_local, wl);
|
||||
#pragma clang loop unroll(full)
|
||||
@@ -324,10 +304,6 @@ void _qmm_dispatch_typed(
|
||||
_qmm_dispatch_group<T, 4>(
|
||||
result, x, w, scales, biases, M, N, K, group_size, transposed_w);
|
||||
break;
|
||||
case 5:
|
||||
_qmm_dispatch_group<T, 5>(
|
||||
result, x, w, scales, biases, M, N, K, group_size, transposed_w);
|
||||
break;
|
||||
case 6:
|
||||
_qmm_dispatch_group<T, 6>(
|
||||
result, x, w, scales, biases, M, N, K, group_size, transposed_w);
|
||||
@@ -350,7 +326,8 @@ void _qmm_dispatch_typed(
|
||||
const array& biases,
|
||||
int bits,
|
||||
int group_size,
|
||||
bool transposed_w) {
|
||||
bool transposed_w,
|
||||
Stream stream) {
|
||||
int K = x.shape(-1);
|
||||
int M = x.ndim() > 1 ? x.shape(-2) : 1;
|
||||
int N = out.shape(-1);
|
||||
@@ -358,25 +335,56 @@ void _qmm_dispatch_typed(
|
||||
int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0;
|
||||
int batch_size = x.size() / (K * M);
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_input_array(w);
|
||||
encoder.set_input_array(scales);
|
||||
encoder.set_input_array(biases);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
auto out_ptr = out.data<T>();
|
||||
auto x_ptr = x.data<T>();
|
||||
auto w_ptr = w.data<uint32_t>();
|
||||
auto scales_ptr = scales.data<T>();
|
||||
auto biases_ptr = biases.data<T>();
|
||||
for (int i = 0; i < batch_size; i++) {
|
||||
_qmm_dispatch_typed<T>(
|
||||
out_ptr + i * M * N,
|
||||
x_ptr + elem_to_loc(i * M * K, x.shape(), x.strides()),
|
||||
w_ptr + elem_to_loc(i * w_els, w.shape(), w.strides()),
|
||||
scales_ptr + elem_to_loc(i * g_els, scales.shape(), scales.strides()),
|
||||
biases_ptr + elem_to_loc(i * g_els, biases.shape(), biases.strides()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
bits,
|
||||
group_size,
|
||||
transposed_w);
|
||||
}
|
||||
|
||||
encoder.dispatch([out_ptr,
|
||||
x_ptr,
|
||||
w_ptr,
|
||||
scales_ptr,
|
||||
biases_ptr,
|
||||
x_shape = x.shape(),
|
||||
x_strides = x.strides(),
|
||||
w_shape = w.shape(),
|
||||
w_strides = w.strides(),
|
||||
scales_shape = scales.shape(),
|
||||
scales_strides = scales.strides(),
|
||||
biases_shape = biases.shape(),
|
||||
biases_strides = biases.strides(),
|
||||
w_els,
|
||||
g_els,
|
||||
batch_size,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
bits,
|
||||
group_size,
|
||||
transposed_w] {
|
||||
for (int i = 0; i < batch_size; i++) {
|
||||
_qmm_dispatch_typed<T>(
|
||||
out_ptr + i * M * N,
|
||||
x_ptr + elem_to_loc(i * M * K, x_shape, x_strides),
|
||||
w_ptr + elem_to_loc(i * w_els, w_shape, w_strides),
|
||||
scales_ptr + elem_to_loc(i * g_els, scales_shape, scales_strides),
|
||||
biases_ptr + elem_to_loc(i * g_els, biases_shape, biases_strides),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
bits,
|
||||
group_size,
|
||||
transposed_w);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void _qmm_dispatch(
|
||||
@@ -387,19 +395,20 @@ void _qmm_dispatch(
|
||||
const array& biases,
|
||||
int bits,
|
||||
int group_size,
|
||||
bool transposed_w) {
|
||||
bool transposed_w,
|
||||
Stream stream) {
|
||||
switch (x.dtype()) {
|
||||
case float32:
|
||||
_qmm_dispatch_typed<float>(
|
||||
out, x, w, scales, biases, bits, group_size, transposed_w);
|
||||
out, x, w, scales, biases, bits, group_size, transposed_w, stream);
|
||||
break;
|
||||
case float16:
|
||||
_qmm_dispatch_typed<float16_t>(
|
||||
out, x, w, scales, biases, bits, group_size, transposed_w);
|
||||
out, x, w, scales, biases, bits, group_size, transposed_w, stream);
|
||||
break;
|
||||
case bfloat16:
|
||||
_qmm_dispatch_typed<bfloat16_t>(
|
||||
out, x, w, scales, biases, bits, group_size, transposed_w);
|
||||
out, x, w, scales, biases, bits, group_size, transposed_w, stream);
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument(
|
||||
@@ -418,7 +427,8 @@ void _bs_qmm_dispatch_typed(
|
||||
const array& rhs_indices,
|
||||
int bits,
|
||||
int group_size,
|
||||
bool transposed_w) {
|
||||
bool transposed_w,
|
||||
Stream stream) {
|
||||
int K = x.shape(-1);
|
||||
int M = x.shape(-2);
|
||||
int N = out.shape(-1);
|
||||
@@ -426,6 +436,15 @@ void _bs_qmm_dispatch_typed(
|
||||
int w_els = w.shape(-1) * w.shape(-2);
|
||||
int g_els = scales.shape(-1) * scales.shape(-2);
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_input_array(w);
|
||||
encoder.set_input_array(scales);
|
||||
encoder.set_input_array(biases);
|
||||
encoder.set_input_array(lhs_indices);
|
||||
encoder.set_input_array(rhs_indices);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
auto out_ptr = out.data<T>();
|
||||
auto x_ptr = x.data<T>();
|
||||
auto w_ptr = w.data<uint32_t>();
|
||||
@@ -434,26 +453,53 @@ void _bs_qmm_dispatch_typed(
|
||||
auto lhs_indices_ptr = lhs_indices.data<uint32_t>();
|
||||
auto rhs_indices_ptr = rhs_indices.data<uint32_t>();
|
||||
|
||||
for (int i = 0; i < lhs_indices.size(); i++) {
|
||||
int x_idx = lhs_indices_ptr[elem_to_loc(
|
||||
i, lhs_indices.shape(), lhs_indices.strides())];
|
||||
int w_idx = rhs_indices_ptr[elem_to_loc(
|
||||
i, rhs_indices.shape(), rhs_indices.strides())];
|
||||
_qmm_dispatch_typed<T>(
|
||||
out_ptr + i * M * N,
|
||||
x_ptr + elem_to_loc(x_idx * M * K, x.shape(), x.strides()),
|
||||
w_ptr + elem_to_loc(w_idx * w_els, w.shape(), w.strides()),
|
||||
scales_ptr +
|
||||
elem_to_loc(w_idx * g_els, scales.shape(), scales.strides()),
|
||||
biases_ptr +
|
||||
elem_to_loc(w_idx * g_els, biases.shape(), biases.strides()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
bits,
|
||||
group_size,
|
||||
transposed_w);
|
||||
}
|
||||
encoder.dispatch([out_ptr,
|
||||
x_ptr,
|
||||
w_ptr,
|
||||
scales_ptr,
|
||||
biases_ptr,
|
||||
lhs_indices_ptr,
|
||||
rhs_indices_ptr,
|
||||
x_shape = x.shape(),
|
||||
x_strides = x.strides(),
|
||||
w_shape = w.shape(),
|
||||
w_strides = w.strides(),
|
||||
scales_shape = scales.shape(),
|
||||
scales_strides = scales.strides(),
|
||||
biases_shape = biases.shape(),
|
||||
biases_strides = biases.strides(),
|
||||
lhs_indices_shape = lhs_indices.shape(),
|
||||
lhs_indices_strides = lhs_indices.strides(),
|
||||
rhs_indices_shape = rhs_indices.shape(),
|
||||
rhs_indices_strides = rhs_indices.strides(),
|
||||
w_els,
|
||||
g_els,
|
||||
indices_size = lhs_indices.size(),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
bits,
|
||||
group_size,
|
||||
transposed_w]() {
|
||||
for (int i = 0; i < indices_size; i++) {
|
||||
int x_idx = lhs_indices_ptr[elem_to_loc(
|
||||
i, lhs_indices_shape, lhs_indices_strides)];
|
||||
int w_idx = rhs_indices_ptr[elem_to_loc(
|
||||
i, rhs_indices_shape, rhs_indices_strides)];
|
||||
_qmm_dispatch_typed<T>(
|
||||
out_ptr + i * M * N,
|
||||
x_ptr + elem_to_loc(x_idx * M * K, x_shape, x_strides),
|
||||
w_ptr + elem_to_loc(w_idx * w_els, w_shape, w_strides),
|
||||
scales_ptr + elem_to_loc(w_idx * g_els, scales_shape, scales_strides),
|
||||
biases_ptr + elem_to_loc(w_idx * g_els, biases_shape, biases_strides),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
bits,
|
||||
group_size,
|
||||
transposed_w);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void _bs_qmm_dispatch(
|
||||
@@ -466,7 +512,8 @@ void _bs_qmm_dispatch(
|
||||
const array& rhs_indices,
|
||||
int bits,
|
||||
int group_size,
|
||||
bool transposed_w) {
|
||||
bool transposed_w,
|
||||
Stream stream) {
|
||||
switch (x.dtype()) {
|
||||
case float32:
|
||||
_bs_qmm_dispatch_typed<float>(
|
||||
@@ -479,7 +526,8 @@ void _bs_qmm_dispatch(
|
||||
rhs_indices,
|
||||
bits,
|
||||
group_size,
|
||||
transposed_w);
|
||||
transposed_w,
|
||||
stream);
|
||||
break;
|
||||
case float16:
|
||||
_bs_qmm_dispatch_typed<float16_t>(
|
||||
@@ -492,7 +540,8 @@ void _bs_qmm_dispatch(
|
||||
rhs_indices,
|
||||
bits,
|
||||
group_size,
|
||||
transposed_w);
|
||||
transposed_w,
|
||||
stream);
|
||||
break;
|
||||
case bfloat16:
|
||||
_bs_qmm_dispatch_typed<bfloat16_t>(
|
||||
@@ -505,7 +554,8 @@ void _bs_qmm_dispatch(
|
||||
rhs_indices,
|
||||
bits,
|
||||
group_size,
|
||||
transposed_w);
|
||||
transposed_w,
|
||||
stream);
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument(
|
||||
@@ -529,7 +579,7 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
return arr;
|
||||
} else {
|
||||
temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
|
||||
copy_cpu(arr, temps.back(), CopyType::General, s);
|
||||
copy(arr, temps.back(), CopyType::General, s);
|
||||
return temps.back();
|
||||
}
|
||||
};
|
||||
@@ -539,25 +589,11 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto scales = ensure_row_contiguous(scales_pre);
|
||||
auto biases = ensure_row_contiguous(biases_pre);
|
||||
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.add_temporaries(std::move(temps));
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_input_array(w);
|
||||
encoder.set_input_array(scales);
|
||||
encoder.set_input_array(biases);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([out = array::unsafe_weak_copy(out),
|
||||
x = array::unsafe_weak_copy(x),
|
||||
w = array::unsafe_weak_copy(w),
|
||||
scales = array::unsafe_weak_copy(scales),
|
||||
biases = array::unsafe_weak_copy(biases),
|
||||
group_size_ = group_size_,
|
||||
bits_ = bits_,
|
||||
transpose_ = transpose_]() mutable {
|
||||
_qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
|
||||
});
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
_qmm_dispatch(
|
||||
out, x, w, scales, biases, group_size_, bits_, transpose_, stream());
|
||||
auto& enc = cpu::get_command_encoder(stream());
|
||||
enc.add_temporaries(std::move(temps));
|
||||
}
|
||||
|
||||
void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -579,7 +615,7 @@ void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
return arr;
|
||||
} else {
|
||||
temps.push_back(array(arr.shape(), arr.dtype(), nullptr, {}));
|
||||
copy_cpu(arr, temps.back(), CopyType::General, s);
|
||||
copy(arr, temps.back(), CopyType::General, s);
|
||||
return temps.back();
|
||||
}
|
||||
};
|
||||
@@ -589,39 +625,21 @@ void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto scales = ensure_row_contiguous_last_dims(scales_pre);
|
||||
auto biases = ensure_row_contiguous_last_dims(biases_pre);
|
||||
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.add_temporaries(std::move(temps));
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_input_array(w);
|
||||
encoder.set_input_array(scales);
|
||||
encoder.set_input_array(biases);
|
||||
encoder.set_input_array(lhs_indices);
|
||||
encoder.set_input_array(rhs_indices);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([out = array::unsafe_weak_copy(out),
|
||||
x = array::unsafe_weak_copy(x),
|
||||
w = array::unsafe_weak_copy(w),
|
||||
scales = array::unsafe_weak_copy(scales),
|
||||
biases = array::unsafe_weak_copy(biases),
|
||||
lhs_indices = array::unsafe_weak_copy(lhs_indices),
|
||||
rhs_indices = array::unsafe_weak_copy(rhs_indices),
|
||||
group_size_ = group_size_,
|
||||
bits_ = bits_,
|
||||
transpose_ = transpose_]() mutable {
|
||||
_bs_qmm_dispatch(
|
||||
out,
|
||||
x,
|
||||
w,
|
||||
scales,
|
||||
biases,
|
||||
lhs_indices,
|
||||
rhs_indices,
|
||||
group_size_,
|
||||
bits_,
|
||||
transpose_);
|
||||
});
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
_bs_qmm_dispatch(
|
||||
out,
|
||||
x,
|
||||
w,
|
||||
scales,
|
||||
biases,
|
||||
lhs_indices,
|
||||
rhs_indices,
|
||||
group_size_,
|
||||
bits_,
|
||||
transpose_,
|
||||
stream());
|
||||
auto& enc = cpu::get_command_encoder(stream());
|
||||
enc.add_temporaries(std::move(temps));
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
@@ -637,8 +655,9 @@ void quantize(
|
||||
float eps = 1e-7;
|
||||
|
||||
bool power_of_2_bits = is_power_of_2(bits);
|
||||
int el_per_int = get_pack_factor(bits, 32);
|
||||
int bytes_per_pack = get_bytes_per_pack(bits);
|
||||
int el_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
|
||||
// For 3/6 bits we read 3 uint8s at a time instead of 1 uint32
|
||||
int bytes_per_pack = power_of_2_bits ? 1 : 3;
|
||||
int int_per_group = group_size * bytes_per_pack / el_per_int;
|
||||
size_t n_groups = w_size / group_size;
|
||||
|
||||
@@ -663,21 +682,15 @@ void quantize(
|
||||
}
|
||||
size_t out_idx = i * int_per_group;
|
||||
for (int j = 0; j < int_per_group / bytes_per_pack; ++j) {
|
||||
uint64_t out_el = 0;
|
||||
uint32_t out_el = 0;
|
||||
for (int k = 0; k < el_per_int; ++k) {
|
||||
float w_el = w[w_idx + j * el_per_int + k];
|
||||
w_el = std::rint((w_el - bias) / scale);
|
||||
w_el = std::min(std::max(w_el, 0.0f), n_bins);
|
||||
out_el |= static_cast<uint64_t>(w_el) << (k * bits);
|
||||
out_el |= static_cast<uint32_t>(w_el) << (k * bits);
|
||||
}
|
||||
if (power_of_2_bits) {
|
||||
out[out_idx + j] = out_el;
|
||||
} else if (bits == 5) {
|
||||
out[out_idx + bytes_per_pack * j] = out_el & 0xff;
|
||||
out[out_idx + bytes_per_pack * j + 1] = (out_el & 0xff00) >> 8;
|
||||
out[out_idx + bytes_per_pack * j + 2] = (out_el & 0xff0000) >> 16;
|
||||
out[out_idx + bytes_per_pack * j + 3] = (out_el & 0xff000000) >> 24;
|
||||
out[out_idx + bytes_per_pack * j + 4] = (out_el & 0xff00000000) >> 32;
|
||||
} else {
|
||||
out[out_idx + bytes_per_pack * j] = out_el & 0xff;
|
||||
out[out_idx + bytes_per_pack * j + 1] = (out_el & 0xff00) >> 8;
|
||||
@@ -696,13 +709,27 @@ void dispatch_quantize(
|
||||
array& scales,
|
||||
array& biases,
|
||||
int bits,
|
||||
int group_size) {
|
||||
int group_size,
|
||||
Stream stream) {
|
||||
auto w_ptr = w.data<T>();
|
||||
auto out_ptr = out.data<U>();
|
||||
auto scales_ptr = scales.data<T>();
|
||||
auto biases_ptr = biases.data<T>();
|
||||
quantize<T, U>(
|
||||
w_ptr, out_ptr, scales_ptr, biases_ptr, bits, group_size, w.size());
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(w);
|
||||
encoder.set_input_array(scales);
|
||||
encoder.set_input_array(biases);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([w_ptr,
|
||||
out_ptr,
|
||||
scales_ptr,
|
||||
biases_ptr,
|
||||
bits,
|
||||
group_size,
|
||||
w_size = w.size()]() {
|
||||
quantize<T, U>(
|
||||
w_ptr, out_ptr, scales_ptr, biases_ptr, bits, group_size, w_size);
|
||||
});
|
||||
}
|
||||
|
||||
void fast::AffineQuantize::eval_cpu(
|
||||
@@ -713,62 +740,50 @@ void fast::AffineQuantize::eval_cpu(
|
||||
return std::make_pair(arr, false);
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy_cpu(arr, arr_copy, CopyType::General, s);
|
||||
copy(arr, arr_copy, CopyType::General, s);
|
||||
return std::make_pair(arr_copy, true);
|
||||
}
|
||||
};
|
||||
|
||||
auto [w, copied] = ensure_row_contiguous(inputs[0]);
|
||||
auto& out = outputs[0];
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
auto& scales = outputs[1];
|
||||
auto& biases = outputs[2];
|
||||
scales.set_data(allocator::malloc(scales.nbytes()));
|
||||
biases.set_data(allocator::malloc(biases.nbytes()));
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
if (copied) {
|
||||
encoder.add_temporary(w);
|
||||
}
|
||||
encoder.set_input_array(w);
|
||||
encoder.set_input_array(scales);
|
||||
encoder.set_input_array(biases);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([w = array::unsafe_weak_copy(w),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
scales = array::unsafe_weak_copy(scales),
|
||||
biases = array::unsafe_weak_copy(biases),
|
||||
group_size_ = group_size_,
|
||||
bits_ = bits_]() mutable {
|
||||
if (w.dtype() == float16) {
|
||||
if (is_power_of_2(bits_)) {
|
||||
dispatch_quantize<float16_t, uint32_t>(
|
||||
w, out, scales, biases, bits_, group_size_);
|
||||
} else {
|
||||
dispatch_quantize<float16_t, uint8_t>(
|
||||
w, out, scales, biases, bits_, group_size_);
|
||||
}
|
||||
} else if (w.dtype() == bfloat16) {
|
||||
if (is_power_of_2(bits_)) {
|
||||
dispatch_quantize<bfloat16_t, uint32_t>(
|
||||
w, out, scales, biases, bits_, group_size_);
|
||||
} else {
|
||||
dispatch_quantize<bfloat16_t, uint8_t>(
|
||||
w, out, scales, biases, bits_, group_size_);
|
||||
}
|
||||
} else if (w.dtype() == float32) {
|
||||
if (is_power_of_2(bits_)) {
|
||||
dispatch_quantize<float, uint32_t>(
|
||||
w, out, scales, biases, bits_, group_size_);
|
||||
} else {
|
||||
dispatch_quantize<float, uint8_t>(
|
||||
w, out, scales, biases, bits_, group_size_);
|
||||
}
|
||||
scales.set_data(allocator::malloc_or_wait(scales.nbytes()));
|
||||
biases.set_data(allocator::malloc_or_wait(biases.nbytes()));
|
||||
if (w.dtype() == float16) {
|
||||
if (is_power_of_2(bits_)) {
|
||||
dispatch_quantize<float16_t, uint32_t>(
|
||||
w, out, scales, biases, bits_, group_size_, stream());
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"[fast::AffineQuantize::eval_cpu] Only supports floating point inputs");
|
||||
dispatch_quantize<float16_t, uint8_t>(
|
||||
w, out, scales, biases, bits_, group_size_, stream());
|
||||
}
|
||||
});
|
||||
} else if (w.dtype() == bfloat16) {
|
||||
if (is_power_of_2(bits_)) {
|
||||
dispatch_quantize<bfloat16_t, uint32_t>(
|
||||
w, out, scales, biases, bits_, group_size_, stream());
|
||||
} else {
|
||||
dispatch_quantize<bfloat16_t, uint8_t>(
|
||||
w, out, scales, biases, bits_, group_size_, stream());
|
||||
}
|
||||
} else if (w.dtype() == float32) {
|
||||
if (is_power_of_2(bits_)) {
|
||||
dispatch_quantize<float, uint32_t>(
|
||||
w, out, scales, biases, bits_, group_size_, stream());
|
||||
} else {
|
||||
dispatch_quantize<float, uint8_t>(
|
||||
w, out, scales, biases, bits_, group_size_, stream());
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"[fast::AffineQuantize::eval_cpu] Only supports floating point inputs");
|
||||
}
|
||||
if (copied) {
|
||||
cpu::get_command_encoder(stream()).add_temporary(w);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -140,23 +140,34 @@ void reduction_op(
|
||||
const array& x,
|
||||
array& out,
|
||||
const std::vector<int>& axes,
|
||||
U init) {
|
||||
U init,
|
||||
Stream stream) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
ReductionPlan plan = get_reduction_plan(x, axes);
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
auto in_ptr = x.data<T>();
|
||||
auto out_ptr = out.data<U>();
|
||||
if (plan.type == ContiguousAllReduce) {
|
||||
*out_ptr = init;
|
||||
contiguous_reduce(in_ptr, out_ptr, x.size(), Op{}, init);
|
||||
encoder.dispatch([in_ptr, out_ptr, init, size = x.size()]() {
|
||||
*out_ptr = init;
|
||||
contiguous_reduce(in_ptr, out_ptr, size, Op{}, init);
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
if (plan.type == ContiguousReduce && plan.shape.size() == 1) {
|
||||
int reduction_size = plan.shape[0];
|
||||
for (int i = 0; i < out.size(); i++, out_ptr++, in_ptr += reduction_size) {
|
||||
*out_ptr = init;
|
||||
contiguous_reduce(in_ptr, out_ptr, reduction_size, Op{}, init);
|
||||
}
|
||||
encoder.dispatch(
|
||||
[in_ptr, out_ptr, init, reduction_size, size = out.size()]() mutable {
|
||||
for (int i = 0; i < size; i++, out_ptr++, in_ptr += reduction_size) {
|
||||
*out_ptr = init;
|
||||
contiguous_reduce(in_ptr, out_ptr, reduction_size, Op{}, init);
|
||||
}
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -167,29 +178,40 @@ void reduction_op(
|
||||
// Unrolling the following loop (and implementing it in order for
|
||||
// ContiguousReduce) should hold extra performance boost.
|
||||
auto [shape, strides] = shapes_without_reduction_axes(x, axes);
|
||||
if (plan.shape.size() == 0) {
|
||||
for (int i = 0; i < out.size(); i++, out_ptr++) {
|
||||
int offset = elem_to_loc(i, shape, strides);
|
||||
*out_ptr = init;
|
||||
contiguous_reduce(in_ptr + offset, out_ptr, reduction_size, Op{}, init);
|
||||
|
||||
encoder.dispatch([in_ptr,
|
||||
out_ptr,
|
||||
init,
|
||||
reduction_size,
|
||||
size = out.size(),
|
||||
plan = std::move(plan),
|
||||
shape = std::move(shape),
|
||||
strides = std::move(strides)]() mutable {
|
||||
if (plan.shape.size() == 0) {
|
||||
for (int i = 0; i < size; i++, out_ptr++) {
|
||||
int offset = elem_to_loc(i, shape, strides);
|
||||
*out_ptr = init;
|
||||
contiguous_reduce(
|
||||
in_ptr + offset, out_ptr, reduction_size, Op{}, init);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < size; i++, out_ptr++) {
|
||||
int offset = elem_to_loc(i, shape, strides);
|
||||
*out_ptr = init;
|
||||
nd_loop(
|
||||
[&](int extra_offset) {
|
||||
contiguous_reduce(
|
||||
in_ptr + offset + extra_offset,
|
||||
out_ptr,
|
||||
reduction_size,
|
||||
Op{},
|
||||
init);
|
||||
},
|
||||
plan.shape,
|
||||
plan.strides);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < out.size(); i++, out_ptr++) {
|
||||
int offset = elem_to_loc(i, shape, strides);
|
||||
*out_ptr = init;
|
||||
nd_loop(
|
||||
[&](int extra_offset) {
|
||||
contiguous_reduce(
|
||||
in_ptr + offset + extra_offset,
|
||||
out_ptr,
|
||||
reduction_size,
|
||||
Op{},
|
||||
init);
|
||||
},
|
||||
plan.shape,
|
||||
plan.strides);
|
||||
}
|
||||
}
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -198,12 +220,20 @@ void reduction_op(
|
||||
size_t reduction_stride = plan.strides.back();
|
||||
plan.shape.pop_back();
|
||||
plan.strides.pop_back();
|
||||
for (int i = 0; i < out.size(); i += reduction_stride) {
|
||||
std::fill_n(out_ptr, reduction_stride, init);
|
||||
strided_reduce(in_ptr, out_ptr, reduction_size, reduction_stride, Op{});
|
||||
in_ptr += reduction_stride * reduction_size;
|
||||
out_ptr += reduction_stride;
|
||||
}
|
||||
|
||||
encoder.dispatch([in_ptr,
|
||||
out_ptr,
|
||||
init,
|
||||
reduction_size,
|
||||
reduction_stride,
|
||||
size = out.size()]() mutable {
|
||||
for (int i = 0; i < size; i += reduction_stride) {
|
||||
std::fill_n(out_ptr, reduction_stride, init);
|
||||
strided_reduce(in_ptr, out_ptr, reduction_size, reduction_stride, Op{});
|
||||
in_ptr += reduction_stride * reduction_size;
|
||||
out_ptr += reduction_stride;
|
||||
}
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -215,49 +245,67 @@ void reduction_op(
|
||||
plan.strides.pop_back();
|
||||
auto [shape, strides] = shapes_without_reduction_axes(x, axes);
|
||||
|
||||
if (plan.shape.size() == 0) {
|
||||
for (int i = 0; i < out.size(); i += reduction_stride) {
|
||||
int offset = elem_to_loc(i, shape, strides);
|
||||
std::fill_n(out_ptr, reduction_stride, init);
|
||||
strided_reduce(
|
||||
in_ptr + offset, out_ptr, reduction_size, reduction_stride, Op{});
|
||||
out_ptr += reduction_stride;
|
||||
encoder.dispatch([in_ptr,
|
||||
out_ptr,
|
||||
init,
|
||||
reduction_size,
|
||||
reduction_stride,
|
||||
size = out.size(),
|
||||
plan = std::move(plan),
|
||||
shape = std::move(shape),
|
||||
strides = std::move(strides)]() mutable {
|
||||
if (plan.shape.size() == 0) {
|
||||
for (int i = 0; i < size; i += reduction_stride) {
|
||||
int offset = elem_to_loc(i, shape, strides);
|
||||
std::fill_n(out_ptr, reduction_stride, init);
|
||||
strided_reduce(
|
||||
in_ptr + offset, out_ptr, reduction_size, reduction_stride, Op{});
|
||||
out_ptr += reduction_stride;
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < size; i += reduction_stride) {
|
||||
int offset = elem_to_loc(i, shape, strides);
|
||||
std::fill_n(out_ptr, reduction_stride, init);
|
||||
nd_loop(
|
||||
[&](int extra_offset) {
|
||||
strided_reduce(
|
||||
in_ptr + offset + extra_offset,
|
||||
out_ptr,
|
||||
reduction_size,
|
||||
reduction_stride,
|
||||
Op{});
|
||||
},
|
||||
plan.shape,
|
||||
plan.strides);
|
||||
out_ptr += reduction_stride;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < out.size(); i += reduction_stride) {
|
||||
int offset = elem_to_loc(i, shape, strides);
|
||||
std::fill_n(out_ptr, reduction_stride, init);
|
||||
nd_loop(
|
||||
[&](int extra_offset) {
|
||||
strided_reduce(
|
||||
in_ptr + offset + extra_offset,
|
||||
out_ptr,
|
||||
reduction_size,
|
||||
reduction_stride,
|
||||
Op{});
|
||||
},
|
||||
plan.shape,
|
||||
plan.strides);
|
||||
out_ptr += reduction_stride;
|
||||
}
|
||||
}
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
if (plan.type == GeneralReduce) {
|
||||
auto [shape, strides] = shapes_without_reduction_axes(x, axes);
|
||||
|
||||
for (int i = 0; i < out.size(); i++, out_ptr++) {
|
||||
int offset = elem_to_loc(i, shape, strides);
|
||||
U val = init;
|
||||
nd_loop(
|
||||
[&](int extra_offset) {
|
||||
val = Op{}(val, *(in_ptr + offset + extra_offset));
|
||||
},
|
||||
plan.shape,
|
||||
plan.strides);
|
||||
*out_ptr = val;
|
||||
}
|
||||
encoder.dispatch([in_ptr,
|
||||
out_ptr,
|
||||
init,
|
||||
size = out.size(),
|
||||
plan = std::move(plan),
|
||||
shape = std::move(shape),
|
||||
strides = std::move(strides)]() mutable {
|
||||
for (int i = 0; i < size; i++, out_ptr++) {
|
||||
int offset = elem_to_loc(i, shape, strides);
|
||||
U val = init;
|
||||
nd_loop(
|
||||
[&](int extra_offset) {
|
||||
val = Op{}(val, *(in_ptr + offset + extra_offset));
|
||||
},
|
||||
plan.shape,
|
||||
plan.strides);
|
||||
*out_ptr = val;
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -325,15 +373,7 @@ struct MaxReduce {
|
||||
};
|
||||
|
||||
template <int N, typename T>
|
||||
std::enable_if_t<std::is_integral_v<T>, T> operator()(simd::Simd<T, N> x) {
|
||||
return simd::max(x);
|
||||
};
|
||||
|
||||
template <int N, typename T>
|
||||
std::enable_if_t<!std::is_integral_v<T>, T> operator()(simd::Simd<T, N> x) {
|
||||
if (simd::any(x != x)) {
|
||||
return static_cast<T>(NAN);
|
||||
}
|
||||
T operator()(simd::Simd<T, N> x) {
|
||||
return simd::max(x);
|
||||
};
|
||||
};
|
||||
@@ -350,15 +390,7 @@ struct MinReduce {
|
||||
};
|
||||
|
||||
template <int N, typename T>
|
||||
std::enable_if_t<std::is_integral_v<T>, T> operator()(simd::Simd<T, N> x) {
|
||||
return simd::min(x);
|
||||
};
|
||||
|
||||
template <int N, typename T>
|
||||
std::enable_if_t<!std::is_integral_v<T>, T> operator()(simd::Simd<T, N> x) {
|
||||
if (simd::any(x != x)) {
|
||||
return static_cast<T>(NAN);
|
||||
}
|
||||
T operator()(simd::Simd<T, N> x) {
|
||||
return simd::min(x);
|
||||
};
|
||||
};
|
||||
@@ -402,11 +434,12 @@ void reduce_dispatch_and_or(
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType rtype,
|
||||
const std::vector<int>& axes) {
|
||||
const std::vector<int>& axes,
|
||||
Stream stream) {
|
||||
if (rtype == Reduce::And) {
|
||||
reduction_op<InT, bool, AndReduce>(in, out, axes, true);
|
||||
reduction_op<InT, bool, AndReduce>(in, out, axes, true, stream);
|
||||
} else {
|
||||
reduction_op<InT, bool, OrReduce>(in, out, axes, false);
|
||||
reduction_op<InT, bool, OrReduce>(in, out, axes, false, stream);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -415,18 +448,19 @@ void reduce_dispatch_sum_prod(
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType rtype,
|
||||
const std::vector<int>& axes) {
|
||||
const std::vector<int>& axes,
|
||||
Stream stream) {
|
||||
if (rtype == Reduce::Sum) {
|
||||
if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
|
||||
reduction_op<InT, int32_t, SumReduce>(in, out, axes, 0);
|
||||
reduction_op<InT, int32_t, SumReduce>(in, out, axes, 0, stream);
|
||||
} else {
|
||||
reduction_op<InT, InT, SumReduce>(in, out, axes, 0);
|
||||
reduction_op<InT, InT, SumReduce>(in, out, axes, 0, stream);
|
||||
}
|
||||
} else {
|
||||
if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
|
||||
reduction_op<InT, int32_t, ProdReduce>(in, out, axes, 1);
|
||||
reduction_op<InT, int32_t, ProdReduce>(in, out, axes, 1, stream);
|
||||
} else {
|
||||
reduction_op<InT, InT, ProdReduce>(in, out, axes, 1);
|
||||
reduction_op<InT, InT, ProdReduce>(in, out, axes, 1, stream);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -436,144 +470,162 @@ void reduce_dispatch_min_max(
|
||||
const array& in,
|
||||
array& out,
|
||||
Reduce::ReduceType rtype,
|
||||
const std::vector<int>& axes) {
|
||||
const std::vector<int>& axes,
|
||||
Stream stream) {
|
||||
if (rtype == Reduce::Max) {
|
||||
auto init = Limits<InT>::min;
|
||||
reduction_op<InT, InT, MaxReduce>(in, out, axes, init);
|
||||
reduction_op<InT, InT, MaxReduce>(in, out, axes, init, stream);
|
||||
} else {
|
||||
auto init = Limits<InT>::max;
|
||||
reduction_op<InT, InT, MinReduce>(in, out, axes, init);
|
||||
reduction_op<InT, InT, MinReduce>(in, out, axes, init, stream);
|
||||
}
|
||||
}
|
||||
|
||||
void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([in = array::unsafe_weak_copy(in),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
reduce_type_ = reduce_type_,
|
||||
axes_ = axes_]() mutable {
|
||||
switch (reduce_type_) {
|
||||
case Reduce::And:
|
||||
case Reduce::Or: {
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
case uint8:
|
||||
case int8:
|
||||
reduce_dispatch_and_or<int8_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int16:
|
||||
case uint16:
|
||||
case float16:
|
||||
case bfloat16:
|
||||
reduce_dispatch_and_or<int16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint32:
|
||||
case int32:
|
||||
case float32:
|
||||
reduce_dispatch_and_or<int32_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint64:
|
||||
case int64:
|
||||
case float64:
|
||||
case complex64:
|
||||
reduce_dispatch_and_or<int64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case Reduce::Sum:
|
||||
case Reduce::Prod: {
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
case uint8:
|
||||
case int8:
|
||||
reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int16:
|
||||
case uint16:
|
||||
reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int32:
|
||||
case uint32:
|
||||
reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int64:
|
||||
case uint64:
|
||||
reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float16:
|
||||
reduce_dispatch_sum_prod<float16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case bfloat16:
|
||||
reduce_dispatch_sum_prod<bfloat16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float32:
|
||||
reduce_dispatch_sum_prod<float>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float64:
|
||||
reduce_dispatch_sum_prod<double>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case complex64:
|
||||
reduce_dispatch_sum_prod<complex64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case Reduce::Max:
|
||||
case Reduce::Min: {
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
reduce_dispatch_min_max<bool>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint8:
|
||||
reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint16:
|
||||
reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint32:
|
||||
reduce_dispatch_min_max<uint32_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case uint64:
|
||||
reduce_dispatch_min_max<uint64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int8:
|
||||
reduce_dispatch_min_max<int8_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int16:
|
||||
reduce_dispatch_min_max<int16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int32:
|
||||
reduce_dispatch_min_max<int32_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case int64:
|
||||
reduce_dispatch_min_max<int64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float16:
|
||||
reduce_dispatch_min_max<float16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float32:
|
||||
reduce_dispatch_min_max<float>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case float64:
|
||||
reduce_dispatch_min_max<double>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case bfloat16:
|
||||
reduce_dispatch_min_max<bfloat16_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
case complex64:
|
||||
reduce_dispatch_min_max<complex64_t>(in, out, reduce_type_, axes_);
|
||||
break;
|
||||
}
|
||||
break;
|
||||
switch (reduce_type_) {
|
||||
case Reduce::And:
|
||||
case Reduce::Or: {
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
case uint8:
|
||||
case int8:
|
||||
reduce_dispatch_and_or<int8_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
break;
|
||||
case int16:
|
||||
case uint16:
|
||||
case float16:
|
||||
case bfloat16:
|
||||
reduce_dispatch_and_or<int16_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
break;
|
||||
case uint32:
|
||||
case int32:
|
||||
case float32:
|
||||
reduce_dispatch_and_or<int32_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
break;
|
||||
case uint64:
|
||||
case int64:
|
||||
case float64:
|
||||
case complex64:
|
||||
reduce_dispatch_and_or<int64_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
break;
|
||||
}
|
||||
break;
|
||||
}
|
||||
});
|
||||
case Reduce::Sum:
|
||||
case Reduce::Prod: {
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
case uint8:
|
||||
case int8:
|
||||
reduce_dispatch_sum_prod<int8_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
break;
|
||||
case int16:
|
||||
case uint16:
|
||||
reduce_dispatch_sum_prod<int16_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
break;
|
||||
case int32:
|
||||
case uint32:
|
||||
reduce_dispatch_sum_prod<int32_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
break;
|
||||
case int64:
|
||||
case uint64:
|
||||
reduce_dispatch_sum_prod<int64_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
break;
|
||||
case float16:
|
||||
reduce_dispatch_sum_prod<float16_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
break;
|
||||
case bfloat16:
|
||||
reduce_dispatch_sum_prod<bfloat16_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
break;
|
||||
case float32:
|
||||
reduce_dispatch_sum_prod<float>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
break;
|
||||
case float64:
|
||||
reduce_dispatch_sum_prod<double>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
break;
|
||||
case complex64:
|
||||
reduce_dispatch_sum_prod<complex64_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
break;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case Reduce::Max:
|
||||
case Reduce::Min: {
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
reduce_dispatch_min_max<bool>(in, out, reduce_type_, axes_, stream());
|
||||
break;
|
||||
case uint8:
|
||||
reduce_dispatch_min_max<uint8_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
break;
|
||||
case uint16:
|
||||
reduce_dispatch_min_max<uint16_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
break;
|
||||
case uint32:
|
||||
reduce_dispatch_min_max<uint32_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
break;
|
||||
case uint64:
|
||||
reduce_dispatch_min_max<uint64_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
break;
|
||||
case int8:
|
||||
reduce_dispatch_min_max<uint8_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
break;
|
||||
case int16:
|
||||
reduce_dispatch_min_max<uint16_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
break;
|
||||
case int32:
|
||||
reduce_dispatch_min_max<int32_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
break;
|
||||
case int64:
|
||||
reduce_dispatch_min_max<int64_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
break;
|
||||
case float16:
|
||||
reduce_dispatch_min_max<float16_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
break;
|
||||
case float32:
|
||||
reduce_dispatch_min_max<float>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
break;
|
||||
case float64:
|
||||
reduce_dispatch_min_max<double>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
break;
|
||||
case bfloat16:
|
||||
reduce_dispatch_min_max<bfloat16_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
break;
|
||||
case complex64:
|
||||
reduce_dispatch_min_max<complex64_t>(
|
||||
in, out, reduce_type_, axes_, stream());
|
||||
break;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -3,7 +3,6 @@
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cpu/binary_ops.h"
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
#include "mlx/backend/cpu/simd/simd.h"
|
||||
@@ -161,29 +160,38 @@ void scan_op(
|
||||
bool reverse,
|
||||
bool inclusive,
|
||||
const Op& op,
|
||||
U init) {
|
||||
U init,
|
||||
Stream stream) {
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
if (in.flags().row_contiguous) {
|
||||
if (in.strides()[axis] == 1) {
|
||||
contiguous_scan(
|
||||
in.data<T>(),
|
||||
out.data<U>(),
|
||||
in.size() / in.shape(axis),
|
||||
in.shape(axis),
|
||||
reverse,
|
||||
inclusive,
|
||||
op,
|
||||
init);
|
||||
encoder.dispatch([in_ptr = in.data<T>(),
|
||||
out_ptr = out.data<U>(),
|
||||
count = in.size() / in.shape(axis),
|
||||
stride = in.shape(axis),
|
||||
reverse,
|
||||
inclusive,
|
||||
op = std::move(op),
|
||||
init]() {
|
||||
contiguous_scan(
|
||||
in_ptr, out_ptr, count, stride, reverse, inclusive, op, init);
|
||||
});
|
||||
} else {
|
||||
strided_scan(
|
||||
in.data<T>(),
|
||||
out.data<U>(),
|
||||
in.size() / in.shape(axis) / in.strides()[axis],
|
||||
in.shape(axis),
|
||||
in.strides()[axis],
|
||||
reverse,
|
||||
inclusive,
|
||||
op,
|
||||
init);
|
||||
encoder.dispatch([in_ptr = in.data<T>(),
|
||||
out_ptr = out.data<U>(),
|
||||
count = in.size() / in.shape(axis) / in.strides()[axis],
|
||||
size = in.shape(axis),
|
||||
stride = in.strides()[axis],
|
||||
reverse,
|
||||
inclusive,
|
||||
op = std::move(op),
|
||||
init]() {
|
||||
strided_scan(
|
||||
in_ptr, out_ptr, count, size, stride, reverse, inclusive, op, init);
|
||||
});
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error("Scan op supports only contiguous inputs");
|
||||
@@ -197,18 +205,19 @@ void scan_dispatch(
|
||||
array& out,
|
||||
int axis,
|
||||
bool reverse,
|
||||
bool inclusive) {
|
||||
bool inclusive,
|
||||
Stream stream) {
|
||||
switch (rtype) {
|
||||
case Scan::Sum: {
|
||||
auto op = [](U y, T x) { return y + x; };
|
||||
auto init = static_cast<U>(0);
|
||||
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
|
||||
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init, stream);
|
||||
break;
|
||||
}
|
||||
case Scan::Prod: {
|
||||
auto op = [](U y, T x) { return y * x; };
|
||||
auto init = static_cast<U>(1);
|
||||
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
|
||||
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init, stream);
|
||||
break;
|
||||
}
|
||||
case Scan::Min: {
|
||||
@@ -216,7 +225,7 @@ void scan_dispatch(
|
||||
auto init = (issubdtype(in.dtype(), floating))
|
||||
? static_cast<U>(std::numeric_limits<float>::infinity())
|
||||
: std::numeric_limits<U>::max();
|
||||
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
|
||||
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init, stream);
|
||||
break;
|
||||
}
|
||||
case Scan::Max: {
|
||||
@@ -224,17 +233,7 @@ void scan_dispatch(
|
||||
auto init = (issubdtype(in.dtype(), floating))
|
||||
? static_cast<U>(-std::numeric_limits<float>::infinity())
|
||||
: std::numeric_limits<U>::min();
|
||||
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
|
||||
break;
|
||||
}
|
||||
case Scan::LogAddExp: {
|
||||
auto op = [](U a, T b) {
|
||||
return detail::LogAddExp{}(a, static_cast<U>(b));
|
||||
};
|
||||
auto init = (issubdtype(in.dtype(), floating))
|
||||
? static_cast<U>(-std::numeric_limits<float>::infinity())
|
||||
: std::numeric_limits<U>::min();
|
||||
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init);
|
||||
scan_op<T, U>(in, out, axis, reverse, inclusive, op, init, stream);
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -245,96 +244,88 @@ void scan_dispatch(
|
||||
void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
|
||||
// Ensure contiguity
|
||||
auto in = inputs[0];
|
||||
bool copied = false;
|
||||
if (!in.flags().row_contiguous) {
|
||||
array arr_copy(in.shape(), in.dtype(), nullptr, {});
|
||||
copy_cpu(in, arr_copy, CopyType::General, stream());
|
||||
copy(in, arr_copy, CopyType::General, stream());
|
||||
in = arr_copy;
|
||||
encoder.add_temporary(arr_copy);
|
||||
copied = true;
|
||||
}
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([in = array::unsafe_weak_copy(in),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
axis_ = axis_,
|
||||
reduce_type_ = reduce_type_,
|
||||
reverse_ = reverse_,
|
||||
inclusive_ = inclusive_]() mutable {
|
||||
switch (in.dtype()) {
|
||||
case bool_: {
|
||||
// We could do a full dtype x dtype switch but this is the only case
|
||||
// where we accumulate in a different type, for now.
|
||||
//
|
||||
// TODO: If we add the option to accumulate floats in higher precision
|
||||
// floats perhaps we should add the full all-to-all dispatch.
|
||||
if (reduce_type_ == Scan::Sum && out.dtype() == int32) {
|
||||
scan_dispatch<bool, int32_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
} else {
|
||||
scan_dispatch<bool, bool>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
}
|
||||
break;
|
||||
switch (in.dtype()) {
|
||||
case bool_: {
|
||||
// We could do a full dtype x dtype switch but this is the only case
|
||||
// where we accumulate in a different type, for now.
|
||||
//
|
||||
// TODO: If we add the option to accumulate floats in higher precision
|
||||
// floats perhaps we should add the full all-to-all dispatch.
|
||||
if (reduce_type_ == Scan::Sum && out.dtype() == int32) {
|
||||
scan_dispatch<bool, int32_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
|
||||
} else {
|
||||
scan_dispatch<bool, bool>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
|
||||
}
|
||||
case uint8:
|
||||
scan_dispatch<uint8_t, uint8_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case uint16:
|
||||
scan_dispatch<uint16_t, uint16_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case uint32:
|
||||
scan_dispatch<uint32_t, uint32_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case uint64:
|
||||
scan_dispatch<uint64_t, uint64_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case int8:
|
||||
scan_dispatch<int8_t, int8_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case int16:
|
||||
scan_dispatch<int16_t, int16_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case int32:
|
||||
scan_dispatch<int32_t, int32_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case int64:
|
||||
scan_dispatch<int64_t, int64_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case float16:
|
||||
scan_dispatch<float16_t, float16_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case float32:
|
||||
scan_dispatch<float, float>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case float64:
|
||||
scan_dispatch<double, double>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case bfloat16:
|
||||
scan_dispatch<bfloat16_t, bfloat16_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
case complex64:
|
||||
scan_dispatch<complex64_t, complex64_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_);
|
||||
break;
|
||||
break;
|
||||
}
|
||||
});
|
||||
case uint8:
|
||||
scan_dispatch<uint8_t, uint8_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
|
||||
break;
|
||||
case uint16:
|
||||
scan_dispatch<uint16_t, uint16_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
|
||||
break;
|
||||
case uint32:
|
||||
scan_dispatch<uint32_t, uint32_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
|
||||
break;
|
||||
case uint64:
|
||||
scan_dispatch<uint64_t, uint64_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
|
||||
break;
|
||||
case int8:
|
||||
scan_dispatch<int8_t, int8_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
|
||||
break;
|
||||
case int16:
|
||||
scan_dispatch<int16_t, int16_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
|
||||
break;
|
||||
case int32:
|
||||
scan_dispatch<int32_t, int32_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
|
||||
break;
|
||||
case int64:
|
||||
scan_dispatch<int64_t, int64_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
|
||||
break;
|
||||
case float16:
|
||||
scan_dispatch<float16_t, float16_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
|
||||
break;
|
||||
case float32:
|
||||
scan_dispatch<float, float>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
|
||||
break;
|
||||
case float64:
|
||||
scan_dispatch<double, double>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
|
||||
break;
|
||||
case bfloat16:
|
||||
scan_dispatch<bfloat16_t, bfloat16_t>(
|
||||
reduce_type_, in, out, axis_, reverse_, inclusive_, stream());
|
||||
break;
|
||||
case complex64:
|
||||
throw std::runtime_error("Scan ops do not support complex types yet");
|
||||
break;
|
||||
}
|
||||
if (copied) {
|
||||
cpu::get_command_encoder(stream()).add_temporary(std::move(in));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -16,70 +16,51 @@ void select_op(
|
||||
const array& b,
|
||||
const array& c,
|
||||
array& out,
|
||||
Op op,
|
||||
Stream stream) {
|
||||
TernaryOpType topt = get_ternary_op_type(a, b, c);
|
||||
set_ternary_op_output_data(a, b, c, out, topt);
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_input_array(c);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
||||
b = array::unsafe_weak_copy(b),
|
||||
c = array::unsafe_weak_copy(c),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
op,
|
||||
topt]() mutable {
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
ternary_op<bool, bool, bool, bool>(a, b, c, out, op, topt);
|
||||
break;
|
||||
case uint8:
|
||||
ternary_op<bool, uint8_t, uint8_t, uint8_t>(a, b, c, out, op, topt);
|
||||
break;
|
||||
case uint16:
|
||||
ternary_op<bool, uint16_t, uint16_t, uint16_t>(a, b, c, out, op, topt);
|
||||
break;
|
||||
case uint32:
|
||||
ternary_op<bool, uint32_t, uint32_t, uint32_t>(a, b, c, out, op, topt);
|
||||
break;
|
||||
case uint64:
|
||||
ternary_op<bool, uint64_t, uint64_t, uint64_t>(a, b, c, out, op, topt);
|
||||
break;
|
||||
case int8:
|
||||
ternary_op<bool, int8_t, int8_t, int8_t>(a, b, c, out, op, topt);
|
||||
break;
|
||||
case int16:
|
||||
ternary_op<bool, int16_t, int16_t, int16_t>(a, b, c, out, op, topt);
|
||||
break;
|
||||
case int32:
|
||||
ternary_op<bool, int32_t, int32_t, int32_t>(a, b, c, out, op, topt);
|
||||
break;
|
||||
case int64:
|
||||
ternary_op<bool, int64_t, int64_t, int64_t>(a, b, c, out, op, topt);
|
||||
break;
|
||||
case float16:
|
||||
ternary_op<bool, float16_t, float16_t, float16_t>(
|
||||
a, b, c, out, op, topt);
|
||||
break;
|
||||
case float32:
|
||||
ternary_op<bool, float, float, float>(a, b, c, out, op, topt);
|
||||
break;
|
||||
case float64:
|
||||
ternary_op<bool, double, double, double>(a, b, c, out, op, topt);
|
||||
break;
|
||||
case bfloat16:
|
||||
ternary_op<bool, bfloat16_t, bfloat16_t, bfloat16_t>(
|
||||
a, b, c, out, op, topt);
|
||||
break;
|
||||
case complex64:
|
||||
ternary_op<bool, complex64_t, complex64_t, complex64_t>(
|
||||
a, b, c, out, op, topt);
|
||||
break;
|
||||
}
|
||||
});
|
||||
Op op) {
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
ternary_op<bool, bool, bool, bool>(a, b, c, out, op);
|
||||
break;
|
||||
case uint8:
|
||||
ternary_op<bool, uint8_t, uint8_t, uint8_t>(a, b, c, out, op);
|
||||
break;
|
||||
case uint16:
|
||||
ternary_op<bool, uint16_t, uint16_t, uint16_t>(a, b, c, out, op);
|
||||
break;
|
||||
case uint32:
|
||||
ternary_op<bool, uint32_t, uint32_t, uint32_t>(a, b, c, out, op);
|
||||
break;
|
||||
case uint64:
|
||||
ternary_op<bool, uint64_t, uint64_t, uint64_t>(a, b, c, out, op);
|
||||
break;
|
||||
case int8:
|
||||
ternary_op<bool, int8_t, int8_t, int8_t>(a, b, c, out, op);
|
||||
break;
|
||||
case int16:
|
||||
ternary_op<bool, int16_t, int16_t, int16_t>(a, b, c, out, op);
|
||||
break;
|
||||
case int32:
|
||||
ternary_op<bool, int32_t, int32_t, int32_t>(a, b, c, out, op);
|
||||
break;
|
||||
case int64:
|
||||
ternary_op<bool, int64_t, int64_t, int64_t>(a, b, c, out, op);
|
||||
break;
|
||||
case float16:
|
||||
ternary_op<bool, float16_t, float16_t, float16_t>(a, b, c, out, op);
|
||||
break;
|
||||
case float32:
|
||||
ternary_op<bool, float, float, float>(a, b, c, out, op);
|
||||
break;
|
||||
case float64:
|
||||
ternary_op<bool, double, double, double>(a, b, c, out, op);
|
||||
break;
|
||||
case bfloat16:
|
||||
ternary_op<bool, bfloat16_t, bfloat16_t, bfloat16_t>(a, b, c, out, op);
|
||||
break;
|
||||
case complex64:
|
||||
ternary_op<bool, complex64_t, complex64_t, complex64_t>(a, b, c, out, op);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@@ -89,7 +70,7 @@ void Select::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
const auto& condition = inputs[0];
|
||||
const auto& a = inputs[1];
|
||||
const auto& b = inputs[2];
|
||||
select_op(condition, a, b, out, detail::Select(), stream());
|
||||
select_op(condition, a, b, out, detail::Select());
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -17,7 +17,7 @@ struct ScalarT<float16_t, N> {
|
||||
#endif
|
||||
|
||||
template <>
|
||||
inline constexpr int max_size<float16_t> = N;
|
||||
static constexpr int max_size<float16_t> = N;
|
||||
|
||||
#define SIMD_FP16_DEFAULT_UNARY(op) \
|
||||
template <> \
|
||||
|
@@ -83,25 +83,25 @@ struct Simd {
|
||||
// Values chosen based on benchmarks on M3 Max
|
||||
// TODO: consider choosing these more optimally
|
||||
template <>
|
||||
inline constexpr int max_size<int8_t> = 16;
|
||||
static constexpr int max_size<int8_t> = 16;
|
||||
template <>
|
||||
inline constexpr int max_size<int16_t> = 16;
|
||||
static constexpr int max_size<int16_t> = 16;
|
||||
template <>
|
||||
inline constexpr int max_size<int> = 8;
|
||||
static constexpr int max_size<int> = 8;
|
||||
template <>
|
||||
inline constexpr int max_size<int64_t> = 4;
|
||||
static constexpr int max_size<int64_t> = 4;
|
||||
template <>
|
||||
inline constexpr int max_size<uint8_t> = 16;
|
||||
static constexpr int max_size<uint8_t> = 16;
|
||||
template <>
|
||||
inline constexpr int max_size<uint16_t> = 16;
|
||||
static constexpr int max_size<uint16_t> = 16;
|
||||
template <>
|
||||
inline constexpr int max_size<uint32_t> = 8;
|
||||
static constexpr int max_size<uint32_t> = 8;
|
||||
template <>
|
||||
inline constexpr int max_size<uint64_t> = 4;
|
||||
static constexpr int max_size<uint64_t> = 4;
|
||||
template <>
|
||||
inline constexpr int max_size<float> = 8;
|
||||
static constexpr int max_size<float> = 8;
|
||||
template <>
|
||||
inline constexpr int max_size<double> = 4;
|
||||
static constexpr int max_size<double> = 4;
|
||||
|
||||
#define SIMD_DEFAULT_UNARY(name, op) \
|
||||
template <typename T, int N> \
|
||||
|
@@ -87,45 +87,14 @@ DEFAULT_UNARY(cosh, std::cosh)
|
||||
DEFAULT_UNARY(expm1, std::expm1)
|
||||
DEFAULT_UNARY(floor, std::floor)
|
||||
DEFAULT_UNARY(log, std::log)
|
||||
DEFAULT_UNARY(log2, std::log2)
|
||||
DEFAULT_UNARY(log10, std::log10)
|
||||
DEFAULT_UNARY(log1p, std::log1p)
|
||||
DEFAULT_UNARY(sinh, std::sinh)
|
||||
DEFAULT_UNARY(sqrt, std::sqrt)
|
||||
DEFAULT_UNARY(tan, std::tan)
|
||||
DEFAULT_UNARY(tanh, std::tanh)
|
||||
|
||||
template <typename T>
|
||||
Simd<T, 1> log1p(Simd<T, 1> in) {
|
||||
if constexpr (is_complex<T>) {
|
||||
auto x = in.value.real();
|
||||
auto y = in.value.imag();
|
||||
auto zabs = std::abs(in.value);
|
||||
auto theta = std::atan2(y, x + 1);
|
||||
if (zabs < 0.5) {
|
||||
auto r = x * (2 + x) + y * y;
|
||||
if (r == 0) { // handle underflow
|
||||
return Simd<T, 1>{T{x, theta}};
|
||||
}
|
||||
return Simd<T, 1>{T{((typeof(x))(0.5)) * std::log1p(r), theta}};
|
||||
} else {
|
||||
auto z0 = std::hypot(x + 1, y);
|
||||
return Simd<T, 1>{T{std::log(z0), theta}};
|
||||
}
|
||||
} else {
|
||||
return Simd<T, 1>{std::log1p(in.value)};
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Simd<T, 1> log2(Simd<T, 1> in) {
|
||||
if constexpr (is_complex<T>) {
|
||||
auto out = std::log(in.value);
|
||||
auto scale = decltype(out.real())(M_LN2);
|
||||
return Simd<T, 1>{T{out.real() / scale, out.imag() / scale}};
|
||||
} else {
|
||||
return Simd<T, 1>{std::log2(in.value)};
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Simd<T, 1> operator~(Simd<T, 1> in) {
|
||||
return ~in.value;
|
||||
|
@@ -119,12 +119,17 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
// Make sure that the last dimension is contiguous
|
||||
auto set_output = [s = stream(), &out](const array& x) {
|
||||
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
|
||||
bool no_copy = x.strides()[x.ndim() - 1] == 1;
|
||||
if (x.ndim() > 1) {
|
||||
auto s = x.strides()[x.ndim() - 2];
|
||||
no_copy &= (s == 0 || s == x.shape().back());
|
||||
}
|
||||
if (no_copy) {
|
||||
if (x.is_donatable()) {
|
||||
out.copy_shared_buffer(x);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc(x.data_size() * x.itemsize()),
|
||||
allocator::malloc_or_wait(x.data_size() * x.itemsize()),
|
||||
x.data_size(),
|
||||
x.strides(),
|
||||
x.flags());
|
||||
@@ -132,7 +137,7 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
return x;
|
||||
} else {
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_cpu(x, x_copy, CopyType::General, s);
|
||||
copy(x, x_copy, CopyType::General, s);
|
||||
out.copy_shared_buffer(x_copy);
|
||||
return x_copy;
|
||||
}
|
||||
@@ -141,6 +146,18 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto in = set_output(inputs[0]);
|
||||
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
case uint8:
|
||||
case uint16:
|
||||
case uint32:
|
||||
case uint64:
|
||||
case int8:
|
||||
case int16:
|
||||
case int32:
|
||||
case int64:
|
||||
throw std::runtime_error(
|
||||
"Softmax is defined only for floating point types");
|
||||
break;
|
||||
case float32:
|
||||
softmax<float, float>(in, out, stream());
|
||||
break;
|
||||
@@ -161,9 +178,9 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
case float64:
|
||||
softmax<double, double>(in, out, stream());
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[softmax] Only defined for floating point types.");
|
||||
case complex64:
|
||||
throw std::invalid_argument(
|
||||
"[Softmax] Not yet implemented for complex64");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
@@ -105,11 +105,15 @@ struct StridedIterator {
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
void sort(array& out, int axis) {
|
||||
void sort(const array& in, array& out, int axis, Stream stream) {
|
||||
// Copy input to output
|
||||
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
|
||||
copy(in, out, ctype, stream);
|
||||
|
||||
// Get axis, shape and stride info
|
||||
axis = axis < 0 ? axis + out.ndim() : axis;
|
||||
size_t in_size = out.size();
|
||||
size_t n_rows = in_size / out.shape(axis);
|
||||
axis = axis < 0 ? axis + in.ndim() : axis;
|
||||
size_t in_size = in.flags().contiguous ? in.data_size() : in.size();
|
||||
size_t n_rows = in_size / in.shape(axis);
|
||||
|
||||
auto remaining_shape = out.shape();
|
||||
remaining_shape.erase(remaining_shape.begin() + axis);
|
||||
@@ -123,20 +127,30 @@ void sort(array& out, int axis) {
|
||||
// Perform sorting in place
|
||||
ContiguousIterator src_it(
|
||||
remaining_shape, remaining_strides, remaining_shape.size());
|
||||
auto out_ptr = out.data<T>();
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
T* data_ptr = out_ptr + src_it.loc;
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([out_ptr = out.data<T>(),
|
||||
src_it = std::move(src_it),
|
||||
n_rows,
|
||||
axis_size,
|
||||
axis_stride]() mutable {
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
T* data_ptr = out_ptr + src_it.loc;
|
||||
|
||||
StridedIterator st(data_ptr, axis_stride, 0);
|
||||
StridedIterator ed(data_ptr, axis_stride, axis_size);
|
||||
StridedIterator st(data_ptr, axis_stride, 0);
|
||||
StridedIterator ed(data_ptr, axis_stride, axis_size);
|
||||
|
||||
std::stable_sort(st, ed);
|
||||
src_it.step();
|
||||
}
|
||||
std::stable_sort(st, ed);
|
||||
src_it.step();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename T, typename IdxT = uint32_t>
|
||||
void argsort(const array& in, array& out, int axis) {
|
||||
void argsort(const array& in, array& out, int axis, Stream stream) {
|
||||
// Allocate output
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
// Get axis, shape and stride info
|
||||
axis = axis < 0 ? axis + in.ndim() : axis;
|
||||
size_t n_rows = in.size() / in.shape(axis);
|
||||
@@ -162,69 +176,99 @@ void argsort(const array& in, array& out, int axis) {
|
||||
in_remaining_shape, in_remaining_strides, in_remaining_shape.size());
|
||||
ContiguousIterator out_it(
|
||||
out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
|
||||
auto in_ptr = in.data<T>();
|
||||
auto out_ptr = out.data<IdxT>();
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
const T* data_ptr = in_ptr + in_it.loc;
|
||||
IdxT* idx_ptr = out_ptr + out_it.loc;
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_input_array(out);
|
||||
encoder.dispatch([in_ptr = in.data<T>(),
|
||||
out_ptr = out.data<IdxT>(),
|
||||
in_it = std::move(in_it),
|
||||
out_it = std::move(out_it),
|
||||
n_rows,
|
||||
axis_size,
|
||||
in_stride,
|
||||
out_stride]() mutable {
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
const T* data_ptr = in_ptr + in_it.loc;
|
||||
IdxT* idx_ptr = out_ptr + out_it.loc;
|
||||
|
||||
in_it.step();
|
||||
out_it.step();
|
||||
in_it.step();
|
||||
out_it.step();
|
||||
|
||||
StridedIterator st_(idx_ptr, out_stride, 0);
|
||||
StridedIterator ed_(idx_ptr, out_stride, axis_size);
|
||||
StridedIterator st_(idx_ptr, out_stride, 0);
|
||||
StridedIterator ed_(idx_ptr, out_stride, axis_size);
|
||||
|
||||
// Initialize with iota
|
||||
std::iota(st_, ed_, IdxT(0));
|
||||
// Initialize with iota
|
||||
std::iota(st_, ed_, IdxT(0));
|
||||
|
||||
// Sort according to vals
|
||||
StridedIterator st(idx_ptr, out_stride, 0);
|
||||
StridedIterator ed(idx_ptr, out_stride, axis_size);
|
||||
// Sort according to vals
|
||||
StridedIterator st(idx_ptr, out_stride, 0);
|
||||
StridedIterator ed(idx_ptr, out_stride, axis_size);
|
||||
|
||||
std::stable_sort(st, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
|
||||
auto v1 = data_ptr[a * in_stride];
|
||||
auto v2 = data_ptr[b * in_stride];
|
||||
return v1 < v2 || (v1 == v2 && a < b);
|
||||
});
|
||||
}
|
||||
std::stable_sort(st, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
|
||||
auto v1 = data_ptr[a * in_stride];
|
||||
auto v2 = data_ptr[b * in_stride];
|
||||
return v1 < v2 || (v1 == v2 && a < b);
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void partition(array& out, int axis, int kth) {
|
||||
// Get axis, shape and stride info
|
||||
axis = axis < 0 ? axis + out.ndim() : axis;
|
||||
size_t in_size = out.size();
|
||||
size_t n_rows = in_size / out.shape(axis);
|
||||
void partition(const array& in, array& out, int axis, int kth, Stream stream) {
|
||||
// Copy input to output
|
||||
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
|
||||
copy(in, out, ctype, stream);
|
||||
|
||||
auto remaining_shape = out.shape();
|
||||
// Get axis, shape and stride info
|
||||
axis = axis < 0 ? axis + in.ndim() : axis;
|
||||
size_t in_size = in.flags().contiguous ? in.data_size() : in.size();
|
||||
size_t n_rows = in_size / in.shape(axis);
|
||||
|
||||
auto remaining_shape = in.shape();
|
||||
remaining_shape.erase(remaining_shape.begin() + axis);
|
||||
|
||||
auto remaining_strides = out.strides();
|
||||
auto remaining_strides = in.strides();
|
||||
remaining_strides.erase(remaining_strides.begin() + axis);
|
||||
|
||||
auto axis_stride = out.strides()[axis];
|
||||
int axis_size = out.shape(axis);
|
||||
auto axis_stride = in.strides()[axis];
|
||||
int axis_size = in.shape(axis);
|
||||
|
||||
kth = kth < 0 ? kth + axis_size : kth;
|
||||
|
||||
// Perform partition in place
|
||||
ContiguousIterator src_it(
|
||||
remaining_shape, remaining_strides, remaining_shape.size());
|
||||
auto out_ptr = out.data<T>();
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
T* data_ptr = out_ptr + src_it.loc;
|
||||
src_it.step();
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([out_ptr = out.data<T>(),
|
||||
src_it = std::move(src_it),
|
||||
n_rows,
|
||||
axis_size,
|
||||
axis_stride,
|
||||
kth]() mutable {
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
T* data_ptr = out_ptr + src_it.loc;
|
||||
src_it.step();
|
||||
|
||||
StridedIterator st(data_ptr, axis_stride, 0);
|
||||
StridedIterator md(data_ptr, axis_stride, kth);
|
||||
StridedIterator ed(data_ptr, axis_stride, axis_size);
|
||||
StridedIterator st(data_ptr, axis_stride, 0);
|
||||
StridedIterator md(data_ptr, axis_stride, kth);
|
||||
StridedIterator ed(data_ptr, axis_stride, axis_size);
|
||||
|
||||
std::nth_element(st, md, ed);
|
||||
}
|
||||
std::nth_element(st, md, ed);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename T, typename IdxT = uint32_t>
|
||||
void argpartition(const array& in, array& out, int axis, int kth) {
|
||||
void argpartition(
|
||||
const array& in,
|
||||
array& out,
|
||||
int axis,
|
||||
int kth,
|
||||
Stream stream) {
|
||||
// Allocate output
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
// Get axis, shape and stride info
|
||||
axis = axis < 0 ? axis + in.ndim() : axis;
|
||||
size_t n_rows = in.size() / in.shape(axis);
|
||||
@@ -253,32 +297,42 @@ void argpartition(const array& in, array& out, int axis, int kth) {
|
||||
ContiguousIterator out_it(
|
||||
out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
|
||||
|
||||
auto in_ptr = in.data<T>();
|
||||
auto out_ptr = out.data<IdxT>();
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_input_array(out);
|
||||
encoder.dispatch([in_ptr = in.data<T>(),
|
||||
out_ptr = out.data<IdxT>(),
|
||||
in_it = std::move(in_it),
|
||||
out_it = std::move(out_it),
|
||||
n_rows,
|
||||
axis_size,
|
||||
in_stride,
|
||||
out_stride,
|
||||
kth]() mutable {
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
const T* data_ptr = in_ptr + in_it.loc;
|
||||
IdxT* idx_ptr = out_ptr + out_it.loc;
|
||||
in_it.step();
|
||||
out_it.step();
|
||||
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
const T* data_ptr = in_ptr + in_it.loc;
|
||||
IdxT* idx_ptr = out_ptr + out_it.loc;
|
||||
in_it.step();
|
||||
out_it.step();
|
||||
StridedIterator st_(idx_ptr, out_stride, 0);
|
||||
StridedIterator ed_(idx_ptr, out_stride, axis_size);
|
||||
|
||||
StridedIterator st_(idx_ptr, out_stride, 0);
|
||||
StridedIterator ed_(idx_ptr, out_stride, axis_size);
|
||||
// Initialize with iota
|
||||
std::iota(st_, ed_, IdxT(0));
|
||||
|
||||
// Initialize with iota
|
||||
std::iota(st_, ed_, IdxT(0));
|
||||
// Sort according to vals
|
||||
StridedIterator st(idx_ptr, out_stride, 0);
|
||||
StridedIterator md(idx_ptr, out_stride, kth);
|
||||
StridedIterator ed(idx_ptr, out_stride, axis_size);
|
||||
|
||||
// Sort according to vals
|
||||
StridedIterator st(idx_ptr, out_stride, 0);
|
||||
StridedIterator md(idx_ptr, out_stride, kth);
|
||||
StridedIterator ed(idx_ptr, out_stride, axis_size);
|
||||
|
||||
std::nth_element(st, md, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
|
||||
auto v1 = data_ptr[a * in_stride];
|
||||
auto v2 = data_ptr[b * in_stride];
|
||||
return v1 < v2 || (v1 == v2 && a < b);
|
||||
});
|
||||
}
|
||||
std::nth_element(st, md, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
|
||||
auto v1 = data_ptr[a * in_stride];
|
||||
auto v2 = data_ptr[b * in_stride];
|
||||
return v1 < v2 || (v1 == v2 && a < b);
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@@ -287,188 +341,144 @@ void ArgSort::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
|
||||
// Allocate output
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_input_array(out);
|
||||
encoder.dispatch([in = array::unsafe_weak_copy(in),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
axis_ = axis_]() mutable {
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
return argsort<bool>(in, out, axis_);
|
||||
case uint8:
|
||||
return argsort<uint8_t>(in, out, axis_);
|
||||
case uint16:
|
||||
return argsort<uint16_t>(in, out, axis_);
|
||||
case uint32:
|
||||
return argsort<uint32_t>(in, out, axis_);
|
||||
case uint64:
|
||||
return argsort<uint64_t>(in, out, axis_);
|
||||
case int8:
|
||||
return argsort<int8_t>(in, out, axis_);
|
||||
case int16:
|
||||
return argsort<int16_t>(in, out, axis_);
|
||||
case int32:
|
||||
return argsort<int32_t>(in, out, axis_);
|
||||
case int64:
|
||||
return argsort<int64_t>(in, out, axis_);
|
||||
case float32:
|
||||
return argsort<float>(in, out, axis_);
|
||||
case float64:
|
||||
return argsort<double>(in, out, axis_);
|
||||
case float16:
|
||||
return argsort<float16_t>(in, out, axis_);
|
||||
case bfloat16:
|
||||
return argsort<bfloat16_t>(in, out, axis_);
|
||||
case complex64:
|
||||
return argsort<complex64_t>(in, out, axis_);
|
||||
}
|
||||
});
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
return argsort<bool>(in, out, axis_, stream());
|
||||
case uint8:
|
||||
return argsort<uint8_t>(in, out, axis_, stream());
|
||||
case uint16:
|
||||
return argsort<uint16_t>(in, out, axis_, stream());
|
||||
case uint32:
|
||||
return argsort<uint32_t>(in, out, axis_, stream());
|
||||
case uint64:
|
||||
return argsort<uint64_t>(in, out, axis_, stream());
|
||||
case int8:
|
||||
return argsort<int8_t>(in, out, axis_, stream());
|
||||
case int16:
|
||||
return argsort<int16_t>(in, out, axis_, stream());
|
||||
case int32:
|
||||
return argsort<int32_t>(in, out, axis_, stream());
|
||||
case int64:
|
||||
return argsort<int64_t>(in, out, axis_, stream());
|
||||
case float32:
|
||||
return argsort<float>(in, out, axis_, stream());
|
||||
case float64:
|
||||
return argsort<double>(in, out, axis_, stream());
|
||||
case float16:
|
||||
return argsort<float16_t>(in, out, axis_, stream());
|
||||
case bfloat16:
|
||||
return argsort<bfloat16_t>(in, out, axis_, stream());
|
||||
case complex64:
|
||||
return argsort<complex64_t>(in, out, axis_, stream());
|
||||
}
|
||||
}
|
||||
|
||||
void Sort::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
|
||||
// Copy input to output
|
||||
CopyType ctype = (in.flags().contiguous && in.strides()[axis_] != 0)
|
||||
? CopyType::Vector
|
||||
: CopyType::General;
|
||||
copy_cpu(in, out, ctype, stream());
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch(
|
||||
[out = array::unsafe_weak_copy(out), axis_ = axis_]() mutable {
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
return sort<bool>(out, axis_);
|
||||
case uint8:
|
||||
return sort<uint8_t>(out, axis_);
|
||||
case uint16:
|
||||
return sort<uint16_t>(out, axis_);
|
||||
case uint32:
|
||||
return sort<uint32_t>(out, axis_);
|
||||
case uint64:
|
||||
return sort<uint64_t>(out, axis_);
|
||||
case int8:
|
||||
return sort<int8_t>(out, axis_);
|
||||
case int16:
|
||||
return sort<int16_t>(out, axis_);
|
||||
case int32:
|
||||
return sort<int32_t>(out, axis_);
|
||||
case int64:
|
||||
return sort<int64_t>(out, axis_);
|
||||
case float32:
|
||||
return sort<float>(out, axis_);
|
||||
case float64:
|
||||
return sort<double>(out, axis_);
|
||||
case float16:
|
||||
return sort<float16_t>(out, axis_);
|
||||
case bfloat16:
|
||||
return sort<bfloat16_t>(out, axis_);
|
||||
case complex64:
|
||||
return sort<complex64_t>(out, axis_);
|
||||
}
|
||||
});
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
return sort<bool>(in, out, axis_, stream());
|
||||
case uint8:
|
||||
return sort<uint8_t>(in, out, axis_, stream());
|
||||
case uint16:
|
||||
return sort<uint16_t>(in, out, axis_, stream());
|
||||
case uint32:
|
||||
return sort<uint32_t>(in, out, axis_, stream());
|
||||
case uint64:
|
||||
return sort<uint64_t>(in, out, axis_, stream());
|
||||
case int8:
|
||||
return sort<int8_t>(in, out, axis_, stream());
|
||||
case int16:
|
||||
return sort<int16_t>(in, out, axis_, stream());
|
||||
case int32:
|
||||
return sort<int32_t>(in, out, axis_, stream());
|
||||
case int64:
|
||||
return sort<int64_t>(in, out, axis_, stream());
|
||||
case float32:
|
||||
return sort<float>(in, out, axis_, stream());
|
||||
case float64:
|
||||
return sort<double>(in, out, axis_, stream());
|
||||
case float16:
|
||||
return sort<float16_t>(in, out, axis_, stream());
|
||||
case bfloat16:
|
||||
return sort<bfloat16_t>(in, out, axis_, stream());
|
||||
case complex64:
|
||||
return sort<complex64_t>(in, out, axis_, stream());
|
||||
}
|
||||
}
|
||||
|
||||
void ArgPartition::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
|
||||
// Allocate output
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.set_input_array(in);
|
||||
encoder.set_input_array(out);
|
||||
encoder.dispatch([in = array::unsafe_weak_copy(in),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
axis_ = axis_,
|
||||
kth_ = kth_]() mutable {
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
return argpartition<bool>(in, out, axis_, kth_);
|
||||
case uint8:
|
||||
return argpartition<uint8_t>(in, out, axis_, kth_);
|
||||
case uint16:
|
||||
return argpartition<uint16_t>(in, out, axis_, kth_);
|
||||
case uint32:
|
||||
return argpartition<uint32_t>(in, out, axis_, kth_);
|
||||
case uint64:
|
||||
return argpartition<uint64_t>(in, out, axis_, kth_);
|
||||
case int8:
|
||||
return argpartition<int8_t>(in, out, axis_, kth_);
|
||||
case int16:
|
||||
return argpartition<int16_t>(in, out, axis_, kth_);
|
||||
case int32:
|
||||
return argpartition<int32_t>(in, out, axis_, kth_);
|
||||
case int64:
|
||||
return argpartition<int64_t>(in, out, axis_, kth_);
|
||||
case float32:
|
||||
return argpartition<float>(in, out, axis_, kth_);
|
||||
case float64:
|
||||
return argpartition<double>(in, out, axis_, kth_);
|
||||
case float16:
|
||||
return argpartition<float16_t>(in, out, axis_, kth_);
|
||||
case bfloat16:
|
||||
return argpartition<bfloat16_t>(in, out, axis_, kth_);
|
||||
case complex64:
|
||||
return argpartition<complex64_t>(in, out, axis_, kth_);
|
||||
}
|
||||
});
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
return argpartition<bool>(in, out, axis_, kth_, stream());
|
||||
case uint8:
|
||||
return argpartition<uint8_t>(in, out, axis_, kth_, stream());
|
||||
case uint16:
|
||||
return argpartition<uint16_t>(in, out, axis_, kth_, stream());
|
||||
case uint32:
|
||||
return argpartition<uint32_t>(in, out, axis_, kth_, stream());
|
||||
case uint64:
|
||||
return argpartition<uint64_t>(in, out, axis_, kth_, stream());
|
||||
case int8:
|
||||
return argpartition<int8_t>(in, out, axis_, kth_, stream());
|
||||
case int16:
|
||||
return argpartition<int16_t>(in, out, axis_, kth_, stream());
|
||||
case int32:
|
||||
return argpartition<int32_t>(in, out, axis_, kth_, stream());
|
||||
case int64:
|
||||
return argpartition<int64_t>(in, out, axis_, kth_, stream());
|
||||
case float32:
|
||||
return argpartition<float>(in, out, axis_, kth_, stream());
|
||||
case float64:
|
||||
return argpartition<double>(in, out, axis_, kth_, stream());
|
||||
case float16:
|
||||
return argpartition<float16_t>(in, out, axis_, kth_, stream());
|
||||
case bfloat16:
|
||||
return argpartition<bfloat16_t>(in, out, axis_, kth_, stream());
|
||||
case complex64:
|
||||
return argpartition<complex64_t>(in, out, axis_, kth_, stream());
|
||||
}
|
||||
}
|
||||
|
||||
void Partition::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
|
||||
// Copy input to output
|
||||
CopyType ctype = (in.flags().contiguous && in.strides()[axis_] != 0)
|
||||
? CopyType::Vector
|
||||
: CopyType::General;
|
||||
copy_cpu(in, out, ctype, stream());
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([out = array::unsafe_weak_copy(out),
|
||||
axis_ = axis_,
|
||||
kth_ = kth_]() mutable {
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
return partition<bool>(out, axis_, kth_);
|
||||
case uint8:
|
||||
return partition<uint8_t>(out, axis_, kth_);
|
||||
case uint16:
|
||||
return partition<uint16_t>(out, axis_, kth_);
|
||||
case uint32:
|
||||
return partition<uint32_t>(out, axis_, kth_);
|
||||
case uint64:
|
||||
return partition<uint64_t>(out, axis_, kth_);
|
||||
case int8:
|
||||
return partition<int8_t>(out, axis_, kth_);
|
||||
case int16:
|
||||
return partition<int16_t>(out, axis_, kth_);
|
||||
case int32:
|
||||
return partition<int32_t>(out, axis_, kth_);
|
||||
case int64:
|
||||
return partition<int64_t>(out, axis_, kth_);
|
||||
case float32:
|
||||
return partition<float>(out, axis_, kth_);
|
||||
case float64:
|
||||
return partition<double>(out, axis_, kth_);
|
||||
case float16:
|
||||
return partition<float16_t>(out, axis_, kth_);
|
||||
case bfloat16:
|
||||
return partition<bfloat16_t>(out, axis_, kth_);
|
||||
case complex64:
|
||||
return partition<complex64_t>(out, axis_, kth_);
|
||||
}
|
||||
});
|
||||
switch (in.dtype()) {
|
||||
case bool_:
|
||||
return partition<bool>(in, out, axis_, kth_, stream());
|
||||
case uint8:
|
||||
return partition<uint8_t>(in, out, axis_, kth_, stream());
|
||||
case uint16:
|
||||
return partition<uint16_t>(in, out, axis_, kth_, stream());
|
||||
case uint32:
|
||||
return partition<uint32_t>(in, out, axis_, kth_, stream());
|
||||
case uint64:
|
||||
return partition<uint64_t>(in, out, axis_, kth_, stream());
|
||||
case int8:
|
||||
return partition<int8_t>(in, out, axis_, kth_, stream());
|
||||
case int16:
|
||||
return partition<int16_t>(in, out, axis_, kth_, stream());
|
||||
case int32:
|
||||
return partition<int32_t>(in, out, axis_, kth_, stream());
|
||||
case int64:
|
||||
return partition<int64_t>(in, out, axis_, kth_, stream());
|
||||
case float32:
|
||||
return partition<float>(in, out, axis_, kth_, stream());
|
||||
case float64:
|
||||
return partition<double>(in, out, axis_, kth_, stream());
|
||||
case float16:
|
||||
return partition<float16_t>(in, out, axis_, kth_, stream());
|
||||
case bfloat16:
|
||||
return partition<bfloat16_t>(in, out, axis_, kth_, stream());
|
||||
case complex64:
|
||||
return partition<complex64_t>(in, out, axis_, kth_, stream());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -31,7 +31,7 @@ void svd_impl(
|
||||
|
||||
// lapack clobbers the input, so we have to make a copy.
|
||||
array in(a.shape(), a.dtype(), nullptr, {});
|
||||
copy_cpu(
|
||||
copy(
|
||||
a,
|
||||
in,
|
||||
a.flags().row_contiguous ? CopyType::Vector : CopyType::General,
|
||||
@@ -50,9 +50,9 @@ void svd_impl(
|
||||
array& s = outputs[1];
|
||||
array& vt = outputs[2];
|
||||
|
||||
u.set_data(allocator::malloc(u.nbytes()));
|
||||
s.set_data(allocator::malloc(s.nbytes()));
|
||||
vt.set_data(allocator::malloc(vt.nbytes()));
|
||||
u.set_data(allocator::malloc_or_wait(u.nbytes()));
|
||||
s.set_data(allocator::malloc_or_wait(s.nbytes()));
|
||||
vt.set_data(allocator::malloc_or_wait(vt.nbytes()));
|
||||
|
||||
encoder.set_output_array(u);
|
||||
encoder.set_output_array(s);
|
||||
@@ -64,7 +64,7 @@ void svd_impl(
|
||||
} else {
|
||||
array& s = outputs[0];
|
||||
|
||||
s.set_data(allocator::malloc(s.nbytes()));
|
||||
s.set_data(allocator::malloc_or_wait(s.nbytes()));
|
||||
|
||||
encoder.set_output_array(s);
|
||||
|
||||
@@ -91,7 +91,7 @@ void svd_impl(
|
||||
|
||||
// Will contain the indices of eigenvectors that failed to converge (not
|
||||
// used here but required by lapack).
|
||||
auto iwork = array::Data{allocator::malloc(sizeof(int) * 12 * K)};
|
||||
auto iwork = array::Data{allocator::malloc_or_wait(sizeof(int) * 12 * K)};
|
||||
|
||||
static const int lwork_query = -1;
|
||||
|
||||
@@ -132,7 +132,7 @@ void svd_impl(
|
||||
}
|
||||
|
||||
const int lwork = workspace_dimension;
|
||||
auto scratch = array::Data{allocator::malloc(sizeof(T) * lwork)};
|
||||
auto scratch = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};
|
||||
|
||||
// Loop over matrices.
|
||||
for (int i = 0; i < num_matrices; i++) {
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user