mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-06 08:24:39 +08:00
Compare commits
93 Commits
compile-te
...
v0.3.0
Author | SHA1 | Date | |
---|---|---|---|
![]() |
bf7cd29970 | ||
![]() |
a000d2288c | ||
![]() |
165abf0e4c | ||
![]() |
818cda16bc | ||
![]() |
85143fecdd | ||
![]() |
35431a4ac8 | ||
![]() |
ccf1645995 | ||
![]() |
1a48713d32 | ||
![]() |
1eb04aa23f | ||
![]() |
0c65517e91 | ||
![]() |
2fdc2462c3 | ||
![]() |
be6e9d6a9f | ||
![]() |
e54cbb7ba6 | ||
![]() |
40c108766b | ||
![]() |
4cc70290f7 | ||
![]() |
74caa68d02 | ||
![]() |
3756381358 | ||
![]() |
d12573daa6 | ||
![]() |
0dbc4c7547 | ||
![]() |
06072601ce | ||
![]() |
11d2c8f7a1 | ||
![]() |
7f3f8d8f8d | ||
![]() |
b96be943dc | ||
![]() |
b670485185 | ||
![]() |
b57bd0488d | ||
![]() |
221f8d3fc2 | ||
![]() |
5c03efaf29 | ||
![]() |
7dccd42133 | ||
![]() |
1b97b2958b | ||
![]() |
e5e816a5ef | ||
![]() |
28eac18571 | ||
![]() |
5fd11c347d | ||
![]() |
ef73393a19 | ||
![]() |
ea406d5e33 | ||
![]() |
146bd69470 | ||
![]() |
316ff490b3 | ||
![]() |
d40a04f8dc | ||
![]() |
d75ae52ecd | ||
![]() |
31fea3758e | ||
![]() |
e319383ef9 | ||
![]() |
5c3ac52dd7 | ||
![]() |
ebfd3618b0 | ||
![]() |
11a9fd40f0 | ||
![]() |
4fd2fb84a6 | ||
![]() |
9852af1a19 | ||
![]() |
16750f3c51 | ||
![]() |
95b5fb8245 | ||
![]() |
83f63f2184 | ||
![]() |
cb6156d35d | ||
![]() |
506d43035c | ||
![]() |
36cff34701 | ||
![]() |
e88e474fd1 | ||
![]() |
601c6d6aa8 | ||
![]() |
ba8d6bf365 | ||
![]() |
4a5f3b21bb | ||
![]() |
fcc5ac1c64 | ||
![]() |
bad67fec37 | ||
![]() |
199aebcf77 | ||
![]() |
0de5988f92 | ||
![]() |
143e2690d5 | ||
![]() |
375446453e | ||
![]() |
1895d34c20 | ||
![]() |
09b9275027 | ||
![]() |
d3a9005454 | ||
![]() |
3f7aba8498 | ||
![]() |
65d0b8df9f | ||
![]() |
3c2f192345 | ||
![]() |
37d98ba6ff | ||
![]() |
8993382aaa | ||
![]() |
07f35c9d8a | ||
![]() |
bf17ab5002 | ||
![]() |
8fa6b322b9 | ||
![]() |
874b739f3c | ||
![]() |
077c1ee64a | ||
![]() |
2463496471 | ||
![]() |
87b7fa9ba2 | ||
![]() |
624065c074 | ||
![]() |
f27ec5e097 | ||
![]() |
f30e63353a | ||
![]() |
4fe2fa2a64 | ||
![]() |
37fc9db82c | ||
![]() |
755dcf6137 | ||
![]() |
6b4b30e3fc | ||
![]() |
86e0c79467 | ||
![]() |
98c37d3a22 | ||
![]() |
f326dd8334 | ||
![]() |
6d3bee3364 | ||
![]() |
ecb174ca9d | ||
![]() |
7a34e46677 | ||
![]() |
92c22c1ea3 | ||
![]() |
d52383367a | ||
![]() |
363d3add6d | ||
![]() |
b207c2c86b |
@@ -1,5 +1,8 @@
|
||||
version: 2.1
|
||||
|
||||
orbs:
|
||||
apple: ml-explore/pr-approval@0.1.0
|
||||
|
||||
parameters:
|
||||
nightly_build:
|
||||
type: boolean
|
||||
@@ -7,6 +10,9 @@ parameters:
|
||||
weekly_build:
|
||||
type: boolean
|
||||
default: false
|
||||
test_release:
|
||||
type: boolean
|
||||
default: false
|
||||
|
||||
jobs:
|
||||
linux_build_and_test:
|
||||
@@ -26,18 +32,23 @@ jobs:
|
||||
command: |
|
||||
pip install --upgrade cmake
|
||||
pip install --upgrade pybind11[global]
|
||||
pip install pybind11-stubgen
|
||||
pip install numpy
|
||||
sudo apt-get update
|
||||
sudo apt-get install libblas-dev
|
||||
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
||||
- run:
|
||||
name: Build python package
|
||||
name: Install Python package
|
||||
command: |
|
||||
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" CMAKE_BUILD_PARALLEL_LEVEL="" python3 setup.py build_ext --inplace
|
||||
CMAKE_ARGS="-DMLX_BUILD_METAL=OFF" CMAKE_BUILD_PARALLEL_LEVEL="" python3 setup.py develop
|
||||
- run:
|
||||
name: Run the python tests
|
||||
name: Generate package stubs
|
||||
command: |
|
||||
python3 -m unittest discover python/tests
|
||||
python3 setup.py generate_stubs
|
||||
- run:
|
||||
name: Run Python tests
|
||||
command: |
|
||||
python3 -m unittest discover python/tests -v
|
||||
# TODO: Reenable when extension api becomes stable
|
||||
# - run:
|
||||
# name: Build example extension
|
||||
@@ -52,169 +63,180 @@ jobs:
|
||||
command: ./build/tests/tests
|
||||
|
||||
mac_build_and_test:
|
||||
machine: true
|
||||
resource_class: ml-explore/m-builder
|
||||
macos:
|
||||
xcode: "15.2.0"
|
||||
resource_class: macos.m1.large.gen1
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: Install dependencies
|
||||
command: |
|
||||
eval "$(conda shell.bash hook)"
|
||||
rm -r $CONDA_PREFIX/envs/runner-env
|
||||
conda create -y -n runner-env python=3.9
|
||||
conda activate runner-env
|
||||
brew install python@3.9
|
||||
python3.9 -m venv env
|
||||
source env/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install --upgrade cmake
|
||||
pip install --upgrade pybind11[global]
|
||||
pip install pybind11-stubgen
|
||||
pip install numpy
|
||||
pip install torch
|
||||
pip install tensorflow
|
||||
pip install unittest-xml-reporting
|
||||
- run:
|
||||
name: Build python package
|
||||
name: Install Python package
|
||||
command: |
|
||||
eval "$(conda shell.bash hook)"
|
||||
conda activate runner-env
|
||||
CMAKE_BUILD_PARALLEL_LEVEL="" python setup.py build_ext --inplace
|
||||
CMAKE_BUILD_PARALLEL_LEVEL="" python setup.py develop
|
||||
source env/bin/activate
|
||||
CMAKE_BUILD_PARALLEL_LEVEL="" pip install -e . -v
|
||||
- run:
|
||||
name: Run the python tests
|
||||
name: Generate package stubs
|
||||
command: |
|
||||
eval "$(conda shell.bash hook)"
|
||||
conda activate runner-env
|
||||
DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
|
||||
DEVICE=gpu python -m xmlrunner discover -v python/tests -o test-results/gpu
|
||||
source env/bin/activate
|
||||
python setup.py generate_stubs
|
||||
- run:
|
||||
name: Run Python tests
|
||||
command: |
|
||||
source env/bin/activate
|
||||
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
|
||||
LOW_MEMORY=1 DEVICE=gpu python3.9 -m xmlrunner discover -v python/tests -o test-results/gpu
|
||||
# TODO: Reenable when extension api becomes stable
|
||||
# - run:
|
||||
# name: Build example extension
|
||||
# command: |
|
||||
# eval "$(conda shell.bash hook)"
|
||||
# conda activate runner-env
|
||||
# cd examples/extensions && python -m pip install .
|
||||
# cd examples/extensions && python3.11 -m pip install .
|
||||
- store_test_results:
|
||||
path: test-results
|
||||
- run:
|
||||
name: Build CPP only
|
||||
command: |
|
||||
source env/bin/activate
|
||||
mkdir -p build && cd build && cmake .. && make -j
|
||||
- run:
|
||||
name: Run CPP tests
|
||||
command: METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./build/tests/tests
|
||||
command: |
|
||||
DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./build/tests/tests
|
||||
DEVICE=cpu ./build/tests/tests
|
||||
|
||||
build_release:
|
||||
machine: true
|
||||
resource_class: ml-explore/m-builder
|
||||
parameters:
|
||||
python_version:
|
||||
type: string
|
||||
default: "3.9"
|
||||
macos_version:
|
||||
xcode_version:
|
||||
type: string
|
||||
default: "14"
|
||||
default: "15.2.0"
|
||||
build_env:
|
||||
type: string
|
||||
default: ""
|
||||
macos:
|
||||
xcode: << parameters.xcode_version >>
|
||||
resource_class: macos.m1.large.gen1
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: Install dependencies
|
||||
command: |
|
||||
eval "$(conda shell.bash hook)"
|
||||
rm -r $CONDA_PREFIX/envs/runner-env
|
||||
conda create -y -n runner-env python=<< parameters.python_version >>
|
||||
conda activate runner-env
|
||||
brew install python@<< parameters.python_version >>
|
||||
python<< parameters.python_version >> -m venv env
|
||||
source env/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install --upgrade cmake
|
||||
pip install --upgrade pybind11[global]
|
||||
pip install --upgrade setuptools
|
||||
pip install pybind11-stubgen
|
||||
pip install numpy
|
||||
pip install twine
|
||||
pip install build
|
||||
- run:
|
||||
name: Build package
|
||||
name: Install Python package
|
||||
command: |
|
||||
eval "$(conda shell.bash hook)"
|
||||
conda activate runner-env
|
||||
DEVELOPER_DIR=$(developer_dir_macos_<< parameters.macos_version >>) \
|
||||
PYPI_RELEASE=1 \
|
||||
source env/bin/activate
|
||||
DEV_RELEASE=1 \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL="" \
|
||||
python setup.py bdist_wheel
|
||||
twine upload dist/* --repository mlx
|
||||
pip install . -v
|
||||
- run:
|
||||
name: Generate package stubs
|
||||
command: |
|
||||
source env/bin/activate
|
||||
python setup.py generate_stubs
|
||||
- run:
|
||||
name: Build Python package
|
||||
command: |
|
||||
source env/bin/activate
|
||||
<< parameters.build_env >> \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL="" \
|
||||
python -m build -w
|
||||
- when:
|
||||
condition: << parameters.build_env >>
|
||||
steps:
|
||||
- run:
|
||||
name: Upload package
|
||||
command: |
|
||||
source env/bin/activate
|
||||
twine upload dist/*
|
||||
- store_artifacts:
|
||||
path: dist/
|
||||
|
||||
build_dev_release:
|
||||
machine: true
|
||||
resource_class: ml-explore/m-builder
|
||||
build_linux_test_release:
|
||||
parameters:
|
||||
python_version:
|
||||
type: string
|
||||
default: "3.9"
|
||||
macos_version:
|
||||
extra_env:
|
||||
type: string
|
||||
default: "14"
|
||||
default: "DEV_RELEASE=1"
|
||||
docker:
|
||||
- image: ubuntu:20.04
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: Install dependencies
|
||||
name: Build wheel
|
||||
command: |
|
||||
eval "$(conda shell.bash hook)"
|
||||
rm -r $CONDA_PREFIX/envs/runner-env
|
||||
conda create -y -n runner-env python=<< parameters.python_version >>
|
||||
conda activate runner-env
|
||||
PYTHON=python<< parameters.python_version >>
|
||||
apt-get update
|
||||
apt-get upgrade -y
|
||||
DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC apt-get -y install tzdata
|
||||
apt-get install -y apt-utils
|
||||
apt-get install -y software-properties-common
|
||||
add-apt-repository -y ppa:deadsnakes/ppa
|
||||
apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full
|
||||
apt-get install -y libblas-dev liblapack-dev liblapacke-dev
|
||||
apt-get install -y build-essential git
|
||||
$PYTHON -m venv env
|
||||
source env/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install --upgrade cmake
|
||||
pip install --upgrade pybind11[global]
|
||||
pip install --upgrade setuptools
|
||||
pip install pybind11-stubgen
|
||||
pip install numpy
|
||||
pip install twine
|
||||
- run:
|
||||
name: Build package
|
||||
command: |
|
||||
eval "$(conda shell.bash hook)"
|
||||
conda activate runner-env
|
||||
DEVELOPER_DIR=$(developer_dir_macos_<< parameters.macos_version >>) \
|
||||
DEV_RELEASE=1 \
|
||||
pip install auditwheel
|
||||
pip install patchelf
|
||||
pip install build
|
||||
<< parameters.extra_env >> \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL="" \
|
||||
python setup.py bdist_wheel
|
||||
twine upload dist/* --repository mlx
|
||||
- store_artifacts:
|
||||
path: dist/
|
||||
|
||||
build_package:
|
||||
machine: true
|
||||
resource_class: ml-explore/m-builder
|
||||
parameters:
|
||||
python_version:
|
||||
type: string
|
||||
default: "3.9"
|
||||
macos_version:
|
||||
type: string
|
||||
default: "14"
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: Install dependencies
|
||||
command: |
|
||||
eval "$(conda shell.bash hook)"
|
||||
rm -r $CONDA_PREFIX/envs/runner-env
|
||||
conda create -y -n runner-env python=<< parameters.python_version >>
|
||||
conda activate runner-env
|
||||
pip install --upgrade cmake
|
||||
pip install --upgrade pybind11[global]
|
||||
pip install numpy
|
||||
pip install twine
|
||||
- run:
|
||||
name: Build package
|
||||
command: |
|
||||
eval "$(conda shell.bash hook)"
|
||||
conda activate runner-env
|
||||
DEVELOPER_DIR=$(developer_dir_macos_<< parameters.macos_version >>) \
|
||||
pip install . -v
|
||||
python setup.py generate_stubs
|
||||
<< parameters.extra_env >> \
|
||||
CMAKE_BUILD_PARALLEL_LEVEL="" \
|
||||
python setup.py bdist_wheel
|
||||
python -m build --wheel
|
||||
auditwheel show dist/*
|
||||
auditwheel repair dist/* --plat manylinux_2_31_x86_64
|
||||
- store_artifacts:
|
||||
path: dist/
|
||||
path: wheelhouse/
|
||||
|
||||
workflows:
|
||||
build_and_test:
|
||||
when:
|
||||
and:
|
||||
- matches:
|
||||
pattern: "^(?!pull/)[-\\w]+$"
|
||||
value: << pipeline.git.branch >>
|
||||
- not: << pipeline.parameters.nightly_build >>
|
||||
- not: << pipeline.parameters.weekly_build >>
|
||||
- not: << pipeline.parameters.test_release >>
|
||||
jobs:
|
||||
- linux_build_and_test
|
||||
- mac_build_and_test
|
||||
- linux_build_and_test
|
||||
- build_release:
|
||||
filters:
|
||||
tags:
|
||||
@@ -224,20 +246,53 @@ workflows:
|
||||
matrix:
|
||||
parameters:
|
||||
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
|
||||
macos_version: ["13", "14"]
|
||||
xcode_version: ["14.3.1", "15.2.0"]
|
||||
build_env: ["PYPI_RELEASE=1"]
|
||||
prb:
|
||||
when:
|
||||
matches:
|
||||
pattern: "^pull/\\d+(/head)?$"
|
||||
value: << pipeline.git.branch >>
|
||||
jobs:
|
||||
- hold:
|
||||
type: approval
|
||||
- apple/authenticate:
|
||||
context: pr-approval
|
||||
- mac_build_and_test:
|
||||
requires: [ hold ]
|
||||
- linux_build_and_test:
|
||||
requires: [ hold ]
|
||||
nightly_build:
|
||||
when: << pipeline.parameters.nightly_build >>
|
||||
when:
|
||||
and:
|
||||
- equal: [ main, << pipeline.git.branch >> ]
|
||||
- << pipeline.parameters.nightly_build >>
|
||||
jobs:
|
||||
- build_package:
|
||||
- build_release:
|
||||
matrix:
|
||||
parameters:
|
||||
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
|
||||
macos_version: ["13", "14"]
|
||||
xcode_version: ["14.3.1", "15.2.0"]
|
||||
weekly_build:
|
||||
when: << pipeline.parameters.weekly_build >>
|
||||
when:
|
||||
and:
|
||||
- equal: [ main, << pipeline.git.branch >> ]
|
||||
- << pipeline.parameters.weekly_build >>
|
||||
jobs:
|
||||
- build_dev_release:
|
||||
- build_release:
|
||||
matrix:
|
||||
parameters:
|
||||
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
|
||||
macos_version: ["13", "14"]
|
||||
xcode_version: ["14.3.1", "15.2.0"]
|
||||
build_env: ["DEV_RELEASE=1"]
|
||||
linux_test_release:
|
||||
when:
|
||||
and:
|
||||
- equal: [ main, << pipeline.git.branch >> ]
|
||||
- << pipeline.parameters.test_release >>
|
||||
jobs:
|
||||
- build_linux_test_release:
|
||||
matrix:
|
||||
parameters:
|
||||
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
|
||||
extra_env: ["PYPI_RELEASE=1"]
|
||||
|
@@ -5,11 +5,11 @@ repos:
|
||||
- id: clang-format
|
||||
# Using this mirror lets us use mypyc-compiled black, which is about 2x faster
|
||||
- repo: https://github.com/psf/black-pre-commit-mirror
|
||||
rev: 23.12.1
|
||||
rev: 24.2.0
|
||||
hooks:
|
||||
- id: black
|
||||
- repo: https://github.com/pycqa/isort
|
||||
rev: 5.12.0
|
||||
rev: 5.13.2
|
||||
hooks:
|
||||
- id: isort
|
||||
args:
|
||||
|
@@ -10,8 +10,8 @@ MLX was developed with contributions from the following individuals:
|
||||
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops.
|
||||
- Juarez Bochi: Fixed bug in cross attention.
|
||||
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
|
||||
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile` and safetensor support
|
||||
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer.
|
||||
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support
|
||||
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented ``MaxPool1d``, ``MaxPool2d``, ``AvgPool1d``, ``AvgPool2d``.
|
||||
|
||||
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
|
||||
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
|
||||
|
@@ -18,7 +18,7 @@ option(MLX_BUILD_METAL "Build metal backend" ON)
|
||||
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
|
||||
|
||||
if(NOT MLX_VERSION)
|
||||
set(MLX_VERSION 0.0.10)
|
||||
set(MLX_VERSION 0.3.0)
|
||||
endif()
|
||||
|
||||
# --------------------- Processor tests -------------------------
|
||||
@@ -31,13 +31,13 @@ if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
|
||||
|
||||
if (${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "x86_64" AND ${CMAKE_HOST_APPLE})
|
||||
message(FATAL_ERROR
|
||||
"Building for x86_64 on macOS is not supported."
|
||||
"Building for x86_64 on macOS is not supported."
|
||||
" If you are on an Apple silicon system, check the build"
|
||||
" documentation for possible fixes: "
|
||||
"https://ml-explore.github.io/mlx/build/html/install.html#build-from-source")
|
||||
elseif (${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "x86_64")
|
||||
message(WARNING
|
||||
"Building for x86_64 on macOS is not supported."
|
||||
message(WARNING
|
||||
"Building for x86_64 on macOS is not supported."
|
||||
" If you are on an Apple silicon system, "
|
||||
" make sure you are building for arm64.")
|
||||
elseif(${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "arm64")
|
||||
@@ -75,7 +75,7 @@ elseif (MLX_BUILD_METAL)
|
||||
COMMAND_ERROR_IS_FATAL ANY)
|
||||
|
||||
message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}")
|
||||
|
||||
|
||||
if (${MACOS_VERSION} GREATER_EQUAL 14.2)
|
||||
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14.2_iOS17.2.zip)
|
||||
elseif (${MACOS_VERSION} GREATER_EQUAL 14.0)
|
||||
@@ -123,16 +123,27 @@ else()
|
||||
/usr/include
|
||||
/usr/local/include
|
||||
$ENV{BLAS_HOME}/include)
|
||||
message(STATUS ${BLAS_LIBRARIES})
|
||||
message(STATUS ${BLAS_INCLUDE_DIRS})
|
||||
message(STATUS "Blas lib " ${BLAS_LIBRARIES})
|
||||
message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
|
||||
target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
|
||||
target_link_libraries(mlx ${BLAS_LIBRARIES})
|
||||
find_package(LAPACK REQUIRED)
|
||||
if (NOT LAPACK_FOUND)
|
||||
message(FATAL_ERROR "Must have LAPACK installed")
|
||||
endif()
|
||||
find_path(LAPACK_INCLUDE_DIRS lapacke.h
|
||||
/usr/include
|
||||
/usr/local/include)
|
||||
message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
|
||||
message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
|
||||
target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
|
||||
target_link_libraries(mlx ${LAPACK_LIBRARIES})
|
||||
endif()
|
||||
|
||||
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)
|
||||
|
||||
target_include_directories(
|
||||
mlx
|
||||
mlx
|
||||
PUBLIC
|
||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
|
||||
$<INSTALL_INTERFACE:include>
|
||||
|
@@ -1,3 +1,4 @@
|
||||
include CMakeLists.txt
|
||||
recursive-include mlx/ *
|
||||
include python/src/*
|
||||
include python/mlx/py.typed # support type hinting as in PEP-561
|
||||
|
12
README.md
12
README.md
@@ -6,8 +6,8 @@
|
||||
|
||||
[](https://circleci.com/gh/ml-explore/mlx)
|
||||
|
||||
MLX is an array framework for machine learning on Apple silicon, brought to you
|
||||
by Apple machine learning research.
|
||||
MLX is an array framework for machine learning research on Apple silicon,
|
||||
brought to you by Apple machine learning research.
|
||||
|
||||
Some key features of MLX include:
|
||||
|
||||
@@ -68,10 +68,18 @@ in the documentation.
|
||||
|
||||
MLX is available on [PyPI](https://pypi.org/project/mlx/). To install the Python API, run:
|
||||
|
||||
**With `pip`**:
|
||||
|
||||
```
|
||||
pip install mlx
|
||||
```
|
||||
|
||||
**With `conda`**:
|
||||
|
||||
```
|
||||
conda install -c conda-forge mlx
|
||||
```
|
||||
|
||||
Checkout the
|
||||
[documentation](https://ml-explore.github.io/mlx/build/html/install.html#)
|
||||
for more information on building the C++ and Python APIs from source.
|
||||
|
@@ -72,6 +72,9 @@ def _quant_matmul(x, w, s, b, transpose, group_size, bits):
|
||||
|
||||
|
||||
quant_matmul = {
|
||||
"quant_matmul_32_2": partial(_quant_matmul, transpose=False, group_size=32, bits=2),
|
||||
"quant_matmul_32_4": partial(_quant_matmul, transpose=False, group_size=32, bits=4),
|
||||
"quant_matmul_32_8": partial(_quant_matmul, transpose=False, group_size=32, bits=8),
|
||||
"quant_matmul_64_2": partial(_quant_matmul, transpose=False, group_size=64, bits=2),
|
||||
"quant_matmul_64_4": partial(_quant_matmul, transpose=False, group_size=64, bits=4),
|
||||
"quant_matmul_64_8": partial(_quant_matmul, transpose=False, group_size=64, bits=8),
|
||||
@@ -84,6 +87,15 @@ quant_matmul = {
|
||||
"quant_matmul_128_8": partial(
|
||||
_quant_matmul, transpose=False, group_size=128, bits=8
|
||||
),
|
||||
"quant_matmul_t_32_2": partial(
|
||||
_quant_matmul, transpose=True, group_size=32, bits=2
|
||||
),
|
||||
"quant_matmul_t_32_4": partial(
|
||||
_quant_matmul, transpose=True, group_size=32, bits=4
|
||||
),
|
||||
"quant_matmul_t_32_8": partial(
|
||||
_quant_matmul, transpose=True, group_size=32, bits=8
|
||||
),
|
||||
"quant_matmul_t_64_2": partial(
|
||||
_quant_matmul, transpose=True, group_size=64, bits=2
|
||||
),
|
||||
|
@@ -80,10 +80,8 @@ if __name__ == "__main__":
|
||||
_filter = make_predicate(args.filter, args.negative_filter)
|
||||
|
||||
if args.mlx_dtypes:
|
||||
compare_filtered = (
|
||||
lambda x: compare_mlx_dtypes(
|
||||
x.split() + rest, args.mlx_dtypes[0], args.mlx_dtypes[1]
|
||||
)
|
||||
compare_filtered = lambda x: (
|
||||
compare_mlx_dtypes(x.split() + rest, args.mlx_dtypes[0], args.mlx_dtypes[1])
|
||||
if _filter(x)
|
||||
else None
|
||||
)
|
||||
|
53
benchmarks/python/gather_bench.py
Normal file
53
benchmarks/python/gather_bench.py
Normal file
@@ -0,0 +1,53 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import argparse
|
||||
from time import time
|
||||
|
||||
import mlx.core as mx
|
||||
import torch
|
||||
from time_utils import measure_runtime
|
||||
|
||||
|
||||
def benchmark_gather_mlx(x_shape, idx_shape):
|
||||
def gather(x, idx):
|
||||
mx.eval(x[idx])
|
||||
|
||||
idx = mx.random.randint(0, x_shape[0] - 1, idx_shape)
|
||||
x = mx.random.normal(x_shape).astype(mx.float32)
|
||||
|
||||
runtime = measure_runtime(gather, x=x, idx=idx)
|
||||
print(f"MLX: {runtime:.3f}ms")
|
||||
|
||||
|
||||
def benchmark_gather_torch(x_shape, idx_shape, device):
|
||||
def gather(x, idx, device):
|
||||
_ = x[idx]
|
||||
if device == torch.device("mps"):
|
||||
torch.mps.synchronize()
|
||||
|
||||
idx = torch.randint(0, x_shape[0] - 1, idx_shape).to(device)
|
||||
x = torch.randn(x_shape, dtype=torch.float32).to(device)
|
||||
|
||||
runtime = measure_runtime(gather, x=x, idx=idx, device=device)
|
||||
print(f"PyTorch: {runtime:.3f}ms")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser("Gather benchmarks.")
|
||||
parser.add_argument("--cpu", action="store_true", help="Use the CPU.")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.cpu:
|
||||
mx.set_default_device(mx.cpu)
|
||||
device = torch.device("cpu")
|
||||
else:
|
||||
device = torch.device("mps")
|
||||
|
||||
idx_shapes = [(1_000_000,), (100_000,), ()]
|
||||
x_shapes = [(100, 64), (100, 1024), (4, 1_000_000)]
|
||||
|
||||
for x_shape, idx_shape in zip(x_shapes, idx_shapes):
|
||||
print("=" * 20)
|
||||
print(f"X {x_shape}, Indices {idx_shape}")
|
||||
benchmark_gather_mlx(x_shape, idx_shape)
|
||||
benchmark_gather_torch(x_shape, idx_shape, device=device)
|
@@ -1,198 +0,0 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import math
|
||||
import time
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from flax import linen as nn
|
||||
|
||||
|
||||
class RoPE(nn.Module):
|
||||
dims: int
|
||||
traditional: bool = False
|
||||
|
||||
def _compute_rope(self, costheta, sintheta, x):
|
||||
x1 = x[..., : self.dims // 2]
|
||||
x2 = x[..., self.dims // 2 : self.dims]
|
||||
rx1 = x1 * costheta - x2 * sintheta
|
||||
rx2 = x1 * sintheta + x2 * costheta
|
||||
|
||||
if self.dims < x.shape[-1]:
|
||||
rx = jnp.concatenate([rx1, rx2, x[..., self.dims :]], axis=-1)
|
||||
else:
|
||||
rx = jnp.concatenate([rx1, rx2], axis=-1)
|
||||
|
||||
return rx
|
||||
|
||||
def _compute_traditional_rope(self, costheta, sintheta, x):
|
||||
x1 = x[..., ::2]
|
||||
x2 = x[..., 1::2]
|
||||
rx1 = x1 * costheta - x2 * sintheta
|
||||
rx2 = x1 * sintheta + x2 * costheta
|
||||
|
||||
if self.dims < x.shape[-1]:
|
||||
raise NotImplementedError(
|
||||
"RoPE doesn't implement partial traditional application"
|
||||
)
|
||||
|
||||
rx = jnp.concatenate([rx1[..., None], rx2[..., None]], axis=-1)
|
||||
|
||||
return rx
|
||||
|
||||
@staticmethod
|
||||
def create_cos_sin_theta(
|
||||
N: int,
|
||||
D: int,
|
||||
offset: int = 0,
|
||||
base: float = 10000,
|
||||
dtype=jnp.float32,
|
||||
):
|
||||
D = D // 2
|
||||
positions = jnp.arange(offset, N, dtype=dtype)
|
||||
freqs = jnp.exp(-jnp.arange(0, D, dtype=dtype) * (math.log(base) / D))
|
||||
theta = positions.reshape((-1, 1)) * freqs.reshape((1, -1))
|
||||
costheta = jnp.cos(theta)
|
||||
sintheta = jnp.sin(theta)
|
||||
|
||||
return costheta, sintheta
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, x, offset: int = 0):
|
||||
shape = x.shape
|
||||
x = x.reshape((-1, shape[-2], shape[-1]))
|
||||
N = x.shape[1] + offset
|
||||
costheta, sintheta = RoPE.create_cos_sin_theta(
|
||||
N, self.dims, offset=offset, dtype=x.dtype
|
||||
)
|
||||
|
||||
rope = (
|
||||
self._compute_traditional_rope if self.traditional else self._compute_rope
|
||||
)
|
||||
rx = rope(costheta, sintheta, x)
|
||||
|
||||
return rx.reshape(shape)
|
||||
|
||||
|
||||
class LlamaAttention(nn.Module):
|
||||
dims: int
|
||||
num_heads: int
|
||||
dtype: jnp.dtype
|
||||
|
||||
def setup(self):
|
||||
num_heads = self.num_heads
|
||||
dims = self.dims
|
||||
|
||||
self.rope = RoPE(dims // num_heads, True)
|
||||
self.query_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
|
||||
self.key_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
|
||||
self.value_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
|
||||
self.out_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
|
||||
|
||||
def __call__(self, queries, keys, values, mask=None, cache=None):
|
||||
queries = self.query_proj(queries)
|
||||
keys = self.key_proj(keys)
|
||||
values = self.value_proj(values)
|
||||
|
||||
num_heads = self.num_heads
|
||||
B, L, D = queries.shape
|
||||
queries = queries.reshape((B, L, num_heads, -1)).transpose((0, 2, 1, 3))
|
||||
keys = keys.reshape((B, L, num_heads, -1)).transpose((0, 2, 1, 3))
|
||||
values = values.reshape((B, L, num_heads, -1)).transpose((0, 2, 1, 3))
|
||||
|
||||
if cache is not None:
|
||||
key_cache, value_cache = cache
|
||||
queries = self.rope(queries, offset=key_cache.shape[2])
|
||||
keys = self.rope(keys, offset=key_cache.shape[2])
|
||||
keys = jnp.concatenate([key_cache, keys], axis=2)
|
||||
values = jnp.concatenate([value_cache, values], axis=2)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
# Dimensions are [batch x num heads x sequence x hidden dim]
|
||||
scale = math.sqrt(1 / queries.shape[-1])
|
||||
scores = (queries * scale) @ keys.transpose((0, 1, 3, 2))
|
||||
if mask is not None:
|
||||
scores = scores + mask
|
||||
scores = jax.nn.softmax(scores, axis=-1)
|
||||
values_hat = (scores @ values).transpose((0, 2, 1, 3)).reshape((B, L, -1))
|
||||
|
||||
return self.out_proj(values_hat), (keys, values)
|
||||
|
||||
|
||||
class LlamaEncoderLayer(nn.Module):
|
||||
dims: int
|
||||
mlp_dims: int
|
||||
num_heads: int
|
||||
dtype: jnp.dtype
|
||||
|
||||
def setup(self):
|
||||
dims = self.dims
|
||||
mlp_dims = self.mlp_dims
|
||||
num_heads = self.num_heads
|
||||
|
||||
self.attention = LlamaAttention(dims, num_heads, dtype)
|
||||
|
||||
self.norm1 = nn.RMSNorm(param_dtype=self.dtype)
|
||||
self.norm2 = nn.RMSNorm(param_dtype=self.dtype)
|
||||
|
||||
self.linear1 = nn.Dense(mlp_dims, use_bias=False, param_dtype=self.dtype)
|
||||
self.linear2 = nn.Dense(mlp_dims, use_bias=False, param_dtype=self.dtype)
|
||||
self.linear3 = nn.Dense(dims, use_bias=False, param_dtype=self.dtype)
|
||||
|
||||
def __call__(self, x, mask=None, cache=None):
|
||||
y = self.norm1(x)
|
||||
y, cache = self.attention(y, y, y, mask, cache)
|
||||
x = x + y
|
||||
|
||||
y = self.norm2(x)
|
||||
a = self.linear1(y)
|
||||
b = self.linear2(y)
|
||||
y = jax.nn.silu(a) * b
|
||||
y = self.linear3(y)
|
||||
x = x + y
|
||||
|
||||
return x, cache
|
||||
|
||||
|
||||
def measure(model, x, cache):
|
||||
for i in range(5):
|
||||
y, c = model(x, mask=None, cache=cache)
|
||||
jax.block_until_ready((y, c))
|
||||
|
||||
start = time.time()
|
||||
for i in range(5):
|
||||
y, c = model(x, mask=None, cache=cache)
|
||||
jax.block_until_ready((y, c))
|
||||
|
||||
end = time.time()
|
||||
return (end - start) * 1000 / 5
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
H = 32
|
||||
D = 4096
|
||||
F = 43 * 256
|
||||
C = 1000
|
||||
dtype = jnp.float16
|
||||
|
||||
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
|
||||
|
||||
x = jax.random.normal(k1, (1, 1, D), dtype)
|
||||
cache = [
|
||||
jax.random.normal(k2, [1, H, C, D // H], dtype),
|
||||
jax.random.normal(k3, [1, H, C, D // H], dtype),
|
||||
]
|
||||
|
||||
layer = LlamaEncoderLayer(D, F, H, dtype=dtype)
|
||||
params = layer.init(k4, x, mask=None, cache=cache)["params"]
|
||||
|
||||
@jax.jit
|
||||
def model_fn(x, mask, cache):
|
||||
return layer.apply({"params": params}, x, mask=mask, cache=cache)
|
||||
|
||||
T = measure(model_fn, x, cache)
|
||||
|
||||
print("Time per layer per token:", T, "ms")
|
||||
print("Lower bound total time per token:", T * 32, "ms")
|
@@ -1,118 +0,0 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import math
|
||||
import time
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import mlx.utils
|
||||
|
||||
|
||||
class LlamaAttention(nn.Module):
|
||||
def __init__(self, dims: int, num_heads: int):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.rope = nn.RoPE(dims // num_heads, True)
|
||||
self.query_proj = nn.Linear(dims, dims, False)
|
||||
self.key_proj = nn.Linear(dims, dims, False)
|
||||
self.value_proj = nn.Linear(dims, dims, False)
|
||||
self.out_proj = nn.Linear(dims, dims, False)
|
||||
|
||||
def __call__(self, queries, keys, values, mask=None, cache=None):
|
||||
queries = self.query_proj(queries)
|
||||
keys = self.key_proj(keys)
|
||||
values = self.value_proj(values)
|
||||
|
||||
num_heads = self.num_heads
|
||||
B, L, D = queries.shape
|
||||
queries = mx.transpose(mx.reshape(queries, (B, L, num_heads, -1)), (0, 2, 1, 3))
|
||||
keys = mx.transpose(mx.reshape(keys, (B, L, num_heads, -1)), (0, 2, 1, 3))
|
||||
values = mx.transpose(mx.reshape(values, (B, L, num_heads, -1)), (0, 2, 1, 3))
|
||||
|
||||
if cache is not None:
|
||||
key_cache, value_cache = cache
|
||||
queries = self.rope(queries, offset=key_cache.shape[2])
|
||||
keys = self.rope(keys, offset=key_cache.shape[2])
|
||||
keys = mx.concatenate([key_cache, keys], axis=2)
|
||||
values = mx.concatenate([value_cache, values], axis=2)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
# Dimensions are [batch x num heads x sequence x hidden dim]
|
||||
scale = mx.array(math.sqrt(1 / queries.shape[-1]), dtype=queries.dtype)
|
||||
scores = (queries * scale) @ mx.transpose(keys, (0, 1, 3, 2))
|
||||
if mask is not None:
|
||||
scores = scores + mask
|
||||
scores = mx.softmax(scores, axis=-1)
|
||||
values_hat = mx.reshape(mx.transpose(scores @ values, (0, 2, 1, 3)), (B, L, -1))
|
||||
|
||||
return self.out_proj(values_hat), (keys, values)
|
||||
|
||||
|
||||
class LlamaEncoderLayer(nn.Module):
|
||||
def __init__(self, dims: int, mlp_dims: int, num_heads: int):
|
||||
super().__init__()
|
||||
|
||||
self.attention = LlamaAttention(dims, num_heads)
|
||||
|
||||
self.norm1 = nn.RMSNorm(dims)
|
||||
self.norm2 = nn.RMSNorm(dims)
|
||||
|
||||
self.linear1 = nn.Linear(dims, mlp_dims, False)
|
||||
self.linear2 = nn.Linear(dims, mlp_dims, False)
|
||||
self.linear3 = nn.Linear(mlp_dims, dims, False)
|
||||
|
||||
def __call__(self, x, mask=None, cache=None):
|
||||
y = self.norm1(x)
|
||||
y, cache = self.attention(y, y, y, mask, cache)
|
||||
x = x + y
|
||||
|
||||
y = self.norm2(x)
|
||||
a = self.linear1(y)
|
||||
b = self.linear2(y)
|
||||
y = a * mx.sigmoid(a) * b
|
||||
y = self.linear3(y)
|
||||
x = x + y
|
||||
|
||||
return x, cache
|
||||
|
||||
|
||||
def measure(model, x, cache):
|
||||
for i in range(5):
|
||||
y, c = model(x, mask=None, cache=cache)
|
||||
mx.eval(y, c)
|
||||
|
||||
start = time.time()
|
||||
rs = []
|
||||
for i in range(5):
|
||||
y, c = model(x, mask=None, cache=cache)
|
||||
rs.append((y, c))
|
||||
mx.eval(rs)
|
||||
end = time.time()
|
||||
|
||||
return (end - start) * 1000 / 5
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
H = 32
|
||||
D = 4096
|
||||
F = 43 * 256
|
||||
C = 1000
|
||||
mx.set_default_device(mx.gpu)
|
||||
dtype = mx.float16
|
||||
|
||||
layer = LlamaEncoderLayer(D, F, H)
|
||||
layer.update(mlx.utils.tree_map(lambda x: x.astype(dtype), layer.parameters()))
|
||||
k1, k2, k3 = mx.random.split(mx.random.key(0), 3)
|
||||
x = mx.random.normal([1, 1, D], dtype=dtype)
|
||||
cache = [
|
||||
mx.random.normal([1, H, C, D // H], dtype=dtype),
|
||||
mx.random.normal([1, H, C, D // H], dtype=dtype),
|
||||
]
|
||||
mx.eval(x, cache)
|
||||
|
||||
T = measure(layer, x, cache)
|
||||
|
||||
print("Time per layer per token:", T, "ms")
|
||||
print("Lower bound total time per token:", T * 32, "ms")
|
@@ -1,199 +0,0 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import math
|
||||
import time
|
||||
|
||||
import torch
|
||||
import torch.mps
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def sync_if_needed(x):
|
||||
if x.device != torch.device("cpu"):
|
||||
torch.mps.synchronize()
|
||||
|
||||
|
||||
class RoPE(nn.Module):
|
||||
def __init__(self, dims: int, traditional: bool = False):
|
||||
super().__init__()
|
||||
self.dims = dims
|
||||
self.traditional = traditional
|
||||
|
||||
def _compute_rope(self, costheta, sintheta, x):
|
||||
x1 = x[..., : self.dims // 2]
|
||||
x2 = x[..., self.dims // 2 : self.dims]
|
||||
rx1 = x1 * costheta - x2 * sintheta
|
||||
rx2 = x1 * sintheta + x2 * costheta
|
||||
|
||||
if self.dims < x.shape[-1]:
|
||||
rx = torch.cat([rx1, rx2, x[..., self.dims :]], dim=-1)
|
||||
else:
|
||||
rx = torch.cat([rx1, rx2], dim=-1)
|
||||
|
||||
return rx
|
||||
|
||||
def _compute_traditional_rope(self, costheta, sintheta, x):
|
||||
x1 = x[..., ::2]
|
||||
x2 = x[..., 1::2]
|
||||
rx1 = x1 * costheta - x2 * sintheta
|
||||
rx2 = x1 * sintheta + x2 * costheta
|
||||
|
||||
if self.dims < x.shape[-1]:
|
||||
raise NotImplementedError(
|
||||
"RoPE doesn't implement partial traditional application"
|
||||
)
|
||||
|
||||
rx = torch.cat([rx1[..., None], rx2[..., None]], dim=-1)
|
||||
|
||||
return rx
|
||||
|
||||
def forward(self, x, offset: int = 0):
|
||||
shape = x.shape
|
||||
x = x.view(-1, shape[-2], shape[-1])
|
||||
N = x.shape[1] + offset
|
||||
costheta, sintheta = RoPE.create_cos_sin_theta(
|
||||
N, self.dims, offset=offset, device=x.device, dtype=x.dtype
|
||||
)
|
||||
|
||||
rope = (
|
||||
self._compute_traditional_rope if self.traditional else self._compute_rope
|
||||
)
|
||||
rx = rope(costheta, sintheta, x)
|
||||
|
||||
return rx.view(*shape)
|
||||
|
||||
@staticmethod
|
||||
def create_cos_sin_theta(
|
||||
N: int,
|
||||
D: int,
|
||||
offset: int = 0,
|
||||
base: float = 10000,
|
||||
device="cpu",
|
||||
dtype=torch.float32,
|
||||
):
|
||||
D = D // 2
|
||||
positions = torch.arange(offset, N, dtype=dtype, device=device)
|
||||
freqs = torch.exp(
|
||||
-torch.arange(0, D, dtype=dtype, device=device) * (math.log(base) / D)
|
||||
)
|
||||
theta = positions.view(-1, 1) * freqs.view(1, -1)
|
||||
costheta = torch.cos(theta)
|
||||
sintheta = torch.sin(theta)
|
||||
|
||||
return costheta, sintheta
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dims: int, epsilon: float = 1e-6):
|
||||
super().__init__()
|
||||
self.gamma = nn.Parameter(torch.ones((dims,)))
|
||||
self.epsilon = epsilon
|
||||
|
||||
def forward(self, x):
|
||||
n = torch.rsqrt(x.square().mean(dim=-1, keepdims=True) + self.epsilon)
|
||||
return self.gamma * x * n
|
||||
|
||||
|
||||
class LlamaAttention(nn.Module):
|
||||
def __init__(self, dims: int, num_heads: int):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.rope = RoPE(dims // num_heads, True)
|
||||
self.query_proj = nn.Linear(dims, dims, bias=False)
|
||||
self.key_proj = nn.Linear(dims, dims, bias=False)
|
||||
self.value_proj = nn.Linear(dims, dims, bias=False)
|
||||
self.out_proj = nn.Linear(dims, dims, bias=False)
|
||||
|
||||
def forward(self, queries, keys, values, mask=None, cache=None):
|
||||
queries = self.query_proj(queries)
|
||||
keys = self.key_proj(keys)
|
||||
values = self.value_proj(values)
|
||||
|
||||
num_heads = self.num_heads
|
||||
B, L, D = queries.shape
|
||||
queries = queries.view(B, L, num_heads, -1).permute(0, 2, 1, 3)
|
||||
keys = keys.view(B, L, num_heads, -1).permute(0, 2, 1, 3)
|
||||
values = values.view(B, L, num_heads, -1).permute(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
key_cache, value_cache = cache
|
||||
queries = self.rope(queries, offset=key_cache.shape[2])
|
||||
keys = self.rope(keys, offset=key_cache.shape[2])
|
||||
keys = torch.cat([key_cache, keys], dim=2)
|
||||
values = torch.cat([value_cache, values], dim=2)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
# Dimensions are [batch x num heads x sequence x hidden dim]
|
||||
scale = math.sqrt(1 / queries.shape[-1])
|
||||
scores = (queries * scale) @ keys.permute(0, 1, 3, 2)
|
||||
if mask is not None:
|
||||
scores = scores + mask
|
||||
scores = torch.softmax(scores, dim=-1)
|
||||
values_hat = (scores @ values).permute(0, 2, 1, 3).reshape(B, L, -1)
|
||||
|
||||
return self.out_proj(values_hat), (keys, values)
|
||||
|
||||
|
||||
class LlamaEncoderLayer(nn.Module):
|
||||
def __init__(self, dims: int, mlp_dims: int, num_heads: int):
|
||||
super().__init__()
|
||||
|
||||
self.attention = LlamaAttention(dims, num_heads)
|
||||
|
||||
self.norm1 = RMSNorm(dims)
|
||||
self.norm2 = RMSNorm(dims)
|
||||
|
||||
self.linear1 = nn.Linear(dims, mlp_dims, bias=False)
|
||||
self.linear2 = nn.Linear(dims, mlp_dims, bias=False)
|
||||
self.linear3 = nn.Linear(mlp_dims, dims, bias=False)
|
||||
|
||||
def forward(self, x, mask=None, cache=None):
|
||||
y = self.norm1(x)
|
||||
y, cache = self.attention(y, y, y, mask, cache)
|
||||
x = x + y
|
||||
|
||||
y = self.norm2(x)
|
||||
a = self.linear1(y)
|
||||
b = self.linear2(y)
|
||||
y = torch.nn.functional.silu(a) * b
|
||||
y = self.linear3(y)
|
||||
x = x + y
|
||||
|
||||
return x, cache
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def measure(model, x, cache):
|
||||
for i in range(5):
|
||||
y, c = model(x, mask=None, cache=cache)
|
||||
sync_if_needed(x)
|
||||
|
||||
start = time.time()
|
||||
for i in range(5):
|
||||
y, c = model(x, mask=None, cache=cache)
|
||||
sync_if_needed(x)
|
||||
end = time.time()
|
||||
return (end - start) * 1000 / 5
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
H = 32
|
||||
D = 4096
|
||||
F = 43 * 256
|
||||
C = 1000
|
||||
device = torch.device("mps")
|
||||
dtype = torch.float16
|
||||
|
||||
layer = LlamaEncoderLayer(D, F, H).to(device).to(dtype)
|
||||
x = torch.randn(1, 1, D).to(device).to(dtype)
|
||||
cache = [
|
||||
torch.randn(1, H, C, D // H).to(device).to(dtype),
|
||||
torch.randn(1, H, C, D // H).to(device).to(dtype),
|
||||
]
|
||||
|
||||
T = measure(layer, x, cache)
|
||||
|
||||
print("Time per layer per token:", T, "ms")
|
||||
print("Lower bound total time per token:", T * 32, "ms")
|
35
benchmarks/python/rope_bench.py
Normal file
35
benchmarks/python/rope_bench.py
Normal file
@@ -0,0 +1,35 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from time_utils import time_fn
|
||||
|
||||
|
||||
def time_rope():
|
||||
rope = nn.RoPE(4096)
|
||||
|
||||
# vec
|
||||
x = mx.random.uniform(shape=(1, 4096)).astype(mx.float16)
|
||||
mx.eval(x)
|
||||
|
||||
def rope_vec(x):
|
||||
for _ in range(32):
|
||||
x = rope(x)
|
||||
return x
|
||||
|
||||
time_fn(rope_vec, x)
|
||||
|
||||
# matrix
|
||||
x = mx.random.uniform(shape=(1024, 4096)).astype(mx.float16)
|
||||
mx.eval(x)
|
||||
|
||||
def rope_mat(x):
|
||||
for _ in range(32):
|
||||
x = rope(x)
|
||||
return x
|
||||
|
||||
time_fn(rope_mat, x)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
time_rope()
|
56
benchmarks/python/scatter_bench.py
Normal file
56
benchmarks/python/scatter_bench.py
Normal file
@@ -0,0 +1,56 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import argparse
|
||||
|
||||
import mlx.core as mx
|
||||
import torch
|
||||
from time_utils import measure_runtime
|
||||
|
||||
|
||||
def benchmark_scatter_mlx(dst_shape, x_shape, idx_shape):
|
||||
def scatter(dst, x, idx):
|
||||
dst[idx] = x
|
||||
mx.eval(dst)
|
||||
|
||||
idx = mx.random.randint(0, dst_shape[0] - 1, idx_shape)
|
||||
x = mx.random.normal(x_shape).astype(mx.float32)
|
||||
dst = mx.random.normal(dst_shape).astype(mx.float32)
|
||||
|
||||
runtime = measure_runtime(scatter, dst=dst, x=x, idx=idx)
|
||||
print(f"MLX: {runtime:.3f}ms")
|
||||
|
||||
|
||||
def benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device):
|
||||
def gather(dst, x, idx, device):
|
||||
dst[idx] = x
|
||||
if device == torch.device("mps"):
|
||||
torch.mps.synchronize()
|
||||
|
||||
idx = torch.randint(0, dst_shape[0] - 1, idx_shape).to(device)
|
||||
x = torch.randn(x_shape, dtype=torch.float32).to(device)
|
||||
dst = torch.randn(dst_shape, dtype=torch.float32).to(device)
|
||||
|
||||
runtime = measure_runtime(gather, dst=dst, x=x, idx=idx, device=device)
|
||||
print(f"PyTorch: {runtime:.3f}ms")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser("Gather benchmarks.")
|
||||
parser.add_argument("--cpu", action="store_true", help="Use the CPU.")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.cpu:
|
||||
mx.set_default_device(mx.cpu)
|
||||
device = torch.device("cpu")
|
||||
else:
|
||||
device = torch.device("mps")
|
||||
|
||||
dst_shapes = [(10, 64), (100_000, 64), (1_000_000, 64)]
|
||||
idx_shapes = [(1_000_000,), (1_000_000,), (100_000,)]
|
||||
x_shapes = [(1_000_000, 64), (1_000_000, 64), (100_000, 64)]
|
||||
|
||||
for dst_shape, x_shape, idx_shape in zip(dst_shapes, x_shapes, idx_shapes):
|
||||
print("=" * 20)
|
||||
print(f"X {x_shape}, Indices {idx_shape}")
|
||||
benchmark_scatter_mlx(dst_shape, x_shape, idx_shape)
|
||||
benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device=device)
|
@@ -44,6 +44,13 @@ def time_matmul():
|
||||
time_fn(mx.matmul, a, b)
|
||||
|
||||
|
||||
def time_maximum():
|
||||
a = mx.random.uniform(shape=(32, 1024, 1024))
|
||||
b = mx.random.uniform(shape=(32, 1024, 1024))
|
||||
mx.eval(a, b)
|
||||
time_fn(mx.maximum, a, b)
|
||||
|
||||
|
||||
def time_negative():
|
||||
a = mx.random.uniform(shape=(10000, 1000))
|
||||
mx.eval(a)
|
||||
@@ -101,6 +108,7 @@ if __name__ == "__main__":
|
||||
|
||||
time_add()
|
||||
time_matmul()
|
||||
time_maximum()
|
||||
time_exp()
|
||||
time_negative()
|
||||
time_logsumexp()
|
||||
|
@@ -1,4 +1,4 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import time
|
||||
|
||||
@@ -20,3 +20,15 @@ def time_fn(fn, *args, **kwargs):
|
||||
|
||||
msec = 1e3 * (toc - tic) / num_iters
|
||||
print(f"{msec:.5f} msec")
|
||||
|
||||
|
||||
def measure_runtime(fn, **kwargs):
|
||||
# Warmup
|
||||
for _ in range(5):
|
||||
fn(**kwargs)
|
||||
|
||||
tic = time.time()
|
||||
iters = 100
|
||||
for _ in range(iters):
|
||||
fn(**kwargs)
|
||||
return (time.time() - tic) * 1000 / iters
|
||||
|
1
docs/.gitignore
vendored
1
docs/.gitignore
vendored
@@ -1,2 +1,3 @@
|
||||
src/python/_autosummary*/
|
||||
src/python/nn/_autosummary*/
|
||||
src/python/optimizers/_autosummary*/
|
||||
|
@@ -1,19 +0,0 @@
|
||||
{{ fullname | escape | underline}}
|
||||
|
||||
.. currentmodule:: {{ module }}
|
||||
|
||||
.. autoclass:: {{ objname }}
|
||||
|
||||
{#{% block methods %}
|
||||
|
||||
{% if methods %}
|
||||
.. rubric:: {{ _('Methods') }}
|
||||
|
||||
.. autosummary::
|
||||
{% for item in methods %}
|
||||
{%- if item not in inherited_members and item != '__init__' %}
|
||||
~{{ name }}.{{ item }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{% endif %}
|
||||
{% endblock %}#}
|
@@ -12,7 +12,7 @@ import mlx.core as mx
|
||||
project = "MLX"
|
||||
copyright = "2023, MLX Contributors"
|
||||
author = "MLX Contributors"
|
||||
version = ".".join(mx.__version__.split()[:-1])
|
||||
version = ".".join(mx.__version__.split(".")[:3])
|
||||
release = version
|
||||
|
||||
# -- General configuration ---------------------------------------------------
|
||||
@@ -26,6 +26,7 @@ extensions = [
|
||||
|
||||
python_use_unqualified_type_names = True
|
||||
autosummary_generate = True
|
||||
autosummary_filename_map = {"mlx.core.Stream": "stream_class"}
|
||||
|
||||
intersphinx_mapping = {
|
||||
"https://docs.python.org/3": None,
|
||||
|
@@ -35,7 +35,7 @@ However, you work with vector math libraries often and realize that the
|
||||
You would really like the part of your applications that does this operation
|
||||
on the CPU to be very fast - so you decide that you want it to rely on the
|
||||
``axpby`` routine provided by the Accelerate_ framework. Continuing to impose
|
||||
our assumptions on to you, let's also assume that you want to learn how add
|
||||
our assumptions on to you, let's also assume that you want to learn how to add
|
||||
your own implementation for the gradients of your new operation while going
|
||||
over the ins-and-outs of the MLX framework.
|
||||
|
||||
@@ -677,9 +677,9 @@ Let's look at the overall directory structure first.
|
||||
Binding to Python
|
||||
^^^^^^^^^^^^^^^^^^
|
||||
|
||||
We use PyBind11_ to build a Python API for the C++ library. Since bindings
|
||||
for all needed components such as `mlx.core.array`, `mlx.core.stream`, etc.
|
||||
are already provided, adding our :meth:`axpby` becomes very simple!
|
||||
We use PyBind11_ to build a Python API for the C++ library. Since bindings for
|
||||
components such as :class:`mlx.core.array`, :class:`mlx.core.stream`, etc. are
|
||||
already provided, adding our :meth:`axpby` is simple!
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
@@ -927,18 +927,18 @@ Results:
|
||||
|
||||
We see some modest improvements right away!
|
||||
|
||||
This operation is now good to be used to build other operations,
|
||||
in :class:`mlx.nn.Module` calls, and also as a part of graph
|
||||
transformations such as :meth:`grad` and :meth:`simplify`!
|
||||
This operation is now good to be used to build other operations, in
|
||||
:class:`mlx.nn.Module` calls, and also as a part of graph transformations like
|
||||
:meth:`grad`!
|
||||
|
||||
Scripts
|
||||
-------
|
||||
|
||||
.. admonition:: Download the code
|
||||
|
||||
The full example code is available in `mlx-examples <code>`_.
|
||||
The full example code is available in `mlx <code>`_.
|
||||
|
||||
.. code: `TODO_LINK/extensions`_
|
||||
.. code: `https://github.com/ml-explore/mlx/tree/main/examples/extensions/`_
|
||||
|
||||
.. _Accelerate: https://developer.apple.com/documentation/accelerate/blas?language=objc
|
||||
.. _Metal: https://developer.apple.com/documentation/metal?language=objc
|
||||
|
@@ -41,6 +41,7 @@ are the CPU and GPU.
|
||||
usage/indexing
|
||||
usage/saving_and_loading
|
||||
usage/function_transforms
|
||||
usage/compile
|
||||
usage/numpy
|
||||
usage/using_streams
|
||||
|
||||
|
@@ -9,9 +9,10 @@ Devices and Streams
|
||||
:toctree: _autosummary
|
||||
|
||||
Device
|
||||
Stream
|
||||
default_device
|
||||
set_default_device
|
||||
Stream
|
||||
default_stream
|
||||
new_stream
|
||||
set_default_stream
|
||||
stream
|
||||
|
@@ -9,3 +9,4 @@ Linear Algebra
|
||||
:toctree: _autosummary
|
||||
|
||||
norm
|
||||
qr
|
||||
|
@@ -180,3 +180,4 @@ In detail:
|
||||
nn/layers
|
||||
nn/functions
|
||||
nn/losses
|
||||
nn/init
|
||||
|
@@ -19,5 +19,6 @@ simple functions.
|
||||
prelu
|
||||
relu
|
||||
selu
|
||||
softshrink
|
||||
silu
|
||||
step
|
||||
|
45
docs/src/python/nn/init.rst
Normal file
45
docs/src/python/nn/init.rst
Normal file
@@ -0,0 +1,45 @@
|
||||
.. _init:
|
||||
|
||||
.. currentmodule:: mlx.nn.init
|
||||
|
||||
Initializers
|
||||
------------
|
||||
|
||||
The ``mlx.nn.init`` package contains commonly used initializers for neural
|
||||
network parameters. Initializers return a function which can be applied to any
|
||||
input :obj:`mlx.core.array` to produce an initialized output.
|
||||
|
||||
For example:
|
||||
|
||||
.. code:: python
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
init_fn = nn.init.uniform()
|
||||
|
||||
# Produces a [2, 2] uniform matrix
|
||||
param = init_fn(mx.zeros((2, 2)))
|
||||
|
||||
To re-initialize all the parameter in an :obj:`mlx.nn.Module` from say a uniform
|
||||
distribution, you can do:
|
||||
|
||||
.. code:: python
|
||||
|
||||
import mlx.nn as nn
|
||||
model = nn.Sequential(nn.Linear(5, 10), nn.ReLU(), nn.Linear(10, 5))
|
||||
init_fn = nn.init.uniform(low=-0.1, high=0.1)
|
||||
model.apply(init_fn)
|
||||
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
constant
|
||||
normal
|
||||
uniform
|
||||
identity
|
||||
glorot_normal
|
||||
glorot_uniform
|
||||
he_normal
|
||||
he_uniform
|
@@ -10,6 +10,8 @@ Layers
|
||||
:template: nn-module-template.rst
|
||||
|
||||
ALiBi
|
||||
AvgPool1d
|
||||
AvgPool2d
|
||||
BatchNorm
|
||||
Conv1d
|
||||
Conv2d
|
||||
@@ -22,6 +24,8 @@ Layers
|
||||
InstanceNorm
|
||||
LayerNorm
|
||||
Linear
|
||||
MaxPool1d
|
||||
MaxPool2d
|
||||
Mish
|
||||
MultiHeadAttention
|
||||
PReLU
|
||||
@@ -33,5 +37,6 @@ Layers
|
||||
Sequential
|
||||
SiLU
|
||||
SinusoidalPositionalEncoding
|
||||
Softshrink
|
||||
Step
|
||||
Transformer
|
||||
|
@@ -18,6 +18,7 @@ Loss Functions
|
||||
kl_div_loss
|
||||
l1_loss
|
||||
log_cosh_loss
|
||||
margin_ranking_loss
|
||||
mse_loss
|
||||
nll_loss
|
||||
smooth_l1_loss
|
||||
|
@@ -11,6 +11,7 @@ Module
|
||||
:toctree: _autosummary
|
||||
|
||||
Module.training
|
||||
Module.state
|
||||
|
||||
.. rubric:: Methods
|
||||
|
||||
|
@@ -35,6 +35,8 @@ Operations
|
||||
cos
|
||||
cosh
|
||||
dequantize
|
||||
diag
|
||||
diagonal
|
||||
divide
|
||||
divmod
|
||||
equal
|
||||
|
@@ -29,19 +29,8 @@ model's parameters and the **optimizer state**.
|
||||
# Compute the new parameters but also the optimizer state.
|
||||
mx.eval(model.parameters(), optimizer.state)
|
||||
|
||||
.. currentmodule:: mlx.optimizers
|
||||
.. toctree::
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
:template: optimizers-template.rst
|
||||
|
||||
OptimizerState
|
||||
Optimizer
|
||||
SGD
|
||||
RMSprop
|
||||
Adagrad
|
||||
AdaDelta
|
||||
Adam
|
||||
AdamW
|
||||
Adamax
|
||||
Lion
|
||||
optimizers/optimizer
|
||||
optimizers/common_optimizers
|
||||
optimizers/schedulers
|
||||
|
20
docs/src/python/optimizers/common_optimizers.rst
Normal file
20
docs/src/python/optimizers/common_optimizers.rst
Normal file
@@ -0,0 +1,20 @@
|
||||
.. _common_optimizers:
|
||||
|
||||
Common Optimizers
|
||||
=================
|
||||
|
||||
.. currentmodule:: mlx.optimizers
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
:template: optimizers-template.rst
|
||||
|
||||
SGD
|
||||
RMSprop
|
||||
Adagrad
|
||||
Adafactor
|
||||
AdaDelta
|
||||
Adam
|
||||
AdamW
|
||||
Adamax
|
||||
Lion
|
23
docs/src/python/optimizers/optimizer.rst
Normal file
23
docs/src/python/optimizers/optimizer.rst
Normal file
@@ -0,0 +1,23 @@
|
||||
Optimizer
|
||||
=========
|
||||
|
||||
.. currentmodule:: mlx.optimizers
|
||||
|
||||
.. autoclass:: Optimizer
|
||||
|
||||
|
||||
.. rubric:: Attributes
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
Optimizer.state
|
||||
|
||||
.. rubric:: Methods
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
Optimizer.apply_gradients
|
||||
Optimizer.init
|
||||
Optimizer.update
|
13
docs/src/python/optimizers/schedulers.rst
Normal file
13
docs/src/python/optimizers/schedulers.rst
Normal file
@@ -0,0 +1,13 @@
|
||||
.. _schedulers:
|
||||
|
||||
Schedulers
|
||||
==========
|
||||
|
||||
.. currentmodule:: mlx.optimizers
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
step_decay
|
||||
exponential_decay
|
||||
cosine_decay
|
@@ -9,9 +9,11 @@ Transforms
|
||||
:toctree: _autosummary
|
||||
|
||||
eval
|
||||
compile
|
||||
disable_compile
|
||||
enable_compile
|
||||
grad
|
||||
value_and_grad
|
||||
jvp
|
||||
vjp
|
||||
vmap
|
||||
simplify
|
||||
|
430
docs/src/usage/compile.rst
Normal file
430
docs/src/usage/compile.rst
Normal file
@@ -0,0 +1,430 @@
|
||||
.. _compile:
|
||||
|
||||
Compilation
|
||||
===========
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
MLX has a :func:`compile` function transformation which compiles computation
|
||||
graphs. Function compilation results in smaller graphs by merging common work
|
||||
and fusing certain operations. In many cases this can lead to big improvements
|
||||
in run-time and memory use.
|
||||
|
||||
Getting started with :func:`compile` is simple, but there are some edge cases
|
||||
that are good to be aware of for more complex graphs and advanced usage.
|
||||
|
||||
Basics of Compile
|
||||
-----------------
|
||||
|
||||
Let's start with a simple example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def fun(x, y):
|
||||
return mx.exp(-x) + y
|
||||
|
||||
x = mx.array(1.0)
|
||||
y = mx.array(2.0)
|
||||
|
||||
# Regular call, no compilation
|
||||
# Prints: array(2.36788, dtype=float32)
|
||||
print(fun(x, y))
|
||||
|
||||
# Compile the function
|
||||
compiled_fun = mx.compile(fun)
|
||||
|
||||
# Prints: array(2.36788, dtype=float32)
|
||||
print(compiled_fun(x, y))
|
||||
|
||||
The output of both the regular function and the compiled function is the same
|
||||
up to numerical precision.
|
||||
|
||||
The first time you call a compiled function, MLX will build the compute
|
||||
graph, optimize it, and generate and compile code. This can be relatively
|
||||
slow. However, MLX will cache compiled functions, so calling a compiled
|
||||
function multiple times will not initiate a new compilation. This means you
|
||||
should typically compile functions that you plan to use more than once.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def fun(x, y):
|
||||
return mx.exp(-x) + y
|
||||
|
||||
x = mx.array(1.0)
|
||||
y = mx.array(2.0)
|
||||
|
||||
compiled_fun = mx.compile(fun)
|
||||
|
||||
# Compiled here
|
||||
compiled_fun(x, y)
|
||||
|
||||
# Not compiled again
|
||||
compiled_fun(x, y)
|
||||
|
||||
# Not compiled again
|
||||
mx.compile(fun)(x, y)
|
||||
|
||||
There are some important cases to be aware of that can cause a function to
|
||||
be recompiled:
|
||||
|
||||
* Changing the shape or number of dimensions
|
||||
* Changing the type of any of the inputs
|
||||
* Changing the number of inputs to the function
|
||||
|
||||
In certain cases only some of the compilation stack will be rerun (for
|
||||
example when changing the shapes) and in other cases the full compilation
|
||||
stack will be rerun (for example when changing the types). In general you
|
||||
should avoid compiling functions too frequently.
|
||||
|
||||
Another idiom to watch out for is compiling functions which get created and
|
||||
destroyed frequently. This can happen, for example, when compiling an anonymous
|
||||
function in a loop:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
a = mx.array(1.0)
|
||||
# Don't do this, compiles lambda at each iteration
|
||||
for _ in range(5):
|
||||
mx.compile(lambda x: mx.exp(mx.abs(x)))(a)
|
||||
|
||||
Example Speedup
|
||||
---------------
|
||||
|
||||
The :func:`mlx.nn.gelu` is a nonlinear activation function commonly used with
|
||||
Transformer-based models. The implementation involves several unary and binary
|
||||
element-wise operations:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def gelu(x):
|
||||
return x * (1 + mx.erf(x / math.sqrt(2))) / 2
|
||||
|
||||
If you use this function with small arrays, it will be overhead bound. If you
|
||||
use it with large arrays it will be memory bandwidth bound. However, all of
|
||||
the operations in the ``gelu`` are fusible into a single kernel with
|
||||
:func:`compile`. This can speedup both cases considerably.
|
||||
|
||||
Let's compare the runtime of the regular function versus the compiled
|
||||
function. We'll use the following timing helper which does a warm up and
|
||||
handles synchronization:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import time
|
||||
|
||||
def timeit(fun, x):
|
||||
# warm up
|
||||
for _ in range(10):
|
||||
mx.eval(fun(x))
|
||||
|
||||
tic = time.perf_counter()
|
||||
for _ in range(100):
|
||||
mx.eval(fun(x))
|
||||
toc = time.perf_counter()
|
||||
tpi = 1e3 * (toc - tic) / 100
|
||||
print(f"Time per iteration {tpi:.3f} (ms)")
|
||||
|
||||
|
||||
Now make an array, and benchmark both functions:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
x = mx.random.uniform(shape=(32, 1000, 4096))
|
||||
timeit(nn.gelu, x)
|
||||
timeit(mx.compile(nn.gelu), x)
|
||||
|
||||
On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is
|
||||
five times faster.
|
||||
|
||||
.. note::
|
||||
|
||||
As of the latest MLX, CPU functions are not fully compiled. Compiling CPU
|
||||
functions can still be helpful, but won't typically result in as large a
|
||||
speedup as compiling operations that run on the GPU.
|
||||
|
||||
|
||||
Debugging
|
||||
---------
|
||||
|
||||
When a compiled function is first called, it is traced with placeholder
|
||||
inputs. This means you can't evaluate arrays (for example to print their
|
||||
contents) inside compiled functions.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@mx.compile
|
||||
def fun(x):
|
||||
z = -x
|
||||
print(z) # Crash
|
||||
return mx.exp(z)
|
||||
|
||||
fun(mx.array(5.0))
|
||||
|
||||
For debugging, inspecting arrays can be helpful. One way to do that is to
|
||||
globally disable compilation using the :func:`disable_compile` function or
|
||||
``MLX_DISABLE_COMPILE`` flag. For example the following is okay even though
|
||||
``fun`` is compiled:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@mx.compile
|
||||
def fun(x):
|
||||
z = -x
|
||||
print(z) # Okay
|
||||
return mx.exp(z)
|
||||
|
||||
mx.disable_compile()
|
||||
fun(mx.array(5.0))
|
||||
|
||||
|
||||
Pure Functions
|
||||
--------------
|
||||
|
||||
Compiled functions are intended to be *pure*; that is they should not have side
|
||||
effects. For example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
state = []
|
||||
|
||||
@mx.compile
|
||||
def fun(x, y):
|
||||
z = x + y
|
||||
state.append(z)
|
||||
return mx.exp(z)
|
||||
|
||||
fun(mx.array(1.0), mx.array(2.0))
|
||||
# Crash!
|
||||
print(state)
|
||||
|
||||
After the first call of ``fun``, the ``state`` list will hold a placeholder
|
||||
array. The placeholder does not have any data; it is only used to build the
|
||||
computation graph. Printing such an array results in a crash.
|
||||
|
||||
You have two options to deal with this. The first option is to simply return
|
||||
``state`` as an output:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
state = []
|
||||
|
||||
@mx.compile
|
||||
def fun(x, y):
|
||||
z = x + y
|
||||
state.append(z)
|
||||
return mx.exp(z), state
|
||||
|
||||
_, state = fun(mx.array(1.0), mx.array(2.0))
|
||||
# Prints [array(3, dtype=float32)]
|
||||
print(state)
|
||||
|
||||
In some cases returning updated state can be pretty inconvenient. Hence,
|
||||
:func:`compile` has a parameter to capture implicit outputs:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from functools import partial
|
||||
|
||||
state = []
|
||||
|
||||
# Tell compile to capture state as an output
|
||||
@partial(mx.compile, outputs=state)
|
||||
def fun(x, y):
|
||||
z = x + y
|
||||
state.append(z)
|
||||
return mx.exp(z), state
|
||||
|
||||
fun(mx.array(1.0), mx.array(2.0))
|
||||
# Prints [array(3, dtype=float32)]
|
||||
print(state)
|
||||
|
||||
This is particularly useful for compiling a function which includes an update
|
||||
to a container of arrays, as is commonly done when training the parameters of a
|
||||
:class:`mlx.nn.Module`.
|
||||
|
||||
Compiled functions will also treat any inputs not in the parameter list as
|
||||
constants. For example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
state = [mx.array(1.0)]
|
||||
|
||||
@mx.compile
|
||||
def fun(x):
|
||||
return x + state[0]
|
||||
|
||||
# Prints array(2, dtype=float32)
|
||||
print(fun(mx.array(1.0)))
|
||||
|
||||
# Update state
|
||||
state[0] = mx.array(5.0)
|
||||
|
||||
# Still prints array(2, dtype=float32)
|
||||
print(fun(mx.array(1.0)))
|
||||
|
||||
In order to have the change of state reflected in the outputs of ``fun`` you
|
||||
again have two options. The first option is to simply pass ``state`` as input
|
||||
to the function. In some cases this can be pretty inconvenient. Hence,
|
||||
:func:`compile` also has a parameter to capture implicit inputs:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from functools import partial
|
||||
state = [mx.array(1.0)]
|
||||
|
||||
# Tell compile to capture state as an input
|
||||
@partial(mx.compile, inputs=state)
|
||||
def fun(x):
|
||||
return x + state[0]
|
||||
|
||||
# Prints array(2, dtype=float32)
|
||||
print(fun(mx.array(1.0)))
|
||||
|
||||
# Update state
|
||||
state[0] = mx.array(5.0)
|
||||
|
||||
# Prints array(6, dtype=float32)
|
||||
print(fun(mx.array(1.0)))
|
||||
|
||||
|
||||
Compiling Training Graphs
|
||||
-------------------------
|
||||
|
||||
This section will step through how to use :func:`compile` with a simple example
|
||||
of a common setup: training a model with :obj:`mlx.nn.Module` using an
|
||||
:obj:`mlx.optimizers.Optimizer` with state. We will show how to compile the
|
||||
full forward, backward, and update with :func:`compile`.
|
||||
|
||||
To start, here is the simple example without any compilation:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import mlx.optimizers as optim
|
||||
|
||||
# 4 examples with 10 features each
|
||||
x = mx.random.uniform(shape=(4, 10))
|
||||
|
||||
# 0, 1 targets
|
||||
y = mx.array([0, 1, 0, 1])
|
||||
|
||||
# Simple linear model
|
||||
model = nn.Linear(10, 1)
|
||||
|
||||
# SGD with momentum
|
||||
optimizer = optim.SGD(learning_rate=0.1, momentum=0.8)
|
||||
|
||||
def loss_fn(model, x, y):
|
||||
logits = model(x).squeeze()
|
||||
return nn.losses.binary_cross_entropy(logits, y)
|
||||
|
||||
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
|
||||
|
||||
# Perform 10 steps of gradient descent
|
||||
for it in range(10):
|
||||
loss, grads = loss_and_grad_fn(model, x, y)
|
||||
optimizer.update(model, grads)
|
||||
mx.eval(model.parameters(), optimizer.state)
|
||||
|
||||
To compile the update we can put it all in a function and compile it with the
|
||||
appropriate input and output captures. Here's the same example but compiled:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import mlx.optimizers as optim
|
||||
from functools import partial
|
||||
|
||||
# 4 examples with 10 features each
|
||||
x = mx.random.uniform(shape=(4, 10))
|
||||
|
||||
# 0, 1 targets
|
||||
y = mx.array([0, 1, 0, 1])
|
||||
|
||||
# Simple linear model
|
||||
model = nn.Linear(10, 1)
|
||||
|
||||
# SGD with momentum
|
||||
optimizer = optim.SGD(learning_rate=0.1, momentum=0.8)
|
||||
|
||||
def loss_fn(model, x, y):
|
||||
logits = model(x).squeeze()
|
||||
return nn.losses.binary_cross_entropy(logits, y)
|
||||
|
||||
# The state that will be captured as input and output
|
||||
state = [model.state, optimizer.state]
|
||||
|
||||
@partial(mx.compile, inputs=state, outputs=state)
|
||||
def step(x, y):
|
||||
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
|
||||
loss, grads = loss_and_grad_fn(model, x, y)
|
||||
optimizer.update(model, grads)
|
||||
return loss
|
||||
|
||||
# Perform 10 steps of gradient descent
|
||||
for it in range(10):
|
||||
loss = step(x, y)
|
||||
# Evaluate the model and optimizer state
|
||||
mx.eval(state)
|
||||
print(loss)
|
||||
|
||||
|
||||
.. note::
|
||||
|
||||
If you are using a module which performs random sampling such as
|
||||
:func:`mlx.nn.Dropout`, make sure you also include ``mx.random.state`` in the
|
||||
``state`` captured by :func:`compile`, i.e. ``state = [model.state,
|
||||
optimizer.state, mx.random.state]``.
|
||||
|
||||
|
||||
.. note::
|
||||
|
||||
For more examples of compiling full training graphs checkout the `MLX
|
||||
Examples <https://github.com/ml-explore/mlx-examples>`_ GitHub repo.
|
||||
|
||||
Transformations with Compile
|
||||
----------------------------
|
||||
|
||||
In MLX function transformations are composable. You can apply any function
|
||||
transformation to the output of any other function transformation. For more on
|
||||
this, see the documentation on :ref:`function transforms
|
||||
<function_transforms>`.
|
||||
|
||||
Compiling transformed functions works just as expected:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
grad_fn = mx.grad(mx.exp)
|
||||
|
||||
compiled_grad_fn = mx.compile(grad_fn)
|
||||
|
||||
# Prints: array(2.71828, dtype=float32)
|
||||
print(grad_fn(mx.array(1.0)))
|
||||
|
||||
# Also prints: array(2.71828, dtype=float32)
|
||||
print(compiled_grad_fn(mx.array(1.0)))
|
||||
|
||||
.. note::
|
||||
|
||||
In order to compile as much as possible, a transformation of a compiled
|
||||
function will not by default be compiled. To compile the transformed
|
||||
function simply pass it through :func:`compile`.
|
||||
|
||||
You can also compile functions which themselves call compiled functions. A
|
||||
good practice is to compile the outer most function to give :func:`compile`
|
||||
the most opportunity to optimize the computation graph:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@mx.compile
|
||||
def inner(x):
|
||||
return mx.exp(-mx.abs(x))
|
||||
|
||||
def outer(x):
|
||||
inner(inner(x))
|
||||
|
||||
# Compiling the outer function is good to do as it will likely
|
||||
# be faster even though the inner functions are compiled
|
||||
fun = mx.compile(outer)
|
@@ -5,9 +5,12 @@ Function Transforms
|
||||
|
||||
.. currentmodule:: mlx.core
|
||||
|
||||
MLX uses composable function transformations for automatic differentiation and
|
||||
vectorization. The key idea behind composable function transformations is that
|
||||
every transformation returns a function which can be further transformed.
|
||||
MLX uses composable function transformations for automatic differentiation,
|
||||
vectorization, and compute graph optimizations. To see the complete list of
|
||||
function transformations check-out the :ref:`API documentation <transforms>`.
|
||||
|
||||
The key idea behind composable function transformations is that every
|
||||
transformation returns a function which can be further transformed.
|
||||
|
||||
Here is a simple example:
|
||||
|
||||
@@ -36,10 +39,10 @@ Using :func:`grad` on the output of :func:`grad` is always ok. You keep
|
||||
getting higher order derivatives.
|
||||
|
||||
Any of the MLX function transformations can be composed in any order to any
|
||||
depth. To see the complete list of function transformations check-out the
|
||||
:ref:`API documentation <transforms>`. See the following sections for more
|
||||
information on :ref:`automatic differentiaion <auto diff>` and
|
||||
:ref:`automatic vectorization <vmap>`.
|
||||
depth. See the following sections for more information on :ref:`automatic
|
||||
differentiaion <auto diff>` and :ref:`automatic vectorization <vmap>`.
|
||||
For more information on :func:`compile` see the :ref:`compile documentation <compile>`.
|
||||
|
||||
|
||||
Automatic Differentiation
|
||||
-------------------------
|
||||
|
@@ -20,7 +20,7 @@ Transforming Compute Graphs
|
||||
|
||||
Lazy evaluation let's us record a compute graph without actually doing any
|
||||
computations. This is useful for function transformations like :func:`grad` and
|
||||
:func:`vmap` and graph optimizations like :func:`simplify`.
|
||||
:func:`vmap` and graph optimizations.
|
||||
|
||||
Currently, MLX does not compile and rerun compute graphs. They are all
|
||||
generated dynamically. However, lazy evaluation makes it much easier to
|
||||
|
@@ -1,4 +1,4 @@
|
||||
cmake_minimum_required(VERSION 3.24)
|
||||
cmake_minimum_required(VERSION 3.27)
|
||||
|
||||
project(mlx_sample_extensions LANGUAGES CXX)
|
||||
|
||||
@@ -63,4 +63,4 @@ target_link_libraries(mlx_sample_extensions PRIVATE mlx_ext)
|
||||
|
||||
if(BUILD_SHARED_LIBS)
|
||||
target_link_options(mlx_sample_extensions PRIVATE -Wl,-rpath,@loader_path)
|
||||
endif()
|
||||
endif()
|
||||
|
@@ -3,8 +3,10 @@ target_sources(
|
||||
PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/graph_utils.cpp
|
||||
@@ -19,7 +21,7 @@ target_sources(
|
||||
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
|
||||
if (MLX_BUILD_ACCELERATE)
|
||||
if (MLX_BUILD_ACCELERATE)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate)
|
||||
else()
|
||||
target_sources(
|
||||
|
@@ -1,4 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <functional>
|
||||
|
||||
@@ -47,6 +47,17 @@ array::array(
|
||||
std::move(primitive),
|
||||
inputs)) {}
|
||||
|
||||
array::array(
|
||||
std::vector<int> shape,
|
||||
Dtype dtype,
|
||||
std::shared_ptr<Primitive> primitive,
|
||||
std::vector<array>&& inputs)
|
||||
: array_desc_(std::make_shared<ArrayDesc>(
|
||||
std::move(shape),
|
||||
dtype,
|
||||
std::move(primitive),
|
||||
std::move(inputs))) {}
|
||||
|
||||
std::vector<array> array::make_arrays(
|
||||
const std::vector<std::vector<int>>& shapes,
|
||||
const std::vector<Dtype>& dtypes,
|
||||
@@ -71,6 +82,13 @@ array::array(std::initializer_list<float> data)
|
||||
init(data.begin());
|
||||
}
|
||||
|
||||
array::array(std::initializer_list<int> data, Dtype dtype)
|
||||
: array_desc_(std::make_shared<ArrayDesc>(
|
||||
std::vector<int>{static_cast<int>(data.size())},
|
||||
dtype)) {
|
||||
init(data.begin());
|
||||
}
|
||||
|
||||
/* Build an array from a shared buffer */
|
||||
array::array(
|
||||
allocator::Buffer data,
|
||||
@@ -86,11 +104,13 @@ void array::detach() {
|
||||
s.array_desc_->inputs.clear();
|
||||
s.array_desc_->siblings.clear();
|
||||
s.array_desc_->position = 0;
|
||||
s.array_desc_->depth = 0;
|
||||
s.array_desc_->primitive = nullptr;
|
||||
}
|
||||
array_desc_->inputs.clear();
|
||||
array_desc_->siblings.clear();
|
||||
array_desc_->position = 0;
|
||||
array_desc_->depth = 0;
|
||||
array_desc_->primitive = nullptr;
|
||||
}
|
||||
|
||||
@@ -144,6 +164,14 @@ void array::copy_shared_buffer(const array& other) {
|
||||
copy_shared_buffer(other, other.strides(), other.flags(), other.data_size());
|
||||
}
|
||||
|
||||
void array::move_shared_buffer(array other) {
|
||||
array_desc_->data = std::move(other.array_desc_->data);
|
||||
array_desc_->strides = other.strides();
|
||||
array_desc_->flags = other.flags();
|
||||
array_desc_->data_size = other.data_size();
|
||||
array_desc_->data_ptr = other.array_desc_->data_ptr;
|
||||
}
|
||||
|
||||
array::ArrayDesc::ArrayDesc(const std::vector<int>& shape, Dtype dtype)
|
||||
: shape(shape), dtype(dtype) {
|
||||
std::tie(size, strides) = cum_prod(shape);
|
||||
@@ -158,10 +186,29 @@ array::ArrayDesc::ArrayDesc(
|
||||
dtype(dtype),
|
||||
primitive(std::move(primitive)),
|
||||
inputs(inputs) {
|
||||
std::tie(size, strides) = cum_prod(shape);
|
||||
for (auto& in : inputs) {
|
||||
std::tie(size, strides) = cum_prod(this->shape);
|
||||
for (auto& in : this->inputs) {
|
||||
is_tracer |= in.is_tracer();
|
||||
depth = std::max(in.graph_depth(), depth);
|
||||
}
|
||||
depth++;
|
||||
}
|
||||
|
||||
array::ArrayDesc::ArrayDesc(
|
||||
std::vector<int>&& shape,
|
||||
Dtype dtype,
|
||||
std::shared_ptr<Primitive> primitive,
|
||||
std::vector<array>&& inputs)
|
||||
: shape(std::move(shape)),
|
||||
dtype(dtype),
|
||||
primitive(std::move(primitive)),
|
||||
inputs(std::move(inputs)) {
|
||||
std::tie(size, strides) = cum_prod(this->shape);
|
||||
for (auto& in : this->inputs) {
|
||||
is_tracer |= in.is_tracer();
|
||||
depth = std::max(in.graph_depth(), depth);
|
||||
}
|
||||
depth++;
|
||||
}
|
||||
|
||||
array::ArrayIterator::ArrayIterator(const array& arr, int idx)
|
||||
|
58
mlx/array.h
58
mlx/array.h
@@ -41,6 +41,9 @@ class array {
|
||||
/* Special case so empty lists default to float32. */
|
||||
array(std::initializer_list<float> data);
|
||||
|
||||
/* Special case so array({}, type) is an empty array. */
|
||||
array(std::initializer_list<int> data, Dtype dtype);
|
||||
|
||||
template <typename T>
|
||||
array(
|
||||
std::initializer_list<T> data,
|
||||
@@ -121,6 +124,9 @@ class array {
|
||||
template <typename T>
|
||||
T item();
|
||||
|
||||
template <typename T>
|
||||
T item() const;
|
||||
|
||||
struct ArrayIterator {
|
||||
using iterator_category = std::random_access_iterator_tag;
|
||||
using difference_type = size_t;
|
||||
@@ -172,6 +178,12 @@ class array {
|
||||
std::shared_ptr<Primitive> primitive,
|
||||
const std::vector<array>& inputs);
|
||||
|
||||
array(
|
||||
std::vector<int> shape,
|
||||
Dtype dtype,
|
||||
std::shared_ptr<Primitive> primitive,
|
||||
std::vector<array>&& inputs);
|
||||
|
||||
static std::vector<array> make_arrays(
|
||||
const std::vector<std::vector<int>>& shapes,
|
||||
const std::vector<Dtype>& dtypes,
|
||||
@@ -215,6 +227,11 @@ class array {
|
||||
return *(array_desc_->primitive);
|
||||
};
|
||||
|
||||
/** A shared pointer to the array's primitive. */
|
||||
std::shared_ptr<Primitive>& primitive_ptr() const {
|
||||
return array_desc_->primitive;
|
||||
};
|
||||
|
||||
/** Check if the array has an attached primitive or is a leaf node. */
|
||||
bool has_primitive() const {
|
||||
return array_desc_->primitive != nullptr;
|
||||
@@ -229,6 +246,11 @@ class array {
|
||||
return array_desc_->inputs;
|
||||
}
|
||||
|
||||
/** True indicates the arrays buffer is safe to reuse */
|
||||
bool is_donatable() const {
|
||||
return array_desc_.use_count() == 1 && (array_desc_->data.use_count() == 1);
|
||||
}
|
||||
|
||||
/** The array's siblings. */
|
||||
const std::vector<array>& siblings() const {
|
||||
return array_desc_->siblings;
|
||||
@@ -251,6 +273,11 @@ class array {
|
||||
return outputs;
|
||||
};
|
||||
|
||||
/** The depth of the array in the graph. Evaluated arrays have depth 0. */
|
||||
uint16_t graph_depth() const {
|
||||
return array_desc_->depth;
|
||||
}
|
||||
|
||||
/** Detach the array from the graph. */
|
||||
void detach();
|
||||
|
||||
@@ -271,6 +298,12 @@ class array {
|
||||
return array_desc_->data->buffer;
|
||||
};
|
||||
|
||||
// Return a copy of the shared pointer
|
||||
// to the array::Data struct
|
||||
std::shared_ptr<Data> data_shared_ptr() const {
|
||||
return array_desc_->data;
|
||||
}
|
||||
// Return a raw pointer to the arrays data
|
||||
template <typename T>
|
||||
T* data() {
|
||||
return static_cast<T*>(array_desc_->data_ptr);
|
||||
@@ -311,6 +344,8 @@ class array {
|
||||
|
||||
void copy_shared_buffer(const array& other);
|
||||
|
||||
void move_shared_buffer(array other);
|
||||
|
||||
void overwrite_descriptor(const array& other) {
|
||||
array_desc_ = other.array_desc_;
|
||||
}
|
||||
@@ -353,6 +388,9 @@ class array {
|
||||
// The arrays position in the output list
|
||||
uint32_t position{0};
|
||||
|
||||
// The depth of the array in the graph.
|
||||
uint16_t depth{0};
|
||||
|
||||
explicit ArrayDesc(const std::vector<int>& shape, Dtype dtype);
|
||||
|
||||
explicit ArrayDesc(
|
||||
@@ -360,12 +398,18 @@ class array {
|
||||
Dtype dtype,
|
||||
std::shared_ptr<Primitive> primitive,
|
||||
const std::vector<array>& inputs);
|
||||
|
||||
explicit ArrayDesc(
|
||||
std::vector<int>&& shape,
|
||||
Dtype dtype,
|
||||
std::shared_ptr<Primitive> primitive,
|
||||
std::vector<array>&& inputs);
|
||||
};
|
||||
|
||||
// The ArrayDesc contains the details of the materialized array including the
|
||||
// shape, strides, the data type. It also includes
|
||||
// the primitive which knows how to compute the array's data from its inputs
|
||||
// and a the list of array's inputs for the primitive.
|
||||
// and the list of array's inputs for the primitive.
|
||||
std::shared_ptr<ArrayDesc> array_desc_{nullptr};
|
||||
};
|
||||
|
||||
@@ -416,6 +460,18 @@ T array::item() {
|
||||
return *data<T>();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T array::item() const {
|
||||
if (size() != 1) {
|
||||
throw std::invalid_argument("item can only be called on arrays of size 1.");
|
||||
}
|
||||
if (!is_evaled()) {
|
||||
throw std::invalid_argument(
|
||||
"item() const can only be called on evaled arrays");
|
||||
}
|
||||
return *data<T>();
|
||||
}
|
||||
|
||||
template <typename It>
|
||||
void array::init(It src) {
|
||||
set_data(allocator::malloc(size() * size_of(dtype())));
|
||||
|
@@ -46,6 +46,14 @@ inline void matmul_cblas_general(
|
||||
size_t N = b.shape(-1);
|
||||
size_t K = a.shape(-1);
|
||||
|
||||
if (M == 0 || N == 0) {
|
||||
return;
|
||||
}
|
||||
if (K == 0) {
|
||||
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
|
||||
return;
|
||||
}
|
||||
|
||||
for (int i = 0; i < (a.size() / (M * K)); ++i) {
|
||||
cblas_sgemm(
|
||||
CblasRowMajor,
|
||||
@@ -89,6 +97,14 @@ inline void matmul_bnns_general(
|
||||
size_t N = b.shape(-1);
|
||||
size_t K = a.shape(-1);
|
||||
|
||||
if (M == 0 || N == 0) {
|
||||
return;
|
||||
}
|
||||
if (K == 0) {
|
||||
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
|
||||
return;
|
||||
}
|
||||
|
||||
BNNSDataType bnns_dtype = to_bnns_dtype(out.dtype());
|
||||
|
||||
const BNNSLayerParametersBroadcastMatMul gemm_params{
|
||||
@@ -201,4 +217,4 @@ void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
return matmul_bnns_general(inputs[0], inputs[1], out, alpha_, beta_);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
} // namespace mlx::core
|
||||
|
@@ -1,4 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
@@ -33,8 +33,12 @@ DEFAULT(ArgSort)
|
||||
DEFAULT(AsStrided)
|
||||
DEFAULT(Broadcast)
|
||||
DEFAULT(Ceil)
|
||||
DEFAULT_MULTI(Compiled)
|
||||
DEFAULT(Concatenate)
|
||||
DEFAULT(Copy)
|
||||
DEFAULT_MULTI(CustomVJP)
|
||||
DEFAULT_MULTI(Depends)
|
||||
DEFAULT_MULTI(DivMod)
|
||||
DEFAULT(Equal)
|
||||
DEFAULT(Erf)
|
||||
DEFAULT(ErfInv)
|
||||
@@ -50,11 +54,15 @@ DEFAULT(LogicalNot)
|
||||
DEFAULT(LogicalAnd)
|
||||
DEFAULT(LogicalOr)
|
||||
DEFAULT(LogAddExp)
|
||||
DEFAULT(Maximum)
|
||||
DEFAULT(Minimum)
|
||||
DEFAULT(NotEqual)
|
||||
DEFAULT(Pad)
|
||||
DEFAULT(Partition)
|
||||
DEFAULT_MULTI(QRF)
|
||||
DEFAULT(RandomBits)
|
||||
DEFAULT(Reshape)
|
||||
DEFAULT(Remainder)
|
||||
DEFAULT(Round)
|
||||
DEFAULT(Scatter)
|
||||
DEFAULT(Sigmoid)
|
||||
@@ -64,27 +72,16 @@ DEFAULT_MULTI(Split)
|
||||
DEFAULT(Sort)
|
||||
DEFAULT(StopGradient)
|
||||
DEFAULT(Transpose)
|
||||
DEFAULT_MULTI(DivMod)
|
||||
|
||||
void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (in.dtype() == float32 && in.flags().contiguous) {
|
||||
auto size = in.data_size();
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(size * out.itemsize()),
|
||||
size,
|
||||
in.strides(),
|
||||
in.flags());
|
||||
vDSP_vabs(in.data<float>(), 1, out.data<float>(), 1, size);
|
||||
set_unary_output_data(in, out);
|
||||
vDSP_vabs(in.data<float>(), 1, out.data<float>(), 1, in.data_size());
|
||||
} else if (in.dtype() == int32 && in.flags().contiguous) {
|
||||
auto size = in.data_size();
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(size * out.itemsize()),
|
||||
size,
|
||||
in.strides(),
|
||||
in.flags());
|
||||
vDSP_vabsi(in.data<int>(), 1, out.data<int>(), 1, size);
|
||||
set_unary_output_data(in, out);
|
||||
vDSP_vabsi(in.data<int>(), 1, out.data<int>(), 1, in.data_size());
|
||||
} else if (is_unsigned(in.dtype())) {
|
||||
// No-op for unsigned types
|
||||
out.copy_shared_buffer(in);
|
||||
@@ -137,12 +134,8 @@ void ArcCos::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
int size = in.data_size();
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(size * out.itemsize()),
|
||||
size,
|
||||
in.strides(),
|
||||
in.flags());
|
||||
vvacosf(out.data<float>(), in.data<float>(), &size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
@@ -153,12 +146,8 @@ void ArcCosh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
int size = in.data_size();
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(size * out.itemsize()),
|
||||
size,
|
||||
in.strides(),
|
||||
in.flags());
|
||||
vvacoshf(out.data<float>(), in.data<float>(), &size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
@@ -169,12 +158,8 @@ void ArcSin::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
int size = in.data_size();
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(size * out.itemsize()),
|
||||
size,
|
||||
in.strides(),
|
||||
in.flags());
|
||||
vvasinf(out.data<float>(), in.data<float>(), &size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
@@ -185,12 +170,8 @@ void ArcSinh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
int size = in.data_size();
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(size * out.itemsize()),
|
||||
size,
|
||||
in.strides(),
|
||||
in.flags());
|
||||
vvasinhf(out.data<float>(), in.data<float>(), &size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
@@ -201,12 +182,8 @@ void ArcTan::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
int size = in.data_size();
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(size * out.itemsize()),
|
||||
size,
|
||||
in.strides(),
|
||||
in.flags());
|
||||
vvatanf(out.data<float>(), in.data<float>(), &size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
@@ -217,12 +194,8 @@ void ArcTanh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
int size = in.data_size();
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(size * out.itemsize()),
|
||||
size,
|
||||
in.strides(),
|
||||
in.flags());
|
||||
vvatanhf(out.data<float>(), in.data<float>(), &size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
@@ -234,30 +207,23 @@ void AsType::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& in = inputs[0];
|
||||
|
||||
if (in.flags().contiguous) {
|
||||
auto allocfn = [&in, &out]() {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(in.data_size() * out.itemsize()),
|
||||
in.data_size(),
|
||||
in.strides(),
|
||||
in.flags());
|
||||
};
|
||||
// Use accelerate functions if possible
|
||||
if (in.dtype() == float32 && out.dtype() == uint32) {
|
||||
allocfn();
|
||||
set_unary_output_data(in, out);
|
||||
vDSP_vfixu32(
|
||||
in.data<float>(), 1, out.data<uint32_t>(), 1, in.data_size());
|
||||
return;
|
||||
} else if (in.dtype() == float32 && out.dtype() == int32) {
|
||||
allocfn();
|
||||
set_unary_output_data(in, out);
|
||||
vDSP_vfix32(in.data<float>(), 1, out.data<int32_t>(), 1, in.data_size());
|
||||
return;
|
||||
} else if (in.dtype() == uint32 && out.dtype() == float32) {
|
||||
allocfn();
|
||||
set_unary_output_data(in, out);
|
||||
vDSP_vfltu32(
|
||||
in.data<uint32_t>(), 1, out.data<float>(), 1, in.data_size());
|
||||
return;
|
||||
} else if (in.dtype() == int32 && out.dtype() == float32) {
|
||||
allocfn();
|
||||
set_unary_output_data(in, out);
|
||||
vDSP_vflt32(in.data<int32_t>(), 1, out.data<float>(), 1, in.data_size());
|
||||
return;
|
||||
}
|
||||
@@ -269,12 +235,8 @@ void Cos::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
int size = in.data_size();
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(size * out.itemsize()),
|
||||
size,
|
||||
in.strides(),
|
||||
in.flags());
|
||||
vvcosf(out.data<float>(), in.data<float>(), &size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
@@ -285,12 +247,8 @@ void Cosh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
int size = in.data_size();
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(size * out.itemsize()),
|
||||
size,
|
||||
in.strides(),
|
||||
in.flags());
|
||||
vvcoshf(out.data<float>(), in.data<float>(), &size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
@@ -335,55 +293,12 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Avoid code duplication with the common backend.
|
||||
struct RemainderFn {
|
||||
template <typename T>
|
||||
std::enable_if_t<!std::is_integral_v<T>, T> operator()(
|
||||
T numerator,
|
||||
T denominator) {
|
||||
return std::fmod(numerator, denominator);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::enable_if_t<std::is_integral_v<T>, T> operator()(
|
||||
T numerator,
|
||||
T denominator) {
|
||||
return numerator % denominator;
|
||||
}
|
||||
};
|
||||
|
||||
void Remainder::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
|
||||
if (a.dtype() == float32) {
|
||||
binary(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
RemainderFn{},
|
||||
UseDefaultBinaryOp(),
|
||||
UseDefaultBinaryOp(),
|
||||
[](const auto* a, const auto* b, auto* o, auto n) {
|
||||
int num_el = n;
|
||||
vvremainderf((float*)o, (const float*)a, (const float*)b, &num_el);
|
||||
});
|
||||
} else {
|
||||
binary(a, b, out, RemainderFn{});
|
||||
}
|
||||
}
|
||||
|
||||
void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
auto size = in.data_size();
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(size * out.itemsize()),
|
||||
size,
|
||||
in.strides(),
|
||||
in.flags());
|
||||
vvexpf(out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
|
||||
} else if (is_floating_point(out.dtype())) {
|
||||
unary_fp(in, out, [](auto x) { return std::exp(x); });
|
||||
@@ -410,12 +325,8 @@ void Log::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
auto size = in.data_size();
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(size * out.itemsize()),
|
||||
size,
|
||||
in.strides(),
|
||||
in.flags());
|
||||
switch (base_) {
|
||||
case Base::e:
|
||||
vvlogf(
|
||||
@@ -439,12 +350,8 @@ void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
auto size = in.data_size();
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(size * out.itemsize()),
|
||||
size,
|
||||
in.strides(),
|
||||
in.flags());
|
||||
vvlog1pf(
|
||||
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
|
||||
} else if (is_floating_point(out.dtype())) {
|
||||
@@ -456,47 +363,6 @@ void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void Maximum::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
if (out.dtype() == float32) {
|
||||
binary(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
[](auto x, auto y) { return (x > y) ? x : y; },
|
||||
UseDefaultBinaryOp(),
|
||||
UseDefaultBinaryOp(),
|
||||
[](const auto* a, const auto* b, auto* out, int n) {
|
||||
vDSP_vmax((const float*)a, 1, (const float*)b, 1, (float*)out, 1, n);
|
||||
});
|
||||
} else {
|
||||
binary(a, b, out, [](auto x, auto y) { return (x > y) ? x : y; });
|
||||
}
|
||||
}
|
||||
|
||||
void Minimum::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
|
||||
if (out.dtype() == float32) {
|
||||
binary(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
[](auto x, auto y) { return (x < y) ? x : y; },
|
||||
UseDefaultBinaryOp(),
|
||||
UseDefaultBinaryOp(),
|
||||
[](const auto* a, const auto* b, auto* out, int n) {
|
||||
vDSP_vmin((const float*)a, 1, (const float*)b, 1, (float*)out, 1, n);
|
||||
});
|
||||
} else {
|
||||
binary(a, b, out, [](auto x, auto y) { return (x < y) ? x : y; });
|
||||
}
|
||||
}
|
||||
|
||||
void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
@@ -526,13 +392,8 @@ void Negative::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (in.dtype() == float32 && in.flags().contiguous) {
|
||||
auto size = in.data_size();
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(size * out.itemsize()),
|
||||
size,
|
||||
in.strides(),
|
||||
in.flags());
|
||||
vDSP_vneg(in.data<float>(), 1, out.data<float>(), 1, size);
|
||||
set_unary_output_data(in, out);
|
||||
vDSP_vneg(in.data<float>(), 1, out.data<float>(), 1, in.data_size());
|
||||
} else {
|
||||
unary(in, out, [](auto x) { return -x; });
|
||||
}
|
||||
@@ -545,7 +406,13 @@ void Power::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
if (out.dtype() == float32 && a.flags().row_contiguous &&
|
||||
b.flags().row_contiguous) {
|
||||
int size = a.size();
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
if (a.is_donatable() && a.itemsize() == out.itemsize()) {
|
||||
out.copy_shared_buffer(a);
|
||||
} else if (b.is_donatable() && b.itemsize() == out.itemsize()) {
|
||||
out.copy_shared_buffer(b);
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
}
|
||||
vvpowf(out.data<float>(), b.data<float>(), a.data<float>(), &size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
@@ -587,12 +454,8 @@ void Sin::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
int size = in.data_size();
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(size * out.itemsize()),
|
||||
size,
|
||||
in.strides(),
|
||||
in.flags());
|
||||
vvsinf(out.data<float>(), in.data<float>(), &size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
@@ -603,12 +466,8 @@ void Sinh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
int size = in.data_size();
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(size * out.itemsize()),
|
||||
size,
|
||||
in.strides(),
|
||||
in.flags());
|
||||
vvsinhf(out.data<float>(), in.data<float>(), &size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
@@ -619,12 +478,8 @@ void Square::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (in.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
auto size = in.data_size();
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(size * out.itemsize()),
|
||||
size,
|
||||
in.strides(),
|
||||
in.flags());
|
||||
vDSP_vsq(in.data<float>(), 1, out.data<float>(), 1, size);
|
||||
} else {
|
||||
unary(in, out, [](auto x) { return x * x; });
|
||||
@@ -635,12 +490,8 @@ void Sqrt::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
if (in.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
int size = in.data_size();
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(size * out.itemsize()),
|
||||
size,
|
||||
in.strides(),
|
||||
in.flags());
|
||||
if (recip_) {
|
||||
vvrsqrtf(out.data<float>(), in.data<float>(), &size);
|
||||
} else {
|
||||
@@ -695,12 +546,8 @@ void Tan::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
int size = in.data_size();
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(size * out.itemsize()),
|
||||
size,
|
||||
in.strides(),
|
||||
in.flags());
|
||||
vvtanf(out.data<float>(), in.data<float>(), &size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
@@ -711,12 +558,8 @@ void Tanh::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
if (out.dtype() == float32 && in.flags().contiguous) {
|
||||
set_unary_output_data(in, out);
|
||||
int size = in.data_size();
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(size * out.itemsize()),
|
||||
size,
|
||||
in.strides(),
|
||||
in.flags());
|
||||
vvtanhf(out.data<float>(), in.data<float>(), &size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
|
@@ -274,7 +274,12 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
// Make sure that the last dimension is contiguous
|
||||
auto check_input = [](array x) {
|
||||
if (x.strides()[x.ndim() - 1] == 1) {
|
||||
bool no_copy = x.strides()[x.ndim() - 1] == 1;
|
||||
if (x.ndim() > 1) {
|
||||
auto s = x.strides()[x.ndim() - 2];
|
||||
no_copy &= (s == 0 || s == x.shape().back());
|
||||
}
|
||||
if (no_copy) {
|
||||
return x;
|
||||
} else {
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
|
@@ -3,6 +3,7 @@ target_sources(
|
||||
PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
|
||||
@@ -10,10 +11,12 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp
|
||||
)
|
||||
|
@@ -140,16 +140,34 @@ void Divide::eval(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
struct RemainderFn {
|
||||
template <typename T>
|
||||
std::enable_if_t<!std::is_integral_v<T>, T> operator()(
|
||||
std::enable_if_t<std::is_integral_v<T> & !std::is_signed_v<T>, T> operator()(
|
||||
T numerator,
|
||||
T denominator) {
|
||||
return std::fmod(numerator, denominator);
|
||||
return numerator % denominator;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::enable_if_t<std::is_integral_v<T>, T> operator()(
|
||||
std::enable_if_t<std::is_integral_v<T> & std::is_signed_v<T>, T> operator()(
|
||||
T numerator,
|
||||
T denominator) {
|
||||
auto r = numerator % denominator;
|
||||
if (r != 0 && (r < 0 != denominator < 0))
|
||||
r += denominator;
|
||||
return r;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::enable_if_t<!std::is_integral_v<T>, T> operator()(
|
||||
T numerator,
|
||||
T denominator) {
|
||||
auto r = std::fmod(numerator, denominator);
|
||||
if (r != 0 && (r < 0 != denominator < 0)) {
|
||||
r += denominator;
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
complex64_t operator()(complex64_t numerator, complex64_t denominator) {
|
||||
return numerator % denominator;
|
||||
}
|
||||
};
|
||||
@@ -233,14 +251,33 @@ void Maximum::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, [](auto x, auto y) { return (x > y) ? x : y; });
|
||||
|
||||
if (is_floating_point(out.dtype())) {
|
||||
binary(a, b, out, [](auto x, auto y) {
|
||||
if (std::isnan(x)) {
|
||||
return x;
|
||||
}
|
||||
return (x > y) ? x : y;
|
||||
});
|
||||
} else {
|
||||
binary(a, b, out, [](auto x, auto y) { return (x > y) ? x : y; });
|
||||
}
|
||||
}
|
||||
|
||||
void Minimum::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
binary(a, b, out, [](auto x, auto y) { return (x < y) ? x : y; });
|
||||
if (is_floating_point(out.dtype())) {
|
||||
binary(a, b, out, [](auto x, auto y) {
|
||||
if (std::isnan(x)) {
|
||||
return x;
|
||||
}
|
||||
return (x < y) ? x : y;
|
||||
});
|
||||
} else {
|
||||
binary(a, b, out, [](auto x, auto y) { return (x < y) ? x : y; });
|
||||
}
|
||||
}
|
||||
|
||||
void Multiply::eval(const std::vector<array>& inputs, array& out) {
|
||||
|
@@ -1,7 +1,6 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
@@ -40,29 +39,83 @@ void set_binary_op_output_data(
|
||||
const array& a,
|
||||
const array& b,
|
||||
array& out,
|
||||
BinaryOpType bopt) {
|
||||
BinaryOpType bopt,
|
||||
bool donate_with_move = false) {
|
||||
switch (bopt) {
|
||||
case ScalarScalar:
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags());
|
||||
break;
|
||||
case ScalarVector:
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(b.data_size() * out.itemsize()),
|
||||
b.data_size(),
|
||||
b.strides(),
|
||||
b.flags());
|
||||
if (b.is_donatable() && b.itemsize() == out.itemsize()) {
|
||||
if (donate_with_move) {
|
||||
out.move_shared_buffer(b);
|
||||
} else {
|
||||
out.copy_shared_buffer(b);
|
||||
}
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(b.data_size() * out.itemsize()),
|
||||
b.data_size(),
|
||||
b.strides(),
|
||||
b.flags());
|
||||
}
|
||||
break;
|
||||
case VectorScalar:
|
||||
if (a.is_donatable() && a.itemsize() == out.itemsize()) {
|
||||
if (donate_with_move) {
|
||||
out.move_shared_buffer(a);
|
||||
} else {
|
||||
out.copy_shared_buffer(a);
|
||||
}
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
|
||||
a.data_size(),
|
||||
a.strides(),
|
||||
a.flags());
|
||||
}
|
||||
break;
|
||||
case VectorVector:
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
|
||||
a.data_size(),
|
||||
a.strides(),
|
||||
a.flags());
|
||||
if (a.is_donatable() && a.itemsize() == out.itemsize()) {
|
||||
if (donate_with_move) {
|
||||
out.move_shared_buffer(a);
|
||||
} else {
|
||||
out.copy_shared_buffer(a);
|
||||
}
|
||||
} else if (b.is_donatable() && b.itemsize() == out.itemsize()) {
|
||||
if (donate_with_move) {
|
||||
out.move_shared_buffer(b);
|
||||
} else {
|
||||
out.copy_shared_buffer(b);
|
||||
}
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
|
||||
a.data_size(),
|
||||
a.strides(),
|
||||
a.flags());
|
||||
}
|
||||
break;
|
||||
case General:
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
if (a.is_donatable() && a.flags().row_contiguous &&
|
||||
a.itemsize() == out.itemsize() && a.size() == out.size()) {
|
||||
if (donate_with_move) {
|
||||
out.move_shared_buffer(a);
|
||||
} else {
|
||||
out.copy_shared_buffer(a);
|
||||
}
|
||||
} else if (
|
||||
b.is_donatable() && b.flags().row_contiguous &&
|
||||
b.itemsize() == out.itemsize() && b.size() == out.size()) {
|
||||
if (donate_with_move) {
|
||||
out.move_shared_buffer(b);
|
||||
} else {
|
||||
out.copy_shared_buffer(b);
|
||||
}
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
59
mlx/backend/common/compiled.cpp
Normal file
59
mlx/backend/common/compiled.cpp
Normal file
@@ -0,0 +1,59 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <queue>
|
||||
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// Build the real tape
|
||||
std::pair<std::queue<array>, std::vector<array>> trace_to_real(
|
||||
const std::vector<array>& trace_tape,
|
||||
const std::vector<array>& trace_inputs,
|
||||
const std::vector<array>& trace_outputs,
|
||||
const std::vector<array>& inputs) {
|
||||
std::unordered_map<uintptr_t, array> trace_to_real;
|
||||
for (int i = 0; i < inputs.size(); ++i) {
|
||||
trace_to_real.insert({trace_inputs[i].id(), inputs[i]});
|
||||
}
|
||||
std::queue<array> tape;
|
||||
for (auto& a : trace_tape) {
|
||||
// Find real inputs
|
||||
std::vector<array> real_inputs;
|
||||
for (auto& in : a.inputs()) {
|
||||
real_inputs.push_back(trace_to_real.at(in.id()));
|
||||
}
|
||||
tape.push(
|
||||
array(a.shape(), a.dtype(), a.primitive_ptr(), std::move(real_inputs)));
|
||||
trace_to_real.insert({a.id(), tape.back()});
|
||||
}
|
||||
|
||||
std::vector<array> outputs;
|
||||
for (auto& o : trace_outputs) {
|
||||
outputs.push_back(trace_to_real.at(o.id()));
|
||||
}
|
||||
return {tape, outputs};
|
||||
}
|
||||
|
||||
void Compiled::eval(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
// Make the a real tape from the tracers
|
||||
auto [tape, real_outputs] = trace_to_real(tape_, inputs_, outputs_, inputs);
|
||||
|
||||
// Run the tape
|
||||
while (!tape.empty()) {
|
||||
auto a = std::move(tape.front());
|
||||
tape.pop();
|
||||
auto outputs = a.outputs();
|
||||
a.primitive().eval_cpu(a.inputs(), outputs);
|
||||
a.detach();
|
||||
}
|
||||
|
||||
// Copy results into outputs
|
||||
for (int o = 0; o < real_outputs.size(); ++o) {
|
||||
outputs[o].copy_shared_buffer(real_outputs[o]);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -3,7 +3,7 @@
|
||||
#include <cassert>
|
||||
|
||||
#ifdef ACCELERATE_NEW_LAPACK
|
||||
#include <vecLib/cblas_new.h>
|
||||
#include <Accelerate/Accelerate.h>
|
||||
#else
|
||||
#include <cblas.h>
|
||||
#endif
|
||||
|
@@ -289,11 +289,16 @@ void copy(const array& src, array& dst, CopyType ctype) {
|
||||
// Allocate the output
|
||||
switch (ctype) {
|
||||
case CopyType::Vector:
|
||||
dst.set_data(
|
||||
allocator::malloc_or_wait(src.data_size() * dst.itemsize()),
|
||||
src.data_size(),
|
||||
src.strides(),
|
||||
src.flags());
|
||||
if (src.is_donatable() && src.itemsize() == dst.itemsize()) {
|
||||
dst.copy_shared_buffer(src);
|
||||
} else {
|
||||
auto size = src.data_size();
|
||||
dst.set_data(
|
||||
allocator::malloc_or_wait(size * dst.itemsize()),
|
||||
size,
|
||||
src.strides(),
|
||||
src.flags());
|
||||
}
|
||||
break;
|
||||
case CopyType::Scalar:
|
||||
case CopyType::General:
|
||||
|
@@ -1,11 +1,13 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#ifdef ACCELERATE_NEW_LAPACK
|
||||
#include <vecLib/cblas_new.h>
|
||||
#include <Accelerate/Accelerate.h>
|
||||
#else
|
||||
#include <cblas.h>
|
||||
#endif
|
||||
|
||||
#include <cstring>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
@@ -39,12 +41,16 @@ DEFAULT(ArgSort)
|
||||
DEFAULT(AsType)
|
||||
DEFAULT(AsStrided)
|
||||
DEFAULT(Broadcast)
|
||||
DEFAULT_MULTI(DivMod)
|
||||
DEFAULT(Ceil)
|
||||
DEFAULT_MULTI(Compiled)
|
||||
DEFAULT(Concatenate)
|
||||
DEFAULT(Convolution)
|
||||
DEFAULT(Copy)
|
||||
DEFAULT(Cos)
|
||||
DEFAULT(Cosh)
|
||||
DEFAULT_MULTI(CustomVJP)
|
||||
DEFAULT_MULTI(Depends)
|
||||
DEFAULT(Divide)
|
||||
DEFAULT(Remainder)
|
||||
DEFAULT(Equal)
|
||||
@@ -74,6 +80,7 @@ DEFAULT(NotEqual)
|
||||
DEFAULT(Pad)
|
||||
DEFAULT(Partition)
|
||||
DEFAULT(Power)
|
||||
DEFAULT_MULTI(QRF)
|
||||
DEFAULT(QuantizedMatmul)
|
||||
DEFAULT(RandomBits)
|
||||
DEFAULT(Reduce)
|
||||
@@ -96,7 +103,6 @@ DEFAULT(Subtract)
|
||||
DEFAULT(Tan)
|
||||
DEFAULT(Tanh)
|
||||
DEFAULT(Transpose)
|
||||
DEFAULT_MULTI(DivMod)
|
||||
|
||||
namespace {
|
||||
|
||||
@@ -126,6 +132,13 @@ inline void matmul_common_general(
|
||||
size_t M = a.shape(-2);
|
||||
size_t N = b.shape(-1);
|
||||
size_t K = a.shape(-1);
|
||||
if (M == 0 || N == 0) {
|
||||
return;
|
||||
}
|
||||
if (K == 0) {
|
||||
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
|
||||
return;
|
||||
}
|
||||
|
||||
for (int i = 0; i < (a.size() / (M * K)); ++i) {
|
||||
cblas_sgemm(
|
||||
|
@@ -232,22 +232,38 @@ void Cosh::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void CustomVJP::eval(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() > outputs.size());
|
||||
for (int i = 0, j = inputs.size() - outputs.size(); i < outputs.size();
|
||||
i++, j++) {
|
||||
outputs[i].copy_shared_buffer(inputs[j]);
|
||||
}
|
||||
}
|
||||
|
||||
void Depends::eval(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() > outputs.size());
|
||||
for (int i = 0; i < outputs.size(); i++) {
|
||||
outputs[i].copy_shared_buffer(inputs[i]);
|
||||
}
|
||||
}
|
||||
|
||||
void Erf::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
const auto& in = inputs[0];
|
||||
switch (out.dtype()) {
|
||||
case float32:
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
unary_op<float>(in, out, [](auto x) { return std::erf(x); });
|
||||
break;
|
||||
case float16:
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
unary_op<float16_t>(in, out, [](auto x) {
|
||||
return static_cast<float16_t>(std::erf(static_cast<float>(x)));
|
||||
});
|
||||
break;
|
||||
case bfloat16:
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
unary_op<bfloat16_t>(in, out, [](auto x) {
|
||||
return static_cast<bfloat16_t>(std::erf(static_cast<float>(x)));
|
||||
});
|
||||
@@ -264,17 +280,14 @@ void ErfInv::eval(const std::vector<array>& inputs, array& out) {
|
||||
const auto& in = inputs[0];
|
||||
switch (out.dtype()) {
|
||||
case float32:
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
unary_op<float>(in, out, [](auto x) { return erfinv(x); });
|
||||
break;
|
||||
case float16:
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
unary_op<float16_t>(in, out, [](auto x) {
|
||||
return static_cast<float16_t>(erfinv(static_cast<float>(x)));
|
||||
});
|
||||
break;
|
||||
case bfloat16:
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
unary_op<bfloat16_t>(in, out, [](auto x) {
|
||||
return static_cast<bfloat16_t>(erfinv(static_cast<float>(x)));
|
||||
});
|
||||
|
153
mlx/backend/common/qrf.cpp
Normal file
153
mlx/backend/common/qrf.cpp
Normal file
@@ -0,0 +1,153 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#ifdef ACCELERATE_NEW_LAPACK
|
||||
#include <Accelerate/Accelerate.h>
|
||||
#else
|
||||
#include <lapack.h>
|
||||
#endif
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
template <typename T>
|
||||
struct lpack;
|
||||
|
||||
template <>
|
||||
struct lpack<float> {
|
||||
static void xgeqrf(
|
||||
const int* m,
|
||||
const int* n,
|
||||
float* a,
|
||||
const int* lda,
|
||||
float* tau,
|
||||
float* work,
|
||||
const int* lwork,
|
||||
int* info) {
|
||||
sgeqrf_(m, n, a, lda, tau, work, lwork, info);
|
||||
}
|
||||
static void xorgqr(
|
||||
const int* m,
|
||||
const int* n,
|
||||
const int* k,
|
||||
float* a,
|
||||
const int* lda,
|
||||
const float* tau,
|
||||
float* work,
|
||||
const int* lwork,
|
||||
int* info) {
|
||||
sorgqr_(m, n, k, a, lda, tau, work, lwork, info);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
void qrf_impl(const array& a, array& q, array& r) {
|
||||
const int M = a.shape(-2);
|
||||
const int N = a.shape(-1);
|
||||
const int lda = std::max(M, N);
|
||||
size_t num_matrices = a.size() / (M * N);
|
||||
int num_reflectors = std::min(M, N);
|
||||
auto tau =
|
||||
allocator::malloc_or_wait(sizeof(T) * num_matrices * num_reflectors);
|
||||
|
||||
// Copy A to inplace input and make it col-contiguous
|
||||
array in(a.shape(), float32, nullptr, {});
|
||||
auto flags = in.flags();
|
||||
|
||||
// Copy the input to be column contiguous
|
||||
flags.col_contiguous = num_matrices == 1;
|
||||
flags.row_contiguous = false;
|
||||
std::vector<size_t> strides = in.strides();
|
||||
strides[in.ndim() - 2] = 1;
|
||||
strides[in.ndim() - 1] = M;
|
||||
in.set_data(
|
||||
allocator::malloc_or_wait(in.nbytes()), in.nbytes(), strides, flags);
|
||||
copy_inplace(a, in, CopyType::GeneralGeneral);
|
||||
|
||||
T optimal_work;
|
||||
int lwork = -1;
|
||||
int info;
|
||||
|
||||
// Compute workspace size
|
||||
lpack<T>::xgeqrf(
|
||||
&M, &N, nullptr, &lda, nullptr, &optimal_work, &lwork, &info);
|
||||
|
||||
// Update workspace size
|
||||
lwork = optimal_work;
|
||||
auto work = allocator::malloc_or_wait(sizeof(T) * lwork);
|
||||
|
||||
// Loop over matrices
|
||||
for (int i = 0; i < num_matrices; ++i) {
|
||||
// Solve
|
||||
lpack<T>::xgeqrf(
|
||||
&M,
|
||||
&N,
|
||||
in.data<float>() + M * N * i,
|
||||
&lda,
|
||||
static_cast<T*>(tau.raw_ptr()) + num_reflectors * i,
|
||||
static_cast<T*>(work.raw_ptr()),
|
||||
&lwork,
|
||||
&info);
|
||||
}
|
||||
allocator::free(work);
|
||||
|
||||
r.set_data(allocator::malloc_or_wait(r.nbytes()));
|
||||
copy_inplace(in, r, CopyType::General);
|
||||
|
||||
for (int i = 0; i < num_matrices; ++i) {
|
||||
// Zero lower triangle
|
||||
for (int j = 0; j < r.shape(-2); ++j) {
|
||||
for (int k = 0; k < j; ++k) {
|
||||
r.data<T>()[i * N * M + j * N + k] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get work size
|
||||
lwork = -1;
|
||||
lpack<T>::xorgqr(
|
||||
&M,
|
||||
&N,
|
||||
&num_reflectors,
|
||||
nullptr,
|
||||
&lda,
|
||||
nullptr,
|
||||
&optimal_work,
|
||||
&lwork,
|
||||
&info);
|
||||
lwork = optimal_work;
|
||||
work = allocator::malloc_or_wait(sizeof(T) * lwork);
|
||||
|
||||
// Loop over matrices
|
||||
for (int i = 0; i < num_matrices; ++i) {
|
||||
// Compute Q
|
||||
lpack<T>::xorgqr(
|
||||
&M,
|
||||
&N,
|
||||
&num_reflectors,
|
||||
in.data<float>() + M * N * i,
|
||||
&lda,
|
||||
static_cast<T*>(tau.raw_ptr()) + num_reflectors * i,
|
||||
static_cast<T*>(work.raw_ptr()),
|
||||
&lwork,
|
||||
&info);
|
||||
}
|
||||
|
||||
q.set_data(allocator::malloc_or_wait(q.nbytes()));
|
||||
copy_inplace(in, q, CopyType::General);
|
||||
|
||||
// Cleanup
|
||||
allocator::free(work);
|
||||
allocator::free(tau);
|
||||
}
|
||||
|
||||
void QRF::eval(const std::vector<array>& inputs, std::vector<array>& outputs) {
|
||||
if (!(inputs[0].dtype() == float32)) {
|
||||
throw std::runtime_error("[QRF::eval] only supports float32.");
|
||||
}
|
||||
qrf_impl<float>(inputs[0], outputs[0], outputs[1]);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -1,7 +1,6 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/primitives.h"
|
||||
@@ -119,6 +118,12 @@ void _qmm_dispatch_typed(
|
||||
switch (bits) {
|
||||
case 2: {
|
||||
switch (group_size) {
|
||||
case 32:
|
||||
if (transposed_w) {
|
||||
return _qmm_t<T, 2, 32>(result, x, w, scales, biases, M, N, K);
|
||||
} else {
|
||||
return _qmm<T, 2, 32>(result, x, w, scales, biases, M, N, K);
|
||||
}
|
||||
case 64:
|
||||
if (transposed_w) {
|
||||
return _qmm_t<T, 2, 64>(result, x, w, scales, biases, M, N, K);
|
||||
@@ -135,6 +140,12 @@ void _qmm_dispatch_typed(
|
||||
}
|
||||
case 4: {
|
||||
switch (group_size) {
|
||||
case 32:
|
||||
if (transposed_w) {
|
||||
return _qmm_t<T, 4, 32>(result, x, w, scales, biases, M, N, K);
|
||||
} else {
|
||||
return _qmm<T, 4, 32>(result, x, w, scales, biases, M, N, K);
|
||||
}
|
||||
case 64:
|
||||
if (transposed_w) {
|
||||
return _qmm_t<T, 4, 64>(result, x, w, scales, biases, M, N, K);
|
||||
@@ -151,6 +162,12 @@ void _qmm_dispatch_typed(
|
||||
}
|
||||
case 8: {
|
||||
switch (group_size) {
|
||||
case 32:
|
||||
if (transposed_w) {
|
||||
return _qmm_t<T, 8, 32>(result, x, w, scales, biases, M, N, K);
|
||||
} else {
|
||||
return _qmm<T, 8, 32>(result, x, w, scales, biases, M, N, K);
|
||||
}
|
||||
case 64:
|
||||
if (transposed_w) {
|
||||
return _qmm_t<T, 8, 64>(result, x, w, scales, biases, M, N, K);
|
||||
|
14
mlx/backend/common/rope.cpp
Normal file
14
mlx/backend/common/rope.cpp
Normal file
@@ -0,0 +1,14 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include "mlx/fast.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core::fast {
|
||||
|
||||
void RoPE::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
throw std::runtime_error("NYI");
|
||||
}
|
||||
|
||||
} // namespace mlx::core::fast
|
@@ -53,7 +53,12 @@ void Softmax::eval(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
// Make sure that the last dimension is contiguous
|
||||
auto check_input = [](array x) {
|
||||
if (x.strides().back() == 1) {
|
||||
bool no_copy = x.strides()[x.ndim() - 1] == 1;
|
||||
if (x.ndim() > 1) {
|
||||
auto s = x.strides()[x.ndim() - 2];
|
||||
no_copy &= (s == 0 || s == x.shape().back());
|
||||
}
|
||||
if (no_copy) {
|
||||
return x;
|
||||
} else {
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
|
@@ -64,15 +64,24 @@ struct RoundOp {
|
||||
}
|
||||
};
|
||||
|
||||
void set_unary_output_data(const array& in, array& out) {
|
||||
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
auto size = in.data_size();
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(size * out.itemsize()),
|
||||
size,
|
||||
in.strides(),
|
||||
in.flags());
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename Op>
|
||||
void unary_op(const array& a, array& out, Op op) {
|
||||
const T* a_ptr = a.data<T>();
|
||||
if (a.flags().contiguous) {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
|
||||
a.data_size(),
|
||||
a.strides(),
|
||||
a.flags());
|
||||
set_unary_output_data(a, out);
|
||||
T* dst = out.data<T>();
|
||||
for (size_t i = 0; i < a.data_size(); ++i) {
|
||||
dst[i] = op(a_ptr[i]);
|
||||
|
@@ -1,7 +1,28 @@
|
||||
add_custom_command(
|
||||
OUTPUT compiled_preamble.cpp
|
||||
COMMAND /bin/bash
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
|
||||
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
|
||||
${CMAKE_C_COMPILER}
|
||||
${CMAKE_SOURCE_DIR}
|
||||
DEPENDS make_compiled_preamble.sh
|
||||
kernels/compiled_preamble.h
|
||||
kernels/unary.h
|
||||
kernels/binary.h
|
||||
)
|
||||
|
||||
add_custom_target(
|
||||
compiled_preamble
|
||||
DEPENDS compiled_preamble.cpp
|
||||
)
|
||||
|
||||
add_dependencies(mlx compiled_preamble)
|
||||
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||
@@ -11,10 +32,12 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
|
||||
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
|
||||
)
|
||||
|
||||
if (NOT MLX_METAL_PATH)
|
||||
|
484
mlx/backend/metal/compiled.cpp
Normal file
484
mlx/backend/metal/compiled.cpp
Normal file
@@ -0,0 +1,484 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/backend/metal/compiled_preamble.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/graph_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
inline bool is_static_cast(const Primitive& p) {
|
||||
return (
|
||||
typeid(p) == typeid(Broadcast) || typeid(p) == typeid(Copy) ||
|
||||
typeid(p) == typeid(StopGradient) || typeid(p) == typeid(AsType));
|
||||
}
|
||||
|
||||
inline auto get_type_string(Dtype d) {
|
||||
switch (d) {
|
||||
case float32:
|
||||
return "float";
|
||||
case float16:
|
||||
return "half";
|
||||
case bfloat16:
|
||||
return "bfloat16_t";
|
||||
case bool_:
|
||||
return "bool";
|
||||
case int8:
|
||||
return "int8_t";
|
||||
case int16:
|
||||
return "int16_t";
|
||||
case int32:
|
||||
return "int32_t";
|
||||
case int64:
|
||||
return "int64_t";
|
||||
case uint8:
|
||||
return "uint8_t";
|
||||
case uint16:
|
||||
return "uint16_t";
|
||||
case uint32:
|
||||
return "uint32_t";
|
||||
case uint64:
|
||||
return "uint64_t";
|
||||
default: {
|
||||
std::ostringstream msg;
|
||||
msg << "Unsupported compilation type " << d;
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void print_float_constant(std::ostream& os, const array& x) {
|
||||
auto old_precision = os.precision();
|
||||
os << std::setprecision(std::numeric_limits<float>::digits10 + 1)
|
||||
<< x.item<T>() << std::setprecision(old_precision);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void print_int_constant(std::ostream& os, const array& x) {
|
||||
os << x.item<T>();
|
||||
}
|
||||
|
||||
void print_constant(std::ostream& os, const array& x) {
|
||||
switch (x.dtype()) {
|
||||
case float32:
|
||||
return print_float_constant<float>(os, x);
|
||||
case float16:
|
||||
return print_float_constant<float16_t>(os, x);
|
||||
case bfloat16:
|
||||
return print_float_constant<bfloat16_t>(os, x);
|
||||
case int8:
|
||||
return print_int_constant<int8_t>(os, x);
|
||||
case int16:
|
||||
return print_int_constant<int16_t>(os, x);
|
||||
case int32:
|
||||
return print_int_constant<int32_t>(os, x);
|
||||
case int64:
|
||||
return print_int_constant<int64_t>(os, x);
|
||||
case uint8:
|
||||
return print_int_constant<uint8_t>(os, x);
|
||||
case uint16:
|
||||
return print_int_constant<uint16_t>(os, x);
|
||||
case uint32:
|
||||
return print_int_constant<uint32_t>(os, x);
|
||||
case uint64:
|
||||
return print_int_constant<uint64_t>(os, x);
|
||||
case bool_:
|
||||
os << std::boolalpha << x.item<bool>();
|
||||
return;
|
||||
default:
|
||||
throw std::runtime_error("Unsupported constant type");
|
||||
}
|
||||
}
|
||||
|
||||
inline std::string build_lib_name(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs,
|
||||
const std::vector<array>& tape,
|
||||
const std::unordered_set<uintptr_t>& constant_ids) {
|
||||
std::ostringstream os;
|
||||
std::ostringstream constant_hasher;
|
||||
|
||||
// The primitives describing the tape. For unary and binary primitives this
|
||||
// must be enough to describe the full computation.
|
||||
for (auto& a : tape) {
|
||||
a.primitive().print(os);
|
||||
}
|
||||
os << "_";
|
||||
|
||||
for (auto& x : inputs) {
|
||||
if (constant_ids.find(x.id()) != constant_ids.end()) {
|
||||
os << "C";
|
||||
print_constant(constant_hasher, x);
|
||||
} else {
|
||||
os << ((x.size() == 1) ? "S" : "V");
|
||||
}
|
||||
}
|
||||
os << "_";
|
||||
for (auto& x : inputs) {
|
||||
if (constant_ids.find(x.id()) != constant_ids.end()) {
|
||||
continue;
|
||||
}
|
||||
os << kindof(x.dtype()) << x.itemsize();
|
||||
}
|
||||
os << "_" << std::hash<std::string>{}(constant_hasher.str());
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
inline void build_kernel(
|
||||
std::ostream& os,
|
||||
const std::string& kernel_name,
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs,
|
||||
const std::vector<array>& tape,
|
||||
const std::unordered_set<uintptr_t>& constant_ids,
|
||||
bool contiguous,
|
||||
int ndim,
|
||||
bool dynamic_dims) {
|
||||
// All outputs should have the exact same shape and will be row contiguous
|
||||
auto output_shape = outputs[0].shape();
|
||||
auto output_strides = outputs[0].strides();
|
||||
|
||||
// Constants are scalars that are captured by value and cannot change
|
||||
auto is_constant = [&constant_ids](const array& x) {
|
||||
return constant_ids.find(x.id()) != constant_ids.end();
|
||||
};
|
||||
|
||||
// For scalar we shouldn't do the indexing things, just read at 0
|
||||
auto is_scalar = [](const array& x) { return x.size() == 1; };
|
||||
|
||||
NodeNamer namer;
|
||||
bool add_indices = false;
|
||||
int cnt = 0;
|
||||
|
||||
// Start the kernel
|
||||
os << "[[host_name(\"" << kernel_name << "\")]]" << std::endl
|
||||
<< "[[kernel]] void " << kernel_name << "(" << std::endl;
|
||||
|
||||
// Add the input arguments
|
||||
for (auto& x : inputs) {
|
||||
auto& xname = namer.get_name(x);
|
||||
|
||||
// Skip constants from the input list
|
||||
if (is_constant(x)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Scalars and contiguous need no strides
|
||||
if (is_scalar(x) || contiguous) {
|
||||
os << " device const " << get_type_string(x.dtype()) << "* " << xname
|
||||
<< " [[buffer(" << cnt++ << ")]]," << std::endl;
|
||||
} else {
|
||||
add_indices = true;
|
||||
os << " device const " << get_type_string(x.dtype()) << "* " << xname
|
||||
<< " [[buffer(" << cnt++ << ")]]," << std::endl
|
||||
<< " constant const size_t* " << xname << "_strides [[buffer("
|
||||
<< cnt++ << ")]]," << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// Add the output arguments
|
||||
for (auto& x : outputs) {
|
||||
os << " device " << get_type_string(x.dtype()) << "* "
|
||||
<< namer.get_name(x) << " [[buffer(" << cnt++ << ")]]," << std::endl;
|
||||
}
|
||||
// Add output strides and shape to extract the indices.
|
||||
if (!contiguous) {
|
||||
os << " constant const size_t* output_strides [[buffer(" << cnt++
|
||||
<< ")]]," << std::endl
|
||||
<< " constant const int* output_shape [[buffer(" << cnt++ << ")]],"
|
||||
<< std::endl;
|
||||
}
|
||||
if (dynamic_dims) {
|
||||
os << " constant const int& ndim [[buffer(" << cnt++ << ")]],"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
// The thread index in the whole grid
|
||||
os << " uint3 pos [[thread_position_in_grid]]," << std::endl
|
||||
<< " uint3 grid [[threads_per_grid]]) {" << std::endl
|
||||
<< " uint index = pos.x + grid.x * (pos.y + grid.y * pos.z);"
|
||||
<< std::endl;
|
||||
|
||||
// Extract the indices per axis to individual uints if we have arrays that
|
||||
// are broadcasted or transposed
|
||||
if (add_indices) {
|
||||
if (!dynamic_dims) {
|
||||
if (ndim == 1) {
|
||||
os << " uint index_0 = pos.x;" << std::endl;
|
||||
} else if (ndim == 2) {
|
||||
os << " uint index_0 = pos.y;" << std::endl
|
||||
<< " uint index_1 = pos.x;" << std::endl;
|
||||
} else if (ndim == 3) {
|
||||
os << " uint index_0 = pos.z;" << std::endl
|
||||
<< " uint index_1 = pos.y;" << std::endl
|
||||
<< " uint index_2 = pos.x;" << std::endl;
|
||||
} else {
|
||||
for (int i = 0; i < ndim - 2; i++) {
|
||||
os << " uint index_" << i << " = (index / uint(output_strides[" << i
|
||||
<< "])) % output_shape[" << i << "];" << std::endl;
|
||||
}
|
||||
os << " uint index_" << ndim - 2 << " = pos.y;" << std::endl
|
||||
<< " uint index_" << ndim - 1 << " = pos.x;" << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Read the inputs in tmps
|
||||
for (auto& x : inputs) {
|
||||
auto& xname = namer.get_name(x);
|
||||
|
||||
if (is_constant(x)) {
|
||||
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = ";
|
||||
print_constant(os, x);
|
||||
os << ";" << std::endl;
|
||||
} else if (is_scalar(x)) {
|
||||
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
|
||||
<< xname << "[0];" << std::endl;
|
||||
} else if (contiguous) {
|
||||
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
|
||||
<< xname << "[index];" << std::endl;
|
||||
} else if (!dynamic_dims) {
|
||||
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
|
||||
<< xname << "[";
|
||||
os << "index_0 * " << xname << "_strides[0]";
|
||||
for (int i = 1; i < ndim; i++) {
|
||||
os << " + index_" << i << " * " << xname << "_strides[" << i << "]";
|
||||
}
|
||||
os << "];" << std::endl;
|
||||
} else {
|
||||
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
|
||||
<< xname << "[elem_to_loc(index, output_shape, " << xname
|
||||
<< "_strides, ndim)];" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// Actually write the computation
|
||||
for (auto& x : tape) {
|
||||
os << " " << get_type_string(x.dtype()) << " tmp_" << namer.get_name(x)
|
||||
<< " = ";
|
||||
if (is_static_cast(x.primitive())) {
|
||||
os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_"
|
||||
<< namer.get_name(x.inputs()[0]) << ");" << std::endl;
|
||||
} else {
|
||||
x.primitive().print(os);
|
||||
os << "()(";
|
||||
for (int i = 0; i < x.inputs().size() - 1; i++) {
|
||||
os << "tmp_" << namer.get_name(x.inputs()[i]) << ", ";
|
||||
}
|
||||
os << "tmp_" << namer.get_name(x.inputs().back()) << ");" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// Write the outputs from tmps
|
||||
for (auto& x : outputs) {
|
||||
os << " " << namer.get_name(x) << "[index] = tmp_" << namer.get_name(x)
|
||||
<< ";" << std::endl;
|
||||
}
|
||||
|
||||
// Finish the kernel
|
||||
os << "}" << std::endl;
|
||||
|
||||
if (cnt > 31) {
|
||||
std::ostringstream msg;
|
||||
msg << "[compile] Too many inputs/outputs fused in the Metal Compile "
|
||||
<< "primitive which exhausted the available argument buffers for "
|
||||
<< "the kernel. Please file an issue with the function that results "
|
||||
<< "in this error. The name of the kernel is '" << kernel_name << "'";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
void Compiled::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
// Make the name for the kernel library
|
||||
if (kernel_lib_.empty()) {
|
||||
kernel_lib_ = build_lib_name(inputs_, outputs_, tape_, constant_ids_);
|
||||
}
|
||||
|
||||
// Get the kernel if someone else built it already
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
auto lib = d.get_library(kernel_lib_);
|
||||
|
||||
// If not we have to build it ourselves
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel;
|
||||
kernel << metal::get_kernel_preamble() << std::endl;
|
||||
build_kernel(
|
||||
kernel,
|
||||
kernel_lib_ + "_contiguous",
|
||||
inputs_,
|
||||
outputs_,
|
||||
tape_,
|
||||
constant_ids_,
|
||||
/* contiguous = */ true,
|
||||
/* ndim = */ 0,
|
||||
/* dynamic_dims = */ false);
|
||||
for (int i = 1; i < 8; i++) {
|
||||
build_kernel(
|
||||
kernel,
|
||||
kernel_lib_ + "_strided_" + std::to_string(i),
|
||||
inputs_,
|
||||
outputs_,
|
||||
tape_,
|
||||
constant_ids_,
|
||||
/* contiguous = */ false,
|
||||
/* ndim = */ i,
|
||||
/* dynamic_dims = */ false);
|
||||
}
|
||||
build_kernel(
|
||||
kernel,
|
||||
kernel_lib_ + "_strided_dynamic",
|
||||
inputs_,
|
||||
outputs_,
|
||||
tape_,
|
||||
constant_ids_,
|
||||
/* contiguous = */ false,
|
||||
/* ndim = */ 0,
|
||||
/* dynamic_dims = */ true);
|
||||
|
||||
kernel_source_ = kernel.str();
|
||||
lib = d.get_library(kernel_lib_, kernel_source_);
|
||||
}
|
||||
|
||||
// Allocate space for the outputs
|
||||
for (auto& out : outputs) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
}
|
||||
|
||||
// Figure out which kernel we are using
|
||||
auto& output_shape = outputs[0].shape();
|
||||
bool contiguous = true;
|
||||
for (auto& x : inputs) {
|
||||
if ((!x.flags().row_contiguous || x.shape() != output_shape) &&
|
||||
x.size() > 1) {
|
||||
contiguous = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Collapse contiguous dims to route to a faster kernel if possible. Also
|
||||
// handle all broadcasting.
|
||||
std::vector<std::vector<size_t>> initial_strides;
|
||||
initial_strides.push_back(outputs[0].strides());
|
||||
std::vector<int> shape;
|
||||
std::vector<std::vector<size_t>> strides;
|
||||
if (!contiguous) {
|
||||
for (int i = 0; i < inputs.size(); i++) {
|
||||
// Skip constants.
|
||||
if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
|
||||
continue;
|
||||
}
|
||||
auto& x = inputs[i];
|
||||
|
||||
// Skip scalar inputs.
|
||||
if (x.size() <= 1) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Broadcast the inputs to the output shape.
|
||||
std::vector<size_t> xstrides;
|
||||
int j = 0;
|
||||
for (; j < output_shape.size() - x.ndim(); j++) {
|
||||
if (output_shape[j] == 1) {
|
||||
xstrides.push_back(outputs[0].strides()[j]);
|
||||
} else {
|
||||
xstrides.push_back(0);
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < x.ndim(); i++, j++) {
|
||||
if (x.shape(i) == 1) {
|
||||
if (output_shape[j] == 1) {
|
||||
xstrides.push_back(outputs[0].strides()[j]);
|
||||
} else {
|
||||
xstrides.push_back(0);
|
||||
}
|
||||
} else {
|
||||
xstrides.push_back(x.strides()[i]);
|
||||
}
|
||||
}
|
||||
initial_strides.push_back(std::move(xstrides));
|
||||
}
|
||||
std::tie(shape, strides) =
|
||||
collapse_contiguous_dims(output_shape, initial_strides);
|
||||
}
|
||||
|
||||
// Get the kernel from the lib
|
||||
int ndim = shape.size();
|
||||
bool dynamic = ndim >= 8;
|
||||
auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_");
|
||||
if (!contiguous) {
|
||||
if (dynamic) {
|
||||
kernel_name += "dynamic";
|
||||
} else {
|
||||
kernel_name += std::to_string(shape.size());
|
||||
}
|
||||
}
|
||||
auto kernel = d.get_kernel(kernel_name, lib);
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// Put the inputs in
|
||||
int cnt = 0;
|
||||
int stride_idx = 1; // idx 0 is the output strides
|
||||
for (int i = 0; i < inputs.size(); i++) {
|
||||
if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
|
||||
continue;
|
||||
}
|
||||
auto& x = inputs[i];
|
||||
set_array_buffer(compute_encoder, x, cnt++);
|
||||
if (!contiguous && x.size() > 1) {
|
||||
compute_encoder->setBytes(
|
||||
strides[stride_idx].data(),
|
||||
strides[stride_idx].size() * sizeof(size_t),
|
||||
cnt++);
|
||||
stride_idx++;
|
||||
}
|
||||
}
|
||||
|
||||
// Put the outputs in
|
||||
for (auto& x : outputs) {
|
||||
set_array_buffer(compute_encoder, x, cnt++);
|
||||
}
|
||||
|
||||
// Put the output shape and strides in
|
||||
if (!contiguous) {
|
||||
compute_encoder->setBytes(
|
||||
strides[0].data(), strides[0].size() * sizeof(size_t), cnt++);
|
||||
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), cnt++);
|
||||
}
|
||||
|
||||
// Put the number of dims in if it is dynamic
|
||||
if (dynamic) {
|
||||
compute_encoder->setBytes(&ndim, sizeof(int), cnt++);
|
||||
}
|
||||
|
||||
// Launch the kernel
|
||||
if (contiguous) {
|
||||
size_t nthreads = outputs[0].size();
|
||||
MTL::Size grid_dims(nthreads, 1, 1);
|
||||
MTL::Size group_dims(
|
||||
std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
} else {
|
||||
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
|
||||
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
|
||||
size_t rest = outputs[0].size() / (dim0 * dim1);
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (thread_group_size != 1024) {
|
||||
throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
|
||||
}
|
||||
auto group_dims = get_block_dims(dim0, dim1, rest);
|
||||
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
9
mlx/backend/metal/compiled_preamble.h
Normal file
9
mlx/backend/metal/compiled_preamble.h
Normal file
@@ -0,0 +1,9 @@
|
||||
// Copyright © 2023-24 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace mlx::core::metal {
|
||||
|
||||
const char* get_kernel_preamble();
|
||||
|
||||
}
|
@@ -2,7 +2,6 @@
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
|
||||
|
@@ -12,11 +12,15 @@ namespace mlx::core {
|
||||
|
||||
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
|
||||
if (ctype == CopyType::Vector) {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(in.data_size() * out.itemsize()),
|
||||
in.data_size(),
|
||||
in.strides(),
|
||||
in.flags());
|
||||
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
|
||||
out.move_shared_buffer(in);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(in.data_size() * out.itemsize()),
|
||||
in.data_size(),
|
||||
in.strides(),
|
||||
in.flags());
|
||||
}
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
}
|
||||
@@ -67,7 +71,8 @@ void copy_gpu_inplace(
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
set_array_buffer(compute_encoder, in, 0);
|
||||
bool donate_in = in.data_shared_ptr() == nullptr;
|
||||
set_array_buffer(compute_encoder, donate_in ? out : in, 0);
|
||||
set_array_buffer(compute_encoder, out, 1);
|
||||
|
||||
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
|
||||
|
@@ -1,4 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-24 Apple Inc.
|
||||
|
||||
#include <dlfcn.h>
|
||||
#include <cstdlib>
|
||||
@@ -26,7 +26,8 @@ static constexpr const char* default_mtllib_path = METAL_PATH;
|
||||
|
||||
auto load_device() {
|
||||
auto devices = MTL::CopyAllDevices();
|
||||
auto device = static_cast<MTL::Device*>(devices->object(0));
|
||||
auto device = static_cast<MTL::Device*>(devices->object(0))
|
||||
?: MTL::CreateSystemDefaultDevice();
|
||||
if (!device) {
|
||||
throw std::runtime_error("Failed to load device");
|
||||
}
|
||||
@@ -214,15 +215,6 @@ MTL::ComputeCommandEncoder* Device::get_command_encoder(int index) {
|
||||
return eit->second;
|
||||
}
|
||||
|
||||
MTL::ArgumentEncoder* Device::argument_encoder(
|
||||
const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const {
|
||||
// NB array here is already autoreleased but the returned argument
|
||||
// encoder is owned by the caller and must be released/autoreleased
|
||||
NS::Array* arg_desc_arr = NS::Array::array(
|
||||
reinterpret_cast<NS::Object* const*>(arg_descs.data()), arg_descs.size());
|
||||
return device_->newArgumentEncoder(arg_desc_arr);
|
||||
}
|
||||
|
||||
void Device::register_library(
|
||||
const std::string& lib_name,
|
||||
const std::string& lib_path) {
|
||||
@@ -242,37 +234,127 @@ void Device::register_library(
|
||||
}
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* Device::get_kernel(
|
||||
const std::string& name,
|
||||
const std::string& lib_name /* = "mlx" */) {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
// Look for cached kernel
|
||||
if (auto it = kernel_map_.find(name); it != kernel_map_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
// Prepare new kernel
|
||||
|
||||
MTL::Library* Device::get_library_cache_(const std::string& lib_name) {
|
||||
// Search for cached metal lib
|
||||
MTL::Library* mtl_lib;
|
||||
if (auto it = library_map_.find(name); it != library_map_.end()) {
|
||||
if (auto it = library_map_.find(lib_name); it != library_map_.end()) {
|
||||
mtl_lib = it->second;
|
||||
} else { // Look for metallib alongside library
|
||||
register_library(lib_name);
|
||||
mtl_lib = library_map_[lib_name];
|
||||
}
|
||||
|
||||
return mtl_lib;
|
||||
}
|
||||
|
||||
MTL::Library* Device::get_library_(const std::string& source_string) {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
|
||||
auto ns_code =
|
||||
NS::String::string(source_string.c_str(), NS::ASCIIStringEncoding);
|
||||
|
||||
NS::Error* error = nullptr;
|
||||
auto mtl_lib = device_->newLibrary(ns_code, nullptr, &error);
|
||||
|
||||
// Throw error if unable to compile library
|
||||
if (!mtl_lib) {
|
||||
std::ostringstream msg;
|
||||
msg << "[metal::Device] Unable to load build metal library from source"
|
||||
<< "\n";
|
||||
if (error) {
|
||||
msg << error->localizedDescription()->utf8String() << "\n";
|
||||
}
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
return mtl_lib;
|
||||
}
|
||||
|
||||
MTL::Library* Device::get_library_(const MTL::StitchedLibraryDescriptor* desc) {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
|
||||
NS::Error* error = nullptr;
|
||||
auto mtl_lib = device_->newLibrary(desc, &error);
|
||||
|
||||
// Throw error if unable to compile library
|
||||
if (!mtl_lib) {
|
||||
std::ostringstream msg;
|
||||
msg << "[metal::Device] Unable to load build stitched metal library"
|
||||
<< "\n";
|
||||
if (error) {
|
||||
msg << error->localizedDescription()->utf8String() << "\n";
|
||||
}
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
return mtl_lib;
|
||||
}
|
||||
|
||||
MTL::Function* Device::get_function_(
|
||||
const std::string& name,
|
||||
MTL::Library* mtl_lib) {
|
||||
// Pull kernel from library
|
||||
auto ns_name = NS::String::string(name.c_str(), NS::ASCIIStringEncoding);
|
||||
auto mtl_function = mtl_lib->newFunction(ns_name);
|
||||
|
||||
return mtl_function;
|
||||
}
|
||||
|
||||
MTL::Function* Device::get_function_(
|
||||
const std::string& name,
|
||||
const std::string& specialized_name,
|
||||
const MTLFCList& func_consts,
|
||||
MTL::Library* mtl_lib) {
|
||||
if (func_consts.empty() && (specialized_name == name)) {
|
||||
return get_function_(name, mtl_lib);
|
||||
}
|
||||
|
||||
// Prepare function constants
|
||||
auto mtl_func_consts = MTL::FunctionConstantValues::alloc()->init();
|
||||
|
||||
for (auto [value, type, index] : func_consts) {
|
||||
mtl_func_consts->setConstantValue(value, type, index);
|
||||
}
|
||||
|
||||
// Prepare function desc
|
||||
auto desc = MTL::FunctionDescriptor::functionDescriptor();
|
||||
desc->setName(NS::String::string(name.c_str(), NS::ASCIIStringEncoding));
|
||||
desc->setSpecializedName(
|
||||
NS::String::string(specialized_name.c_str(), NS::ASCIIStringEncoding));
|
||||
desc->setConstantValues(mtl_func_consts);
|
||||
|
||||
// Pull kernel from library
|
||||
NS::Error* error = nullptr;
|
||||
auto mtl_function = mtl_lib->newFunction(desc, &error);
|
||||
|
||||
// Throw error if unable to build metal function
|
||||
if (!mtl_function) {
|
||||
std::ostringstream msg;
|
||||
msg << "[metal::Device] Unable to load function " << name << "\n";
|
||||
if (error) {
|
||||
msg << error->localizedDescription()->utf8String() << "\n";
|
||||
}
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
mtl_func_consts->release();
|
||||
desc->release();
|
||||
|
||||
return mtl_function;
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* Device::get_kernel_(
|
||||
const std::string& name,
|
||||
const MTL::Function* mtl_function) {
|
||||
// Compile kernel to compute pipeline
|
||||
NS::Error* error = nullptr;
|
||||
MTL::ComputePipelineState* kernel;
|
||||
|
||||
if (mtl_function) {
|
||||
kernel = device_->newComputePipelineState(mtl_function, &error);
|
||||
mtl_function->release();
|
||||
}
|
||||
|
||||
// Throw error if unable to compile metal function
|
||||
if (!mtl_function || !kernel) {
|
||||
std::ostringstream msg;
|
||||
msg << "[metal::Device] Unable to load kernel " << name << "\n";
|
||||
@@ -282,11 +364,175 @@ MTL::ComputePipelineState* Device::get_kernel(
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
// Add kernel to cache
|
||||
kernel_map_.insert({name, kernel});
|
||||
return kernel;
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* Device::get_kernel_(
|
||||
const std::string& name,
|
||||
const MTL::Function* mtl_function,
|
||||
const MTL::LinkedFunctions* linked_functions) {
|
||||
// Check inputs
|
||||
if (!linked_functions) {
|
||||
return get_kernel_(name, mtl_function);
|
||||
}
|
||||
|
||||
if (!mtl_function) {
|
||||
std::ostringstream msg;
|
||||
msg << "[metal::Device] Unable to load kernel " << name << "\n";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
// Prepare compute pipeline state descriptor
|
||||
auto desc = MTL::ComputePipelineDescriptor::alloc()->init();
|
||||
desc->setComputeFunction(mtl_function);
|
||||
desc->setLinkedFunctions(linked_functions);
|
||||
|
||||
// Compile kernel to compute pipeline
|
||||
NS::Error* error = nullptr;
|
||||
auto kernel = device_->newComputePipelineState(
|
||||
desc, MTL::PipelineOptionNone, nullptr, &error);
|
||||
|
||||
// Throw error if unable to compile metal function
|
||||
if (!kernel) {
|
||||
std::ostringstream msg;
|
||||
msg << "[metal::Device] Unable to load kernel " << name << "\n";
|
||||
if (error) {
|
||||
msg << error->localizedDescription()->utf8String() << "\n";
|
||||
}
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
return kernel;
|
||||
}
|
||||
|
||||
MTL::Library* Device::get_library(const std::string& name) {
|
||||
auto it = library_map_.find(name);
|
||||
return (it != library_map_.end()) ? it->second : nullptr;
|
||||
}
|
||||
|
||||
MTL::Library* Device::get_library(
|
||||
const std::string& name,
|
||||
const std::string& source,
|
||||
bool cache /* = true */) {
|
||||
if (cache) {
|
||||
if (auto it = library_map_.find(name); it != library_map_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
}
|
||||
|
||||
auto mtl_lib = get_library_(source);
|
||||
|
||||
if (cache) {
|
||||
library_map_.insert({name, mtl_lib});
|
||||
}
|
||||
|
||||
return mtl_lib;
|
||||
}
|
||||
|
||||
MTL::Library* Device::get_library(
|
||||
const std::string& name,
|
||||
const MTL::StitchedLibraryDescriptor* desc,
|
||||
bool cache /* = true */) {
|
||||
if (cache) {
|
||||
if (auto it = library_map_.find(name); it != library_map_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
}
|
||||
|
||||
auto mtl_lib = get_library_(desc);
|
||||
|
||||
if (cache) {
|
||||
library_map_.insert({name, mtl_lib});
|
||||
}
|
||||
|
||||
return mtl_lib;
|
||||
}
|
||||
|
||||
MTL::Function* Device::get_function(
|
||||
const std::string& base_name,
|
||||
MTL::Library* mtl_lib,
|
||||
const std::string& specialized_name /* = "" */,
|
||||
const MTLFCList& func_consts /* = {} */) {
|
||||
return get_function_(base_name, specialized_name, func_consts, mtl_lib);
|
||||
}
|
||||
|
||||
MTL::Function* Device::get_function(
|
||||
const std::string& base_name,
|
||||
const std::string& lib_name /* = "mlx" */,
|
||||
const std::string& specialized_name /* = "" */,
|
||||
const MTLFCList& func_consts /* = {} */) {
|
||||
// Search for cached metal lib
|
||||
MTL::Library* mtl_lib = get_library_cache_(lib_name);
|
||||
|
||||
return get_function(base_name, mtl_lib, specialized_name, func_consts);
|
||||
}
|
||||
|
||||
MTL::LinkedFunctions* Device::get_linked_functions_(
|
||||
const std::vector<MTL::Function*>& funcs) {
|
||||
if (funcs.empty()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto lfuncs = MTL::LinkedFunctions::linkedFunctions();
|
||||
|
||||
std::vector<NS::Object*> objs(funcs.size());
|
||||
for (int i = 0; i < funcs.size(); i++) {
|
||||
objs[i] = funcs[i];
|
||||
}
|
||||
|
||||
NS::Array* funcs_arr = NS::Array::array(objs.data(), funcs.size());
|
||||
|
||||
lfuncs->setPrivateFunctions(funcs_arr);
|
||||
|
||||
return lfuncs;
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* Device::get_kernel(
|
||||
const std::string& base_name,
|
||||
MTL::Library* mtl_lib,
|
||||
const std::string& hash_name /* = "" */,
|
||||
const MTLFCList& func_consts /* = {} */,
|
||||
const std::vector<MTL::Function*>& linked_functions /* = {} */) {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
|
||||
// Look for cached kernel
|
||||
const auto& kname = hash_name.empty() ? base_name : hash_name;
|
||||
if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
// Pull kernel from library
|
||||
auto mtl_function = get_function_(base_name, kname, func_consts, mtl_lib);
|
||||
|
||||
// Compile kernel to compute pipeline
|
||||
auto mtl_linked_funcs = get_linked_functions_(linked_functions);
|
||||
auto kernel = get_kernel_(kname, mtl_function, mtl_linked_funcs);
|
||||
mtl_function->release();
|
||||
mtl_linked_funcs->release();
|
||||
|
||||
// Add kernel to cache
|
||||
kernel_map_.insert({kname, kernel});
|
||||
return kernel;
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* Device::get_kernel(
|
||||
const std::string& base_name,
|
||||
const std::string& lib_name /* = "mlx" */,
|
||||
const std::string& hash_name /* = "" */,
|
||||
const MTLFCList& func_consts /* = {} */,
|
||||
const std::vector<MTL::Function*>& linked_functions /* = {} */) {
|
||||
// Look for cached kernel
|
||||
const auto& kname = hash_name.size() == 0 ? base_name : hash_name;
|
||||
if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
// Search for cached metal lib
|
||||
MTL::Library* mtl_lib = get_library_cache_(lib_name);
|
||||
|
||||
return get_kernel(base_name, mtl_lib, kname, func_consts, linked_functions);
|
||||
}
|
||||
|
||||
Device& device(mlx::core::Device) {
|
||||
static Device metal_device;
|
||||
return metal_device;
|
||||
|
@@ -1,4 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-24 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -31,6 +31,9 @@ inline std::string get_colocated_mtllib_path(const std::string& lib_name) {
|
||||
return mtllib_path;
|
||||
}
|
||||
|
||||
using MTLFCList =
|
||||
std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;
|
||||
|
||||
class Device {
|
||||
public:
|
||||
Device();
|
||||
@@ -59,14 +62,73 @@ class Device {
|
||||
const std::function<std::string(const std::string&)>& lib_path_func =
|
||||
get_colocated_mtllib_path);
|
||||
|
||||
MTL::ComputePipelineState* get_kernel(
|
||||
MTL::Library* get_library(const std::string& name);
|
||||
|
||||
MTL::Library* get_library(
|
||||
const std::string& name,
|
||||
const std::string& lib_name = "mlx");
|
||||
const std::string& source_string,
|
||||
bool cache = true);
|
||||
|
||||
MTL::Library* get_library(
|
||||
const std::string& name,
|
||||
const MTL::StitchedLibraryDescriptor* desc,
|
||||
bool cache = true);
|
||||
|
||||
MTL::Function* get_function(
|
||||
const std::string& base_name,
|
||||
MTL::Library* mtl_lib,
|
||||
const std::string& specialized_name = "",
|
||||
const MTLFCList& func_consts = {});
|
||||
|
||||
MTL::Function* get_function(
|
||||
const std::string& base_name,
|
||||
const std::string& lib_name = "mlx",
|
||||
const std::string& specialized_name = "",
|
||||
const MTLFCList& func_consts = {});
|
||||
|
||||
MTL::ComputePipelineState* get_kernel(
|
||||
const std::string& base_name,
|
||||
MTL::Library* mtl_lib,
|
||||
const std::string& hash_name = "",
|
||||
const MTLFCList& func_consts = {},
|
||||
const std::vector<MTL::Function*>& linked_functions = {});
|
||||
|
||||
MTL::ComputePipelineState* get_kernel(
|
||||
const std::string& base_name,
|
||||
const std::string& lib_name = "mlx",
|
||||
const std::string& hash_name = "",
|
||||
const MTLFCList& func_consts = {},
|
||||
const std::vector<MTL::Function*>& linked_functions = {});
|
||||
|
||||
MTL::ArgumentEncoder* argument_encoder(
|
||||
const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const;
|
||||
|
||||
private:
|
||||
MTL::Library* get_library_cache_(const std::string& name);
|
||||
|
||||
MTL::Library* get_library_(const std::string& source_string);
|
||||
MTL::Library* get_library_(const MTL::StitchedLibraryDescriptor* desc);
|
||||
|
||||
MTL::Function* get_function_(const std::string& name, MTL::Library* mtl_lib);
|
||||
|
||||
MTL::Function* get_function_(
|
||||
const std::string& name,
|
||||
const std::string& specialized_name,
|
||||
const MTLFCList& func_consts,
|
||||
MTL::Library* mtl_lib);
|
||||
|
||||
MTL::LinkedFunctions* get_linked_functions_(
|
||||
const std::vector<MTL::Function*>& funcs);
|
||||
|
||||
MTL::ComputePipelineState* get_kernel_(
|
||||
const std::string& name,
|
||||
const MTL::Function* mtl_function);
|
||||
|
||||
MTL::ComputePipelineState* get_kernel_(
|
||||
const std::string& name,
|
||||
const MTL::Function* mtl_function,
|
||||
const MTL::LinkedFunctions* linked_functions);
|
||||
|
||||
MTL::Device* device_;
|
||||
std::unordered_map<int32_t, MTL::CommandQueue*> queue_map_;
|
||||
std::unordered_map<int32_t, std::pair<int, MTL::CommandBuffer*>> buffer_map_;
|
||||
|
@@ -1,4 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <numeric>
|
||||
@@ -39,114 +39,75 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
int idx_ndim = nidx ? inputs[1].ndim() : 0;
|
||||
size_t ndim = src.ndim();
|
||||
|
||||
std::ostringstream kname;
|
||||
std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
|
||||
kname << "gather" << type_to_name(src) << idx_type_name << "_" << nidx;
|
||||
if (idx_ndim <= 1) {
|
||||
kname << "_" << idx_ndim;
|
||||
}
|
||||
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
size_t slice_size = 1;
|
||||
for (auto s : slice_sizes_) {
|
||||
slice_size *= s;
|
||||
}
|
||||
|
||||
size_t ndim = src.ndim();
|
||||
size_t nthreads = out.size();
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (thread_group_size > nthreads) {
|
||||
thread_group_size = nthreads;
|
||||
}
|
||||
// Launch 2D grid of threads: indices x slice
|
||||
size_t dim0 = out.size() / slice_size;
|
||||
size_t dim1 = slice_size;
|
||||
auto group_dims = get_block_dims(dim0, dim1, 1);
|
||||
MTL::Size grid_dims = MTL::Size(dim0, dim1, 1);
|
||||
|
||||
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
|
||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||
// Collect all idx shapes and strides into one place
|
||||
std::vector<int> idx_shapes;
|
||||
std::vector<size_t> idx_strides;
|
||||
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// Make the argument buffer to store the indices for the
|
||||
// `Indices` struct in kernels/indexing.metal
|
||||
std::vector<MTL::ArgumentDescriptor*> arg_descs(4);
|
||||
arg_descs[0] = MTL::ArgumentDescriptor::argumentDescriptor();
|
||||
arg_descs[0]->setIndex(0);
|
||||
arg_descs[0]->setDataType(MTL::DataType::DataTypePointer);
|
||||
arg_descs[0]->setArrayLength(nidx);
|
||||
|
||||
// Shapes
|
||||
arg_descs[1] = MTL::ArgumentDescriptor::argumentDescriptor();
|
||||
arg_descs[1]->setDataType(MTL::DataType::DataTypePointer);
|
||||
arg_descs[1]->setIndex(nidx + 1);
|
||||
|
||||
// Strides
|
||||
arg_descs[2] = MTL::ArgumentDescriptor::argumentDescriptor();
|
||||
arg_descs[2]->setDataType(MTL::DataType::DataTypePointer);
|
||||
arg_descs[2]->setIndex(nidx + 2);
|
||||
|
||||
// Indices ndim
|
||||
arg_descs[3] = MTL::ArgumentDescriptor::argumentDescriptor();
|
||||
arg_descs[3]->setDataType(MTL::DataType::DataTypeInt);
|
||||
arg_descs[3]->setIndex(nidx + 3);
|
||||
|
||||
// Get the argument encoder
|
||||
auto arg_enc = d.argument_encoder(arg_descs);
|
||||
|
||||
// Allocate and fill buffers for shapes and strides
|
||||
int idx_ndim = nidx ? inputs[1].ndim() : 0;
|
||||
auto idx_shapes_buf = allocator::malloc_or_wait(sizeof(int) * idx_ndim);
|
||||
auto idx_strides_buf = allocator::malloc_or_wait(sizeof(size_t) * idx_ndim);
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
std::copy(
|
||||
idx_shapes.insert(
|
||||
idx_shapes.end(),
|
||||
inputs[i + 1].shape().begin(),
|
||||
inputs[i + 1].shape().end(),
|
||||
static_cast<int*>(idx_shapes_buf.raw_ptr()) + i * idx_ndim);
|
||||
std::copy(
|
||||
inputs[i + 1].shape().end());
|
||||
|
||||
idx_strides.insert(
|
||||
idx_strides.end(),
|
||||
inputs[i + 1].strides().begin(),
|
||||
inputs[i + 1].strides().end(),
|
||||
static_cast<size_t*>(idx_strides_buf.raw_ptr()) + i * idx_ndim);
|
||||
inputs[i + 1].strides().end());
|
||||
}
|
||||
|
||||
// Allocate the argument buffer
|
||||
auto arg_buf = allocator::malloc_or_wait(arg_enc->encodedLength());
|
||||
|
||||
// Register data with the encoder
|
||||
arg_enc->setArgumentBuffer(static_cast<MTL::Buffer*>(arg_buf.ptr()), 0);
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
set_array_buffer(compute_encoder, arg_enc, inputs[i + 1], i);
|
||||
}
|
||||
if (idx_ndim > 0) {
|
||||
arg_enc->setBuffer(
|
||||
static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()), 0, nidx + 1);
|
||||
compute_encoder->useResource(
|
||||
static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()),
|
||||
MTL::ResourceUsageRead);
|
||||
arg_enc->setBuffer(
|
||||
static_cast<MTL::Buffer*>(idx_strides_buf.ptr()), 0, nidx + 2);
|
||||
compute_encoder->useResource(
|
||||
static_cast<MTL::Buffer*>(idx_strides_buf.ptr()),
|
||||
MTL::ResourceUsageRead);
|
||||
}
|
||||
*static_cast<int*>(arg_enc->constantData(nidx + 3)) = idx_ndim;
|
||||
|
||||
// Set all the buffers
|
||||
set_array_buffer(compute_encoder, src, 0);
|
||||
compute_encoder->setBuffer(static_cast<MTL::Buffer*>(arg_buf.ptr()), 0, 1);
|
||||
set_array_buffer(compute_encoder, out, 2);
|
||||
compute_encoder->setBytes(src.shape().data(), ndim * sizeof(int), 3);
|
||||
compute_encoder->setBytes(src.strides().data(), ndim * sizeof(size_t), 4);
|
||||
compute_encoder->setBytes(&ndim, sizeof(size_t), 5);
|
||||
compute_encoder->setBytes(slice_sizes_.data(), ndim * sizeof(int), 6);
|
||||
compute_encoder->setBytes(&slice_size, sizeof(size_t), 7);
|
||||
compute_encoder->setBytes(axes_.data(), nidx * sizeof(int), 8);
|
||||
set_array_buffer(compute_encoder, out, 1);
|
||||
|
||||
// Set source info
|
||||
compute_encoder->setBytes(src.shape().data(), ndim * sizeof(int), 2);
|
||||
compute_encoder->setBytes(src.strides().data(), ndim * sizeof(size_t), 3);
|
||||
compute_encoder->setBytes(&ndim, sizeof(size_t), 4);
|
||||
compute_encoder->setBytes(slice_sizes_.data(), ndim * sizeof(int), 5);
|
||||
compute_encoder->setBytes(axes_.data(), nidx * sizeof(int), 6);
|
||||
|
||||
// Set index info
|
||||
//
|
||||
// We don't need to check for empty idx_shapes because gather has a
|
||||
// idx_ndim == 0 specialization
|
||||
compute_encoder->setBytes(
|
||||
idx_shapes.data(), idx_shapes.size() * sizeof(int), 7);
|
||||
compute_encoder->setBytes(
|
||||
idx_strides.data(), idx_strides.size() * sizeof(size_t), 8);
|
||||
compute_encoder->setBytes(&idx_ndim, sizeof(int), 9);
|
||||
|
||||
// Set index buffers
|
||||
for (int i = 1; i < nidx + 1; ++i) {
|
||||
set_array_buffer(compute_encoder, inputs[i], 20 + i);
|
||||
}
|
||||
|
||||
// Launch grid
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
|
||||
// Cleanup temporaries
|
||||
arg_enc->release();
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[arg_buf, idx_shapes_buf, idx_strides_buf](MTL::CommandBuffer*) {
|
||||
allocator::free(arg_buf);
|
||||
allocator::free(idx_shapes_buf);
|
||||
allocator::free(idx_strides_buf);
|
||||
});
|
||||
}
|
||||
|
||||
void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -211,82 +172,35 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
thread_group_size = nthreads;
|
||||
}
|
||||
|
||||
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
|
||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// Make the argument buffer to store the indices for the
|
||||
// `Indices` struct in kernels/indexing.metal
|
||||
std::vector<MTL::ArgumentDescriptor*> arg_descs(4);
|
||||
arg_descs[0] = MTL::ArgumentDescriptor::argumentDescriptor();
|
||||
arg_descs[0]->setIndex(0);
|
||||
arg_descs[0]->setDataType(MTL::DataType::DataTypePointer);
|
||||
arg_descs[0]->setArrayLength(nidx);
|
||||
|
||||
// Shapes
|
||||
arg_descs[1] = MTL::ArgumentDescriptor::argumentDescriptor();
|
||||
arg_descs[1]->setDataType(MTL::DataType::DataTypePointer);
|
||||
arg_descs[1]->setIndex(nidx + 1);
|
||||
|
||||
// Strides
|
||||
arg_descs[2] = MTL::ArgumentDescriptor::argumentDescriptor();
|
||||
arg_descs[2]->setDataType(MTL::DataType::DataTypePointer);
|
||||
arg_descs[2]->setIndex(nidx + 2);
|
||||
|
||||
// Indices ndim
|
||||
arg_descs[3] = MTL::ArgumentDescriptor::argumentDescriptor();
|
||||
arg_descs[3]->setDataType(MTL::DataType::DataTypeInt);
|
||||
arg_descs[3]->setIndex(nidx + 3);
|
||||
|
||||
// Get the argument encoder
|
||||
auto arg_enc = d.argument_encoder(arg_descs);
|
||||
|
||||
// Allocate and fill buffers for shapes and strides
|
||||
// Collect all idx shapes and strides into one place
|
||||
int idx_ndim = nidx ? inputs[1].ndim() : 0;
|
||||
auto idx_shapes_buf = allocator::malloc_or_wait(sizeof(int) * idx_ndim);
|
||||
auto idx_strides_buf = allocator::malloc_or_wait(sizeof(size_t) * idx_ndim);
|
||||
std::vector<int> idx_shapes;
|
||||
std::vector<size_t> idx_strides;
|
||||
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
std::copy(
|
||||
idx_shapes.insert(
|
||||
idx_shapes.end(),
|
||||
inputs[i + 1].shape().begin(),
|
||||
inputs[i + 1].shape().end(),
|
||||
static_cast<int*>(idx_shapes_buf.raw_ptr()) + i * idx_ndim);
|
||||
std::copy(
|
||||
inputs[i + 1].shape().end());
|
||||
|
||||
idx_strides.insert(
|
||||
idx_strides.end(),
|
||||
inputs[i + 1].strides().begin(),
|
||||
inputs[i + 1].strides().end(),
|
||||
static_cast<size_t*>(idx_strides_buf.raw_ptr()) + i * idx_ndim);
|
||||
inputs[i + 1].strides().end());
|
||||
}
|
||||
|
||||
// Allocate the argument buffer
|
||||
auto arg_buf = allocator::malloc_or_wait(arg_enc->encodedLength());
|
||||
// Set all the buffers
|
||||
set_array_buffer(compute_encoder, upd, 1);
|
||||
set_array_buffer(compute_encoder, out, 2);
|
||||
|
||||
// Register data with the encoder
|
||||
arg_enc->setArgumentBuffer(static_cast<MTL::Buffer*>(arg_buf.ptr()), 0);
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
set_array_buffer(compute_encoder, arg_enc, inputs[i + 1], i);
|
||||
}
|
||||
if (idx_ndim > 0) {
|
||||
arg_enc->setBuffer(
|
||||
static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()), 0, nidx + 1);
|
||||
compute_encoder->useResource(
|
||||
static_cast<MTL::Buffer*>(idx_shapes_buf.ptr()),
|
||||
MTL::ResourceUsageRead);
|
||||
arg_enc->setBuffer(
|
||||
static_cast<MTL::Buffer*>(idx_strides_buf.ptr()), 0, nidx + 2);
|
||||
compute_encoder->useResource(
|
||||
static_cast<MTL::Buffer*>(idx_strides_buf.ptr()),
|
||||
MTL::ResourceUsageRead);
|
||||
}
|
||||
*static_cast<int*>(arg_enc->constantData(nidx + 3)) = idx_ndim;
|
||||
|
||||
compute_encoder->setBuffer(static_cast<MTL::Buffer*>(arg_buf.ptr()), 0, 0);
|
||||
// Set update info
|
||||
size_t upd_ndim = upd.ndim();
|
||||
size_t upd_size = 1;
|
||||
for (int i = idx_ndim; i < upd.ndim(); ++i) {
|
||||
upd_size *= upd.shape(i);
|
||||
}
|
||||
set_array_buffer(compute_encoder, upd, 1);
|
||||
set_array_buffer(compute_encoder, out, 2);
|
||||
if (upd_ndim == 0) {
|
||||
// Need placeholders so Metal doesn't compalain
|
||||
int shape_ = 0;
|
||||
@@ -301,6 +215,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5);
|
||||
compute_encoder->setBytes(&upd_size, sizeof(size_t), 6);
|
||||
|
||||
// Set output info
|
||||
size_t out_ndim = out.ndim();
|
||||
if (out_ndim == 0) {
|
||||
// Need placeholders so Metal doesn't compalain
|
||||
@@ -316,16 +231,28 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9);
|
||||
compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10);
|
||||
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
// Set index info
|
||||
if (idx_ndim == 0) {
|
||||
// Add a 0 in idx_shapes and strides to avoid the missing buffer binding
|
||||
// error in the metal API.
|
||||
idx_shapes.push_back(0);
|
||||
idx_strides.push_back(0);
|
||||
}
|
||||
compute_encoder->setBytes(
|
||||
idx_shapes.data(), idx_shapes.size() * sizeof(int), 11);
|
||||
compute_encoder->setBytes(
|
||||
idx_strides.data(), idx_strides.size() * sizeof(size_t), 12);
|
||||
compute_encoder->setBytes(&idx_ndim, sizeof(int), 13);
|
||||
|
||||
// Cleanup temporaries
|
||||
arg_enc->release();
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[arg_buf, idx_shapes_buf, idx_strides_buf](MTL::CommandBuffer*) {
|
||||
allocator::free(arg_buf);
|
||||
allocator::free(idx_shapes_buf);
|
||||
allocator::free(idx_strides_buf);
|
||||
});
|
||||
// Set index buffers
|
||||
for (int i = 1; i < nidx + 1; ++i) {
|
||||
set_array_buffer(compute_encoder, inputs[i], 20 + i);
|
||||
}
|
||||
|
||||
// Launch grid
|
||||
MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
|
||||
MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -6,6 +6,7 @@ set(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/complex.h
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/defines.h
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/erf.h
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.h
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.h
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.h
|
||||
)
|
||||
@@ -22,11 +23,13 @@ set(
|
||||
"quantized"
|
||||
"random"
|
||||
"reduce"
|
||||
"rope"
|
||||
"scan"
|
||||
"softmax"
|
||||
"sort"
|
||||
"unary"
|
||||
"indexing"
|
||||
"gather"
|
||||
"scatter"
|
||||
)
|
||||
|
||||
function(build_kernel_base TARGET SRCFILE DEPS)
|
||||
|
@@ -63,18 +63,6 @@ struct ArgMax {
|
||||
}
|
||||
};
|
||||
|
||||
bool simd_shuffle_down(bool data, uint16_t delta) {
|
||||
return simd_shuffle_down(static_cast<uint32_t>(data), delta);
|
||||
}
|
||||
|
||||
uint64_t simd_shuffle_down(uint64_t data, uint16_t delta) {
|
||||
return as_type<uint64_t>(simd_shuffle_down(as_type<uint2>(data), delta));
|
||||
}
|
||||
|
||||
int64_t simd_shuffle_down(int64_t data, uint16_t delta) {
|
||||
return as_type<int64_t>(simd_shuffle_down(as_type<uint2>(data), delta));
|
||||
}
|
||||
|
||||
template <typename U>
|
||||
IndexValPair<U> simd_shuffle_down(IndexValPair<U> data, uint16_t delta) {
|
||||
return IndexValPair<U>(
|
||||
|
231
mlx/backend/metal/kernels/binary.h
Normal file
231
mlx/backend/metal/kernels/binary.h
Normal file
@@ -0,0 +1,231 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <metal_integer>
|
||||
#include <metal_math>
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
struct Add {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x + y;
|
||||
}
|
||||
};
|
||||
|
||||
struct Divide {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x / y;
|
||||
}
|
||||
};
|
||||
|
||||
struct Remainder {
|
||||
template <typename T>
|
||||
metal::enable_if_t<metal::is_integral_v<T> & !metal::is_signed_v<T>, T>
|
||||
operator()(T x, T y) {
|
||||
return x % y;
|
||||
}
|
||||
template <typename T>
|
||||
metal::enable_if_t<metal::is_integral_v<T> & metal::is_signed_v<T>, T>
|
||||
operator()(T x, T y) {
|
||||
auto r = x % y;
|
||||
if (r != 0 && (r < 0 != y < 0)) {
|
||||
r += y;
|
||||
}
|
||||
return r;
|
||||
}
|
||||
template <typename T>
|
||||
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
|
||||
T r = fmod(x, y);
|
||||
if (r != 0 && (r < 0 != y < 0)) {
|
||||
r += y;
|
||||
}
|
||||
return r;
|
||||
}
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x, complex64_t y) {
|
||||
return x % y;
|
||||
}
|
||||
};
|
||||
|
||||
struct Equal {
|
||||
template <typename T>
|
||||
bool operator()(T x, T y) {
|
||||
return x == y;
|
||||
}
|
||||
};
|
||||
|
||||
struct NaNEqual {
|
||||
template <typename T>
|
||||
bool operator()(T x, T y) {
|
||||
return x == y || (metal::isnan(x) && metal::isnan(y));
|
||||
}
|
||||
template <>
|
||||
bool operator()(complex64_t x, complex64_t y) {
|
||||
return x == y ||
|
||||
(metal::isnan(x.real) && metal::isnan(y.real) && metal::isnan(x.imag) &&
|
||||
metal::isnan(y.imag)) ||
|
||||
(x.real == y.real && metal::isnan(x.imag) && metal::isnan(y.imag)) ||
|
||||
(metal::isnan(x.real) && metal::isnan(y.real) && x.imag == y.imag);
|
||||
}
|
||||
};
|
||||
|
||||
struct Greater {
|
||||
template <typename T>
|
||||
bool operator()(T x, T y) {
|
||||
return x > y;
|
||||
}
|
||||
};
|
||||
|
||||
struct GreaterEqual {
|
||||
template <typename T>
|
||||
bool operator()(T x, T y) {
|
||||
return x >= y;
|
||||
}
|
||||
};
|
||||
|
||||
struct Less {
|
||||
template <typename T>
|
||||
bool operator()(T x, T y) {
|
||||
return x < y;
|
||||
}
|
||||
};
|
||||
|
||||
struct LessEqual {
|
||||
template <typename T>
|
||||
bool operator()(T x, T y) {
|
||||
return x <= y;
|
||||
}
|
||||
};
|
||||
|
||||
struct LogAddExp {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
if (metal::isnan(x) || metal::isnan(y)) {
|
||||
return metal::numeric_limits<T>::quiet_NaN();
|
||||
}
|
||||
constexpr T inf = metal::numeric_limits<T>::infinity();
|
||||
T maxval = metal::max(x, y);
|
||||
T minval = metal::min(x, y);
|
||||
return (minval == -inf || maxval == inf)
|
||||
? maxval
|
||||
: (maxval + log1p(metal::exp(minval - maxval)));
|
||||
};
|
||||
};
|
||||
|
||||
struct Maximum {
|
||||
template <typename T>
|
||||
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {
|
||||
return metal::max(x, y);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
|
||||
if (metal::isnan(x)) {
|
||||
return x;
|
||||
}
|
||||
return x > y ? x : y;
|
||||
}
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x, complex64_t y) {
|
||||
if (metal::isnan(x.real) || metal::isnan(x.imag)) {
|
||||
return x;
|
||||
}
|
||||
return x > y ? x : y;
|
||||
}
|
||||
};
|
||||
|
||||
struct Minimum {
|
||||
template <typename T>
|
||||
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {
|
||||
return metal::min(x, y);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
|
||||
if (metal::isnan(x)) {
|
||||
return x;
|
||||
}
|
||||
return x < y ? x : y;
|
||||
}
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x, complex64_t y) {
|
||||
if (metal::isnan(x.real) || metal::isnan(x.imag)) {
|
||||
return x;
|
||||
}
|
||||
return x < y ? x : y;
|
||||
}
|
||||
};
|
||||
|
||||
struct Multiply {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x * y;
|
||||
}
|
||||
};
|
||||
|
||||
struct NotEqual {
|
||||
template <typename T>
|
||||
bool operator()(T x, T y) {
|
||||
return x != y;
|
||||
}
|
||||
template <>
|
||||
bool operator()(complex64_t x, complex64_t y) {
|
||||
return x.real != y.real || x.imag != y.imag;
|
||||
}
|
||||
};
|
||||
|
||||
struct Power {
|
||||
template <typename T>
|
||||
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T base, T exp) {
|
||||
return metal::pow(base, exp);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T base, T exp) {
|
||||
T res = 1;
|
||||
while (exp) {
|
||||
if (exp & 1) {
|
||||
res *= base;
|
||||
}
|
||||
exp >>= 1;
|
||||
base *= base;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x, complex64_t y) {
|
||||
auto x_theta = metal::atan(x.imag / x.real);
|
||||
auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag);
|
||||
auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta);
|
||||
auto phase = y.imag * x_ln_r + y.real * x_theta;
|
||||
return {mag * metal::cos(phase), mag * metal::sin(phase)};
|
||||
}
|
||||
};
|
||||
|
||||
struct Subtract {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x - y;
|
||||
}
|
||||
};
|
||||
|
||||
struct LogicalAnd {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x && y;
|
||||
};
|
||||
};
|
||||
|
||||
struct LogicalOr {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x || y;
|
||||
};
|
||||
};
|
@@ -1,145 +1,6 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <metal_integer>
|
||||
#include <metal_math>
|
||||
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
|
||||
struct Add {
|
||||
template <typename T> T operator()(T x, T y) { return x + y; }
|
||||
};
|
||||
|
||||
struct Divide {
|
||||
template <typename T> T operator()(T x, T y) { return x / y; }
|
||||
};
|
||||
|
||||
struct Remainder {
|
||||
template <typename T> T operator()(T x, T y) { return x % y; }
|
||||
template <> float operator()(float x, float y) { return fmod(x, y); }
|
||||
template <> half operator()(half x, half y) { return fmod(x, y); }
|
||||
template <> bfloat16_t operator()(bfloat16_t x, bfloat16_t y) { return fmod(x, y); }
|
||||
};
|
||||
|
||||
struct Equal {
|
||||
template <typename T> bool operator()(T x, T y) { return x == y; }
|
||||
};
|
||||
|
||||
struct NaNEqual {
|
||||
template <typename T> bool operator()(T x, T y) {
|
||||
return x == y || (metal::isnan(x) && metal::isnan(y));
|
||||
}
|
||||
template <>
|
||||
bool operator()(complex64_t x, complex64_t y) {
|
||||
return x == y ||
|
||||
(metal::isnan(x.real) && metal::isnan(y.real)
|
||||
&& metal::isnan(x.imag) && metal::isnan(y.imag)) ||
|
||||
(x.real == y.real && metal::isnan(x.imag) && metal::isnan(y.imag)) ||
|
||||
(metal::isnan(x.real) && metal::isnan(y.real) && x.imag == y.imag);
|
||||
}
|
||||
};
|
||||
|
||||
struct Greater {
|
||||
template <typename T> bool operator()(T x, T y) { return x > y; }
|
||||
};
|
||||
|
||||
struct GreaterEqual {
|
||||
template <typename T> bool operator()(T x, T y) { return x >= y; }
|
||||
};
|
||||
|
||||
struct Less {
|
||||
template <typename T> bool operator()(T x, T y) { return x < y; }
|
||||
};
|
||||
|
||||
struct LessEqual {
|
||||
template <typename T> bool operator()(T x, T y) { return x <= y; }
|
||||
};
|
||||
|
||||
struct LogAddExp {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
constexpr T inf = metal::numeric_limits<T>::infinity();
|
||||
T maxval = metal::max(x, y);
|
||||
T minval = metal::min(x, y);
|
||||
return (minval == -inf || maxval == inf) ? maxval :
|
||||
(maxval + log1p(metal::exp(minval - maxval)));
|
||||
};
|
||||
};
|
||||
|
||||
struct Maximum {
|
||||
template <typename T> T operator()(T x, T y) { return metal::max(x, y); }
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x, complex64_t y) {
|
||||
return x >= y ? x : y;
|
||||
}
|
||||
};
|
||||
|
||||
struct Minimum {
|
||||
template <typename T> T operator()(T x, T y) { return metal::min(x, y); }
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x, complex64_t y) {
|
||||
return x <= y ? x : y;
|
||||
}
|
||||
};
|
||||
|
||||
struct Multiply {
|
||||
template <typename T> T operator()(T x, T y) { return x * y; }
|
||||
};
|
||||
|
||||
struct NotEqual {
|
||||
template <typename T> bool operator()(T x, T y) { return x != y; }
|
||||
template <>
|
||||
bool operator()(complex64_t x, complex64_t y) {
|
||||
return x.real != y.real || x.imag != y.imag;
|
||||
}
|
||||
};
|
||||
|
||||
struct Power {
|
||||
|
||||
template <typename T>
|
||||
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T base, T exp) {
|
||||
return metal::pow(base, exp);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T base, T exp) {
|
||||
T res = 1;
|
||||
while (exp) {
|
||||
if (exp & 1) {
|
||||
res *= base;
|
||||
}
|
||||
exp >>= 1;
|
||||
base *= base;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x, complex64_t y) {
|
||||
auto x_theta = metal::atan(x.imag / x.real);
|
||||
auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag);
|
||||
auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta);
|
||||
auto phase = y.imag * x_ln_r + y.real * x_theta;
|
||||
return {mag * metal::cos(phase), mag * metal::sin(phase)};
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
struct Subtract {
|
||||
template <typename T> T operator()(T x, T y) { return x - y; }
|
||||
};
|
||||
|
||||
struct LogicalAnd {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) { return x && y; };
|
||||
};
|
||||
|
||||
struct LogicalOr {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) { return x || y; };
|
||||
};
|
||||
#include "mlx/backend/metal/kernels/binary.h"
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void binary_op_s2s(
|
||||
@@ -389,4 +250,4 @@ instantiate_binary_all(naneq, bfloat16, bfloat16_t, bool, NaNEqual)
|
||||
instantiate_binary_all(naneq, complex64, complex64_t, bool, NaNEqual)
|
||||
|
||||
instantiate_binary_all(lor, bool_, bool, bool, LogicalOr)
|
||||
instantiate_binary_all(land, bool_, bool, bool, LogicalAnd)
|
||||
instantiate_binary_all(land, bool_, bool, bool, LogicalAnd)
|
||||
|
@@ -14,10 +14,29 @@ struct FloorDivide {
|
||||
};
|
||||
|
||||
struct Remainder {
|
||||
template <typename T> T operator()(T x, T y) { return x % y; }
|
||||
template <> float operator()(float x, float y) { return fmod(x, y); }
|
||||
template <> half operator()(half x, half y) { return fmod(x, y); }
|
||||
template <> bfloat16_t operator()(bfloat16_t x, bfloat16_t y) { return fmod(x, y); }
|
||||
template <typename T>
|
||||
metal::enable_if_t<metal::is_integral_v<T> & !metal::is_signed_v<T>, T> operator()(T x, T y) {
|
||||
return x % y;
|
||||
}
|
||||
template <typename T>
|
||||
metal::enable_if_t<metal::is_integral_v<T> & metal::is_signed_v<T>, T> operator()(T x, T y) {
|
||||
auto r = x % y;
|
||||
if (r != 0 && (r < 0 != y < 0)) {
|
||||
r += y;
|
||||
}
|
||||
return r;
|
||||
}
|
||||
template <typename T>
|
||||
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
|
||||
T r = fmod(x, y);
|
||||
if (r != 0 && (r < 0 != y < 0)) {
|
||||
r += y;
|
||||
}
|
||||
return r;
|
||||
}
|
||||
template <> complex64_t operator()(complex64_t x, complex64_t y) {
|
||||
return x % y;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename U, typename Op1, typename Op2>
|
||||
|
4
mlx/backend/metal/kernels/compiled_preamble.h
Normal file
4
mlx/backend/metal/kernels/compiled_preamble.h
Normal file
@@ -0,0 +1,4 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/metal/kernels/binary.h"
|
||||
#include "mlx/backend/metal/kernels/unary.h"
|
@@ -121,5 +121,11 @@ constexpr complex64_t operator/(complex64_t a, complex64_t b) {
|
||||
constexpr complex64_t operator%(complex64_t a, complex64_t b) {
|
||||
auto real = a.real - (b.real * static_cast<int64_t>(a.real / b.real));
|
||||
auto imag = a.imag - (b.imag * static_cast<int64_t>(a.imag / b.imag));
|
||||
if (real != 0 && (real < 0 != b.real < 0)) {
|
||||
real += b.real;
|
||||
}
|
||||
if (imag != 0 && (imag < 0 != b.imag < 0)) {
|
||||
imag += b.imag;
|
||||
}
|
||||
return {real, imag};
|
||||
}
|
||||
|
187
mlx/backend/metal/kernels/gather.metal
Normal file
187
mlx/backend/metal/kernels/gather.metal
Normal file
@@ -0,0 +1,187 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <metal_atomic>
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/indexing.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
using namespace metal;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////
|
||||
// Gather kernel
|
||||
/////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T, typename IdxT, int NIDX, int IDX_NDIM>
|
||||
METAL_FUNC void gather_impl(
|
||||
const device T *src [[buffer(0)]],
|
||||
device T *out [[buffer(1)]],
|
||||
const constant int *src_shape [[buffer(2)]],
|
||||
const constant size_t *src_strides [[buffer(3)]],
|
||||
const constant size_t& src_ndim [[buffer(4)]],
|
||||
const constant int *slice_sizes [[buffer(5)]],
|
||||
const constant int *axes [[buffer(6)]],
|
||||
const thread Indices<IdxT, NIDX>& indices,
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
|
||||
auto ind_idx = index.x;
|
||||
auto ind_offset = index.y;
|
||||
|
||||
size_t src_idx = 0;
|
||||
for (int i = 0; i < NIDX; ++i) {
|
||||
size_t idx_loc;
|
||||
if (IDX_NDIM == 0) {
|
||||
idx_loc = 0;
|
||||
} else if (IDX_NDIM == 1) {
|
||||
idx_loc = ind_idx * indices.strides[indices.ndim * i];
|
||||
} else {
|
||||
idx_loc = elem_to_loc(
|
||||
ind_idx,
|
||||
&indices.shapes[indices.ndim * i],
|
||||
&indices.strides[indices.ndim * i],
|
||||
indices.ndim);
|
||||
}
|
||||
auto ax = axes[i];
|
||||
auto idx_val = offset_neg_idx(
|
||||
indices.buffers[i][idx_loc], src_shape[ax]);
|
||||
src_idx += idx_val * src_strides[ax];
|
||||
}
|
||||
|
||||
auto src_offset = elem_to_loc(
|
||||
ind_offset, slice_sizes, src_strides, src_ndim);
|
||||
|
||||
size_t out_idx = index.y + static_cast<size_t>(grid_dim.y) * index.x;
|
||||
out[out_idx] = src[src_offset + src_idx];
|
||||
|
||||
}
|
||||
|
||||
#define make_gather_impl(IDX_ARG, IDX_ARR) \
|
||||
template <typename T, typename IdxT, int NIDX, int IDX_NDIM> \
|
||||
[[kernel]] void gather( \
|
||||
const device T *src [[buffer(0)]], \
|
||||
device T *out [[buffer(1)]], \
|
||||
const constant int *src_shape [[buffer(2)]], \
|
||||
const constant size_t *src_strides [[buffer(3)]], \
|
||||
const constant size_t& src_ndim [[buffer(4)]], \
|
||||
const constant int *slice_sizes [[buffer(5)]], \
|
||||
const constant int *axes [[buffer(6)]], \
|
||||
const constant int *idx_shapes [[buffer(7)]], \
|
||||
const constant size_t *idx_strides [[buffer(8)]], \
|
||||
const constant int& idx_ndim [[buffer(9)]], \
|
||||
IDX_ARG(IdxT) \
|
||||
uint2 index [[thread_position_in_grid]], \
|
||||
uint2 grid_dim [[threads_per_grid]]) { \
|
||||
\
|
||||
Indices<IdxT, NIDX> idxs{ \
|
||||
{{IDX_ARR()}}, \
|
||||
idx_shapes, \
|
||||
idx_strides, \
|
||||
idx_ndim}; \
|
||||
\
|
||||
return gather_impl<T, IdxT, NIDX, IDX_NDIM>( \
|
||||
src, \
|
||||
out, \
|
||||
src_shape, \
|
||||
src_strides, \
|
||||
src_ndim, \
|
||||
slice_sizes, \
|
||||
axes, \
|
||||
idxs, \
|
||||
index, \
|
||||
grid_dim); \
|
||||
}
|
||||
|
||||
#define make_gather(n) make_gather_impl(IDX_ARG_ ##n, IDX_ARR_ ##n)
|
||||
|
||||
make_gather(0)
|
||||
make_gather(1)
|
||||
make_gather(2)
|
||||
make_gather(3)
|
||||
make_gather(4)
|
||||
make_gather(5)
|
||||
make_gather(6)
|
||||
make_gather(7)
|
||||
make_gather(8)
|
||||
make_gather(9)
|
||||
make_gather(10)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////
|
||||
// Gather instantiations
|
||||
/////////////////////////////////////////////////////////////////////
|
||||
|
||||
#define instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG, nd, nd_name) \
|
||||
template [[host_name("gather" name "_" #nidx "" #nd_name)]] \
|
||||
[[kernel]] void gather<src_t, idx_t, nidx, nd>( \
|
||||
const device src_t *src [[buffer(0)]], \
|
||||
device src_t *out [[buffer(1)]], \
|
||||
const constant int *src_shape [[buffer(2)]], \
|
||||
const constant size_t *src_strides [[buffer(3)]], \
|
||||
const constant size_t& src_ndim [[buffer(4)]], \
|
||||
const constant int *slice_sizes [[buffer(5)]], \
|
||||
const constant int *axes [[buffer(6)]], \
|
||||
const constant int *idx_shapes [[buffer(7)]], \
|
||||
const constant size_t *idx_strides [[buffer(8)]], \
|
||||
const constant int& idx_ndim [[buffer(9)]], \
|
||||
IDX_ARG(idx_t) \
|
||||
uint2 index [[thread_position_in_grid]], \
|
||||
uint2 grid_dim [[threads_per_grid]]);
|
||||
|
||||
#define instantiate_gather5(name, src_t, idx_t, nidx, nd, nd_name) \
|
||||
instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG_ ##nidx, nd, nd_name)
|
||||
|
||||
#define instantiate_gather4(name, src_t, idx_t, nidx) \
|
||||
instantiate_gather5(name, src_t, idx_t, nidx, 0, _0) \
|
||||
instantiate_gather5(name, src_t, idx_t, nidx, 1, _1) \
|
||||
instantiate_gather5(name, src_t, idx_t, nidx, 2, )
|
||||
|
||||
|
||||
// Special for case NIDX=0
|
||||
instantiate_gather4("bool_", bool, bool, 0)
|
||||
instantiate_gather4("uint8", uint8_t, bool, 0)
|
||||
instantiate_gather4("uint16", uint16_t, bool, 0)
|
||||
instantiate_gather4("uint32", uint32_t, bool, 0)
|
||||
instantiate_gather4("uint64", uint64_t, bool, 0)
|
||||
instantiate_gather4("int8", int8_t, bool, 0)
|
||||
instantiate_gather4("int16", int16_t, bool, 0)
|
||||
instantiate_gather4("int32", int32_t, bool, 0)
|
||||
instantiate_gather4("int64", int64_t, bool, 0)
|
||||
instantiate_gather4("float16", half, bool, 0)
|
||||
instantiate_gather4("float32", float, bool, 0)
|
||||
instantiate_gather4("bfloat16", bfloat16_t, bool, 0)
|
||||
|
||||
#define instantiate_gather3(name, src_type, ind_type) \
|
||||
instantiate_gather4(name, src_type, ind_type, 1) \
|
||||
instantiate_gather4(name, src_type, ind_type, 2) \
|
||||
instantiate_gather4(name, src_type, ind_type, 3) \
|
||||
instantiate_gather4(name, src_type, ind_type, 4) \
|
||||
instantiate_gather4(name, src_type, ind_type, 5) \
|
||||
instantiate_gather4(name, src_type, ind_type, 6) \
|
||||
instantiate_gather4(name, src_type, ind_type, 7) \
|
||||
instantiate_gather4(name, src_type, ind_type, 8) \
|
||||
instantiate_gather4(name, src_type, ind_type, 9) \
|
||||
instantiate_gather4(name, src_type, ind_type, 10)
|
||||
|
||||
#define instantiate_gather(name, src_type) \
|
||||
instantiate_gather3(#name "bool_", src_type, bool) \
|
||||
instantiate_gather3(#name "uint8", src_type, uint8_t) \
|
||||
instantiate_gather3(#name "uint16", src_type, uint16_t) \
|
||||
instantiate_gather3(#name "uint32", src_type, uint32_t) \
|
||||
instantiate_gather3(#name "uint64", src_type, uint64_t) \
|
||||
instantiate_gather3(#name "int8", src_type, int8_t) \
|
||||
instantiate_gather3(#name "int16", src_type, int16_t) \
|
||||
instantiate_gather3(#name "int32", src_type, int32_t) \
|
||||
instantiate_gather3(#name "int64", src_type, int64_t)
|
||||
|
||||
instantiate_gather(bool_, bool)
|
||||
instantiate_gather(uint8, uint8_t)
|
||||
instantiate_gather(uint16, uint16_t)
|
||||
instantiate_gather(uint32, uint32_t)
|
||||
instantiate_gather(uint64, uint64_t)
|
||||
instantiate_gather(int8, int8_t)
|
||||
instantiate_gather(int16, int16_t)
|
||||
instantiate_gather(int32, int32_t)
|
||||
instantiate_gather(int64, int64_t)
|
||||
instantiate_gather(float16, half)
|
||||
instantiate_gather(float32, float)
|
||||
instantiate_gather(bfloat16, bfloat16_t)
|
@@ -121,8 +121,18 @@ struct GEMVKernel {
|
||||
for(int tm = 0; tm < TM; tm++) {
|
||||
|
||||
// Load for the row
|
||||
for(int tn = 0; tn < TN; tn++) {
|
||||
inter[tn] = mat[tm * in_vec_size + bn + tn];
|
||||
if(bn + TN <= in_vec_size) {
|
||||
#pragma clang loop unroll(full)
|
||||
for(int tn = 0; tn < TN; tn++) {
|
||||
inter[tn] = mat[tm * in_vec_size + bn + tn];
|
||||
}
|
||||
|
||||
} else { // Edgecase
|
||||
#pragma clang loop unroll(full)
|
||||
for(int tn = 0; tn < TN; tn++) {
|
||||
int col_idx = (bn + tn) < in_vec_size ? (bn + tn) : (in_vec_size - 1);
|
||||
inter[tn] = mat[tm * in_vec_size + col_idx];
|
||||
}
|
||||
}
|
||||
|
||||
// Accumulate results
|
||||
|
54
mlx/backend/metal/kernels/indexing.h
Normal file
54
mlx/backend/metal/kernels/indexing.h
Normal file
@@ -0,0 +1,54 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <metal_stdlib>
|
||||
|
||||
using namespace metal;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////
|
||||
// Indexing utils
|
||||
/////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename IdxT, int NIDX>
|
||||
struct Indices {
|
||||
const array<const device IdxT*, NIDX> buffers;
|
||||
const constant int* shapes;
|
||||
const constant size_t* strides;
|
||||
const int ndim;
|
||||
};
|
||||
|
||||
template <typename IdxT>
|
||||
METAL_FUNC size_t offset_neg_idx(IdxT idx, size_t size) {
|
||||
if (is_unsigned_v<IdxT>) {
|
||||
return idx;
|
||||
} else {
|
||||
return (idx < 0) ? idx + size : idx;
|
||||
}
|
||||
}
|
||||
|
||||
#define IDX_ARG_N(idx_t, n) const device idx_t *idx##n [[buffer(n)]],
|
||||
|
||||
#define IDX_ARG_0(idx_t)
|
||||
#define IDX_ARG_1(idx_t) IDX_ARG_0(idx_t) IDX_ARG_N(idx_t, 21)
|
||||
#define IDX_ARG_2(idx_t) IDX_ARG_1(idx_t) IDX_ARG_N(idx_t, 22)
|
||||
#define IDX_ARG_3(idx_t) IDX_ARG_2(idx_t) IDX_ARG_N(idx_t, 23)
|
||||
#define IDX_ARG_4(idx_t) IDX_ARG_3(idx_t) IDX_ARG_N(idx_t, 24)
|
||||
#define IDX_ARG_5(idx_t) IDX_ARG_4(idx_t) IDX_ARG_N(idx_t, 25)
|
||||
#define IDX_ARG_6(idx_t) IDX_ARG_5(idx_t) IDX_ARG_N(idx_t, 26)
|
||||
#define IDX_ARG_7(idx_t) IDX_ARG_6(idx_t) IDX_ARG_N(idx_t, 27)
|
||||
#define IDX_ARG_8(idx_t) IDX_ARG_7(idx_t) IDX_ARG_N(idx_t, 28)
|
||||
#define IDX_ARG_9(idx_t) IDX_ARG_8(idx_t) IDX_ARG_N(idx_t, 29)
|
||||
#define IDX_ARG_10(idx_t) IDX_ARG_9(idx_t) IDX_ARG_N(idx_t, 30)
|
||||
|
||||
#define IDX_ARR_N(n) idx##n,
|
||||
|
||||
#define IDX_ARR_0()
|
||||
#define IDX_ARR_1() IDX_ARR_0() IDX_ARR_N(21)
|
||||
#define IDX_ARR_2() IDX_ARR_1() IDX_ARR_N(22)
|
||||
#define IDX_ARR_3() IDX_ARR_2() IDX_ARR_N(23)
|
||||
#define IDX_ARR_4() IDX_ARR_3() IDX_ARR_N(24)
|
||||
#define IDX_ARR_5() IDX_ARR_4() IDX_ARR_N(25)
|
||||
#define IDX_ARR_6() IDX_ARR_5() IDX_ARR_N(26)
|
||||
#define IDX_ARR_7() IDX_ARR_6() IDX_ARR_N(27)
|
||||
#define IDX_ARR_8() IDX_ARR_7() IDX_ARR_N(28)
|
||||
#define IDX_ARR_9() IDX_ARR_8() IDX_ARR_N(29)
|
||||
#define IDX_ARR_10() IDX_ARR_9() IDX_ARR_N(30)
|
@@ -1,254 +0,0 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <metal_atomic>
|
||||
#include <metal_texture>
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/reduce.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
using namespace metal;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////
|
||||
// Gather kernel
|
||||
/////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename IdxT, int NIDX>
|
||||
struct Indices {
|
||||
const array<device IdxT*, NIDX> buffers [[id(0)]];
|
||||
device int* shapes [[id(NIDX + 1)]];
|
||||
device size_t* strides [[id(NIDX + 2)]];
|
||||
const int ndim [[id(NIDX + 3)]];
|
||||
};
|
||||
|
||||
template <typename IdxT>
|
||||
inline size_t offset_neg_idx(IdxT idx, size_t size) {
|
||||
return (idx < 0) ? idx + size : idx;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline size_t offset_neg_idx(bool idx, size_t) {
|
||||
return idx;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline size_t offset_neg_idx(uint32_t idx, size_t) {
|
||||
return idx;
|
||||
}
|
||||
|
||||
template <typename T, typename IdxT, int NIDX>
|
||||
[[kernel]] void gather(
|
||||
const device T *src [[buffer(0)]],
|
||||
const device Indices<IdxT, NIDX>& indices [[buffer(1)]],
|
||||
device T *out [[buffer(2)]],
|
||||
const device int *src_shape [[buffer(3)]],
|
||||
const device size_t *src_strides [[buffer(4)]],
|
||||
const device size_t& src_ndim [[buffer(5)]],
|
||||
const device int *slice_sizes [[buffer(6)]],
|
||||
const device size_t& slice_size [[buffer(7)]],
|
||||
const device int *axes [[buffer(8)]],
|
||||
uint gid [[thread_position_in_grid]]) {
|
||||
|
||||
auto ind_idx = gid / slice_size;
|
||||
auto ind_offset = gid % slice_size;
|
||||
|
||||
size_t src_idx = 0;
|
||||
for (int i = 0; i < NIDX; ++i) {
|
||||
auto idx_loc = elem_to_loc(
|
||||
ind_idx,
|
||||
&indices.shapes[indices.ndim * i],
|
||||
&indices.strides[indices.ndim * i],
|
||||
indices.ndim);
|
||||
auto ax = axes[i];
|
||||
auto idx_val = offset_neg_idx(
|
||||
indices.buffers[i][idx_loc], src_shape[ax]);
|
||||
src_idx += idx_val * src_strides[ax];
|
||||
}
|
||||
|
||||
auto src_offset = elem_to_loc(
|
||||
ind_offset, slice_sizes, src_strides, src_ndim);
|
||||
out[gid] = src[src_idx + src_offset];
|
||||
}
|
||||
|
||||
#define instantiate_gather4(name, src_type, ind_type, nindex) \
|
||||
template [[host_name("gather" name "_" #nindex)]] \
|
||||
[[kernel]] void gather( \
|
||||
const device src_type *src [[buffer(0)]], \
|
||||
const device Indices<ind_type, nindex>& indices [[buffer(1)]], \
|
||||
device src_type *out [[buffer(2)]], \
|
||||
const device int *src_shape [[buffer(3)]], \
|
||||
const device size_t *src_strides [[buffer(4)]], \
|
||||
const device size_t& src_ndim [[buffer(5)]], \
|
||||
const device int *slice_sizes [[buffer(6)]], \
|
||||
const device size_t& slice_size [[buffer(7)]], \
|
||||
const device int* axes [[buffer(8)]], \
|
||||
uint gid [[thread_position_in_grid]]);
|
||||
|
||||
// Special for case NIDX=0
|
||||
instantiate_gather4("bool_", bool, bool, 0)
|
||||
instantiate_gather4("uint8", uint8_t, bool, 0)
|
||||
instantiate_gather4("uint16", uint16_t, bool, 0)
|
||||
instantiate_gather4("uint32", uint32_t, bool, 0)
|
||||
instantiate_gather4("uint64", uint64_t, bool, 0)
|
||||
instantiate_gather4("int8", int8_t, bool, 0)
|
||||
instantiate_gather4("int16", int16_t, bool, 0)
|
||||
instantiate_gather4("int32", int32_t, bool, 0)
|
||||
instantiate_gather4("int64", int64_t, bool, 0)
|
||||
instantiate_gather4("float16", half, bool, 0)
|
||||
instantiate_gather4("float32", float, bool, 0)
|
||||
instantiate_gather4("bfloat16", bfloat16_t, bool, 0)
|
||||
|
||||
#define instantiate_gather3(name, src_type, ind_type) \
|
||||
instantiate_gather4(name, src_type, ind_type, 1) \
|
||||
instantiate_gather4(name, src_type, ind_type, 2) \
|
||||
instantiate_gather4(name, src_type, ind_type, 3) \
|
||||
instantiate_gather4(name, src_type, ind_type, 4) \
|
||||
instantiate_gather4(name, src_type, ind_type, 5) \
|
||||
instantiate_gather4(name, src_type, ind_type, 6) \
|
||||
instantiate_gather4(name, src_type, ind_type, 7) \
|
||||
instantiate_gather4(name, src_type, ind_type, 8) \
|
||||
instantiate_gather4(name, src_type, ind_type, 9) \
|
||||
instantiate_gather4(name, src_type, ind_type, 10)
|
||||
|
||||
#define instantiate_gather(name, src_type) \
|
||||
instantiate_gather3(#name "bool_", src_type, bool) \
|
||||
instantiate_gather3(#name "uint8", src_type, uint8_t) \
|
||||
instantiate_gather3(#name "uint16", src_type, uint16_t) \
|
||||
instantiate_gather3(#name "uint32", src_type, uint32_t) \
|
||||
instantiate_gather3(#name "uint64", src_type, uint64_t) \
|
||||
instantiate_gather3(#name "int8", src_type, int8_t) \
|
||||
instantiate_gather3(#name "int16", src_type, int16_t) \
|
||||
instantiate_gather3(#name "int32", src_type, int32_t) \
|
||||
instantiate_gather3(#name "int64", src_type, int64_t)
|
||||
|
||||
instantiate_gather(bool_, bool)
|
||||
instantiate_gather(uint8, uint8_t)
|
||||
instantiate_gather(uint16, uint16_t)
|
||||
instantiate_gather(uint32, uint32_t)
|
||||
instantiate_gather(uint64, uint64_t)
|
||||
instantiate_gather(int8, int8_t)
|
||||
instantiate_gather(int16, int16_t)
|
||||
instantiate_gather(int32, int32_t)
|
||||
instantiate_gather(int64, int64_t)
|
||||
instantiate_gather(float16, half)
|
||||
instantiate_gather(float32, float)
|
||||
instantiate_gather(bfloat16, bfloat16_t)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////
|
||||
// Scatter kernel
|
||||
/////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T, typename IdxT, typename Op, int NIDX>
|
||||
[[kernel]] void scatter(
|
||||
const device Indices<IdxT, NIDX>& indices [[buffer(0)]],
|
||||
const device T *updates [[buffer(1)]],
|
||||
device mlx_atomic<T> *out [[buffer(2)]],
|
||||
const device int *upd_shape [[buffer(3)]],
|
||||
const device size_t *upd_strides [[buffer(4)]],
|
||||
const device size_t& upd_ndim [[buffer(5)]],
|
||||
const device size_t& upd_size [[buffer(6)]],
|
||||
const device int *out_shape [[buffer(7)]],
|
||||
const device size_t *out_strides [[buffer(8)]],
|
||||
const device size_t& out_ndim [[buffer(9)]],
|
||||
const device int* axes [[buffer(10)]],
|
||||
uint gid [[thread_position_in_grid]]) {
|
||||
|
||||
Op op;
|
||||
auto ind_idx = gid / upd_size;
|
||||
auto ind_offset = gid % upd_size;
|
||||
|
||||
size_t out_idx = 0;
|
||||
for (int i = 0; i < NIDX; ++i) {
|
||||
auto idx_loc = elem_to_loc(
|
||||
ind_idx,
|
||||
&indices.shapes[indices.ndim * i],
|
||||
&indices.strides[indices.ndim * i],
|
||||
indices.ndim);
|
||||
auto ax = axes[i];
|
||||
auto idx_val = offset_neg_idx(
|
||||
indices.buffers[i][idx_loc], out_shape[ax]);
|
||||
out_idx += idx_val * out_strides[ax];
|
||||
}
|
||||
|
||||
auto out_offset = elem_to_loc(
|
||||
ind_offset, upd_shape + indices.ndim, out_strides, out_ndim);
|
||||
auto upd_idx = elem_to_loc(gid, upd_shape, upd_strides, upd_ndim);
|
||||
op.atomic_update(out, updates[upd_idx], out_idx + out_offset);
|
||||
}
|
||||
|
||||
#define instantiate_scatter4(name, type, ind_type, op_type, nindex) \
|
||||
template [[host_name("scatter" name "_" #nindex)]] \
|
||||
[[kernel]] void scatter<type, ind_type, op_type, nindex>( \
|
||||
const device Indices<ind_type, nindex>& indices [[buffer(0)]], \
|
||||
const device type *updates [[buffer(1)]], \
|
||||
device mlx_atomic<type> *out [[buffer(2)]], \
|
||||
const device int *upd_shape [[buffer(3)]], \
|
||||
const device size_t *upd_strides [[buffer(4)]], \
|
||||
const device size_t& upd_ndim [[buffer(5)]], \
|
||||
const device size_t& upd_size [[buffer(6)]], \
|
||||
const device int *out_shape [[buffer(7)]], \
|
||||
const device size_t *out_strides [[buffer(8)]], \
|
||||
const device size_t& out_ndim [[buffer(9)]], \
|
||||
const device int* axes [[buffer(10)]], \
|
||||
uint gid [[thread_position_in_grid]]);
|
||||
|
||||
// Special case NINDEX=0
|
||||
#define instantiate_scatter_nd0(name, type) \
|
||||
instantiate_scatter4(#name "none", type, bool, None, 0) \
|
||||
instantiate_scatter4(#name "_sum", type, bool, Sum<type>, 0) \
|
||||
instantiate_scatter4(#name "_prod", type, bool, Prod<type>, 0) \
|
||||
instantiate_scatter4(#name "_max", type, bool, Max<type>, 0) \
|
||||
instantiate_scatter4(#name "_min", type, bool, Min<type>, 0)
|
||||
|
||||
#define instantiate_scatter3(name, type, ind_type, op_type) \
|
||||
instantiate_scatter4(name, type, ind_type, op_type, 1) \
|
||||
instantiate_scatter4(name, type, ind_type, op_type, 2) \
|
||||
instantiate_scatter4(name, type, ind_type, op_type, 3) \
|
||||
instantiate_scatter4(name, type, ind_type, op_type, 4) \
|
||||
instantiate_scatter4(name, type, ind_type, op_type, 5) \
|
||||
instantiate_scatter4(name, type, ind_type, op_type, 6) \
|
||||
instantiate_scatter4(name, type, ind_type, op_type, 7) \
|
||||
instantiate_scatter4(name, type, ind_type, op_type, 8) \
|
||||
instantiate_scatter4(name, type, ind_type, op_type, 9) \
|
||||
instantiate_scatter4(name, type, ind_type, op_type, 10)
|
||||
|
||||
#define instantiate_scatter2(name, type, ind_type) \
|
||||
instantiate_scatter3(name "_none", type, ind_type, None) \
|
||||
instantiate_scatter3(name "_sum", type, ind_type, Sum<type>) \
|
||||
instantiate_scatter3(name "_prod", type, ind_type, Prod<type>) \
|
||||
instantiate_scatter3(name "_max", type, ind_type, Max<type>) \
|
||||
instantiate_scatter3(name "_min", type, ind_type, Min<type>)
|
||||
|
||||
#define instantiate_scatter(name, type) \
|
||||
instantiate_scatter2(#name "bool_", type, bool) \
|
||||
instantiate_scatter2(#name "uint8", type, uint8_t) \
|
||||
instantiate_scatter2(#name "uint16", type, uint16_t) \
|
||||
instantiate_scatter2(#name "uint32", type, uint32_t) \
|
||||
instantiate_scatter2(#name "uint64", type, uint64_t) \
|
||||
instantiate_scatter2(#name "int8", type, int8_t) \
|
||||
instantiate_scatter2(#name "int16", type, int16_t) \
|
||||
instantiate_scatter2(#name "int32", type, int32_t) \
|
||||
instantiate_scatter2(#name "int64", type, int64_t)
|
||||
|
||||
// TODO uint64 and int64 unsupported
|
||||
instantiate_scatter_nd0(bool_, bool)
|
||||
instantiate_scatter_nd0(uint8, uint8_t)
|
||||
instantiate_scatter_nd0(uint16, uint16_t)
|
||||
instantiate_scatter_nd0(uint32, uint32_t)
|
||||
instantiate_scatter_nd0(int8, int8_t)
|
||||
instantiate_scatter_nd0(int16, int16_t)
|
||||
instantiate_scatter_nd0(int32, int32_t)
|
||||
instantiate_scatter_nd0(float16, half)
|
||||
instantiate_scatter_nd0(float32, float)
|
||||
instantiate_scatter_nd0(bfloat16, bfloat16_t)
|
||||
|
||||
instantiate_scatter(bool_, bool)
|
||||
instantiate_scatter(uint8, uint8_t)
|
||||
instantiate_scatter(uint16, uint16_t)
|
||||
instantiate_scatter(uint32, uint32_t)
|
||||
instantiate_scatter(int8, int8_t)
|
||||
instantiate_scatter(int16, int16_t)
|
||||
instantiate_scatter(int32, int32_t)
|
||||
instantiate_scatter(float16, half)
|
||||
instantiate_scatter(float32, float)
|
||||
instantiate_scatter(bfloat16, bfloat16_t)
|
@@ -15,6 +15,14 @@ using namespace metal;
|
||||
|
||||
MLX_MTL_CONST int SIMD_SIZE = 32;
|
||||
|
||||
template <typename T> struct AccT {
|
||||
typedef T acc_t;
|
||||
};
|
||||
|
||||
template <> struct AccT<bfloat16_t> {
|
||||
typedef float acc_t;
|
||||
};
|
||||
|
||||
template <typename T, const int BM, const int BN, const int group_size, const int bits>
|
||||
[[kernel]] void qmv(
|
||||
const device uint32_t* w [[buffer(0)]],
|
||||
@@ -31,21 +39,23 @@ template <typename T, const int BM, const int BN, const int group_size, const in
|
||||
|
||||
static_assert(BN == SIMD_SIZE, "qmv expects BN to be equal to SIMD_SIZE");
|
||||
|
||||
(void)lid;
|
||||
|
||||
constexpr int bitmask = (1 << bits) - 1;
|
||||
constexpr int el_per_thread = 32 / bits;
|
||||
constexpr int colgroup = BN * el_per_thread;
|
||||
constexpr int groups_per_block = colgroup / group_size;
|
||||
constexpr int simdgroups_fetching_vec = colgroup / SIMD_SIZE;
|
||||
|
||||
threadgroup T scales_block[BM * groups_per_block];
|
||||
threadgroup T biases_block[BM * groups_per_block];
|
||||
threadgroup T x_block[colgroup];
|
||||
typedef typename AccT<T>::acc_t U;
|
||||
threadgroup U scales_block[BM * groups_per_block];
|
||||
threadgroup U biases_block[BM * groups_per_block];
|
||||
threadgroup U x_block[colgroup];
|
||||
|
||||
thread uint32_t w_local;
|
||||
thread T result = 0;
|
||||
thread T scale = 1;
|
||||
thread T bias = 0;
|
||||
thread T x_thread[el_per_thread];
|
||||
thread U result = 0;
|
||||
thread U scale = 1;
|
||||
thread U bias = 0;
|
||||
thread U x_thread[el_per_thread];
|
||||
|
||||
// Adjust positions
|
||||
const int in_vec_size_w = in_vec_size / el_per_thread;
|
||||
@@ -57,12 +67,19 @@ template <typename T, const int BM, const int BN, const int group_size, const in
|
||||
x += tid.z * in_vec_size;
|
||||
y += tid.z * out_vec_size;
|
||||
|
||||
if (out_row >= out_vec_size) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Loop over in_vec in blocks of colgroup
|
||||
for (int i=0; i<in_vec_size; i+=colgroup) {
|
||||
// Load the vec to shared memory
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
if (simd_gid < simdgroups_fetching_vec) {
|
||||
x_block[lid] = x[lid + i];
|
||||
if (simd_gid == 0) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (int j=0; j<el_per_thread; j++) {
|
||||
x_block[simd_lid * el_per_thread + j] = x[i + simd_lid * el_per_thread + j];
|
||||
}
|
||||
}
|
||||
if (simd_lid == 0) {
|
||||
#pragma clang loop unroll(full)
|
||||
@@ -90,7 +107,7 @@ template <typename T, const int BM, const int BN, const int group_size, const in
|
||||
// Do all the work.
|
||||
#pragma clang loop unroll(full)
|
||||
for (int k=0; k<el_per_thread; k++) {
|
||||
result += (scale * static_cast<T>(w_local & bitmask) + bias) * x_thread[k];
|
||||
result += (scale * static_cast<U>(w_local & bitmask) + bias) * x_thread[k];
|
||||
w_local >>= bits;
|
||||
}
|
||||
}
|
||||
@@ -100,7 +117,7 @@ template <typename T, const int BM, const int BN, const int group_size, const in
|
||||
|
||||
// Store the result
|
||||
if (simd_lid == 0) {
|
||||
y[out_row] = result;
|
||||
y[out_row] = static_cast<T>(result);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -129,23 +146,25 @@ template <typename T, const int BM, const int BN, const int group_size, const in
|
||||
constexpr int colgroup = BN * el_per_int;
|
||||
constexpr int groups_per_block = colgroup / group_size;
|
||||
|
||||
threadgroup T scales_block[BM * groups_per_block];
|
||||
threadgroup T biases_block[BM * groups_per_block];
|
||||
threadgroup T x_block[BM];
|
||||
typedef typename AccT<T>::acc_t U;
|
||||
threadgroup U scales_block[BM * groups_per_block];
|
||||
threadgroup U biases_block[BM * groups_per_block];
|
||||
threadgroup U x_block[BM];
|
||||
|
||||
thread uint32_t w_local;
|
||||
thread T result[el_per_int] = {0};
|
||||
thread T scale = 1;
|
||||
thread T bias = 0;
|
||||
thread T x_local = 0;
|
||||
thread U result[el_per_int] = {0};
|
||||
thread U scale = 1;
|
||||
thread U bias = 0;
|
||||
thread U x_local = 0;
|
||||
|
||||
// Adjust positions
|
||||
const int out_vec_size_w = out_vec_size / el_per_int;
|
||||
const int out_vec_size_g = out_vec_size / group_size;
|
||||
int out_col = (tid.y * BN + simd_gid) * el_per_int;
|
||||
int out_col_start = tid.y * (BN * el_per_int);
|
||||
int out_col = out_col_start + simd_gid * el_per_int;
|
||||
w += out_col / el_per_int;
|
||||
scales += out_col / group_size;
|
||||
biases += out_col / group_size;
|
||||
scales += out_col_start / group_size;
|
||||
biases += out_col_start / group_size;
|
||||
x += tid.z * in_vec_size;
|
||||
y += tid.z * out_vec_size + out_col;
|
||||
|
||||
@@ -155,26 +174,22 @@ template <typename T, const int BM, const int BN, const int group_size, const in
|
||||
|
||||
// Loop over in_vec in blocks of colgroup
|
||||
for (int i=0; i<in_vec_size; i+=BM) {
|
||||
int offset = simd_lid + i;
|
||||
bool thread_in_bounds = offset < in_vec_size;
|
||||
int offset_lid = simd_lid + i;
|
||||
int offset_gid = simd_gid + i;
|
||||
bool thread_in_bounds = offset_lid < in_vec_size;
|
||||
bool group_in_bounds = offset_gid < in_vec_size;
|
||||
|
||||
// Load the vec to shared memory
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
if (simd_gid == 0) {
|
||||
x_block[simd_lid] = (thread_in_bounds) ? x[offset] : 0;
|
||||
x_block[simd_lid] = (thread_in_bounds) ? x[offset_lid] : 0;
|
||||
}
|
||||
|
||||
// Load the scales and biases to shared memory
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
if (simd_gid == 0) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (int j=0; j<groups_per_block; j++) {
|
||||
scales_block[simd_lid * groups_per_block + j] = scales[(i + simd_lid) * out_vec_size_g + j];
|
||||
}
|
||||
#pragma clang loop unroll(full)
|
||||
for (int j=0; j<groups_per_block; j++) {
|
||||
biases_block[simd_lid * groups_per_block + j] = biases[(i + simd_lid) * out_vec_size_g + j];
|
||||
}
|
||||
if (simd_lid < groups_per_block && group_in_bounds) {
|
||||
scales_block[simd_gid * groups_per_block + simd_lid] = scales[offset_gid * out_vec_size_g + simd_lid];
|
||||
biases_block[simd_gid * groups_per_block + simd_lid] = biases[offset_gid * out_vec_size_g + simd_lid];
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
@@ -184,12 +199,12 @@ template <typename T, const int BM, const int BN, const int group_size, const in
|
||||
bias = biases_block[simd_lid * groups_per_block + (simd_gid * el_per_int) / group_size];
|
||||
|
||||
// Load the matrix elements
|
||||
w_local = (thread_in_bounds) ? w[offset * out_vec_size_w] : 0;
|
||||
w_local = (thread_in_bounds) ? w[offset_lid * out_vec_size_w] : 0;
|
||||
|
||||
// Do all the work.
|
||||
#pragma clang loop unroll(full)
|
||||
for (int k=0; k<el_per_int; k++) {
|
||||
result[k] += (scale * static_cast<T>(w_local & bitmask) + bias) * x_local;
|
||||
result[k] += (scale * static_cast<U>(w_local & bitmask) + bias) * x_local;
|
||||
w_local >>= bits;
|
||||
}
|
||||
}
|
||||
@@ -204,7 +219,7 @@ template <typename T, const int BM, const int BN, const int group_size, const in
|
||||
if (simd_lid == 0) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (int k=0; k<el_per_int; k++) {
|
||||
y[k] = result[k];
|
||||
y[k] = static_cast<T>(result[k]);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -243,7 +258,6 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
|
||||
using mma_t = mlx::steel::BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK, BK>;
|
||||
using loader_x_t = mlx::steel::BlockLoader<T, BM, BK, BK, 1, WM * WN * SIMD_SIZE, 1, 4>;
|
||||
|
||||
|
||||
threadgroup T scales_block[BN * groups_per_block];
|
||||
threadgroup T biases_block[BN * groups_per_block];
|
||||
threadgroup T Xs[BM * BK];
|
||||
@@ -306,7 +320,7 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
|
||||
const device uint32_t * w_local = w + offset_row * K_w + offset_col;
|
||||
threadgroup T * Ws_local = Ws + offset_row * BK + offset_col * el_per_int;
|
||||
|
||||
if (y_col + offset_col < N) {
|
||||
if (y_row + offset_row < N) {
|
||||
uint32_t wi = *w_local;
|
||||
T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
|
||||
T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
|
||||
@@ -421,8 +435,9 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
|
||||
for (int k=0; k<K; k += BK) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// Load the x tile
|
||||
if (num_els < BM) {
|
||||
loader_x.load_safe(short2(BK, num_els));
|
||||
short num_k = min(BK, K - k);
|
||||
if (num_els < BM || num_k < BK) {
|
||||
loader_x.load_safe(short2(num_k, num_els));
|
||||
} else {
|
||||
loader_x.load_unsafe();
|
||||
}
|
||||
@@ -450,7 +465,7 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
|
||||
|
||||
// Load the w tile
|
||||
{
|
||||
if (k + BK >= K) {
|
||||
if (num_k < BK) {
|
||||
for (int wo=0; wo<w_els_per_thread; wo++) {
|
||||
int offset = lid * w_els_per_thread + wo;
|
||||
int offset_row = offset / (BN / el_per_int);
|
||||
@@ -543,6 +558,9 @@ instantiate_qmv_types(128, 8)
|
||||
instantiate_qmv_types( 64, 2)
|
||||
instantiate_qmv_types( 64, 4)
|
||||
instantiate_qmv_types( 64, 8)
|
||||
instantiate_qmv_types( 32, 2)
|
||||
instantiate_qmv_types( 32, 4)
|
||||
instantiate_qmv_types( 32, 8)
|
||||
|
||||
#define instantiate_qvm(name, itype, group_size, bits) \
|
||||
template [[host_name("qvm_" #name "_gs_" #group_size "_b_" #bits)]] \
|
||||
@@ -570,6 +588,9 @@ instantiate_qvm_types(128, 8)
|
||||
instantiate_qvm_types( 64, 2)
|
||||
instantiate_qvm_types( 64, 4)
|
||||
instantiate_qvm_types( 64, 8)
|
||||
instantiate_qvm_types( 32, 2)
|
||||
instantiate_qvm_types( 32, 4)
|
||||
instantiate_qvm_types( 32, 8)
|
||||
|
||||
#define instantiate_qmm_t(name, itype, group_size, bits, aligned_N) \
|
||||
template [[host_name("qmm_t_" #name "_gs_" #group_size "_b_" #bits "_alN_" #aligned_N)]] \
|
||||
@@ -601,6 +622,9 @@ instantiate_qmm_t_types(128, 8)
|
||||
instantiate_qmm_t_types( 64, 2)
|
||||
instantiate_qmm_t_types( 64, 4)
|
||||
instantiate_qmm_t_types( 64, 8)
|
||||
instantiate_qmm_t_types( 32, 2)
|
||||
instantiate_qmm_t_types( 32, 4)
|
||||
instantiate_qmm_t_types( 32, 8)
|
||||
|
||||
#define instantiate_qmm_n(name, itype, group_size, bits) \
|
||||
template [[host_name("qmm_n_" #name "_gs_" #group_size "_b_" #bits)]] \
|
||||
@@ -629,3 +653,6 @@ instantiate_qmm_n_types(128, 8)
|
||||
instantiate_qmm_n_types( 64, 2)
|
||||
instantiate_qmm_n_types( 64, 4)
|
||||
instantiate_qmm_n_types( 64, 8)
|
||||
instantiate_qmm_n_types( 32, 2)
|
||||
instantiate_qmm_n_types( 32, 4)
|
||||
instantiate_qmm_n_types( 32, 8)
|
||||
|
@@ -24,11 +24,59 @@ template <typename T, typename Op>
|
||||
device otype *out [[buffer(1)]], \
|
||||
uint tid [[thread_position_in_grid]]);
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// All reduce
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
|
||||
inline U per_thread_all_reduce(
|
||||
const device T *in,
|
||||
const device size_t& in_size,
|
||||
uint gid,
|
||||
uint grid_size) {
|
||||
Op op;
|
||||
U total_val = Op::init;
|
||||
|
||||
if (gid * N_READS < in_size) {
|
||||
in += gid * N_READS;
|
||||
|
||||
int r = 0;
|
||||
for(; r < (int)ceildiv(in_size, grid_size * N_READS) - 1; r++) {
|
||||
U vals[N_READS] = {op.init};
|
||||
|
||||
for(int i = 0; i < N_READS; i++) {
|
||||
vals[i] = static_cast<U>(in[i]);
|
||||
}
|
||||
for(int i = 0; i < N_READS; i++) {
|
||||
total_val = op(vals[i], total_val);
|
||||
}
|
||||
|
||||
in += grid_size * N_READS;
|
||||
}
|
||||
|
||||
// Separate case for the last set as we close the reduction size
|
||||
size_t curr_idx = (gid + r * (size_t)grid_size) * N_READS;
|
||||
if (curr_idx < in_size) {
|
||||
int max_reads = in_size - curr_idx;
|
||||
T vals[N_READS];
|
||||
|
||||
for(int i = 0, idx = 0; i < N_READS; i++, idx++) {
|
||||
idx = idx < max_reads ? idx : max_reads - 1;
|
||||
vals[i] = in[idx];
|
||||
}
|
||||
for(int i = 0; i < N_READS; i++) {
|
||||
U val = i < max_reads ? vals[i] : Op::init;
|
||||
total_val = op(static_cast<U>(val), total_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return total_val;
|
||||
}
|
||||
|
||||
// NB: This kernel assumes threads_per_threadgroup is at most
|
||||
// 1024. This way with a simd_size of 32, we are guaranteed to
|
||||
// complete the reduction in two steps of simd-level reductions.
|
||||
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
|
||||
[[kernel]] void all_reduce(
|
||||
const device T *in [[buffer(0)]],
|
||||
@@ -40,53 +88,18 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
|
||||
uint simd_per_group [[simdgroups_per_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
// NB: this kernel assumes threads_per_threadgroup is at most
|
||||
// 1024. This way with a simd_size of 32, we are guaranteed to
|
||||
// complete the reduction in two steps of simd-level reductions.
|
||||
|
||||
Op op;
|
||||
threadgroup U local_vals[simd_size];
|
||||
|
||||
U total_val = Op::init;
|
||||
|
||||
in += gid * N_READS;
|
||||
|
||||
int r = 0;
|
||||
for(; r < (int)ceildiv(in_size, grid_size * N_READS) - 1; r++) {
|
||||
U vals[N_READS] = {op.init};
|
||||
|
||||
for(int i = 0; i < N_READS; i++) {
|
||||
vals[i] = static_cast<U>(in[i]);
|
||||
}
|
||||
for(int i = 0; i < N_READS; i++) {
|
||||
total_val = op(vals[i], total_val);
|
||||
}
|
||||
|
||||
in += grid_size * N_READS;
|
||||
}
|
||||
|
||||
// Separate case for the last set as we close the reduction size
|
||||
size_t curr_idx = (gid + r * (size_t)grid_size) * N_READS;
|
||||
if (curr_idx < in_size) {
|
||||
int max_reads = in_size - curr_idx;
|
||||
T vals[N_READS];
|
||||
|
||||
for(int i = 0, idx = 0; i < N_READS; i++, idx++) {
|
||||
idx = idx < max_reads ? idx : max_reads - 1;
|
||||
vals[i] = in[idx];
|
||||
}
|
||||
for(int i = 0; i < N_READS; i++) {
|
||||
U val = i < max_reads ? vals[i] : Op::init;
|
||||
total_val = op(static_cast<U>(val), total_val);
|
||||
}
|
||||
}
|
||||
U total_val = per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);
|
||||
|
||||
// Reduction within simd group
|
||||
total_val = op.simd_reduce(total_val);
|
||||
if (simd_lane_id == 0) {
|
||||
local_vals[simd_group_id] = total_val;
|
||||
}
|
||||
|
||||
|
||||
// Reduction within thread group
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
total_val = lid < simd_per_group ? local_vals[lid] : op.init;
|
||||
@@ -98,6 +111,46 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
|
||||
[[kernel]] void all_reduce_no_atomics(
|
||||
const device T *in [[buffer(0)]],
|
||||
device U *out [[buffer(1)]],
|
||||
const device size_t& in_size [[buffer(2)]],
|
||||
uint gid [[thread_position_in_grid]],
|
||||
uint lid [[thread_position_in_threadgroup]],
|
||||
uint grid_size [[threads_per_grid]],
|
||||
uint simd_per_group [[simdgroups_per_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint thread_group_id [[threadgroup_position_in_grid]]) {
|
||||
|
||||
Op op;
|
||||
threadgroup U local_vals[simd_size];
|
||||
|
||||
U total_val = per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);
|
||||
|
||||
// Reduction within simd group (simd_add isn't supported for uint64/int64 types)
|
||||
for (uint16_t lane_offset = simd_size/2; lane_offset > 0; lane_offset /= 2) {
|
||||
total_val = op(total_val, simd_shuffle_down(total_val, lane_offset));
|
||||
}
|
||||
// Write simd group reduction results to local memory
|
||||
if (simd_lane_id == 0) {
|
||||
local_vals[simd_group_id] = total_val;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Reduction of simdgroup reduction results within threadgroup.
|
||||
total_val = lid < simd_per_group ? local_vals[lid] : op.init;
|
||||
for (uint16_t lane_offset = simd_size/2; lane_offset > 0; lane_offset /= 2) {
|
||||
total_val = op(total_val, simd_shuffle_down(total_val, lane_offset));
|
||||
}
|
||||
|
||||
// Reduction across threadgroups
|
||||
if (lid == 0) {
|
||||
out[thread_group_id] = total_val;
|
||||
}
|
||||
}
|
||||
|
||||
#define instantiate_all_reduce(name, itype, otype, op) \
|
||||
template [[host_name("all_reduce_" #name)]] \
|
||||
[[kernel]] void all_reduce<itype, otype, op>( \
|
||||
@@ -111,11 +164,80 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
#define instantiate_all_reduce_no_atomics(name, itype, otype, op) \
|
||||
template [[host_name("all_reduce_no_atomics_" #name)]] \
|
||||
[[kernel]] void all_reduce_no_atomics<itype, otype, op>( \
|
||||
const device itype *in [[buffer(0)]], \
|
||||
device otype *out [[buffer(1)]], \
|
||||
const device size_t& in_size [[buffer(2)]], \
|
||||
uint gid [[thread_position_in_grid]], \
|
||||
uint lid [[thread_position_in_threadgroup]], \
|
||||
uint grid_size [[threads_per_grid]], \
|
||||
uint simd_per_group [[simdgroups_per_threadgroup]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
||||
uint thread_group_id [[threadgroup_position_in_grid]]);
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Row atomics
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
|
||||
inline U per_thread_row_reduce(
|
||||
const device T *in,
|
||||
const constant size_t& reduction_size,
|
||||
const constant size_t& out_size,
|
||||
const constant int* shape,
|
||||
const constant size_t* strides,
|
||||
const constant int& ndim,
|
||||
uint lsize_x,
|
||||
uint lid_x,
|
||||
uint2 tid) {
|
||||
|
||||
Op op;
|
||||
|
||||
// Each threadgroup handles 1 reduction
|
||||
// TODO: Specializing elem_to_loc would be slightly faster
|
||||
int idx = tid.y * out_size + tid.x;
|
||||
int extra_offset = elem_to_loc(idx, shape, strides, ndim);
|
||||
in += extra_offset + lid_x * N_READS;
|
||||
|
||||
// The reduction is accumulated here
|
||||
U total_val = Op::init;
|
||||
|
||||
// Loop over the reduction size within thread group
|
||||
int r = 0;
|
||||
for (; r < (int)ceildiv(reduction_size, N_READS*lsize_x) - 1; r++) {
|
||||
T vals[N_READS];
|
||||
for(int i = 0; i < N_READS; i++) {
|
||||
vals[i] = in[i];
|
||||
}
|
||||
for(int i = 0; i < N_READS; i++) {
|
||||
total_val = op(static_cast<U>(vals[i]), total_val);
|
||||
}
|
||||
|
||||
in += lsize_x * N_READS;
|
||||
}
|
||||
|
||||
// Separate case for the last set as we close the reduction size
|
||||
size_t reduction_index = (lid_x + (size_t)lsize_x * r) * N_READS;
|
||||
if(reduction_index < reduction_size) {
|
||||
int max_reads = reduction_size - reduction_index;
|
||||
|
||||
T vals[N_READS];
|
||||
for(int i = 0; i < N_READS; i++) {
|
||||
int idx = min(i, max_reads - 1);
|
||||
vals[i] = static_cast<U>(in[idx]);
|
||||
}
|
||||
for(int i = 0; i < N_READS; i++) {
|
||||
T val = i < max_reads ? vals[i] : Op::init;
|
||||
total_val = op(static_cast<U>(val), total_val);
|
||||
}
|
||||
}
|
||||
|
||||
return total_val;
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
|
||||
[[kernel]] void row_reduce_general(
|
||||
const device T *in [[buffer(0)]],
|
||||
@@ -133,46 +255,9 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
Op op;
|
||||
|
||||
// Each threadgroup handles 1 reduction
|
||||
// TODO: Specializing elem_to_loc would be slightly faster
|
||||
int idx = tid.y * out_size + tid.x;
|
||||
int extra_offset = elem_to_loc(idx, shape, strides, ndim);
|
||||
in += extra_offset + lid.x * N_READS;
|
||||
|
||||
// The reduction is accumulated here
|
||||
U total_val = Op::init;
|
||||
threadgroup U local_vals[simd_size];
|
||||
|
||||
// Loop over the reduction size within thread group
|
||||
int r = 0;
|
||||
for (; r < (int)ceildiv(reduction_size, N_READS*lsize.x) - 1; r++) {
|
||||
T vals[N_READS];
|
||||
for(int i = 0; i < N_READS; i++) {
|
||||
vals[i] = in[i];
|
||||
}
|
||||
for(int i = 0; i < N_READS; i++) {
|
||||
total_val = op(static_cast<U>(vals[i]), total_val);
|
||||
}
|
||||
|
||||
in += lsize.x * N_READS;
|
||||
}
|
||||
|
||||
// Separate case for the last set as we close the reduction size
|
||||
size_t reduction_index = (lid.x + (size_t)lsize.x * r) * N_READS;
|
||||
if(reduction_index < reduction_size) {
|
||||
int max_reads = reduction_size - reduction_index;
|
||||
|
||||
T vals[N_READS];
|
||||
for(int i = 0; i < N_READS; i++) {
|
||||
int idx = min(i, max_reads - 1);
|
||||
vals[i] = static_cast<U>(in[idx]);
|
||||
}
|
||||
for(int i = 0; i < N_READS; i++) {
|
||||
T val = i < max_reads ? vals[i] : Op::init;
|
||||
total_val = op(static_cast<U>(val), total_val);
|
||||
}
|
||||
}
|
||||
U total_val = per_thread_row_reduce<T, U, Op, N_READS>(in, reduction_size, out_size, shape, strides, ndim, lsize.x, lid.x, tid.xy);
|
||||
|
||||
total_val = op.simd_reduce(total_val);
|
||||
|
||||
@@ -194,6 +279,53 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
|
||||
[[kernel]] void row_reduce_general_no_atomics(
|
||||
const device T *in [[buffer(0)]],
|
||||
device U *out [[buffer(1)]],
|
||||
const constant size_t& reduction_size [[buffer(2)]],
|
||||
const constant size_t& out_size [[buffer(3)]],
|
||||
const constant int* shape [[buffer(4)]],
|
||||
const constant size_t* strides [[buffer(5)]],
|
||||
const constant int& ndim [[buffer(6)]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint3 lsize [[threads_per_threadgroup]],
|
||||
uint3 gsize [[threads_per_grid]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_per_group [[simdgroups_per_threadgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
Op op;
|
||||
|
||||
threadgroup U local_vals[simd_size];
|
||||
U total_val = per_thread_row_reduce<T, U, Op, N_READS>(in, reduction_size, out_size, shape, strides, ndim, lsize.x, lid.x, tid.xy);
|
||||
|
||||
// Reduction within simd group - simd_add isn't supported for int64 types
|
||||
for (uint16_t i = simd_size/2; i > 0; i /= 2) {
|
||||
total_val = op(total_val, simd_shuffle_down(total_val, i));
|
||||
}
|
||||
|
||||
// Prepare next level
|
||||
if (simd_lane_id == 0) {
|
||||
local_vals[simd_group_id] = total_val;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Reduction within thread group
|
||||
// Only needed if thread group has multiple simd groups
|
||||
if(ceildiv(reduction_size, N_READS) > simd_size) {
|
||||
total_val = lid.x < simd_per_group ? local_vals[lid.x] : op.init;
|
||||
for (uint16_t i = simd_size/2; i > 0; i /= 2) {
|
||||
total_val = op(total_val, simd_shuffle_down(total_val, i));
|
||||
}
|
||||
}
|
||||
// Write row reduce output for threadgroup with 1st thread in thread group
|
||||
if (lid.x == 0) {
|
||||
out[(ceildiv(gsize.y, lsize.y) * tid.x) + tid.y] = total_val;
|
||||
}
|
||||
}
|
||||
|
||||
#define instantiate_row_reduce_general(name, itype, otype, op) \
|
||||
template [[host_name("row_reduce_general_" #name)]] \
|
||||
[[kernel]] void row_reduce_general<itype, otype, op>( \
|
||||
@@ -211,52 +343,59 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
|
||||
uint simd_per_group [[simdgroups_per_threadgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
#define instantiate_row_reduce_general_no_atomics(name, itype, otype, op) \
|
||||
template [[host_name("row_reduce_general_no_atomics_" #name)]] \
|
||||
[[kernel]] void row_reduce_general_no_atomics<itype, otype, op>( \
|
||||
const device itype *in [[buffer(0)]], \
|
||||
device otype *out [[buffer(1)]], \
|
||||
const constant size_t& reduction_size [[buffer(2)]], \
|
||||
const constant size_t& out_size [[buffer(3)]], \
|
||||
const constant int* shape [[buffer(4)]], \
|
||||
const constant size_t* strides [[buffer(5)]], \
|
||||
const constant int& ndim [[buffer(6)]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint3 lsize [[threads_per_threadgroup]], \
|
||||
uint3 gsize [[threads_per_grid]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
||||
uint simd_per_group [[simdgroups_per_threadgroup]], \
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Column reduce
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
|
||||
inline void _contiguous_strided_reduce(
|
||||
const device T *in,
|
||||
device mlx_atomic<U> *out,
|
||||
threadgroup U *local_data,
|
||||
uint in_idx,
|
||||
uint out_idx,
|
||||
uint reduction_size,
|
||||
uint reduction_stride,
|
||||
uint2 tid,
|
||||
uint2 lid,
|
||||
inline U _contiguous_strided_reduce(
|
||||
const device T *in,
|
||||
threadgroup U *local_data,
|
||||
uint in_idx,
|
||||
uint reduction_size,
|
||||
uint reduction_stride,
|
||||
uint2 tid,
|
||||
uint2 lid,
|
||||
uint2 lsize) {
|
||||
|
||||
Op op;
|
||||
T local_vals[N_READS];
|
||||
U total_val = Op::init;
|
||||
|
||||
uint base_offset = (tid.y * lsize.y + lid.y) * N_READS;
|
||||
|
||||
for(uint r = 0; r < N_READS; r++) {
|
||||
uint offset = base_offset + r;
|
||||
offset = offset < reduction_size ? offset : reduction_size - 1;
|
||||
local_vals[r] = in[in_idx + offset * reduction_stride];
|
||||
}
|
||||
|
||||
U total_val = Op::init;
|
||||
for(uint r = 0; r < N_READS && (base_offset + r) < reduction_size; r++) {
|
||||
total_val = op(static_cast<U>(total_val), local_vals[r]);
|
||||
uint offset = base_offset + r;
|
||||
total_val = op(static_cast<U>(total_val), in[in_idx + offset * reduction_stride]);
|
||||
}
|
||||
local_data[lsize.y * lid.x + lid.y] = total_val;
|
||||
|
||||
local_data[lsize.y * lid.x + lid.y] = total_val;
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
U val = Op::init;
|
||||
if(lid.y == 0) {
|
||||
U val = op.init;
|
||||
|
||||
// Perform reduction across columns in thread group
|
||||
for(uint i = 0; i < lsize.y; i++) {
|
||||
val = op(val, local_data[lsize.y * lid.x + i]);
|
||||
val = op(val, local_data[lsize.y * lid.x + i]);
|
||||
}
|
||||
|
||||
op.atomic_update(out, val, out_idx);
|
||||
}
|
||||
|
||||
return val;
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
|
||||
@@ -265,13 +404,13 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
|
||||
device mlx_atomic<U> *out [[buffer(1)]],
|
||||
const constant size_t& reduction_size [[buffer(2)]],
|
||||
const constant size_t& reduction_stride [[buffer(3)]],
|
||||
const constant size_t& out_size [[buffer(4)]],
|
||||
const constant size_t& out_size [[buffer(4)]],
|
||||
const constant int* shape [[buffer(5)]],
|
||||
const constant size_t* strides [[buffer(6)]],
|
||||
const constant int& ndim [[buffer(7)]],
|
||||
threadgroup U *local_data [[threadgroup(0)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint3 lsize [[threads_per_threadgroup]]) {
|
||||
auto out_idx = tid.x * lsize.x + lid.x;
|
||||
auto in_idx = elem_to_loc(
|
||||
@@ -281,18 +420,66 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
|
||||
ndim
|
||||
);
|
||||
|
||||
Op op;
|
||||
if(out_idx < out_size) {
|
||||
_contiguous_strided_reduce<T, U, Op, N_READS>(
|
||||
in,
|
||||
out,
|
||||
local_data,
|
||||
in_idx,
|
||||
out_idx,
|
||||
reduction_size,
|
||||
reduction_stride,
|
||||
tid.xy,
|
||||
lid.xy,
|
||||
lsize.xy);
|
||||
U val = _contiguous_strided_reduce<T, U, Op, N_READS>(
|
||||
in,
|
||||
local_data,
|
||||
in_idx,
|
||||
reduction_size,
|
||||
reduction_stride,
|
||||
tid.xy,
|
||||
lid.xy,
|
||||
lsize.xy);
|
||||
|
||||
// Write out reduction results generated by threadgroups working on specific output element, contiguously.
|
||||
if (lid.y == 0) {
|
||||
op.atomic_update(out, val, out_idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
|
||||
[[kernel]] void col_reduce_general_no_atomics(
|
||||
const device T *in [[buffer(0)]],
|
||||
device U *out [[buffer(1)]],
|
||||
const constant size_t& reduction_size [[buffer(2)]],
|
||||
const constant size_t& reduction_stride [[buffer(3)]],
|
||||
const constant size_t& out_size [[buffer(4)]],
|
||||
const constant int* shape [[buffer(5)]],
|
||||
const constant size_t* strides [[buffer(6)]],
|
||||
const constant int& ndim [[buffer(7)]],
|
||||
threadgroup U *local_data [[threadgroup(0)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint3 gid [[thread_position_in_grid]],
|
||||
uint3 lsize [[threads_per_threadgroup]],
|
||||
uint3 gsize [[threads_per_grid]]) {
|
||||
auto out_idx = tid.x * lsize.x + lid.x;
|
||||
auto in_idx = elem_to_loc(
|
||||
out_idx + tid.z * out_size,
|
||||
shape,
|
||||
strides,
|
||||
ndim
|
||||
);
|
||||
|
||||
if(out_idx < out_size) {
|
||||
U val = _contiguous_strided_reduce<T, U, Op, N_READS>(
|
||||
in,
|
||||
local_data,
|
||||
in_idx,
|
||||
reduction_size,
|
||||
reduction_stride,
|
||||
tid.xy,
|
||||
lid.xy,
|
||||
lsize.xy);
|
||||
|
||||
// Write out reduction results generated by threadgroups working on specific output element, contiguously.
|
||||
if (lid.y == 0) {
|
||||
uint tgsize_y = ceildiv(gsize.y, lsize.y);
|
||||
uint tgsize_z = ceildiv(gsize.z, lsize.z);
|
||||
out[tgsize_y * tgsize_z * gid.x + tgsize_y * tid.z + tid.y] = val;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -312,6 +499,23 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint3 lsize [[threads_per_threadgroup]]);
|
||||
|
||||
#define instantiate_col_reduce_general_no_atomics(name, itype, otype, op) \
|
||||
template [[host_name("col_reduce_general_no_atomics_" #name)]] \
|
||||
[[kernel]] void col_reduce_general_no_atomics<itype, otype, op>( \
|
||||
const device itype *in [[buffer(0)]], \
|
||||
device otype *out [[buffer(1)]], \
|
||||
const constant size_t& reduction_size [[buffer(2)]], \
|
||||
const constant size_t& reduction_stride [[buffer(3)]], \
|
||||
const constant size_t& out_size [[buffer(4)]], \
|
||||
const constant int* shape [[buffer(5)]], \
|
||||
const constant size_t* strides [[buffer(6)]], \
|
||||
const constant int& ndim [[buffer(7)]], \
|
||||
threadgroup otype *local_data [[threadgroup(0)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint3 lid [[thread_position_in_threadgroup]], \
|
||||
uint3 gid [[thread_position_in_grid]], \
|
||||
uint3 lsize [[threads_per_threadgroup]], \
|
||||
uint3 gsize [[threads_per_grid]]);
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Instantiations
|
||||
@@ -322,6 +526,15 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
|
||||
instantiate_row_reduce_general(name, itype, otype, op) \
|
||||
instantiate_col_reduce_general(name, itype, otype, op)
|
||||
|
||||
#define instantiate_reduce_no_atomics(name, itype, otype, op) \
|
||||
instantiate_all_reduce_no_atomics(name, itype, otype, op) \
|
||||
instantiate_row_reduce_general_no_atomics(name, itype, otype, op) \
|
||||
instantiate_col_reduce_general_no_atomics(name, itype, otype, op)
|
||||
|
||||
#define instantiate_same_reduce_no_atomics(name, tname, type, op) \
|
||||
instantiate_init_reduce(name ##tname, type, op<type>) \
|
||||
instantiate_reduce_no_atomics(name ##tname, type, type, op<type>)
|
||||
|
||||
#define instantiate_same_reduce(name, tname, type, op) \
|
||||
instantiate_init_reduce(name ##tname, type, op<type>) \
|
||||
instantiate_reduce(name ##tname, type, type, op<type>)
|
||||
@@ -353,6 +566,9 @@ instantiate_same_reduce(sum, int32, int32_t, Sum)
|
||||
instantiate_same_reduce(sum, float16, half, Sum)
|
||||
instantiate_same_reduce(sum, float32, float, Sum)
|
||||
|
||||
instantiate_same_reduce_no_atomics(sum, int64, int64_t, Sum)
|
||||
instantiate_same_reduce_no_atomics(sum, uint64, uint64_t, Sum)
|
||||
|
||||
instantiate_same_reduce(prod, uint8, uint8_t, Prod)
|
||||
instantiate_same_reduce(prod, uint16, uint16_t, Prod)
|
||||
instantiate_same_reduce(prod, uint32, uint32_t, Prod)
|
||||
@@ -362,6 +578,9 @@ instantiate_same_reduce(prod, int32, int32_t, Prod)
|
||||
instantiate_same_reduce(prod, float16, half, Prod)
|
||||
instantiate_same_reduce(prod, float32, float, Prod)
|
||||
|
||||
instantiate_same_reduce_no_atomics(prod, int64, int64_t, Prod)
|
||||
instantiate_same_reduce_no_atomics(prod, uint64, uint64_t, Prod)
|
||||
|
||||
instantiate_same_reduce(sum, bfloat16, bfloat16_t, Sum)
|
||||
instantiate_same_reduce(prod, bfloat16, bfloat16_t, Prod)
|
||||
|
||||
@@ -381,6 +600,9 @@ instantiate_same_reduce(min_, int32, int32_t, Min)
|
||||
instantiate_same_reduce(min_, float16, half, Min)
|
||||
instantiate_same_reduce(min_, float32, float, Min)
|
||||
|
||||
instantiate_same_reduce_no_atomics(min_, int64, int64_t, Min)
|
||||
instantiate_same_reduce_no_atomics(min_, uint64, uint64_t, Min)
|
||||
|
||||
instantiate_same_reduce(max_, uint8, uint8_t, Max)
|
||||
instantiate_same_reduce(max_, uint16, uint16_t, Max)
|
||||
instantiate_same_reduce(max_, uint32, uint32_t, Max)
|
||||
@@ -390,5 +612,8 @@ instantiate_same_reduce(max_, int32, int32_t, Max)
|
||||
instantiate_same_reduce(max_, float16, half, Max)
|
||||
instantiate_same_reduce(max_, float32, float, Max)
|
||||
|
||||
instantiate_same_reduce_no_atomics(max_, int64, int64_t, Max)
|
||||
instantiate_same_reduce_no_atomics(max_, uint64, uint64_t, Max)
|
||||
|
||||
instantiate_same_reduce(min_, bfloat16, bfloat16_t, Min)
|
||||
instantiate_same_reduce(max_, bfloat16, bfloat16_t, Max)
|
||||
|
68
mlx/backend/metal/kernels/rope.metal
Normal file
68
mlx/backend/metal/kernels/rope.metal
Normal file
@@ -0,0 +1,68 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <metal_math>
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
template <typename T, bool traditional>
|
||||
[[kernel]] void rope(
|
||||
const device T *in [[buffer(0)]],
|
||||
device T * out [[buffer(1)]],
|
||||
constant const size_t strides[3],
|
||||
constant const int& offset,
|
||||
constant const float& base,
|
||||
constant const float& scale,
|
||||
uint3 pos [[thread_position_in_grid]],
|
||||
uint3 grid [[threads_per_grid]]) {
|
||||
// Compute the input and output indices
|
||||
uint in_index_1, in_index_2;
|
||||
uint out_index_1, out_index_2;
|
||||
if (traditional) {
|
||||
out_index_1 = 2 * (pos.x + grid.x * (pos.y + grid.y * pos.z));
|
||||
out_index_2 = out_index_1 + 1;
|
||||
in_index_1 = 2 * pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0];
|
||||
in_index_2 = in_index_1 + strides[2];
|
||||
} else {
|
||||
out_index_1 = pos.x + 2*(grid.x * (pos.y + grid.y * pos.z));
|
||||
out_index_2 = out_index_1 + grid.x;
|
||||
in_index_1 = pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0];
|
||||
in_index_2 = in_index_1 + grid.x * strides[2];
|
||||
}
|
||||
|
||||
// Figure out L and d.
|
||||
float L = scale * static_cast<float>(pos.y + offset);
|
||||
float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
|
||||
|
||||
// Compute costheta, sintheta
|
||||
float theta = L * metal::exp2(-d * base);
|
||||
float costheta = metal::fast::cos(theta);
|
||||
float sintheta = metal::fast::sin(theta);
|
||||
|
||||
// Read and write the output
|
||||
float x1 = static_cast<float>(in[in_index_1]);
|
||||
float x2 = static_cast<float>(in[in_index_2]);
|
||||
float rx1 = x1 * costheta - x2 * sintheta;
|
||||
float rx2 = x1 * sintheta + x2 * costheta;
|
||||
out[out_index_1] = static_cast<T>(rx1);
|
||||
out[out_index_2] = static_cast<T>(rx2);
|
||||
}
|
||||
|
||||
#define instantiate_rope(name, type, traditional) \
|
||||
template [[host_name("rope_" #name)]] \
|
||||
[[kernel]] void rope<type, traditional>( \
|
||||
const device type* in [[buffer(0)]], \
|
||||
device type* out [[buffer(1)]], \
|
||||
constant const size_t strides[3], \
|
||||
constant const int& offset, \
|
||||
constant const float& base, \
|
||||
constant const float& scale, \
|
||||
uint3 pos [[thread_position_in_grid]], \
|
||||
uint3 grid [[threads_per_grid]]);
|
||||
|
||||
instantiate_rope(traditional_float16, half, true)
|
||||
instantiate_rope(traditional_bfloat16, bfloat16_t, true)
|
||||
instantiate_rope(traditional_float32, float, true)
|
||||
instantiate_rope(float16, half, false)
|
||||
instantiate_rope(bfloat16, bfloat16_t, false)
|
||||
instantiate_rope(float32, float, false)
|
194
mlx/backend/metal/kernels/scatter.metal
Normal file
194
mlx/backend/metal/kernels/scatter.metal
Normal file
@@ -0,0 +1,194 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <metal_atomic>
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/indexing.h"
|
||||
#include "mlx/backend/metal/kernels/reduce.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
using namespace metal;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////
|
||||
// Scatter kernel
|
||||
/////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
template <typename T, typename IdxT, typename Op, int NIDX>
|
||||
METAL_FUNC void scatter_impl(
|
||||
const device T *updates [[buffer(1)]],
|
||||
device mlx_atomic<T> *out [[buffer(2)]],
|
||||
const constant int *upd_shape [[buffer(3)]],
|
||||
const constant size_t *upd_strides [[buffer(4)]],
|
||||
const constant size_t& upd_ndim [[buffer(5)]],
|
||||
const constant size_t& upd_size [[buffer(6)]],
|
||||
const constant int *out_shape [[buffer(7)]],
|
||||
const constant size_t *out_strides [[buffer(8)]],
|
||||
const constant size_t& out_ndim [[buffer(9)]],
|
||||
const constant int* axes [[buffer(10)]],
|
||||
const thread Indices<IdxT, NIDX>& indices,
|
||||
uint2 gid [[thread_position_in_grid]]) {
|
||||
|
||||
Op op;
|
||||
auto ind_idx = gid.y;
|
||||
auto ind_offset = gid.x;
|
||||
|
||||
size_t out_idx = 0;
|
||||
for (int i = 0; i < NIDX; ++i) {
|
||||
auto idx_loc = elem_to_loc(
|
||||
ind_idx,
|
||||
&indices.shapes[indices.ndim * i],
|
||||
&indices.strides[indices.ndim * i],
|
||||
indices.ndim);
|
||||
auto ax = axes[i];
|
||||
auto idx_val = offset_neg_idx(
|
||||
indices.buffers[i][idx_loc], out_shape[ax]);
|
||||
out_idx += idx_val * out_strides[ax];
|
||||
}
|
||||
|
||||
auto out_offset = elem_to_loc(
|
||||
ind_offset, upd_shape + indices.ndim, out_strides, out_ndim);
|
||||
auto upd_idx = elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim);
|
||||
op.atomic_update(out, updates[upd_idx], out_idx + out_offset);
|
||||
}
|
||||
|
||||
#define make_scatter_impl(IDX_ARG, IDX_ARR) \
|
||||
template <typename T, typename IdxT, typename Op, int NIDX> \
|
||||
[[kernel]] void scatter( \
|
||||
const device T *updates [[buffer(1)]], \
|
||||
device mlx_atomic<T> *out [[buffer(2)]], \
|
||||
const constant int *upd_shape [[buffer(3)]], \
|
||||
const constant size_t *upd_strides [[buffer(4)]], \
|
||||
const constant size_t& upd_ndim [[buffer(5)]], \
|
||||
const constant size_t& upd_size [[buffer(6)]], \
|
||||
const constant int *out_shape [[buffer(7)]], \
|
||||
const constant size_t *out_strides [[buffer(8)]], \
|
||||
const constant size_t& out_ndim [[buffer(9)]], \
|
||||
const constant int* axes [[buffer(10)]], \
|
||||
const constant int *idx_shapes [[buffer(11)]], \
|
||||
const constant size_t *idx_strides [[buffer(12)]], \
|
||||
const constant int& idx_ndim [[buffer(13)]], \
|
||||
IDX_ARG(IdxT) \
|
||||
uint2 gid [[thread_position_in_grid]]) { \
|
||||
\
|
||||
Indices<IdxT, NIDX> idxs{ \
|
||||
{{IDX_ARR()}}, \
|
||||
idx_shapes, \
|
||||
idx_strides, \
|
||||
idx_ndim}; \
|
||||
\
|
||||
return scatter_impl<T, IdxT, Op, NIDX>( \
|
||||
updates, \
|
||||
out, \
|
||||
upd_shape, \
|
||||
upd_strides, \
|
||||
upd_ndim, \
|
||||
upd_size, \
|
||||
out_shape, \
|
||||
out_strides, \
|
||||
out_ndim, \
|
||||
axes, \
|
||||
idxs, \
|
||||
gid); \
|
||||
}
|
||||
|
||||
#define make_scatter(n) make_scatter_impl(IDX_ARG_ ##n, IDX_ARR_ ##n)
|
||||
|
||||
make_scatter(0)
|
||||
make_scatter(1)
|
||||
make_scatter(2)
|
||||
make_scatter(3)
|
||||
make_scatter(4)
|
||||
make_scatter(5)
|
||||
make_scatter(6)
|
||||
make_scatter(7)
|
||||
make_scatter(8)
|
||||
make_scatter(9)
|
||||
make_scatter(10)
|
||||
|
||||
/////////////////////////////////////////////////////////////////////
|
||||
// Scatter instantiations
|
||||
/////////////////////////////////////////////////////////////////////
|
||||
|
||||
#define instantiate_scatter5(name, src_t, idx_t, op_t, nidx, IDX_ARG) \
|
||||
template [[host_name("scatter" name "_" #nidx)]] \
|
||||
[[kernel]] void scatter<src_t, idx_t, op_t, nidx>( \
|
||||
const device src_t *updates [[buffer(1)]], \
|
||||
device mlx_atomic<src_t> *out [[buffer(2)]], \
|
||||
const constant int *upd_shape [[buffer(3)]], \
|
||||
const constant size_t *upd_strides [[buffer(4)]], \
|
||||
const constant size_t& upd_ndim [[buffer(5)]], \
|
||||
const constant size_t& upd_size [[buffer(6)]], \
|
||||
const constant int *out_shape [[buffer(7)]], \
|
||||
const constant size_t *out_strides [[buffer(8)]], \
|
||||
const constant size_t& out_ndim [[buffer(9)]], \
|
||||
const constant int* axes [[buffer(10)]], \
|
||||
const constant int *idx_shapes [[buffer(11)]], \
|
||||
const constant size_t *idx_strides [[buffer(12)]], \
|
||||
const constant int& idx_ndim [[buffer(13)]], \
|
||||
IDX_ARG(idx_t) \
|
||||
uint2 gid [[thread_position_in_grid]]);
|
||||
|
||||
#define instantiate_scatter4(name, src_t, idx_t, op_t, nidx) \
|
||||
instantiate_scatter5(name, src_t, idx_t, op_t, nidx, IDX_ARG_ ##nidx)
|
||||
|
||||
// Special case NINDEX=0
|
||||
#define instantiate_scatter_nd0(name, type) \
|
||||
instantiate_scatter4(#name "none", type, bool, None, 0) \
|
||||
instantiate_scatter4(#name "_sum", type, bool, Sum<type>, 0) \
|
||||
instantiate_scatter4(#name "_prod", type, bool, Prod<type>, 0) \
|
||||
instantiate_scatter4(#name "_max", type, bool, Max<type>, 0) \
|
||||
instantiate_scatter4(#name "_min", type, bool, Min<type>, 0)
|
||||
|
||||
#define instantiate_scatter3(name, type, ind_type, op_type) \
|
||||
instantiate_scatter4(name, type, ind_type, op_type, 1) \
|
||||
instantiate_scatter4(name, type, ind_type, op_type, 2) \
|
||||
instantiate_scatter4(name, type, ind_type, op_type, 3) \
|
||||
instantiate_scatter4(name, type, ind_type, op_type, 4) \
|
||||
instantiate_scatter4(name, type, ind_type, op_type, 5) \
|
||||
instantiate_scatter4(name, type, ind_type, op_type, 6) \
|
||||
instantiate_scatter4(name, type, ind_type, op_type, 7) \
|
||||
instantiate_scatter4(name, type, ind_type, op_type, 8) \
|
||||
instantiate_scatter4(name, type, ind_type, op_type, 9) \
|
||||
instantiate_scatter4(name, type, ind_type, op_type, 10)
|
||||
|
||||
#define instantiate_scatter2(name, type, ind_type) \
|
||||
instantiate_scatter3(name "_none", type, ind_type, None) \
|
||||
instantiate_scatter3(name "_sum", type, ind_type, Sum<type>) \
|
||||
instantiate_scatter3(name "_prod", type, ind_type, Prod<type>) \
|
||||
instantiate_scatter3(name "_max", type, ind_type, Max<type>) \
|
||||
instantiate_scatter3(name "_min", type, ind_type, Min<type>)
|
||||
|
||||
#define instantiate_scatter(name, type) \
|
||||
instantiate_scatter2(#name "bool_", type, bool) \
|
||||
instantiate_scatter2(#name "uint8", type, uint8_t) \
|
||||
instantiate_scatter2(#name "uint16", type, uint16_t) \
|
||||
instantiate_scatter2(#name "uint32", type, uint32_t) \
|
||||
instantiate_scatter2(#name "uint64", type, uint64_t) \
|
||||
instantiate_scatter2(#name "int8", type, int8_t) \
|
||||
instantiate_scatter2(#name "int16", type, int16_t) \
|
||||
instantiate_scatter2(#name "int32", type, int32_t) \
|
||||
instantiate_scatter2(#name "int64", type, int64_t)
|
||||
|
||||
// TODO uint64 and int64 unsupported
|
||||
instantiate_scatter_nd0(bool_, bool)
|
||||
instantiate_scatter_nd0(uint8, uint8_t)
|
||||
instantiate_scatter_nd0(uint16, uint16_t)
|
||||
instantiate_scatter_nd0(uint32, uint32_t)
|
||||
instantiate_scatter_nd0(int8, int8_t)
|
||||
instantiate_scatter_nd0(int16, int16_t)
|
||||
instantiate_scatter_nd0(int32, int32_t)
|
||||
instantiate_scatter_nd0(float16, half)
|
||||
instantiate_scatter_nd0(float32, float)
|
||||
instantiate_scatter_nd0(bfloat16, bfloat16_t)
|
||||
|
||||
instantiate_scatter(bool_, bool)
|
||||
instantiate_scatter(uint8, uint8_t)
|
||||
instantiate_scatter(uint16, uint16_t)
|
||||
instantiate_scatter(uint32, uint32_t)
|
||||
instantiate_scatter(int8, int8_t)
|
||||
instantiate_scatter(int16, int16_t)
|
||||
instantiate_scatter(int32, int32_t)
|
||||
instantiate_scatter(float16, half)
|
||||
instantiate_scatter(float32, float)
|
||||
instantiate_scatter(bfloat16, bfloat16_t)
|
@@ -89,20 +89,9 @@ struct GEMMKernel {
|
||||
// Appease the compiler
|
||||
(void)l;
|
||||
|
||||
thread bool mask_A[loader_a_t::n_rows][loader_a_t::vec_size];
|
||||
thread bool mask_B[loader_b_t::n_rows][loader_b_t::vec_size];
|
||||
short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
|
||||
|
||||
if (!M_aligned) {
|
||||
short2 tile_dims_A =
|
||||
transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
|
||||
loader_a.set_mask(tile_dims_A, mask_A);
|
||||
}
|
||||
|
||||
if (!N_aligned) {
|
||||
short2 tile_dims_B =
|
||||
transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
|
||||
loader_b.set_mask(tile_dims_B, mask_B);
|
||||
}
|
||||
short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
|
||||
|
||||
for (int k = 0; k < gemm_k_iterations; k++) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
@@ -110,13 +99,13 @@ struct GEMMKernel {
|
||||
if (M_aligned) {
|
||||
loader_a.load_unsafe();
|
||||
} else {
|
||||
loader_a.load_safe(mask_A);
|
||||
loader_a.load_safe(tile_dims_A);
|
||||
}
|
||||
|
||||
if (N_aligned) {
|
||||
loader_b.load_unsafe();
|
||||
} else {
|
||||
loader_b.load_safe(mask_B);
|
||||
loader_b.load_safe(tile_dims_B);
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
@@ -137,11 +126,8 @@ struct GEMMKernel {
|
||||
short2 tile_dims_B_last =
|
||||
transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk);
|
||||
|
||||
loader_a.set_mask(tile_dims_A_last, mask_A);
|
||||
loader_b.set_mask(tile_dims_B_last, mask_B);
|
||||
|
||||
loader_a.load_safe(mask_A);
|
||||
loader_b.load_safe(mask_B);
|
||||
loader_a.load_safe(tile_dims_A_last);
|
||||
loader_b.load_safe(tile_dims_B_last);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
@@ -218,14 +204,8 @@ struct GEMMKernel {
|
||||
short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
|
||||
short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
|
||||
|
||||
thread bool mask_A[loader_a_t::n_rows][loader_a_t::vec_size];
|
||||
thread bool mask_B[loader_b_t::n_rows][loader_b_t::vec_size];
|
||||
|
||||
loader_a.set_mask(tile_dims_A, mask_A);
|
||||
loader_b.set_mask(tile_dims_B, mask_B);
|
||||
|
||||
loader_a.load_safe(mask_A);
|
||||
loader_b.load_safe(mask_B);
|
||||
loader_a.load_safe(tile_dims_A);
|
||||
loader_b.load_safe(tile_dims_B);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
|
@@ -112,14 +112,8 @@ template <typename T,
|
||||
short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
|
||||
short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
|
||||
|
||||
thread bool mask_A[loader_a_t::n_rows][loader_a_t::vec_size];
|
||||
thread bool mask_B[loader_b_t::n_rows][loader_b_t::vec_size];
|
||||
|
||||
loader_a.set_mask(tile_dims_A, mask_A);
|
||||
loader_b.set_mask(tile_dims_B, mask_B);
|
||||
|
||||
loader_a.load_safe(mask_A);
|
||||
loader_b.load_safe(mask_B);
|
||||
loader_a.load_safe(tile_dims_A);
|
||||
loader_b.load_safe(tile_dims_B);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
|
@@ -67,24 +67,22 @@ struct BlockLoader {
|
||||
}
|
||||
}
|
||||
|
||||
/* Load from device memory into threadgroup memory - without bound checking */
|
||||
METAL_FUNC void set_mask(
|
||||
thread const short2& src_tile_dims,
|
||||
thread bool mask[n_rows][vec_size]) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < n_rows; i++) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
mask[i][j] =
|
||||
((bi + i) < src_tile_dims.y) && ((bj + j) < src_tile_dims.x);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Load from device memory into threadgroup memory - with bound checking */
|
||||
METAL_FUNC void load_safe(short2 src_tile_dim) const {
|
||||
src_tile_dim = src_tile_dim - short2(bj, bi);
|
||||
|
||||
// Skip loading if thread has no valid reads
|
||||
if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0; i < BROWS; i += TROWS) {
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
dst[i * dst_ld + j] = T(0);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// Use fast thread memory for bound checks
|
||||
bool tmp_idx[vec_size];
|
||||
T tmp_val[vec_size];
|
||||
@@ -117,39 +115,6 @@ struct BlockLoader {
|
||||
}
|
||||
}
|
||||
|
||||
/* Load from device memory into threadgroup memory - with bound checking */
|
||||
METAL_FUNC void load_safe(const thread bool mask[n_rows][vec_size]) const {
|
||||
T tmp_val[vec_size];
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short i = 0, ii = 0; i < BROWS; i += TROWS, ii++) {
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
// Use fast thread memory for bound checks
|
||||
|
||||
// Read valid indices into tmp_val
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
tmp_val[j] = src[(mask[ii][j] ? i * src_ld + j : 0)];
|
||||
}
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Zero out uneeded values
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
tmp_val[j] = mask[ii][j] ? tmp_val[j] : T(0);
|
||||
}
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Copy values to threadgroup memory
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (short j = 0; j < vec_size; j++) {
|
||||
dst[i * dst_ld + j] = tmp_val[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Iteration helper */
|
||||
METAL_FUNC void next() {
|
||||
src += tile_stride;
|
||||
|
376
mlx/backend/metal/kernels/unary.h
Normal file
376
mlx/backend/metal/kernels/unary.h
Normal file
@@ -0,0 +1,376 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <metal_integer>
|
||||
#include <metal_math>
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/erf.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
struct Abs {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return metal::abs(x);
|
||||
};
|
||||
template <>
|
||||
uint8_t operator()(uint8_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
uint16_t operator()(uint16_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
uint32_t operator()(uint32_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
uint64_t operator()(uint64_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
bool operator()(bool x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
return {metal::precise::sqrt(x.real * x.real + x.imag * x.imag), 0};
|
||||
};
|
||||
};
|
||||
|
||||
struct ArcCos {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return metal::precise::acos(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct ArcCosh {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return metal::precise::acosh(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct ArcSin {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return metal::precise::asin(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct ArcSinh {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return metal::precise::asinh(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct ArcTan {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return metal::precise::atan(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct ArcTanh {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return metal::precise::atanh(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct Ceil {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return metal::ceil(x);
|
||||
};
|
||||
template <>
|
||||
int8_t operator()(int8_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
int16_t operator()(int16_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
int32_t operator()(int32_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
int64_t operator()(int64_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
uint8_t operator()(uint8_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
uint16_t operator()(uint16_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
uint32_t operator()(uint32_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
uint64_t operator()(uint64_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
bool operator()(bool x) {
|
||||
return x;
|
||||
};
|
||||
};
|
||||
|
||||
struct Cos {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return metal::precise::cos(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
return {
|
||||
metal::precise::cos(x.real) * metal::precise::cosh(x.imag),
|
||||
-metal::precise::sin(x.real) * metal::precise::sinh(x.imag)};
|
||||
};
|
||||
};
|
||||
|
||||
struct Cosh {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return metal::precise::cosh(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
return {
|
||||
metal::precise::cosh(x.real) * metal::precise::cos(x.imag),
|
||||
metal::precise::sinh(x.real) * metal::precise::sin(x.imag)};
|
||||
};
|
||||
};
|
||||
|
||||
struct Erf {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return static_cast<T>(erf(static_cast<float>(x)));
|
||||
};
|
||||
};
|
||||
|
||||
struct ErfInv {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return static_cast<T>(erfinv(static_cast<float>(x)));
|
||||
};
|
||||
};
|
||||
|
||||
struct Exp {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return metal::precise::exp(x);
|
||||
};
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
auto m = metal::precise::exp(x.real);
|
||||
return {m * metal::precise::cos(x.imag), m * metal::precise::sin(x.imag)};
|
||||
}
|
||||
};
|
||||
|
||||
struct Floor {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return metal::floor(x);
|
||||
};
|
||||
template <>
|
||||
int8_t operator()(int8_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
int16_t operator()(int16_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
int32_t operator()(int32_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
int64_t operator()(int64_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
uint8_t operator()(uint8_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
uint16_t operator()(uint16_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
uint32_t operator()(uint32_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
uint64_t operator()(uint64_t x) {
|
||||
return x;
|
||||
};
|
||||
template <>
|
||||
bool operator()(bool x) {
|
||||
return x;
|
||||
};
|
||||
};
|
||||
|
||||
struct Log {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return metal::precise::log(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct Log2 {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return metal::precise::log2(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct Log10 {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return metal::precise::log10(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct Log1p {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return log1p(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct LogicalNot {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return !x;
|
||||
};
|
||||
};
|
||||
|
||||
struct Negative {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return -x;
|
||||
};
|
||||
};
|
||||
|
||||
struct Round {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return metal::rint(x);
|
||||
};
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
return {metal::rint(x.real), metal::rint(x.imag)};
|
||||
};
|
||||
};
|
||||
|
||||
struct Sigmoid {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
auto y = 1 / (1 + metal::exp(-metal::abs(x)));
|
||||
return (x < 0) ? 1 - y : y;
|
||||
}
|
||||
};
|
||||
|
||||
struct Sign {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return (x > T(0)) - (x < T(0));
|
||||
};
|
||||
template <>
|
||||
uint32_t operator()(uint32_t x) {
|
||||
return x != 0;
|
||||
};
|
||||
};
|
||||
|
||||
struct Sin {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return metal::precise::sin(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
return {
|
||||
metal::precise::sin(x.real) * metal::precise::cosh(x.imag),
|
||||
metal::precise::cos(x.real) * metal::precise::sinh(x.imag)};
|
||||
};
|
||||
};
|
||||
|
||||
struct Sinh {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return metal::precise::sinh(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
return {
|
||||
metal::precise::sinh(x.real) * metal::precise::cos(x.imag),
|
||||
metal::precise::cosh(x.real) * metal::precise::sin(x.imag)};
|
||||
};
|
||||
};
|
||||
|
||||
struct Square {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return x * x;
|
||||
};
|
||||
};
|
||||
|
||||
struct Sqrt {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return metal::precise::sqrt(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct Rsqrt {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return metal::precise::rsqrt(x);
|
||||
};
|
||||
};
|
||||
|
||||
struct Tan {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return metal::precise::tan(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
float tan_a = metal::precise::tan(x.real);
|
||||
float tanh_b = metal::precise::tanh(x.imag);
|
||||
float t1 = tan_a * tanh_b;
|
||||
float denom = 1. + t1 * t1;
|
||||
return {(tan_a - tanh_b * t1) / denom, (tanh_b + tan_a * t1) / denom};
|
||||
};
|
||||
};
|
||||
|
||||
struct Tanh {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return metal::precise::tanh(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
float tanh_a = metal::precise::tanh(x.real);
|
||||
float tan_b = metal::precise::tan(x.imag);
|
||||
float t1 = tanh_a * tan_b;
|
||||
float denom = 1. + t1 * t1;
|
||||
return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom};
|
||||
};
|
||||
};
|
@@ -1,223 +1,6 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <metal_integer>
|
||||
#include <metal_math>
|
||||
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
#include "mlx/backend/metal/kernels/erf.h"
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
|
||||
struct Abs {
|
||||
template <typename T> T operator()(T x) { return metal::abs(x); };
|
||||
template <> uint8_t operator()(uint8_t x) { return x; };
|
||||
template <> uint16_t operator()(uint16_t x) { return x; };
|
||||
template <> uint32_t operator()(uint32_t x) { return x; };
|
||||
template <> uint64_t operator()(uint64_t x) { return x; };
|
||||
template <> bool operator()(bool x) { return x; };
|
||||
template <> complex64_t operator()(complex64_t x) {
|
||||
return {metal::precise::sqrt(x.real * x.real + x.imag * x.imag), 0};
|
||||
};
|
||||
};
|
||||
|
||||
struct ArcCos {
|
||||
template <typename T> T operator()(T x) { return metal::precise::acos(x); };
|
||||
};
|
||||
|
||||
struct ArcCosh {
|
||||
template <typename T> T operator()(T x) { return metal::precise::acosh(x); };
|
||||
};
|
||||
|
||||
struct ArcSin {
|
||||
template <typename T> T operator()(T x) { return metal::precise::asin(x); };
|
||||
};
|
||||
|
||||
struct ArcSinh {
|
||||
template <typename T> T operator()(T x) { return metal::precise::asinh(x); };
|
||||
};
|
||||
|
||||
struct ArcTan {
|
||||
template <typename T> T operator()(T x) { return metal::precise::atan(x); };
|
||||
};
|
||||
|
||||
struct ArcTanh {
|
||||
template <typename T> T operator()(T x) { return metal::precise::atanh(x); };
|
||||
};
|
||||
|
||||
struct Ceil {
|
||||
template <typename T> T operator()(T x) { return metal::ceil(x); };
|
||||
template <> int8_t operator()(int8_t x) { return x; };
|
||||
template <> int16_t operator()(int16_t x) { return x; };
|
||||
template <> int32_t operator()(int32_t x) { return x; };
|
||||
template <> int64_t operator()(int64_t x) { return x; };
|
||||
template <> uint8_t operator()(uint8_t x) { return x; };
|
||||
template <> uint16_t operator()(uint16_t x) { return x; };
|
||||
template <> uint32_t operator()(uint32_t x) { return x; };
|
||||
template <> uint64_t operator()(uint64_t x) { return x; };
|
||||
template <> bool operator()(bool x) { return x; };
|
||||
};
|
||||
|
||||
struct Cos {
|
||||
template <typename T> T operator()(T x) { return metal::precise::cos(x); };
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
return {
|
||||
metal::precise::cos(x.real) * metal::precise::cosh(x.imag),
|
||||
-metal::precise::sin(x.real) * metal::precise::sinh(x.imag)
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
struct Cosh {
|
||||
template <typename T> T operator()(T x) { return metal::precise::cosh(x); };
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
return {
|
||||
metal::precise::cosh(x.real) * metal::precise::cos(x.imag),
|
||||
metal::precise::sinh(x.real) * metal::precise::sin(x.imag)
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
struct Erf {
|
||||
template <typename T> T operator()(T x) { return static_cast<T>(erf(static_cast<float>(x))); };
|
||||
};
|
||||
|
||||
struct ErfInv {
|
||||
template <typename T> T operator()(T x) { return static_cast<T>(erfinv(static_cast<float>(x))); };
|
||||
};
|
||||
|
||||
struct Exp {
|
||||
template <typename T> T operator()(T x) { return metal::precise::exp(x); };
|
||||
template <> complex64_t operator()(complex64_t x) {
|
||||
auto m = metal::precise::exp(x.real);
|
||||
return {m * metal::precise::cos(x.imag), m * metal::precise::sin(x.imag)};
|
||||
}
|
||||
};
|
||||
|
||||
struct Floor {
|
||||
template <typename T> T operator()(T x) { return metal::floor(x); };
|
||||
template <> int8_t operator()(int8_t x) { return x; };
|
||||
template <> int16_t operator()(int16_t x) { return x; };
|
||||
template <> int32_t operator()(int32_t x) { return x; };
|
||||
template <> int64_t operator()(int64_t x) { return x; };
|
||||
template <> uint8_t operator()(uint8_t x) { return x; };
|
||||
template <> uint16_t operator()(uint16_t x) { return x; };
|
||||
template <> uint32_t operator()(uint32_t x) { return x; };
|
||||
template <> uint64_t operator()(uint64_t x) { return x; };
|
||||
template <> bool operator()(bool x) { return x; };
|
||||
};
|
||||
|
||||
struct Log {
|
||||
template <typename T> T operator()(T x) { return metal::precise::log(x); };
|
||||
};
|
||||
|
||||
struct Log2 {
|
||||
template <typename T> T operator()(T x) { return metal::precise::log2(x); };
|
||||
};
|
||||
|
||||
struct Log10 {
|
||||
template <typename T> T operator()(T x) { return metal::precise::log10(x); };
|
||||
};
|
||||
|
||||
struct Log1p {
|
||||
template <typename T> T operator()(T x) { return log1p(x); };
|
||||
};
|
||||
|
||||
struct LogicalNot {
|
||||
template <typename T> T operator()(T x) { return !x; };
|
||||
};
|
||||
|
||||
struct Negative {
|
||||
template <typename T> T operator()(T x) { return -x; };
|
||||
};
|
||||
|
||||
struct Round {
|
||||
template <typename T> T operator()(T x) { return metal::rint(x); };
|
||||
template <> complex64_t operator()(complex64_t x) { return {metal::rint(x.real), metal::rint(x.imag)}; };
|
||||
};
|
||||
|
||||
struct Sigmoid {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
auto y = 1 / (1 + metal::exp(-metal::abs(x)));
|
||||
return (x < 0) ? 1 - y : y;
|
||||
}
|
||||
};
|
||||
|
||||
struct Sign {
|
||||
template <typename T> T operator()(T x) { return (x > T(0)) - (x < T(0)); };
|
||||
template <> uint32_t operator()(uint32_t x) { return x != 0; };
|
||||
};
|
||||
|
||||
struct Sin {
|
||||
template <typename T> T operator()(T x) { return metal::precise::sin(x); };
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
return {
|
||||
metal::precise::sin(x.real) * metal::precise::cosh(x.imag),
|
||||
metal::precise::cos(x.real) * metal::precise::sinh(x.imag)
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
struct Sinh {
|
||||
template <typename T> T operator()(T x) { return metal::precise::sinh(x); };
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
return {
|
||||
metal::precise::sinh(x.real) * metal::precise::cos(x.imag),
|
||||
metal::precise::cosh(x.real) * metal::precise::sin(x.imag)
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
struct Square {
|
||||
template <typename T> T operator()(T x) { return x * x; };
|
||||
};
|
||||
|
||||
struct Sqrt {
|
||||
template <typename T> T operator()(T x) { return metal::precise::sqrt(x); };
|
||||
};
|
||||
|
||||
struct Rsqrt {
|
||||
template <typename T> T operator()(T x) { return metal::precise::rsqrt(x); };
|
||||
};
|
||||
|
||||
struct Tan {
|
||||
template <typename T> T operator()(T x) { return metal::precise::tan(x); };
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
float tan_a = metal::precise::tan(x.real);
|
||||
float tanh_b = metal::precise::tanh(x.imag);
|
||||
float t1 = tan_a * tanh_b;
|
||||
float denom = 1. + t1 * t1;
|
||||
return {
|
||||
(tan_a - tanh_b * t1) / denom,
|
||||
(tanh_b + tan_a * t1) / denom
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
struct Tanh {
|
||||
template <typename T> T operator()(T x) { return metal::precise::tanh(x); };
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x) {
|
||||
float tanh_a = metal::precise::tanh(x.real);
|
||||
float tan_b = metal::precise::tan(x.imag);
|
||||
float t1 = tanh_a * tan_b;
|
||||
float denom = 1. + t1 * t1;
|
||||
return {
|
||||
(tanh_a + tan_b * t1) / denom,
|
||||
(tan_b - tanh_a * t1) / denom
|
||||
};
|
||||
};
|
||||
};
|
||||
#include "mlx/backend/metal/kernels/unary.h"
|
||||
|
||||
template <typename T, typename Op>
|
||||
[[kernel]] void unary_op_v(
|
||||
|
@@ -12,10 +12,10 @@
|
||||
|
||||
template <typename U>
|
||||
struct Limits {
|
||||
static const constant U max;
|
||||
static const constant U min;
|
||||
static const constant U finite_max;
|
||||
static const constant U finite_min;
|
||||
static const constant U max = metal::numeric_limits<U>::max();
|
||||
static const constant U min = metal::numeric_limits<U>::min();
|
||||
static const constant U finite_max = metal::numeric_limits<U>::max();
|
||||
static const constant U finite_min = metal::numeric_limits<U>::min();
|
||||
};
|
||||
|
||||
#define instantiate_default_limit(type) \
|
||||
@@ -71,7 +71,7 @@ inline size_t elem_to_loc(
|
||||
device const size_t* strides,
|
||||
int ndim) {
|
||||
size_t loc = 0;
|
||||
for (int i = ndim - 1; i >= 0; --i) {
|
||||
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
|
||||
loc += (elem % shape[i]) * strides[i];
|
||||
elem /= shape[i];
|
||||
}
|
||||
@@ -84,7 +84,7 @@ inline size_t elem_to_loc(
|
||||
constant const size_t* strides,
|
||||
int ndim) {
|
||||
size_t loc = 0;
|
||||
for (int i = ndim - 1; i >= 0; --i) {
|
||||
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
|
||||
loc += (elem % shape[i]) * strides[i];
|
||||
elem /= shape[i];
|
||||
}
|
||||
@@ -235,12 +235,42 @@ inline size_t ceildiv(size_t N, size_t M) {
|
||||
// https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html#1202
|
||||
inline float log1p(float x) {
|
||||
float xp1 = 1.0f + x;
|
||||
return (xp1 == 1.0f) ? x : x * (metal::log(xp1) / (xp1 - 1.0f));
|
||||
if (xp1 == Limits<float>::max) {
|
||||
return Limits<float>::max;
|
||||
}
|
||||
if (xp1 == 1.0f) {
|
||||
return x;
|
||||
}
|
||||
|
||||
return x * (metal::log(xp1) / (xp1 - 1.0f));
|
||||
}
|
||||
|
||||
inline bfloat16_t log1p(bfloat16_t x) {
|
||||
float xp1 = 1.0f + static_cast<float>(x);
|
||||
bfloat16_t ret =
|
||||
(xp1 == 1.0f) ? x : bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f)));
|
||||
return ret;
|
||||
if (xp1 == Limits<float>::max) {
|
||||
return Limits<bfloat16_t>::max;
|
||||
}
|
||||
if (xp1 == 1.0f) {
|
||||
return x;
|
||||
}
|
||||
|
||||
return bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f)));
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// SIMD shuffle ops
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
inline uint64_t simd_shuffle_down(uint64_t data, uint16_t delta) {
|
||||
return as_type<uint64_t>(
|
||||
metal::simd_shuffle_down(as_type<uint2>(data), delta));
|
||||
}
|
||||
|
||||
inline int64_t simd_shuffle_down(int64_t data, uint16_t delta) {
|
||||
return as_type<int64_t>(
|
||||
metal::simd_shuffle_down(as_type<uint2>(data), delta));
|
||||
}
|
||||
|
||||
inline bool simd_shuffle_down(bool data, uint16_t delta) {
|
||||
return simd_shuffle_down(static_cast<uint32_t>(data), delta);
|
||||
}
|
||||
|
28
mlx/backend/metal/make_compiled_preamble.sh
Normal file
28
mlx/backend/metal/make_compiled_preamble.sh
Normal file
@@ -0,0 +1,28 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
# This script generates a C++ function that provides the Metal unary and binary
|
||||
# ops at runtime for use with kernel generation.
|
||||
#
|
||||
# Copyright © 2023-24 Apple Inc.
|
||||
|
||||
|
||||
OUTPUT_FILE=$1
|
||||
CC=$2
|
||||
SRCDIR=$3
|
||||
|
||||
CONTENT=$($CC -I $SRCDIR -E $SRCDIR/mlx/backend/metal/kernels/compiled_preamble.h 2>/dev/null)
|
||||
|
||||
cat << EOF > "$OUTPUT_FILE"
|
||||
// Copyright © 2023-24 Apple Inc.
|
||||
|
||||
namespace mlx::core::metal {
|
||||
|
||||
const char* get_kernel_preamble() {
|
||||
return R"preamble(
|
||||
$CONTENT
|
||||
)preamble";
|
||||
|
||||
}
|
||||
|
||||
} // namespace mlx::core::metal
|
||||
EOF
|
@@ -1,4 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
@@ -615,7 +615,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
|
||||
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
assert(inputs.size() == 3);
|
||||
if (!is_floating_point(out.dtype())) {
|
||||
throw std::runtime_error(
|
||||
"[matmul] Does not yet support non-floating point types.");
|
||||
|
@@ -63,15 +63,32 @@ std::function<void()> make_task(
|
||||
auto s = arr.primitive().stream();
|
||||
auto command_buffer = increment_command_buffer(s);
|
||||
auto outputs = arr.outputs();
|
||||
arr.primitive().eval_gpu(arr.inputs(), outputs);
|
||||
{
|
||||
// If the array is a tracer hold a reference
|
||||
// to its inputs so they don't get donated
|
||||
std::vector<array> inputs;
|
||||
if (arr.is_tracer()) {
|
||||
inputs = arr.inputs();
|
||||
}
|
||||
arr.primitive().eval_gpu(arr.inputs(), outputs);
|
||||
}
|
||||
std::vector<std::shared_ptr<array::Data>> buffers;
|
||||
for (auto& in : arr.inputs()) {
|
||||
buffers.push_back(in.data_shared_ptr());
|
||||
}
|
||||
for (auto& s : arr.siblings()) {
|
||||
buffers.push_back(s.data_shared_ptr());
|
||||
}
|
||||
if (!arr.is_tracer()) {
|
||||
arr.detach();
|
||||
}
|
||||
|
||||
if (p) {
|
||||
metal::device(s.device).end_encoding(s.index);
|
||||
scheduler::notify_new_task(s);
|
||||
command_buffer->addCompletedHandler(
|
||||
[s, arr, p = std::move(p)](MTL::CommandBuffer* cbuf) mutable {
|
||||
if (!arr.is_tracer()) {
|
||||
arr.detach();
|
||||
}
|
||||
[s, buffers = std::move(buffers), p = std::move(p)](
|
||||
MTL::CommandBuffer* cbuf) {
|
||||
p->set_value();
|
||||
scheduler::notify_task_completion(s);
|
||||
check_error(cbuf);
|
||||
@@ -79,10 +96,7 @@ std::function<void()> make_task(
|
||||
metal::device(s.device).commit_command_buffer(s.index);
|
||||
} else {
|
||||
command_buffer->addCompletedHandler(
|
||||
[s, arr](MTL::CommandBuffer* cbuf) mutable {
|
||||
if (!arr.is_tracer()) {
|
||||
arr.detach();
|
||||
}
|
||||
[s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
|
||||
check_error(cbuf);
|
||||
});
|
||||
}
|
||||
|
@@ -1,4 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
@@ -27,8 +27,8 @@ void binary_op(
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, outputs[0], bopt);
|
||||
set_binary_op_output_data(a, b, outputs[1], bopt);
|
||||
set_binary_op_output_data(a, b, outputs[0], bopt, true);
|
||||
set_binary_op_output_data(a, b, outputs[1], bopt, true);
|
||||
|
||||
auto& out = outputs[0];
|
||||
if (out.size() == 0) {
|
||||
@@ -60,7 +60,7 @@ void binary_op(
|
||||
break;
|
||||
}
|
||||
kname << op << type_to_name(a);
|
||||
if (bopt == General && out.ndim() <= MAX_BINARY_SPECIALIZED_DIMS) {
|
||||
if (bopt == General && shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
|
||||
kname << "_" << shape.size();
|
||||
}
|
||||
|
||||
@@ -69,8 +69,14 @@ void binary_op(
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
set_array_buffer(compute_encoder, a, 0);
|
||||
set_array_buffer(compute_encoder, b, 1);
|
||||
// - If a is donated it goes to the first output
|
||||
// - If b is donated it goes to the first output if a was not donated
|
||||
// otherwise it goes to the second output
|
||||
bool donate_a = a.data_shared_ptr() == nullptr;
|
||||
bool donate_b = b.data_shared_ptr() == nullptr;
|
||||
set_array_buffer(compute_encoder, donate_a ? outputs[0] : a, 0);
|
||||
set_array_buffer(
|
||||
compute_encoder, donate_b ? (donate_a ? outputs[1] : outputs[0]) : b, 1);
|
||||
set_array_buffer(compute_encoder, outputs[0], 2);
|
||||
set_array_buffer(compute_encoder, outputs[1], 3);
|
||||
|
||||
@@ -122,7 +128,7 @@ void binary_op(
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt);
|
||||
set_binary_op_output_data(a, b, out, bopt, true);
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
@@ -152,7 +158,7 @@ void binary_op(
|
||||
break;
|
||||
}
|
||||
kname << op << type_to_name(a);
|
||||
if (bopt == General && out.ndim() <= MAX_BINARY_SPECIALIZED_DIMS) {
|
||||
if (bopt == General && shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
|
||||
kname << "_" << shape.size();
|
||||
}
|
||||
|
||||
@@ -161,8 +167,10 @@ void binary_op(
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
set_array_buffer(compute_encoder, a, 0);
|
||||
set_array_buffer(compute_encoder, b, 1);
|
||||
bool donate_a = a.data_shared_ptr() == nullptr;
|
||||
bool donate_b = b.data_shared_ptr() == nullptr;
|
||||
set_array_buffer(compute_encoder, donate_a ? out : a, 0);
|
||||
set_array_buffer(compute_encoder, donate_b ? out : b, 1);
|
||||
set_array_buffer(compute_encoder, out, 2);
|
||||
|
||||
if (bopt == General) {
|
||||
@@ -212,11 +220,15 @@ void unary_op(
|
||||
auto& in = inputs[0];
|
||||
bool contig = in.flags().contiguous;
|
||||
if (contig) {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(in.data_size() * out.itemsize()),
|
||||
in.data_size(),
|
||||
in.strides(),
|
||||
in.flags());
|
||||
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
|
||||
out.move_shared_buffer(in);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(in.data_size() * out.itemsize()),
|
||||
in.data_size(),
|
||||
in.strides(),
|
||||
in.flags());
|
||||
}
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
}
|
||||
@@ -240,7 +252,8 @@ void unary_op(
|
||||
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
set_array_buffer(compute_encoder, in, 0);
|
||||
set_array_buffer(
|
||||
compute_encoder, in.data_shared_ptr() == nullptr ? out : in, 0);
|
||||
set_array_buffer(compute_encoder, out, 1);
|
||||
if (!contig) {
|
||||
compute_encoder->setBytes(in.shape().data(), in.ndim() * sizeof(int), 2);
|
||||
@@ -473,6 +486,18 @@ void Cosh::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
unary_op(inputs, out, "cosh");
|
||||
}
|
||||
|
||||
void CustomVJP::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
eval(inputs, outputs);
|
||||
}
|
||||
|
||||
void Depends::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
eval(inputs, outputs);
|
||||
}
|
||||
|
||||
void Divide::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "div");
|
||||
}
|
||||
@@ -769,4 +794,10 @@ void Transpose::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
eval(inputs, out);
|
||||
}
|
||||
|
||||
void QRF::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
throw std::runtime_error("[QRF::eval_gpu] Metal QR factorization NYI.");
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -1,4 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
|
||||
@@ -55,7 +55,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
int bo = std::min(32, O);
|
||||
int bd = 32;
|
||||
MTL::Size group_dims = MTL::Size(bd, bo, 1);
|
||||
MTL::Size grid_dims = MTL::Size(1, O / bo, B);
|
||||
MTL::Size grid_dims = MTL::Size(1, (O + bo - 1) / bo, B);
|
||||
|
||||
set_array_buffer(compute_encoder, w, 0);
|
||||
set_array_buffer(compute_encoder, scales, 1);
|
||||
|
@@ -28,35 +28,40 @@ inline auto safe_divup(size_t n, size_t m) {
|
||||
return safe_div(n, m) * m;
|
||||
}
|
||||
|
||||
inline bool is_64b_int(Dtype dtype) {
|
||||
return dtype == int64 || dtype == uint64;
|
||||
}
|
||||
|
||||
// All Reduce
|
||||
void all_reduce_dispatch(
|
||||
const array& in,
|
||||
array& out,
|
||||
const std::string& op_name,
|
||||
MTL::ComputeCommandEncoder* compute_encoder,
|
||||
metal::Device& d) {
|
||||
// Get kernel and encode buffers
|
||||
size_t in_size = in.size();
|
||||
auto kernel = d.get_kernel("all_reduce_" + op_name + type_to_name(in));
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
Dtype out_dtype = out.dtype();
|
||||
bool is_out_64b_int = is_64b_int(out_dtype);
|
||||
auto kernel = (is_out_64b_int)
|
||||
? d.get_kernel("all_reduce_no_atomics_" + op_name + type_to_name(in))
|
||||
: d.get_kernel("all_reduce_" + op_name + type_to_name(in));
|
||||
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
set_array_buffer(compute_encoder, in, 0);
|
||||
set_array_buffer(compute_encoder, out, 1);
|
||||
compute_encoder->setBytes(&in_size, sizeof(size_t), 2);
|
||||
|
||||
// Set grid dimensions
|
||||
|
||||
// We make sure each thread has enough to do by making it read in
|
||||
// at least n_reads inputs
|
||||
int n_reads = REDUCE_N_READS;
|
||||
size_t in_size = in.size();
|
||||
|
||||
// mod_in_size gives us the groups of n_reads needed to go over the entire
|
||||
// input
|
||||
uint mod_in_size = (in_size + n_reads - 1) / n_reads;
|
||||
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
thread_group_size =
|
||||
mod_in_size > thread_group_size ? thread_group_size : mod_in_size;
|
||||
uint simd_size = kernel->threadExecutionWidth();
|
||||
thread_group_size =
|
||||
((thread_group_size + simd_size - 1) / simd_size) * simd_size;
|
||||
|
||||
// If the number of thread groups needed exceeds 1024, we reuse threads groups
|
||||
uint n_thread_groups = safe_div(mod_in_size, thread_group_size);
|
||||
@@ -66,7 +71,52 @@ void all_reduce_dispatch(
|
||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
|
||||
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
// Encode buffers and dispatch
|
||||
if (is_out_64b_int == false || n_thread_groups == 1) {
|
||||
set_array_buffer(compute_encoder, in, 0);
|
||||
set_array_buffer(compute_encoder, out, 1);
|
||||
compute_encoder->setBytes(&in_size, sizeof(size_t), 2);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
|
||||
} else {
|
||||
// Allocate intermediate array to store partial reduction results
|
||||
size_t intermediate_size = n_thread_groups;
|
||||
array intermediate =
|
||||
array({static_cast<int>(intermediate_size)}, out_dtype, nullptr, {});
|
||||
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
|
||||
std::vector<array> intermediates = {intermediate};
|
||||
|
||||
// First dispatch
|
||||
set_array_buffer(compute_encoder, in, 0);
|
||||
set_array_buffer(compute_encoder, intermediate, 1);
|
||||
compute_encoder->setBytes(&in_size, sizeof(size_t), 2);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
|
||||
// Second pass to reduce intermediate reduction results written to DRAM
|
||||
set_array_buffer(compute_encoder, intermediate, 0);
|
||||
set_array_buffer(compute_encoder, out, 1);
|
||||
compute_encoder->setBytes(&intermediate_size, sizeof(size_t), 2);
|
||||
|
||||
mod_in_size = (intermediate_size + n_reads - 1) / n_reads;
|
||||
|
||||
thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
thread_group_size =
|
||||
mod_in_size > thread_group_size ? thread_group_size : mod_in_size;
|
||||
thread_group_size =
|
||||
((thread_group_size + simd_size - 1) / simd_size) * simd_size;
|
||||
|
||||
// If the number of thread groups needed exceeds 1024, we reuse threads
|
||||
// groups
|
||||
nthreads = thread_group_size;
|
||||
group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||
grid_dims = MTL::Size(nthreads, 1, 1);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[intermediates](MTL::CommandBuffer*) mutable {
|
||||
intermediates.clear();
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
void row_reduce_general_dispatch(
|
||||
@@ -76,22 +126,31 @@ void row_reduce_general_dispatch(
|
||||
const ReductionPlan& plan,
|
||||
const std::vector<int>& axes,
|
||||
MTL::ComputeCommandEncoder* compute_encoder,
|
||||
metal::Device& d) {
|
||||
auto kernel =
|
||||
d.get_kernel("row_reduce_general_" + op_name + type_to_name(in));
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
Dtype out_dtype = out.dtype();
|
||||
bool is_out_64b_int = is_64b_int(out_dtype);
|
||||
auto kernel = (is_out_64b_int)
|
||||
? d.get_kernel(
|
||||
"row_reduce_general_no_atomics_" + op_name + type_to_name(in))
|
||||
: d.get_kernel("row_reduce_general_" + op_name + type_to_name(in));
|
||||
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// Prepare the arguments for the kernel
|
||||
int n_reads = REDUCE_N_READS;
|
||||
size_t reduction_size = plan.shape.back();
|
||||
size_t out_size = out.size();
|
||||
auto shape = plan.shape;
|
||||
auto strides = plan.strides;
|
||||
|
||||
shape.pop_back();
|
||||
strides.pop_back();
|
||||
|
||||
size_t non_row_reductions = 1;
|
||||
for (auto s : shape) {
|
||||
non_row_reductions *= static_cast<size_t>(s);
|
||||
}
|
||||
size_t out_size = out.size();
|
||||
auto [rem_shape, rem_strides] = shapes_without_reduction_axes(in, axes);
|
||||
for (auto s : rem_shape) {
|
||||
shape.push_back(s);
|
||||
@@ -101,16 +160,6 @@ void row_reduce_general_dispatch(
|
||||
}
|
||||
int ndim = shape.size();
|
||||
|
||||
// Set the arguments for the kernel
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
set_array_buffer(compute_encoder, in, 0);
|
||||
set_array_buffer(compute_encoder, out, 1);
|
||||
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
|
||||
compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
|
||||
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 4);
|
||||
compute_encoder->setBytes(strides.data(), strides.size() * sizeof(size_t), 5);
|
||||
compute_encoder->setBytes(&ndim, sizeof(int), 6);
|
||||
|
||||
// Each thread group is responsible for 1 output
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
thread_group_size =
|
||||
@@ -127,7 +176,88 @@ void row_reduce_general_dispatch(
|
||||
MTL::Size grid_dims = MTL::Size(n_threads, non_row_reductions, 1);
|
||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
if (is_out_64b_int == false || non_row_reductions == 1) {
|
||||
// Set the arguments for the kernel
|
||||
set_array_buffer(compute_encoder, in, 0);
|
||||
set_array_buffer(compute_encoder, out, 1);
|
||||
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
|
||||
compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
|
||||
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 4);
|
||||
compute_encoder->setBytes(
|
||||
strides.data(), strides.size() * sizeof(size_t), 5);
|
||||
compute_encoder->setBytes(&ndim, sizeof(int), 6);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
|
||||
} else {
|
||||
// Allocate intermediate array to store partial reduction results
|
||||
array intermediate = array(
|
||||
{static_cast<int>(out.size()), static_cast<int>(non_row_reductions)},
|
||||
out_dtype,
|
||||
nullptr,
|
||||
{});
|
||||
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
|
||||
std::vector<array> intermediates = {intermediate};
|
||||
|
||||
// Set the arguments for the kernel
|
||||
set_array_buffer(compute_encoder, in, 0);
|
||||
set_array_buffer(compute_encoder, intermediate, 1);
|
||||
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
|
||||
compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
|
||||
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 4);
|
||||
compute_encoder->setBytes(
|
||||
strides.data(), strides.size() * sizeof(size_t), 5);
|
||||
compute_encoder->setBytes(&ndim, sizeof(int), 6);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
|
||||
// Set up second dispatch
|
||||
reduction_size = non_row_reductions;
|
||||
out_size = 1;
|
||||
|
||||
// Shape of axes that aren't participating in reduction remains unchanged.
|
||||
std::vector<int> new_shape = rem_shape;
|
||||
|
||||
// Update their strides since they'll be different post partial reduction in
|
||||
// first compute dispatch.
|
||||
std::vector<size_t> new_strides = rem_strides;
|
||||
new_strides.back() = reduction_size;
|
||||
for (int i = new_shape.size() - 2; i >= 0; i--) {
|
||||
new_strides[i] = new_shape[i + 1] * new_strides[i + 1];
|
||||
}
|
||||
ndim = new_shape.size();
|
||||
|
||||
// Set the arguments for the kernel
|
||||
set_array_buffer(compute_encoder, intermediate, 0);
|
||||
set_array_buffer(compute_encoder, out, 1);
|
||||
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
|
||||
compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
|
||||
compute_encoder->setBytes(
|
||||
new_shape.data(), new_shape.size() * sizeof(int), 4);
|
||||
compute_encoder->setBytes(
|
||||
new_strides.data(), new_strides.size() * sizeof(size_t), 5);
|
||||
compute_encoder->setBytes(&ndim, sizeof(int), 6);
|
||||
|
||||
// Each thread group is responsible for 1 output
|
||||
thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
thread_group_size =
|
||||
std::min((reduction_size + n_reads - 1) / n_reads, thread_group_size);
|
||||
|
||||
// Align thread group size with simd_size
|
||||
thread_group_size =
|
||||
(thread_group_size + simd_size - 1) / simd_size * simd_size;
|
||||
assert(thread_group_size <= kernel->maxTotalThreadsPerThreadgroup());
|
||||
|
||||
// Launch enough thread groups for each output
|
||||
n_threads = thread_group_size;
|
||||
grid_dims = MTL::Size(n_threads, out.size(), 1);
|
||||
group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[intermediates](MTL::CommandBuffer*) mutable {
|
||||
intermediates.clear();
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
void strided_reduce_general_dispatch(
|
||||
@@ -137,9 +267,16 @@ void strided_reduce_general_dispatch(
|
||||
const ReductionPlan& plan,
|
||||
const std::vector<int>& axes,
|
||||
MTL::ComputeCommandEncoder* compute_encoder,
|
||||
metal::Device& d) {
|
||||
auto kernel =
|
||||
d.get_kernel("col_reduce_general_" + op_name + type_to_name(in));
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
Dtype out_dtype = out.dtype();
|
||||
bool is_out_64b_int = is_64b_int(out_dtype);
|
||||
auto kernel = (is_out_64b_int)
|
||||
? d.get_kernel(
|
||||
"col_reduce_general_no_atomics_" + op_name + type_to_name(in))
|
||||
: d.get_kernel("col_reduce_general_" + op_name + type_to_name(in));
|
||||
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// Prepare the arguments for the kernel
|
||||
size_t reduction_size = plan.shape.back();
|
||||
@@ -162,19 +299,7 @@ void strided_reduce_general_dispatch(
|
||||
}
|
||||
int ndim = shape.size();
|
||||
|
||||
// Set the arguments for the kernel
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
set_array_buffer(compute_encoder, in, 0);
|
||||
set_array_buffer(compute_encoder, out, 1);
|
||||
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
|
||||
compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3);
|
||||
compute_encoder->setBytes(&out_size, sizeof(size_t), 4);
|
||||
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 5);
|
||||
compute_encoder->setBytes(strides.data(), strides.size() * sizeof(size_t), 6);
|
||||
compute_encoder->setBytes(&ndim, sizeof(int), 7);
|
||||
|
||||
// Select block dimensions
|
||||
|
||||
// Each thread reads 16 inputs to give it more work
|
||||
uint n_inputs_per_thread = REDUCE_N_READS;
|
||||
uint n_threads_per_output =
|
||||
@@ -183,14 +308,22 @@ void strided_reduce_general_dispatch(
|
||||
// We spread outputs over the x dimension and inputs over the y dimension
|
||||
// Threads with the same lid.x in a given threadgroup work on the same
|
||||
// output and each thread in the y dimension accumulates for that output
|
||||
|
||||
// Threads with same lid.x, i.e. each column of threads work on same output
|
||||
uint threadgroup_dim_x = std::min(out_size, 128ul);
|
||||
|
||||
// Number of threads along y, is dependent on number of reductions needed.
|
||||
uint threadgroup_dim_y =
|
||||
kernel->maxTotalThreadsPerThreadgroup() / threadgroup_dim_x;
|
||||
threadgroup_dim_y = std::min(n_threads_per_output, threadgroup_dim_y);
|
||||
|
||||
// Derive number of thread groups along x, based on how many threads we need
|
||||
// along x
|
||||
uint n_threadgroups_x =
|
||||
(out_size + threadgroup_dim_x - 1) / threadgroup_dim_x;
|
||||
|
||||
// Derive number of thread groups along y based on how many threads we need
|
||||
// along y
|
||||
uint n_threadgroups_y =
|
||||
(n_threads_per_output + threadgroup_dim_y - 1) / threadgroup_dim_y;
|
||||
|
||||
@@ -199,18 +332,122 @@ void strided_reduce_general_dispatch(
|
||||
MTL::Size(n_threadgroups_x, n_threadgroups_y, non_col_reductions);
|
||||
MTL::Size group_dims = MTL::Size(threadgroup_dim_x, threadgroup_dim_y, 1);
|
||||
|
||||
// We set shared memory to be exploited here for reductions within a
|
||||
// threadgroup - each thread must be able to update its accumulated output
|
||||
// Note: Each threadgroup should have 32kB of data in threadgroup memory
|
||||
// and threadgroup_dim_x * threadgroup_dim_y <= 1024 by design
|
||||
// This should be fine for floats, but we might need to revisit
|
||||
// if we ever come to doubles. In that case, we should also cut
|
||||
// down the number of threads we launch in a threadgroup
|
||||
compute_encoder->setThreadgroupMemoryLength(
|
||||
safe_divup(threadgroup_dim_x * threadgroup_dim_y * out.itemsize(), 16),
|
||||
0);
|
||||
if (is_out_64b_int == false) {
|
||||
// Set the arguments for the kernel
|
||||
set_array_buffer(compute_encoder, in, 0);
|
||||
set_array_buffer(compute_encoder, out, 1);
|
||||
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
|
||||
compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3);
|
||||
compute_encoder->setBytes(&out_size, sizeof(size_t), 4);
|
||||
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 5);
|
||||
compute_encoder->setBytes(
|
||||
strides.data(), strides.size() * sizeof(size_t), 6);
|
||||
compute_encoder->setBytes(&ndim, sizeof(int), 7);
|
||||
|
||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||
// We set shared memory to be exploited here for reductions within a
|
||||
// threadgroup - each thread must be able to update its accumulated output
|
||||
// Note: Each threadgroup should have 32kB of data in threadgroup memory
|
||||
// and threadgroup_dim_x * threadgroup_dim_y <= 1024 by design
|
||||
// This should be fine for floats, but we might need to revisit
|
||||
// if we ever come to doubles. In that case, we should also cut
|
||||
// down the number of threads we launch in a threadgroup
|
||||
compute_encoder->setThreadgroupMemoryLength(
|
||||
safe_divup(threadgroup_dim_x * threadgroup_dim_y * out.itemsize(), 16),
|
||||
0);
|
||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||
|
||||
} else {
|
||||
// Allocate intermediate array to store reduction results from all thread
|
||||
// groups
|
||||
array intermediate = array(
|
||||
{static_cast<int>(out.size()),
|
||||
static_cast<int>(n_threadgroups_y * non_col_reductions)},
|
||||
out_dtype,
|
||||
nullptr,
|
||||
{});
|
||||
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
|
||||
std::vector<array> intermediates = {intermediate};
|
||||
|
||||
// Set the arguments for the kernel
|
||||
set_array_buffer(compute_encoder, in, 0);
|
||||
set_array_buffer(compute_encoder, intermediate, 1);
|
||||
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
|
||||
compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3);
|
||||
compute_encoder->setBytes(&out_size, sizeof(size_t), 4);
|
||||
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 5);
|
||||
compute_encoder->setBytes(
|
||||
strides.data(), strides.size() * sizeof(size_t), 6);
|
||||
compute_encoder->setBytes(&ndim, sizeof(int), 7);
|
||||
|
||||
// We set shared memory to be exploited here for reductions within a
|
||||
// threadgroup - each thread must be able to update its accumulated output
|
||||
// Note: Each threadgroup should have 32kB of data in threadgroup memory
|
||||
// and threadgroup_dim_x * threadgroup_dim_y <= 1024 by design
|
||||
// This should be fine for floats, but we might need to revisit
|
||||
// if we ever come to doubles. In that case, we should also cut
|
||||
// down the number of threads we launch in a threadgroup
|
||||
compute_encoder->setThreadgroupMemoryLength(
|
||||
safe_divup(threadgroup_dim_x * threadgroup_dim_y * out.itemsize(), 16),
|
||||
0);
|
||||
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
|
||||
|
||||
// Perform second pass of reductions
|
||||
// Reduce results of threadgroups along y, z from first pass, that
|
||||
// collectively work on each output element.
|
||||
reduction_size = n_threadgroups_y * non_col_reductions;
|
||||
out_size = 1;
|
||||
|
||||
// Shape of axes that aren't participating in reduction remains unchanged.
|
||||
std::vector<int> new_shape = rem_shape;
|
||||
|
||||
// Update their strides since they'll be different after a partial reduction
|
||||
// post first compute dispatch.
|
||||
std::vector<size_t> new_strides = rem_strides;
|
||||
new_strides.back() = reduction_size;
|
||||
for (int i = new_shape.size() - 2; i >= 0; i--) {
|
||||
new_strides[i] = new_shape[i + 1] * new_strides[i + 1];
|
||||
}
|
||||
ndim = new_shape.size();
|
||||
|
||||
auto row_reduce_kernel = d.get_kernel(
|
||||
"row_reduce_general_no_atomics_" + op_name +
|
||||
type_to_name(intermediate));
|
||||
compute_encoder->setComputePipelineState(row_reduce_kernel);
|
||||
set_array_buffer(compute_encoder, intermediate, 0);
|
||||
set_array_buffer(compute_encoder, out, 1);
|
||||
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
|
||||
compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
|
||||
compute_encoder->setBytes(
|
||||
new_shape.data(), new_shape.size() * sizeof(int), 4);
|
||||
compute_encoder->setBytes(
|
||||
new_strides.data(), new_strides.size() * sizeof(size_t), 5);
|
||||
compute_encoder->setBytes(&ndim, sizeof(int), 6);
|
||||
|
||||
// Each thread group is responsible for 1 output
|
||||
size_t n_reads = REDUCE_N_READS;
|
||||
size_t thread_group_size =
|
||||
row_reduce_kernel->maxTotalThreadsPerThreadgroup();
|
||||
thread_group_size =
|
||||
std::min((reduction_size + n_reads - 1) / n_reads, thread_group_size);
|
||||
|
||||
// Align thread group size with simd_size
|
||||
uint simd_size = row_reduce_kernel->threadExecutionWidth();
|
||||
thread_group_size =
|
||||
(thread_group_size + simd_size - 1) / simd_size * simd_size;
|
||||
assert(thread_group_size <= kernel->maxTotalThreadsPerThreadgroup());
|
||||
|
||||
// Launch enough thread groups for each output
|
||||
uint n_threads = thread_group_size;
|
||||
grid_dims = MTL::Size(n_threads, out.size(), 1);
|
||||
group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[intermediates](MTL::CommandBuffer*) mutable {
|
||||
intermediates.clear();
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@@ -223,14 +460,6 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
array in = inputs[0];
|
||||
|
||||
// TODO: Allow specific row and column reductions with types disabled
|
||||
// due to atomics ?
|
||||
if (size_of(in.dtype()) == 8) {
|
||||
std::ostringstream msg;
|
||||
msg << "[Reduce::eval_gpu] Does not support " << in.dtype();
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
// Make sure no identity reductions trickle down here
|
||||
assert(!axes_.empty());
|
||||
|
||||
@@ -297,7 +526,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
// Reducing over everything and the data is all there no broadcasting or
|
||||
// slicing etc.
|
||||
if (plan.type == ContiguousAllReduce) {
|
||||
all_reduce_dispatch(in, out, op_name, compute_encoder, d);
|
||||
all_reduce_dispatch(in, out, op_name, compute_encoder, d, s);
|
||||
}
|
||||
|
||||
// At least the last dimension is row contiguous and we are reducing over
|
||||
@@ -305,7 +534,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
else if (
|
||||
plan.type == ContiguousReduce || plan.type == GeneralContiguousReduce) {
|
||||
row_reduce_general_dispatch(
|
||||
in, out, op_name, plan, axes_, compute_encoder, d);
|
||||
in, out, op_name, plan, axes_, compute_encoder, d, s);
|
||||
}
|
||||
|
||||
// At least the last two dimensions are contiguous and we are doing a
|
||||
@@ -314,7 +543,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
plan.type == ContiguousStridedReduce ||
|
||||
plan.type == GeneralStridedReduce) {
|
||||
strided_reduce_general_dispatch(
|
||||
in, out, op_name, plan, axes_, compute_encoder, d);
|
||||
in, out, op_name, plan, axes_, compute_encoder, d, s);
|
||||
}
|
||||
|
||||
if (!copies.empty()) {
|
||||
|
55
mlx/backend/metal/rope.cpp
Normal file
55
mlx/backend/metal/rope.cpp
Normal file
@@ -0,0 +1,55 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/fast.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core::fast {
|
||||
|
||||
void RoPE::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() == 1);
|
||||
assert(outputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
auto& out = outputs[0];
|
||||
|
||||
if (in.ndim() != 3) {
|
||||
throw std::runtime_error(
|
||||
"[RoPE] Only 3 dimensions are supported (batch x sequence x dims)");
|
||||
}
|
||||
if (dims_ != in.shape(-1)) {
|
||||
throw std::runtime_error("[RoPE] Partial RoPE application not supported");
|
||||
}
|
||||
if (in.flags().row_contiguous && in.is_donatable()) {
|
||||
out.move_shared_buffer(in);
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
}
|
||||
|
||||
auto& s = out.primitive().stream();
|
||||
auto& d = metal::device(s.device);
|
||||
std::ostringstream kname;
|
||||
kname << "rope_" << (traditional_ ? "traditional_" : "") << type_to_name(in);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
|
||||
bool donated = in.data_shared_ptr() == nullptr;
|
||||
float base = std::log2(base_);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
set_array_buffer(compute_encoder, donated ? out : in, 0);
|
||||
set_array_buffer(compute_encoder, out, 1);
|
||||
compute_encoder->setBytes(in.strides().data(), 3 * sizeof(size_t), 2);
|
||||
compute_encoder->setBytes(&offset_, sizeof(int), 3);
|
||||
compute_encoder->setBytes(&base, sizeof(float), 4);
|
||||
compute_encoder->setBytes(&scale_, sizeof(float), 5);
|
||||
|
||||
int dim0 = in.shape(2) / 2;
|
||||
int dim1 = in.shape(1);
|
||||
int dim2 = in.shape(0);
|
||||
auto group_dims = get_block_dims(dim0, dim1, dim2);
|
||||
auto grid_dims = MTL::Size(dim0, dim1, dim2);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::fast
|
@@ -22,7 +22,12 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
// Make sure that the last dimension is contiguous
|
||||
std::vector<array> copies;
|
||||
auto check_input = [&copies, &s](const array& x) {
|
||||
if (x.strides()[x.ndim() - 1] == 1) {
|
||||
bool no_copy = x.strides()[x.ndim() - 1] == 1;
|
||||
if (x.ndim() > 1) {
|
||||
auto s = x.strides()[x.ndim() - 2];
|
||||
no_copy &= (s == 0 || s == x.shape().back());
|
||||
}
|
||||
if (no_copy) {
|
||||
return x;
|
||||
} else {
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
|
@@ -9,20 +9,6 @@ namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
void set_array_buffer(
|
||||
MTL::ComputeCommandEncoder* compute_encoder,
|
||||
MTL::ArgumentEncoder* enc,
|
||||
const array& a,
|
||||
int idx) {
|
||||
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
|
||||
auto offset = a.data<char>() -
|
||||
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
|
||||
enc->setBuffer(a_buf, offset, idx);
|
||||
// MTL::Resource usage through argument buffer needs to be explicitly
|
||||
// flagged to enable hazard tracking
|
||||
compute_encoder->useResource(a_buf, MTL::ResourceUsageRead);
|
||||
}
|
||||
|
||||
void set_array_buffer(
|
||||
MTL::ComputeCommandEncoder* enc,
|
||||
const array& a,
|
||||
@@ -117,16 +103,18 @@ MTL::Size get_block_dims(int dim0, int dim1, int dim2) {
|
||||
// When multiple arrays are passed they should all have the same shape. The
|
||||
// collapsed axes are also the same so one shape is returned.
|
||||
std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
|
||||
collapse_contiguous_dims(const std::vector<array>& xs) {
|
||||
collapse_contiguous_dims(
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<std::vector<size_t>> strides) {
|
||||
// Make a vector that has axes separated with -1. Collapse all axes between
|
||||
// -1.
|
||||
std::vector<int> to_collapse;
|
||||
if (xs[0].ndim() > 0) {
|
||||
if (shape.size() > 0) {
|
||||
to_collapse.push_back(0);
|
||||
for (int i = 1; i < xs[0].ndim(); i++) {
|
||||
for (int i = 1; i < shape.size(); i++) {
|
||||
bool contiguous = true;
|
||||
for (auto& x : xs) {
|
||||
if (x.strides()[i] * x.shape()[i] != x.strides()[i - 1]) {
|
||||
for (const std::vector<size_t>& st : strides) {
|
||||
if (st[i] * shape[i] != st[i - 1]) {
|
||||
contiguous = false;
|
||||
}
|
||||
if (!contiguous) {
|
||||
@@ -142,21 +130,31 @@ collapse_contiguous_dims(const std::vector<array>& xs) {
|
||||
}
|
||||
|
||||
std::vector<int> out_shape;
|
||||
std::vector<std::vector<size_t>> out_strides(xs.size());
|
||||
std::vector<std::vector<size_t>> out_strides(strides.size());
|
||||
for (int i = 0; i < to_collapse.size(); i++) {
|
||||
int current_shape = xs[0].shape()[to_collapse[i]];
|
||||
int current_shape = shape[to_collapse[i]];
|
||||
while (to_collapse[++i] != -1) {
|
||||
current_shape *= xs[0].shape()[to_collapse[i]];
|
||||
current_shape *= shape[to_collapse[i]];
|
||||
}
|
||||
out_shape.push_back(current_shape);
|
||||
for (int j = 0; j < xs.size(); j++) {
|
||||
out_strides[j].push_back(xs[j].strides()[to_collapse[i - 1]]);
|
||||
for (int j = 0; j < strides.size(); j++) {
|
||||
const std::vector<size_t>& st = strides[j];
|
||||
out_strides[j].push_back(st[to_collapse[i - 1]]);
|
||||
}
|
||||
}
|
||||
|
||||
return std::make_tuple(out_shape, out_strides);
|
||||
}
|
||||
|
||||
std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
|
||||
collapse_contiguous_dims(const std::vector<array>& xs) {
|
||||
std::vector<std::vector<size_t>> strides;
|
||||
for (auto& x : xs) {
|
||||
strides.emplace_back(x.strides());
|
||||
}
|
||||
return collapse_contiguous_dims(xs[0].shape(), strides);
|
||||
}
|
||||
|
||||
template <typename... Arrays>
|
||||
std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
|
||||
collapse_contiguous_dims(Arrays... xs) {
|
||||
|
@@ -1,6 +1,7 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/fast.h"
|
||||
|
||||
#define NO_GPU_MULTI(func) \
|
||||
void func::eval_gpu( \
|
||||
@@ -32,12 +33,16 @@ NO_GPU(AsType)
|
||||
NO_GPU(AsStrided)
|
||||
NO_GPU(Broadcast)
|
||||
NO_GPU(Ceil)
|
||||
NO_GPU_MULTI(Compiled)
|
||||
NO_GPU(Concatenate)
|
||||
NO_GPU(Convolution)
|
||||
NO_GPU(Copy)
|
||||
NO_GPU(Cos)
|
||||
NO_GPU(Cosh)
|
||||
NO_GPU_MULTI(CustomVJP)
|
||||
NO_GPU_MULTI(Depends)
|
||||
NO_GPU(Divide)
|
||||
NO_GPU_MULTI(DivMod)
|
||||
NO_GPU(Remainder)
|
||||
NO_GPU(Equal)
|
||||
NO_GPU(Erf)
|
||||
@@ -67,6 +72,7 @@ NO_GPU(NotEqual)
|
||||
NO_GPU(Pad)
|
||||
NO_GPU(Partition)
|
||||
NO_GPU(Power)
|
||||
NO_GPU_MULTI(QRF)
|
||||
NO_GPU(QuantizedMatmul)
|
||||
NO_GPU(RandomBits)
|
||||
NO_GPU(Reduce)
|
||||
@@ -89,6 +95,9 @@ NO_GPU(Subtract)
|
||||
NO_GPU(Tan)
|
||||
NO_GPU(Tanh)
|
||||
NO_GPU(Transpose)
|
||||
NO_GPU_MULTI(DivMod)
|
||||
|
||||
namespace fast {
|
||||
NO_GPU_MULTI(RoPE)
|
||||
} // namespace fast
|
||||
|
||||
} // namespace mlx::core
|
||||
|
774
mlx/compile.cpp
Normal file
774
mlx/compile.cpp
Normal file
@@ -0,0 +1,774 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <map>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/compile.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/transforms.h"
|
||||
#include "mlx/transforms_impl.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
constexpr int max_compile_depth = 10;
|
||||
|
||||
bool is_unary(const Primitive& p) {
|
||||
return (
|
||||
typeid(p) == typeid(Abs) || typeid(p) == typeid(ArcCos) ||
|
||||
typeid(p) == typeid(ArcCosh) || typeid(p) == typeid(ArcSin) ||
|
||||
typeid(p) == typeid(ArcSinh) || typeid(p) == typeid(ArcTan) ||
|
||||
typeid(p) == typeid(ArcTanh) || typeid(p) == typeid(AsType) ||
|
||||
typeid(p) == typeid(Ceil) || typeid(p) == typeid(Cos) ||
|
||||
typeid(p) == typeid(Cosh) || typeid(p) == typeid(Remainder) ||
|
||||
typeid(p) == typeid(Erf) || typeid(p) == typeid(ErfInv) ||
|
||||
typeid(p) == typeid(Exp) || typeid(p) == typeid(Floor) ||
|
||||
typeid(p) == typeid(Log) || typeid(p) == typeid(Log1p) ||
|
||||
typeid(p) == typeid(LogicalNot) || typeid(p) == typeid(Negative) ||
|
||||
typeid(p) == typeid(Round) || typeid(p) == typeid(Sigmoid) ||
|
||||
typeid(p) == typeid(Sign) || typeid(p) == typeid(Sin) ||
|
||||
typeid(p) == typeid(Sinh) || typeid(p) == typeid(Square) ||
|
||||
typeid(p) == typeid(Sqrt) || typeid(p) == typeid(Tan) ||
|
||||
typeid(p) == typeid(Tanh));
|
||||
}
|
||||
|
||||
bool is_binary(const Primitive& p) {
|
||||
return (
|
||||
typeid(p) == typeid(Add) || typeid(p) == typeid(Divide) ||
|
||||
typeid(p) == typeid(Equal) || typeid(p) == typeid(Greater) ||
|
||||
typeid(p) == typeid(GreaterEqual) || typeid(p) == typeid(Less) ||
|
||||
typeid(p) == typeid(LessEqual) || typeid(p) == typeid(LogicalNot) ||
|
||||
typeid(p) == typeid(LogicalAnd) || typeid(p) == typeid(LogicalOr) ||
|
||||
typeid(p) == typeid(LogAddExp) || typeid(p) == typeid(Maximum) ||
|
||||
typeid(p) == typeid(Minimum) || typeid(p) == typeid(Multiply) ||
|
||||
typeid(p) == typeid(NotEqual) || typeid(p) == typeid(Power) ||
|
||||
typeid(p) == typeid(Subtract));
|
||||
}
|
||||
|
||||
bool is_broadcast(const Primitive& p) {
|
||||
return typeid(p) == typeid(Broadcast);
|
||||
}
|
||||
|
||||
bool is_noop(const Primitive& p) {
|
||||
return typeid(p) == typeid(Copy) || typeid(p) == typeid(StopGradient);
|
||||
}
|
||||
|
||||
bool is_fusable(const Primitive& p) {
|
||||
return is_unary(p) || is_binary(p) || is_broadcast(p) || is_noop(p);
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
|
||||
std::vector<array> compile_replace(
|
||||
const std::vector<array>& tape,
|
||||
const std::vector<array>& trace_inputs,
|
||||
const std::vector<array>& trace_outputs,
|
||||
const std::vector<array>& inputs);
|
||||
|
||||
} // namespace detail
|
||||
|
||||
Compiled::Compiled(
|
||||
Stream stream,
|
||||
std::vector<array> inputs,
|
||||
std::vector<array> outputs,
|
||||
std::vector<array> tape,
|
||||
std::unordered_set<uintptr_t> constant_ids)
|
||||
: Primitive(stream),
|
||||
inputs_(std::move(inputs)),
|
||||
outputs_(std::move(outputs)),
|
||||
tape_(std::move(tape)),
|
||||
constant_ids_(std::move(constant_ids)) {}
|
||||
|
||||
std::vector<array> Compiled::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) {
|
||||
throw std::runtime_error("[Compiled] Cannot vjp primitive.");
|
||||
}
|
||||
|
||||
std::vector<array> Compiled::jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
throw std::runtime_error("[Compiled] Cannot jvp primitive.");
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> Compiled::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
throw std::runtime_error("[Compiled] Cannot vmap primitive.");
|
||||
}
|
||||
|
||||
bool Compiled::is_equivalent(const Primitive& other) const {
|
||||
const Compiled& a_other = static_cast<const Compiled&>(other);
|
||||
return std::equal(
|
||||
tape_.begin(),
|
||||
tape_.end(),
|
||||
a_other.tape_.begin(),
|
||||
a_other.tape_.end(),
|
||||
[](const array& a1, const array& a2) {
|
||||
auto& p1 = a1.primitive();
|
||||
auto& p2 = a2.primitive();
|
||||
return typeid(p1) == typeid(p2) && p1.is_equivalent(p2);
|
||||
});
|
||||
}
|
||||
|
||||
void Compiled::print(std::ostream& os) {
|
||||
os << "Compiled";
|
||||
for (auto& a : tape_) {
|
||||
a.primitive().print(os);
|
||||
}
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
|
||||
CompileMode& compile_mode() {
|
||||
auto get_val = []() {
|
||||
if (const char* buff_str = std::getenv("MLX_DISABLE_COMPILE")) {
|
||||
return CompileMode::disabled;
|
||||
} else {
|
||||
return CompileMode::enabled;
|
||||
}
|
||||
};
|
||||
static CompileMode compile_mode_ = get_val();
|
||||
return compile_mode_;
|
||||
}
|
||||
|
||||
using CompileFn = std::function<std::vector<array>(const std::vector<array>&)>;
|
||||
using ParentsMap =
|
||||
std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>;
|
||||
|
||||
// Helper that merges two arrays in the graph by setting the parents of the
|
||||
// source to point to the destination
|
||||
void merge(array& dst, array& src, ParentsMap& parents_map) {
|
||||
// Canonicalize the order of the primitives outputs
|
||||
auto sources = src.outputs();
|
||||
auto dests = dst.outputs();
|
||||
// For each src parent, point it to the corresponding dst
|
||||
for (int i = 0; i < sources.size(); ++i) {
|
||||
auto src_parents = parents_map.find(sources[i].id());
|
||||
if (src_parents == parents_map.end()) {
|
||||
continue;
|
||||
}
|
||||
auto& pairs = parents_map[dests[i].id()];
|
||||
for (auto& parent : src_parents->second) {
|
||||
parent.first.inputs()[parent.second] = dests[i];
|
||||
pairs.push_back(parent);
|
||||
}
|
||||
// Remove the source from the map to avoid fusing with it again
|
||||
parents_map.erase(src_parents);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename... U>
|
||||
size_t getAddress(std::function<T(U...)> f) {
|
||||
typedef T(fnType)(U...);
|
||||
fnType** fnPointer = f.template target<fnType*>();
|
||||
if (fnPointer == nullptr) {
|
||||
throw std::invalid_argument(
|
||||
"[compile] Cannot compile a non-addressable function.");
|
||||
}
|
||||
return (size_t)*fnPointer;
|
||||
}
|
||||
|
||||
struct CompilerCache {
|
||||
struct CacheEntry {
|
||||
std::vector<array> inputs;
|
||||
std::vector<array> outputs;
|
||||
std::vector<array> tape;
|
||||
bool empty{true};
|
||||
};
|
||||
|
||||
// Returns a reference to a CacheEntry which can be updated
|
||||
// by the caller to avoid copying large tapes / inputs / outputs
|
||||
CacheEntry& find(size_t fun_id, const std::vector<array>& inputs) {
|
||||
// Try to find the entry
|
||||
auto [entry_it, inserted] = cache_.insert({fun_id, {}});
|
||||
auto& entries = entry_it->second;
|
||||
auto is_match = [](const std::vector<array>& in1,
|
||||
const std::vector<array>& in2) {
|
||||
if (in1.size() != in2.size()) {
|
||||
return false;
|
||||
}
|
||||
for (int i = 0; i < in1.size(); ++i) {
|
||||
if (in1[i].shape() != in2[i].shape()) {
|
||||
return false;
|
||||
}
|
||||
if (in1[i].dtype() != in2[i].dtype()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
// Loop over entries and check inputs match i.e. shapes and types must be
|
||||
// equal. Note this could get really slow if one compiles the same
|
||||
// function with many different shapes. May want to store entries in a
|
||||
// more easily searchable structure.
|
||||
for (auto& entry : entries) {
|
||||
// Check the inputs match and return if so
|
||||
if (is_match(inputs, entry.inputs)) {
|
||||
return entry;
|
||||
}
|
||||
}
|
||||
// Otherwise append a new cache entry
|
||||
entries.push_back(CacheEntry{});
|
||||
return entries.back();
|
||||
};
|
||||
|
||||
void erase(size_t fun_id) {
|
||||
cache_.erase(fun_id);
|
||||
}
|
||||
|
||||
private:
|
||||
CompilerCache() {
|
||||
// Make sure the allocator is fully
|
||||
// initialized before the compiler cache
|
||||
allocator::allocator();
|
||||
}
|
||||
friend CompilerCache& compiler_cache();
|
||||
std::unordered_map<size_t, std::vector<CacheEntry>> cache_;
|
||||
};
|
||||
|
||||
CompilerCache& compiler_cache() {
|
||||
static CompilerCache compiler_cache_;
|
||||
return compiler_cache_;
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<array>> compile_trace(
|
||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||
const std::vector<array>& inputs) {
|
||||
// Set the global tracing flag.
|
||||
detail::InTracing in_tracing;
|
||||
|
||||
// Run the function on placeholder inputs
|
||||
// to get compute graph
|
||||
std::vector<array> tracer_inputs;
|
||||
for (int i = 0; i < inputs.size(); ++i) {
|
||||
array in(inputs[i].shape(), inputs[i].dtype(), nullptr, {});
|
||||
in.set_tracer(true);
|
||||
tracer_inputs.push_back(std::move(in));
|
||||
}
|
||||
return {tracer_inputs, fun(tracer_inputs)};
|
||||
}
|
||||
|
||||
// Traverses the graph to build a tape and a map of array ids to their parents
|
||||
std::pair<std::vector<array>, ParentsMap> compile_dfs(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<array>& outputs) {
|
||||
std::function<void(const array&)> recurse;
|
||||
std::vector<array> tape;
|
||||
std::unordered_set<std::uintptr_t> input_set;
|
||||
std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>
|
||||
parents_map;
|
||||
for (int i = 0; i < inputs.size(); ++i) {
|
||||
auto in = inputs[i];
|
||||
input_set.insert(in.id());
|
||||
}
|
||||
|
||||
// DFS the graph to build the tape, and log parents and scalars
|
||||
std::unordered_set<std::uintptr_t> cache;
|
||||
recurse = [&](const array& a) {
|
||||
auto id = a.id();
|
||||
if (cache.find(id) != cache.end()) {
|
||||
return;
|
||||
}
|
||||
for (int i = 0; i < a.inputs().size(); i++) {
|
||||
auto& in = a.inputs()[i];
|
||||
parents_map[in.id()].push_back({a, i});
|
||||
for (auto& s : a.siblings()) {
|
||||
parents_map[in.id()].push_back({s, i});
|
||||
}
|
||||
// Don't recurse on inputs (but add them to the tape for the purpose
|
||||
// of future optimizations)
|
||||
if (input_set.find(a.id()) == input_set.end()) {
|
||||
recurse(in);
|
||||
}
|
||||
}
|
||||
cache.insert(id);
|
||||
for (auto& s : a.siblings()) {
|
||||
cache.insert(s.id());
|
||||
}
|
||||
tape.push_back(a);
|
||||
};
|
||||
for (auto& a : outputs) {
|
||||
recurse(a);
|
||||
}
|
||||
return {tape, parents_map};
|
||||
}
|
||||
|
||||
// Simplify the tape. Note, this function modifies in-place both the tape and
|
||||
// the parents map to remove orphaned arrays
|
||||
void compile_simplify(
|
||||
std::vector<array>& tape,
|
||||
ParentsMap& parents_map,
|
||||
const std::vector<array>& outputs,
|
||||
int passes) {
|
||||
// Helpers to identify identical scalars
|
||||
std::map<std::pair<uint64_t, Dtype::Val>, array> scalars;
|
||||
auto is_scalar = [](const array& a) {
|
||||
return a.is_evaled() && a.ndim() == 0;
|
||||
};
|
||||
auto get_scalar_rep = [](const array& a) {
|
||||
uint64_t v = 0;
|
||||
int dtype;
|
||||
switch (a.dtype().size) {
|
||||
case 1:
|
||||
v = *a.data<uint8_t>();
|
||||
break;
|
||||
case 4:
|
||||
v = *a.data<uint32_t>();
|
||||
break;
|
||||
case 8:
|
||||
v = *a.data<uint64_t>();
|
||||
break;
|
||||
}
|
||||
return std::make_pair(v, a.dtype().val);
|
||||
};
|
||||
|
||||
for (auto& a : tape) {
|
||||
if (is_scalar(a)) {
|
||||
scalars.insert({get_scalar_rep(a), a});
|
||||
}
|
||||
}
|
||||
|
||||
// Depth-1 array equivalence check.
|
||||
auto array_equivalent = [](const array& a, const array& b) {
|
||||
if (!a.has_primitive() || !b.has_primitive()) {
|
||||
return false;
|
||||
}
|
||||
if (a.primitive_id() == b.primitive_id()) {
|
||||
return false;
|
||||
}
|
||||
const auto& pa = a.primitive();
|
||||
const auto& pb = b.primitive();
|
||||
if (typeid(pa) != typeid(pb)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (a.inputs().size() != b.inputs().size()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (int i = 0; i < a.inputs().size(); i++) {
|
||||
if (a.inputs()[i].id() != b.inputs()[i].id()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return pa.is_equivalent(pb);
|
||||
};
|
||||
|
||||
// Merge scalars
|
||||
std::vector<array> new_tape;
|
||||
for (auto& arr : tape) {
|
||||
// Check if we can merge scalars
|
||||
if (is_scalar(arr)) {
|
||||
auto scalar = scalars.find(get_scalar_rep(arr));
|
||||
if (scalar->second.id() != arr.id()) {
|
||||
merge(scalar->second, arr, parents_map);
|
||||
// Don't keep orphaned scalars in the tape
|
||||
continue;
|
||||
}
|
||||
}
|
||||
new_tape.push_back(std::move(arr));
|
||||
}
|
||||
tape = std::move(new_tape);
|
||||
|
||||
std::unordered_set<uintptr_t> output_set;
|
||||
for (auto& o : outputs) {
|
||||
output_set.insert(o.id());
|
||||
}
|
||||
// Multi-pass merge only keeping non-orphaned arrays in the tape
|
||||
for (int pass = 0; pass < passes; ++pass) {
|
||||
for (auto& arr : tape) {
|
||||
// Helper to check if we can merge the parents of the
|
||||
// given array
|
||||
auto maybe_merge_parents = [&](auto& a) {
|
||||
auto parents = parents_map.find(a.id());
|
||||
if (parents != parents_map.end()) {
|
||||
auto N = parents->second.size();
|
||||
std::vector<bool> mask(N, false);
|
||||
for (int i = 0; i < N; i++) {
|
||||
if (mask[i]) {
|
||||
continue;
|
||||
}
|
||||
for (int j = i + 1; j < N; j++) {
|
||||
if (mask[j]) {
|
||||
continue;
|
||||
}
|
||||
auto& src = parents->second[j].first;
|
||||
auto& dst = parents->second[i].first;
|
||||
if (src.id() != dst.id() && array_equivalent(src, dst)) {
|
||||
merge(dst, src, parents_map);
|
||||
mask[j] = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Erase orphaned parents so we don't keep fusing with them
|
||||
for (int i = N - 1; i > 0; --i) {
|
||||
if (mask[i]) {
|
||||
parents->second.erase(parents->second.begin() + i);
|
||||
}
|
||||
}
|
||||
return false;
|
||||
} else {
|
||||
return output_set.find(a.id()) == output_set.end();
|
||||
}
|
||||
};
|
||||
|
||||
bool discard = maybe_merge_parents(arr);
|
||||
for (auto& s : arr.siblings()) {
|
||||
discard &= maybe_merge_parents(s);
|
||||
}
|
||||
// If an array and its siblings have no parents, and none of them are
|
||||
// outputs, it is safe to remove it from the tape
|
||||
if (!discard) {
|
||||
new_tape.push_back(std::move(arr));
|
||||
}
|
||||
}
|
||||
tape = std::move(new_tape);
|
||||
}
|
||||
}
|
||||
|
||||
// Extract sub-graphs of the graph that can be compiled
|
||||
// and replace them with a Compiled Primitive.
|
||||
void compile_fuse(
|
||||
std::vector<array>& tape,
|
||||
ParentsMap& parents_map,
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
// Track outputs to replace with new compiled outputs
|
||||
std::unordered_map<uintptr_t, array> output_map;
|
||||
for (auto& o : outputs) {
|
||||
output_map.insert({o.id(), o});
|
||||
}
|
||||
|
||||
// Set of inputs to distinguish constants
|
||||
std::unordered_set<uintptr_t> input_ids;
|
||||
for (auto& in : inputs) {
|
||||
input_ids.insert(in.id());
|
||||
}
|
||||
|
||||
// Go through the tape in reverse order and check for fusable sub-graphs
|
||||
std::vector<array> new_tape;
|
||||
std::unordered_set<uintptr_t> global_cache;
|
||||
for (int i = tape.size() - 1; i >= 0; --i) {
|
||||
auto& arr = tape[i];
|
||||
|
||||
// Already compiled
|
||||
if (global_cache.find(arr.id()) != global_cache.end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Two pass recursion:
|
||||
// First pass:
|
||||
// - Collect all the primitives which we can fuse with
|
||||
// - Keeps a cache of fusable primitives which may be added out of
|
||||
// DAG order. We have to determine if all of a fused primitive's
|
||||
// outputs are also in the fused section, and this may not be the
|
||||
// case the first time we visit it.
|
||||
// Second pass:
|
||||
// - Collect inputs to the new compiled primitive
|
||||
// - Add fusable primitives to a tape in the correct order
|
||||
|
||||
std::function<void(const array&, int, const Stream&)> recurse;
|
||||
std::unordered_set<uintptr_t> cache;
|
||||
recurse = [&](const array& a, int depth, const Stream& s) {
|
||||
if (cache.find(a.id()) != cache.end()) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Stop fusing if:
|
||||
// - Depth limit exceeded
|
||||
// - Constant input
|
||||
// - Stream mismatch
|
||||
// - Non fusable primitive
|
||||
if (depth >= max_compile_depth || !a.has_primitive() ||
|
||||
a.primitive().stream() != s || !is_fusable(a.primitive())) {
|
||||
return;
|
||||
}
|
||||
|
||||
bool all_parents_in = true;
|
||||
if (depth > 0) {
|
||||
// Guaranteed to have a parent since nested in the
|
||||
// recursion.
|
||||
auto& parents = parents_map.at(a.id());
|
||||
for (auto& [p, idx] : parents) {
|
||||
auto in_cache = cache.find(p.id()) != cache.end();
|
||||
if (!in_cache) {
|
||||
all_parents_in = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Arrays with a mix of parents outside the compilable section
|
||||
// are not fusable
|
||||
if (!all_parents_in) {
|
||||
return;
|
||||
}
|
||||
|
||||
cache.insert({a.id()});
|
||||
|
||||
for (auto& in : a.inputs()) {
|
||||
recurse(in, depth + 1, s);
|
||||
}
|
||||
};
|
||||
|
||||
if (arr.has_primitive()) {
|
||||
Stream s = arr.primitive().stream();
|
||||
recurse(arr, 0, s);
|
||||
}
|
||||
|
||||
// Not worth fusing a single primitive
|
||||
if (cache.size() <= 1) {
|
||||
new_tape.push_back(arr);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Recurse a second time to build the tape in the right
|
||||
// order and collect the inputs
|
||||
std::unordered_set<uintptr_t> input_set;
|
||||
std::vector<array> inputs;
|
||||
std::vector<array> fused_tape;
|
||||
std::unordered_set<uintptr_t> tape_set;
|
||||
std::function<void(const array&)> recurse_tape;
|
||||
recurse_tape = [&](const array& a) {
|
||||
if (cache.find(a.id()) == cache.end()) {
|
||||
if (input_set.find(a.id()) == input_set.end()) {
|
||||
input_set.insert(a.id());
|
||||
inputs.push_back(a);
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (tape_set.find(a.id()) != tape_set.end()) {
|
||||
return;
|
||||
}
|
||||
tape_set.insert(a.id());
|
||||
for (auto& in : a.inputs()) {
|
||||
recurse_tape(in);
|
||||
}
|
||||
fused_tape.push_back(a);
|
||||
};
|
||||
recurse_tape(arr);
|
||||
|
||||
std::vector<array> old_outputs;
|
||||
// Add to global cache and add any global outputs to outputs
|
||||
// of new primitive
|
||||
for (int j = 0; j < fused_tape.size() - 1; ++j) {
|
||||
auto& f = fused_tape[j];
|
||||
if (output_map.find(f.id()) != output_map.end()) {
|
||||
old_outputs.push_back(f);
|
||||
// Parents are now siblings, update the parent map
|
||||
auto& pairs = parents_map[f.id()];
|
||||
pairs.erase(
|
||||
std::remove_if(
|
||||
pairs.begin(),
|
||||
pairs.end(),
|
||||
[&](auto& p) {
|
||||
return cache.find(p.first.id()) != cache.end();
|
||||
}),
|
||||
pairs.end());
|
||||
} else {
|
||||
// Remove inner fused arrays parents from the parents map
|
||||
// to keep the parents map in a valid state
|
||||
parents_map.erase(f.id());
|
||||
}
|
||||
global_cache.insert({f.id()});
|
||||
}
|
||||
old_outputs.push_back(arr);
|
||||
|
||||
std::vector<std::vector<int>> shapes;
|
||||
std::vector<Dtype> types;
|
||||
for (auto& o : old_outputs) {
|
||||
shapes.push_back(o.shape());
|
||||
types.push_back(o.dtype());
|
||||
}
|
||||
std::unordered_set<uintptr_t> constant_ids;
|
||||
for (auto& in : inputs) {
|
||||
// Scalar constant
|
||||
if (in.size() == 1 && !in.has_primitive() &&
|
||||
input_ids.find(in.id()) == input_ids.end()) {
|
||||
constant_ids.insert(in.id());
|
||||
}
|
||||
}
|
||||
auto compiled_outputs = array::make_arrays(
|
||||
shapes,
|
||||
types,
|
||||
std::make_shared<Compiled>(
|
||||
old_outputs.back().primitive().stream(),
|
||||
inputs,
|
||||
old_outputs,
|
||||
std::move(fused_tape),
|
||||
std::move(constant_ids)),
|
||||
inputs);
|
||||
|
||||
// One output per primitive
|
||||
new_tape.push_back(compiled_outputs.back());
|
||||
|
||||
// Replace inputs old parents with compiled_outputs
|
||||
for (int i = 0; i < inputs.size(); ++i) {
|
||||
auto& pairs = parents_map[inputs[i].id()];
|
||||
pairs.erase(
|
||||
std::remove_if(
|
||||
pairs.begin(),
|
||||
pairs.end(),
|
||||
[&](auto& p) { return cache.find(p.first.id()) != cache.end(); }),
|
||||
pairs.end());
|
||||
for (auto& o : compiled_outputs) {
|
||||
pairs.push_back({o, i});
|
||||
}
|
||||
}
|
||||
|
||||
// - Update outputs parents to point to compiled outputs
|
||||
// - Update any overall graph outputs to be compiled outputs
|
||||
for (int o = 0; o < old_outputs.size(); ++o) {
|
||||
merge(compiled_outputs[o], old_outputs[o], parents_map);
|
||||
if (auto it = output_map.find(old_outputs[o].id());
|
||||
it != output_map.end()) {
|
||||
it->second = compiled_outputs[o];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::reverse(new_tape.begin(), new_tape.end());
|
||||
tape = std::move(new_tape);
|
||||
|
||||
// Replace output with potentially compiled output
|
||||
for (auto& o : outputs) {
|
||||
o = output_map.at(o.id());
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<array> compile_replace(
|
||||
const std::vector<array>& tape,
|
||||
const std::vector<array>& trace_inputs,
|
||||
const std::vector<array>& trace_outputs,
|
||||
const std::vector<array>& inputs) {
|
||||
std::unordered_map<uintptr_t, array> trace_to_real;
|
||||
for (int i = 0; i < inputs.size(); ++i) {
|
||||
trace_to_real.insert({trace_inputs[i].id(), inputs[i]});
|
||||
}
|
||||
|
||||
for (auto& a : tape) {
|
||||
// Arrays in the tape without primitives are constants
|
||||
// and can be used directly
|
||||
if (!a.has_primitive()) {
|
||||
trace_to_real.insert({a.id(), a});
|
||||
} else {
|
||||
// Find real inputs
|
||||
std::vector<array> real_inputs;
|
||||
for (auto& in : a.inputs()) {
|
||||
real_inputs.push_back(trace_to_real.at(in.id()));
|
||||
}
|
||||
if (a.siblings().empty()) {
|
||||
auto real_a = array(
|
||||
a.shape(), a.dtype(), a.primitive_ptr(), std::move(real_inputs));
|
||||
trace_to_real.insert({a.id(), std::move(real_a)});
|
||||
} else {
|
||||
// Ensure the order is correct for multi-output primitives
|
||||
std::vector<std::vector<int>> shapes;
|
||||
std::vector<Dtype> types;
|
||||
auto trace_out = a.outputs();
|
||||
for (auto& o : trace_out) {
|
||||
shapes.push_back(o.shape());
|
||||
types.push_back(o.dtype());
|
||||
}
|
||||
auto real_out =
|
||||
array::make_arrays(shapes, types, a.primitive_ptr(), real_inputs);
|
||||
for (int i = 0; i < trace_out.size(); ++i) {
|
||||
trace_to_real.insert({trace_out[i].id(), std::move(real_out[i])});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<array> outputs;
|
||||
for (auto& o : trace_outputs) {
|
||||
outputs.push_back(trace_to_real.at(o.id()));
|
||||
}
|
||||
return outputs;
|
||||
}
|
||||
|
||||
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
|
||||
size_t fun_id) {
|
||||
if (compile_mode() == CompileMode::disabled) {
|
||||
return fun;
|
||||
}
|
||||
return [fun, fun_id](const std::vector<array>& inputs) {
|
||||
// If the inputs are tracers, trace the original graph
|
||||
if (std::any_of(inputs.begin(), inputs.end(), [](auto& in) {
|
||||
return in.is_tracer();
|
||||
})) {
|
||||
return fun(inputs);
|
||||
}
|
||||
|
||||
// Find a cache entry with the correct inputs
|
||||
auto& entry = compiler_cache().find(fun_id, inputs);
|
||||
|
||||
// No matching cache entry existed, so compile
|
||||
if (entry.empty) {
|
||||
// Mark the entry as not empty since we are about to fill it
|
||||
entry.empty = false;
|
||||
// Trace to build the graph
|
||||
std::tie(entry.inputs, entry.outputs) = compile_trace(fun, inputs);
|
||||
|
||||
// DFS the graph and get a tape, and a map of array id to (parent,
|
||||
// position in parent inputs)
|
||||
std::unordered_map<uintptr_t, std::vector<std::pair<array, int>>>
|
||||
parents_map;
|
||||
std::tie(entry.tape, parents_map) =
|
||||
compile_dfs(entry.inputs, entry.outputs);
|
||||
|
||||
// Simplify the tape
|
||||
if (compile_mode() != CompileMode::no_simplify) {
|
||||
compile_simplify(
|
||||
entry.tape, parents_map, entry.outputs, /* passes */ 3);
|
||||
}
|
||||
|
||||
// Kernel fusion to generate Compiled primitives. The tape and
|
||||
// new outputs must be updated accordingly
|
||||
if (compile_mode() != CompileMode::no_fuse) {
|
||||
compile_fuse(entry.tape, parents_map, entry.inputs, entry.outputs);
|
||||
}
|
||||
}
|
||||
|
||||
// At this point we must have a tape, now replace the placeholders
|
||||
// with real arrays that can be evaluated
|
||||
return compile_replace(entry.tape, entry.inputs, entry.outputs, inputs);
|
||||
};
|
||||
}
|
||||
|
||||
void compile_erase(size_t fun_id) {
|
||||
detail::compiler_cache().erase(fun_id);
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun) {
|
||||
if (detail::compile_mode() == CompileMode::disabled) {
|
||||
return fun;
|
||||
}
|
||||
auto fun_id = detail::getAddress(fun);
|
||||
return detail::compile(fun, fun_id);
|
||||
}
|
||||
|
||||
void disable_compile() {
|
||||
detail::compile_mode() = CompileMode::disabled;
|
||||
}
|
||||
|
||||
void enable_compile() {
|
||||
detail::compile_mode() = CompileMode::enabled;
|
||||
}
|
||||
|
||||
void set_compile_mode(CompileMode mode) {
|
||||
detail::compile_mode() = mode;
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
28
mlx/compile.h
Normal file
28
mlx/compile.h
Normal file
@@ -0,0 +1,28 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/array.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
enum class CompileMode { disabled, no_simplify, no_fuse, enabled };
|
||||
|
||||
// Compile takes a function and returns a new function
|
||||
std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
const std::function<std::vector<array>(const std::vector<array>&)>& fun);
|
||||
|
||||
/** Globally disable compilation.
|
||||
* Setting the environment variable ``MLX_DISABLE_COMPILE`` can also
|
||||
* be used to disable compilation.
|
||||
*/
|
||||
void disable_compile();
|
||||
|
||||
/** Globally enable compilation.
|
||||
* This will override the environment variable ``MLX_DISABLE_COMPILE``.
|
||||
*/
|
||||
void enable_compile();
|
||||
|
||||
/** Set the compiler mode to the given value. */
|
||||
void set_compile_mode(CompileMode mode);
|
||||
} // namespace mlx::core
|
128
mlx/fast.cpp
Normal file
128
mlx/fast.cpp
Normal file
@@ -0,0 +1,128 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include "mlx/fast.h"
|
||||
#include "mlx/transforms.h"
|
||||
|
||||
namespace mlx::core::fast {
|
||||
|
||||
std::vector<array> Custom::vjp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& cotangents,
|
||||
const std::vector<int>& argnums,
|
||||
const std::vector<array>& outputs) {
|
||||
auto [_, vjps] = mlx::core::vjp(fallback_, primals, cotangents);
|
||||
std::vector<array> vjp_outs;
|
||||
for (int i = 0, j = 0; i < vjps.size(); ++i) {
|
||||
if (i < argnums.size() && i == argnums[j]) {
|
||||
vjp_outs.push_back(vjps[i]);
|
||||
j++;
|
||||
}
|
||||
}
|
||||
return vjp_outs;
|
||||
}
|
||||
|
||||
std::vector<array> Custom::jvp(
|
||||
const std::vector<array>& primals,
|
||||
const std::vector<array>& tangents,
|
||||
const std::vector<int>& argnums) {
|
||||
auto [_, jvps] = mlx::core::jvp(fallback_, primals, tangents);
|
||||
std::vector<array> jvp_outs;
|
||||
for (int i = 0, j = 0; i < jvps.size(); ++i) {
|
||||
if (i < argnums.size() && i == argnums[j]) {
|
||||
jvp_outs.push_back(jvps[i]);
|
||||
j++;
|
||||
}
|
||||
}
|
||||
return jvp_outs;
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> Custom::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
auto outputs = mlx::core::vmap(fallback_, axes)(inputs);
|
||||
auto out_axes = std::vector<int>(outputs.size(), 0);
|
||||
return {outputs, out_axes};
|
||||
}
|
||||
|
||||
array rope(
|
||||
const array& x,
|
||||
int dims,
|
||||
bool traditional,
|
||||
float base,
|
||||
float scale,
|
||||
int offset,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
if (x.ndim() != 3) {
|
||||
std::ostringstream msg;
|
||||
msg << "[rope] Input must have 3 dimensions but got input with " << x.ndim()
|
||||
<< " dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (traditional && x.shape(-1) != dims) {
|
||||
throw std::invalid_argument(
|
||||
"[rope] Does not support partial traditional application.");
|
||||
}
|
||||
|
||||
auto fallback = [dims, traditional, base, scale, offset, s](
|
||||
const std::vector<array>& inputs) {
|
||||
auto& x = inputs[0];
|
||||
auto t = x.dtype();
|
||||
auto N = x.shape(1) + offset;
|
||||
// Compute sines and cosines
|
||||
auto half_dims = dims / 2;
|
||||
auto positions = multiply(arange(offset, N, t, s), array(scale, t), s);
|
||||
auto freqs = negative(arange(0, half_dims, t, s), s);
|
||||
freqs = exp(multiply(freqs, array(std::log(base) / half_dims, t), s), s);
|
||||
auto theta =
|
||||
multiply(expand_dims(positions, 1, s), expand_dims(freqs, 0, s), s);
|
||||
auto coss = cos(theta, s);
|
||||
auto sins = sin(theta, s);
|
||||
|
||||
if (traditional) {
|
||||
auto x1 = slice(x, {0, 0, 0}, x.shape(), {1, 1, 2}, s);
|
||||
auto x2 = slice(x, {0, 0, 1}, x.shape(), {1, 1, 2}, s);
|
||||
std::vector<array> outs;
|
||||
outs.push_back(subtract(multiply(x1, coss, s), multiply(x2, sins, s), s));
|
||||
outs.push_back(add(multiply(x1, sins, s), multiply(x2, coss, s), s));
|
||||
for (auto& o : outs) {
|
||||
o = expand_dims(o, 3, s);
|
||||
}
|
||||
return std::vector<array>{reshape(concatenate(outs, 3, s), x.shape(), s)};
|
||||
} else {
|
||||
auto out_s = x.shape();
|
||||
out_s.back() = half_dims;
|
||||
auto x1 = slice(x, {0, 0, 0}, out_s, s);
|
||||
out_s.back() = dims;
|
||||
auto x2 = slice(x, {0, 0, half_dims}, out_s, s);
|
||||
|
||||
std::vector<array> outs;
|
||||
outs.push_back(subtract(multiply(x1, coss, s), multiply(x2, sins, s), s));
|
||||
outs.push_back(add(multiply(x1, sins, s), multiply(x2, coss, s), s));
|
||||
if (dims < x.shape(-1)) {
|
||||
outs.push_back(slice(x, {0, 0, dims}, x.shape(), s));
|
||||
}
|
||||
return std::vector<array>{concatenate(outs, 2, s)};
|
||||
}
|
||||
};
|
||||
// TODO change to condition for using custom prim
|
||||
auto stream = to_stream(s);
|
||||
if (stream.device == Device::gpu && x.shape(-1) == dims) {
|
||||
return array(
|
||||
x.shape(),
|
||||
x.dtype(),
|
||||
std::make_unique<RoPE>(
|
||||
stream, fallback, dims, traditional, base, scale, offset),
|
||||
{x});
|
||||
}
|
||||
return fallback({x})[0];
|
||||
}
|
||||
|
||||
bool RoPE::is_equivalent(const Primitive& other) const {
|
||||
const RoPE& a_other = static_cast<const RoPE&>(other);
|
||||
return (
|
||||
dims_ == a_other.dims_ && base_ == a_other.base_ &&
|
||||
scale_ == a_other.scale_ && traditional_ == a_other.traditional_ &&
|
||||
offset_ == a_other.offset_);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::fast
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user