Compare commits

...

58 Commits

Author SHA1 Message Date
Awni Hannun
44390bd3d0 Bump (#869)
* bump

* fix none in a few ops
2024-03-21 13:56:56 -07:00
Angelos Katharopoulos
2225374060 Adds mx.fast.layer_norm (#870) 2024-03-21 13:55:51 -07:00
nicolov
105d236889 Add vmap for SVD and inverse (#849) 2024-03-21 13:18:27 -07:00
Angelos Katharopoulos
53e6a9367c Use reshape and transpose for non-overlapping pooling windows (#867) 2024-03-21 10:21:03 -07:00
Chime Ogbuji
f5a1582fe8 Add minimum for cosine decay function (#859)
* Add minimum for cosine decay function

* Update python/mlx/optimizers/schedulers.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-03-21 07:33:29 -07:00
Awni Hannun
a54f06b16f Fast RMS Norm (#862)
* fast rmsnorm

* no rms gpu

* kernel

* fix shared mem

* looped rms and donation in softmax

* Make the squaring in float32 to avoid underflow

* Fix the default StreamOrDevice for rope and rms_norm in fast

* nits

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-03-21 07:20:54 -07:00
Cheng
4650d94d98 Add missing && in eval (#864)
Without the && args would be copied and perfect forwarding won't work.

To avoid eval calling itself recursively, the vector version of eval is
changed to take by value instead, which will save a copy of array when a
rvalue is passed.
2024-03-21 06:15:48 -07:00
Jagrit Digani
a5681ebc52 Update set item (#861)
* Update mlx_set_item to handle regular slices without expanding

* Refactor ellipsis handling

* Route mlx_set_item to slice_update where possible

* Update mlx_scatter_args_slice

* Don't route to gather if no array indices
2024-03-21 02:48:13 -07:00
Cheng
e849b3424a Do not use static constexpr in header (#863)
Doing so results in each compilation unit (.cpp file) having its own
copy of the variable, while inline constexpr makes sure there is only
one copy.
2024-03-20 21:28:05 -07:00
Jagrit Digani
b219d12a6b Check edge case handling in row reduce med kernel (#858) 2024-03-20 11:37:58 -07:00
Jagrit Digani
cec8661113 Add a SliceUpdate op and primitive (#850)
* Enable copy to work with int64 strides
* Fix uniform buffer indices or copy kernel arguments
* Update utils.h
* Remove manual unrolling of elem to loc loop
* GPU copy updated to handle negative strides
* Add slice update primitive
2024-03-20 10:39:25 -07:00
Cheng
73a8c090e0 Pass shape and inputs by value in array's constructor (#853)
Since the shape and inputs are always saved as copy in ArrayDesc, we can
unify array's constructors to just take the arguments by value.

There are 2 cases:
1. When shape is a lvalue, it will be copied into array's constructor and
   then moved into ArrayDesc's member. So only 1 copy happens.
2. When shape is a rvalue, it will be moved into array's constructor and
   then moved into ArrayDesc's member. So no copy happens.

So having 1 constructor that takes by value is equivalent to having 2
constructors that const reference and rvalue separately.
2024-03-20 07:54:30 -07:00
Md. Rasel Mandol
db6796ac61 simple typo fille (#848) 2024-03-19 06:15:17 -07:00
Awni Hannun
9a8ee00246 Switch to nanobind (#839)
* mostly builds

* most tests pass

* fix circle build

* add back buffer protocol

* includes

* fix for py38

* limit to cpu device

* include

* fix stubs

* move signatures for docs

* stubgen + docs fix

* doc for compiled function, comments
2024-03-18 20:12:25 -07:00
Cheng
d39ed54f8e Some C++ code are not needed (#841)
1. Anonymous namespace means internal linkage, static keyword is not needed.
2. The default constructor of std::shared_ptr initializes the pointer to
   nullptr, you don't need to explicitly set it.
2024-03-18 17:04:10 -07:00
Awni Hannun
16546c70d8 No reshape rope (#838)
* no reshape rope

* no reshape rope
2024-03-18 17:03:07 -07:00
nicolov
eaba55c9bf Add matrix inversion primitive (#822) 2024-03-15 06:34:36 -07:00
Awni Hannun
19ec023256 vmap matmul and admm (#836) 2024-03-14 14:38:22 -07:00
Awni Hannun
63ab0ab580 version (#835) 2024-03-14 12:20:40 -07:00
Jagrit Digani
8dfc376c00 Strided reduce specialization for small reductions (#826)
* Add small column / general reduction specialization
2024-03-14 09:16:53 -07:00
Angelos Katharopoulos
1efee9db09 Add types and order in kernel name (#831) 2024-03-13 20:34:06 -07:00
Awni Hannun
43abc402d8 route to fallback (#828) 2024-03-13 19:56:04 -07:00
Angelos Katharopoulos
3f8b1668c4 Make reshape faster for row_contiguous cases (#829) 2024-03-13 16:22:03 -07:00
Angelos Katharopoulos
76c919b4ec NumberOfElements for shapeless compile and vmap fixes (#802) 2024-03-13 10:34:14 -07:00
Angelos Katharopoulos
29d0c10ee5 Reshape improvement (#818) 2024-03-12 17:54:31 -07:00
Jagrit Digani
5ad133f8bb No copy gems (#801)
* Enable collapsing batch dims in gemm
* Update gemm to only make copies when neither of the last 2 axes are contiguous
* Update addmm to support gemv shapes
* Update addmm to support irregular batch strides
* Update tests
2024-03-12 13:13:41 -07:00
nicolov
d0c544a868 Add SVD primitive (#809)
Add SVD op using Accelerate's LAPACK following
https://developer.apple.com/documentation/accelerate/
compressing_an_image_using_linear_algebra

Co-authored-by: Nicolo Valigi <nvaligi@apple.com>
2024-03-12 12:30:11 -07:00
Daniel Falbel
ffb19df3c0 Fix docstring for correctly rendering (#820) 2024-03-12 11:46:44 -07:00
Awni Hannun
8b7532b9ab fix scatter (#821) 2024-03-12 11:42:07 -07:00
Awni Hannun
366478c560 fix modules with dict (#819) 2024-03-12 08:54:06 -07:00
Justin Deschenaux
8e5600022a Implement RNN, GRU, LSTM (#268)
* RNN base implementation

* Address comments+format

* nits in docs

* add tests for prb

* fix test

* add a couple tests

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-03-11 21:14:44 -07:00
Awni Hannun
0e95b64942 Fix bug in tape order during simplify (#816)
* fix bug in tape order during simplify

* properly fix compile

* last bug
2024-03-11 17:29:05 -07:00
nicolov
0ae22b915b Remove code duplication in reduce ops (#793)
* Remove code duplication in reduce ops

* Remove the unnecessary lambda

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-03-11 10:57:07 -07:00
Awni Hannun
7c441600fe Compile stride bug (#812)
* fix compile stride bug

* revert sdpa fix

* fix cpu

* fix bug with simplifying outputs
2024-03-11 06:31:31 -07:00
Awni Hannun
a4d290adb9 Remove depth traversal (#813)
* no depth traversal

* counter outside loop
2024-03-09 20:21:32 -08:00
Awni Hannun
28301807c2 Version bump and os error (#807) 2024-03-07 13:57:58 -08:00
Awni Hannun
74ed0974b3 Support 13.0+ with xcode 14.3 (#806)
* Support 13.0+ with xcode 14.3

* revert revert
2024-03-07 13:27:57 -08:00
Jagrit Digani
ec8a4864fa Fix SDPA kernel bug on Mac OS 13.3 SDK (#805)
* Move sdpa kernel to allocate tgp mem statically and allow macOS 13.3 SDK builds

* Style
2024-03-07 10:18:09 -08:00
Awni Hannun
b7588fd5d7 fix inplace to not make a shallow copy (#804) 2024-03-07 09:34:11 -08:00
Awni Hannun
f512b905c7 Minimum xcode / sdk (#800)
* minimum xcode /sdk

* try multiple xcode versions in CI

* update python

* metal validation for python tests
2024-03-07 08:19:43 -08:00
Awni Hannun
afd5274049 route to fallback for bfloat (#794) 2024-03-06 15:39:12 -08:00
Awni Hannun
1074674e32 Add a maximum graph depth (#797)
* add a maximum graph depth

* remember how to use C++
2024-03-06 15:39:00 -08:00
AlexCheema
7762e07fde Update function_transforms.rst (#796)
Fix typo in function_transforms.rst
2024-03-06 12:03:37 -08:00
Luca Arnaboldi
cbefd9129e Implementation of pickle, copy and deepcopy for Python arrays (#300 & #367). (#713)
* Implemented pickling and copy for Python arrays(#300 & #367)

* Fixing typos

* Pickle with NumPy arrays

* Pickle: workaround for bfloat16

* Revert "Pickle: workaround for bfloat16"

This reverts commit 25afe6bc09.

* Added an error when pickling bfloat16

* Update python/tests/test_array.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/tests/test_array.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/array.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/array.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* clang-format applied

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-03-06 08:02:41 -08:00
Angelos Katharopoulos
e39bebe13e Fix reshaping of empty arrays (#791) 2024-03-05 23:33:22 -08:00
Angelos Katharopoulos
14b4e51a7c Improved quantized matrix vector product (#786) 2024-03-05 17:32:19 -08:00
Awni Hannun
cbcf44a4ca Some fixes in cache / thread safety (#777)
* some fixes in cache / thread safety

* speed up no cache case

* fix opt test

* optimizer docs

* otpimizer docs

* fix adafactor

* fix adafactor
2024-03-05 13:30:50 -08:00
Awni Hannun
859ae15a54 Fix test (#785) 2024-03-04 23:02:27 -08:00
Brian Keene
0787724c44 Fast Inference SDPA op (#735)
* Fast Inference SDPA op

Implements metal shaders for:

o = mx.fast_inference_sdpa(queries, keys, values, scale, mask)

Supports fp16, fp32 dtypes; assumes d_k = 128.

Generic op support / prompt encoding supported via mlx primitives.
Metal implementation is for the inference use case only.

Majority of performance benefits appears to results from GQA & reduced
bandwidth requirements; there is approximate performance parity for the
MHA use case (from some measurements on M3 Max).

* Flush shared memory to zero before unprotected reads for (scores @ values)

* Move to fast:: namespace, address reviewer comments

... also attempt to revert formatter auto-change for files not relevant
to this change

* Shared memory flush to top of kernel

* Resolve compiler warnings

* Update python/src/fast.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/fast.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/fast.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/fast.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update docstring per PR feedback

* Softmax in higher precision, ...

* route to fallback for more use cases - batch size > 1, head_dim other
  than 128, etc.
* Address linux build failure
* Address other reviewer comments

* Remove extraneous eval_cpu function per review

---------

Co-authored-by: Atila Orhon <64497909+atiorh@users.noreply.github.com>
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: atila <atiorh@icloud.com>
2024-03-04 21:06:11 -08:00
Awni Hannun
7b463ffb07 Ios compile (#784)
* try to fix build for ios

* skip cpu compile

* fix namespace

* fix namespace

* Use CMake for platform specific cpu compile

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-03-04 20:02:26 -08:00
Jagrit Digani
6686e61ca4 Reduce update (#783)
* Split reduction files to reduce compile times

* Add small and medium axis size specializations for row reductions

* Add non-row-reduction options for small and med kernels
2024-03-04 19:09:51 -08:00
Awni Hannun
c096a77b9b revision bump (#778) 2024-03-04 13:41:53 -08:00
Awni Hannun
5121f028d9 nice tensordot for mlx c (#782) 2024-03-04 09:51:02 -08:00
Piotr Rybiec
6a665ea6ed Dilation for convolutional layers (#766)
* add dilation parameter to Conv1d layer

* space here too

* add conv1d dilation test

* add dilation parameter for Conv2d layer

* conv2d dilation test
2024-03-04 06:43:00 -08:00
Awni Hannun
bc06cb9ff6 Pickle + dtype fix for numpy conversion (#763)
* pickle + dtype fix for numpy conversion

* fix getattribute on Module base

* remove unused function

* fix tests

* add topk to ops

* fix doc
2024-03-02 06:09:29 -08:00
Angelos Katharopoulos
8e281c76c3 Fix the top-k op (#768) 2024-03-01 22:08:43 -08:00
Awni Hannun
d5964a2710 bindings for memory info (#761)
* bindings for memory info

* update api

* keep cache low if requested

* fix default

* nit in ops error
2024-03-01 19:51:58 -08:00
Ikko Eltociear Ashimine
cf3eb87e52 Fix typo in transforms.cpp (#764)
occuring -> occurring
2024-02-29 22:23:46 -08:00
151 changed files with 11006 additions and 5128 deletions

View File

@@ -31,8 +31,7 @@ jobs:
name: Install dependencies
command: |
pip install --upgrade cmake
pip install --upgrade pybind11[global]
pip install pybind11-stubgen
pip install git+https://github.com/wjakob/nanobind.git@4148debcf91f5ccab0c3b8d67b5c3cabd61f407f
pip install numpy
sudo apt-get update
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
@@ -44,7 +43,8 @@ jobs:
- run:
name: Generate package stubs
command: |
python3 setup.py generate_stubs
echo "stubs"
python -m nanobind.stubgen -m mlx.core -r -O python
- run:
name: Run Python tests
command: |
@@ -63,21 +63,24 @@ jobs:
command: ./build/tests/tests
mac_build_and_test:
parameters:
xcode_version:
type: string
default: "15.2.0"
macos:
xcode: "15.2.0"
xcode: << parameters.xcode_version >>
resource_class: macos.m1.large.gen1
steps:
- checkout
- run:
name: Install dependencies
command: |
brew install python@3.9
python3.9 -m venv env
brew install python@3.8
python3.8 -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install --upgrade pybind11[global]
pip install pybind11-stubgen
pip install git+https://github.com/wjakob/nanobind.git@4148debcf91f5ccab0c3b8d67b5c3cabd61f407f
pip install numpy
pip install torch
pip install tensorflow
@@ -91,13 +94,13 @@ jobs:
name: Generate package stubs
command: |
source env/bin/activate
python setup.py generate_stubs
python -m nanobind.stubgen -m mlx.core -r -O python
- run:
name: Run Python tests
command: |
source env/bin/activate
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
LOW_MEMORY=1 DEVICE=gpu python3.9 -m xmlrunner discover -v python/tests -o test-results/gpu
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
# TODO: Reenable when extension api becomes stable
# - run:
# name: Build example extension
@@ -140,9 +143,8 @@ jobs:
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install --upgrade pybind11[global]
pip install git+https://github.com/wjakob/nanobind.git@4148debcf91f5ccab0c3b8d67b5c3cabd61f407f
pip install --upgrade setuptools
pip install pybind11-stubgen
pip install numpy
pip install twine
pip install build
@@ -157,7 +159,7 @@ jobs:
name: Generate package stubs
command: |
source env/bin/activate
python setup.py generate_stubs
python -m nanobind.stubgen -m mlx.core -r -O python
- run:
name: Build Python package
command: |
@@ -205,9 +207,8 @@ jobs:
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install --upgrade pybind11[global]
pip install git+https://github.com/wjakob/nanobind.git@4148debcf91f5ccab0c3b8d67b5c3cabd61f407f
pip install --upgrade setuptools
pip install pybind11-stubgen
pip install numpy
pip install auditwheel
pip install patchelf
@@ -215,7 +216,7 @@ jobs:
<< parameters.extra_env >> \
CMAKE_BUILD_PARALLEL_LEVEL="" \
pip install . -v
python setup.py generate_stubs
python -m nanobind.stubgen -m mlx.core -r -O python
<< parameters.extra_env >> \
CMAKE_BUILD_PARALLEL_LEVEL="" \
python -m build --wheel
@@ -235,7 +236,10 @@ workflows:
- not: << pipeline.parameters.weekly_build >>
- not: << pipeline.parameters.test_release >>
jobs:
- mac_build_and_test
- mac_build_and_test:
matrix:
parameters:
xcode_version: ["15.0.0", "15.2.0"]
- linux_build_and_test
build_pypi_release:
@@ -254,7 +258,7 @@ workflows:
matrix:
parameters:
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
xcode_version: ["14.3.1", "15.2.0"]
xcode_version: ["15.0.0", "15.2.0"]
build_env: ["PYPI_RELEASE=1"]
prb:
when:
@@ -268,6 +272,9 @@ workflows:
context: pr-approval
- mac_build_and_test:
requires: [ hold ]
matrix:
parameters:
xcode_version: ["15.0.0", "15.2.0"]
- linux_build_and_test:
requires: [ hold ]
nightly_build:
@@ -280,7 +287,7 @@ workflows:
matrix:
parameters:
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
xcode_version: ["14.3.1", "15.2.0"]
xcode_version: ["15.0.0", "15.2.0"]
weekly_build:
when:
and:
@@ -291,7 +298,7 @@ workflows:
matrix:
parameters:
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
xcode_version: ["14.3.1", "15.2.0"]
xcode_version: ["15.0.0", "15.2.0"]
build_env: ["DEV_RELEASE=1"]
linux_test_release:
when:

View File

@@ -13,7 +13,8 @@ MLX was developed with contributions from the following individuals:
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support.
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented pooling layers and ``Upsample``.
- Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops.
- Luca Arnaboldi: Added `Ceil` and `Floor` ops; implemented pickling, copy and deepcopy for mlx arrays.
- Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention`
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
</a>

View File

@@ -18,7 +18,7 @@ option(MLX_BUILD_METAL "Build metal backend" ON)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
if(NOT MLX_VERSION)
set(MLX_VERSION 0.5.0)
set(MLX_VERSION 0.8.0)
endif()
# --------------------- Processor tests -------------------------
@@ -28,7 +28,6 @@ message(STATUS "Building MLX for ${CMAKE_HOST_SYSTEM_PROCESSOR} processor on ${C
set(MLX_BUILD_ARM OFF)
if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
if (${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "x86_64" AND ${CMAKE_HOST_APPLE})
message(FATAL_ERROR
"Building for x86_64 on macOS is not supported."
@@ -78,10 +77,8 @@ elseif (MLX_BUILD_METAL)
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14.2_iOS17.2.zip)
elseif (${MACOS_VERSION} GREATER_EQUAL 14.0)
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14_iOS17-beta.zip)
elseif (${MACOS_VERSION} GREATER_EQUAL 13.3)
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS13.3_iOS16.4.zip)
else()
message(FATAL_ERROR "MLX requires macOS >= 13.4 to be built with MLX_BUILD_METAL=ON" )
message(FATAL_ERROR "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON" )
endif()
FetchContent_Declare(
@@ -149,8 +146,12 @@ target_include_directories(
if (MLX_BUILD_PYTHON_BINDINGS)
message(STATUS "Building Python bindings.")
find_package(Python COMPONENTS Interpreter Development)
find_package(pybind11 CONFIG REQUIRED)
find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
execute_process(
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE NB_DIR)
list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}")
find_package(nanobind CONFIG REQUIRED)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src)
endif()

View File

@@ -380,10 +380,6 @@ if __name__ == "__main__":
if len(args.axis) > 1:
args.axis.pop(0)
if args.print_pid:
print(os.getpid())
input("Press enter to run")
if args.cpu:
mx.set_default_device(mx.cpu)
else:
@@ -406,6 +402,10 @@ if __name__ == "__main__":
x = xs[0]
axis = args.axis[0]
if args.print_pid:
print(os.getpid())
input("Press enter to run")
if args.benchmark == "matmul_square":
print(bench(matmul_square, x))

View File

@@ -331,10 +331,6 @@ if __name__ == "__main__":
if len(args.axis) > 1:
args.axis.pop(0)
if args.print_pid:
print(os.getpid())
input("Press enter to run")
torch.set_num_threads(1)
device = "cpu" if args.cpu else "mps"
@@ -354,6 +350,10 @@ if __name__ == "__main__":
x = xs[0]
axis = args.axis[0]
if args.print_pid:
print(os.getpid())
input("Press enter to run")
if args.benchmark == "matmul_square":
print(bench(matmul_square, x))

View File

@@ -29,8 +29,8 @@ autosummary_generate = True
autosummary_filename_map = {"mlx.core.Stream": "stream_class"}
intersphinx_mapping = {
"https://docs.python.org/3": None,
"https://numpy.org/doc/stable/": None,
"python": ("https://docs.python.org/3", None),
"numpy": ("https://numpy.org/doc/stable/", None),
}
templates_path = ["_templates"]
@@ -59,3 +59,14 @@ html_theme_options = {
# -- Options for HTMLHelp output ---------------------------------------------
htmlhelp_basename = "mlx_doc"
def setup(app):
wrapped = app.registry.documenters["function"].can_document_member
def nanobind_function_patch(member: Any, *args, **kwargs) -> bool:
return "nanobind.nb_func" in str(type(member)) or wrapped(
member, *args, **kwargs
)
app.registry.documenters["function"].can_document_member = nanobind_function_patch

View File

@@ -64,6 +64,7 @@ are the CPU and GPU.
python/transforms
python/fft
python/linalg
python/metal
python/nn
python/optimizers
python/tree_utils

View File

@@ -15,10 +15,10 @@ To install from PyPI you must meet the following requirements:
- Using an M series chip (Apple silicon)
- Using a native Python >= 3.8
- macOS >= 13.3
- macOS >= 13.5
.. note::
MLX is only available on devices running macOS >= 13.3
MLX is only available on devices running macOS >= 13.5
It is highly recommended to use macOS 14 (Sonoma)
@@ -54,7 +54,7 @@ Build Requirements
- A C++ compiler with C++17 support (e.g. Clang >= 5.0)
- `cmake <https://cmake.org/>`_ -- version 3.24 or later, and ``make``
- Xcode >= 14.3 (Xcode >= 15.0 for macOS 14 and above)
- Xcode >= 15.0 and macOS SDK >= 14.0
.. note::
Ensure your shell environment is native ``arm``, not ``x86`` via Rosetta. If
@@ -70,16 +70,13 @@ To build and install the MLX python library from source, first, clone MLX from
git clone git@github.com:ml-explore/mlx.git mlx && cd mlx
Make sure that you have `pybind11 <https://pybind11.readthedocs.io/en/stable/index.html>`_
installed. You can install ``pybind11`` with ``pip``, ``brew`` or ``conda`` as follows:
Install `nanobind <https://nanobind.readthedocs.io/en/latest/>`_ with:
.. code-block:: shell
pip install "pybind11[global]"
conda install pybind11
brew install pybind11
pip install git+https://github.com/wjakob/nanobind.git
Then simply build and install it using pip:
Then simply build and install MLX using pip:
.. code-block:: shell

14
docs/src/python/metal.rst Normal file
View File

@@ -0,0 +1,14 @@
Metal
=====
.. currentmodule:: mlx.core.metal
.. autosummary::
:toctree: _autosummary
is_available
get_active_memory
get_peak_memory
get_cache_memory
set_memory_limit
set_cache_limit

View File

@@ -21,9 +21,11 @@ Layers
Embedding
GELU
GroupNorm
GRU
InstanceNorm
LayerNorm
Linear
LSTM
MaxPool1d
MaxPool2d
Mish
@@ -32,6 +34,7 @@ Layers
QuantizedLinear
RMSNorm
ReLU
RNN
RoPE
SELU
Sequential

View File

@@ -57,6 +57,7 @@ Operations
greater_equal
identity
inner
isclose
isnan
isposinf
isneginf
@@ -121,6 +122,8 @@ Operations
tan
tanh
tensordot
tile
topk
transpose
tri
tril

View File

@@ -40,7 +40,7 @@ getting higher order derivatives.
Any of the MLX function transformations can be composed in any order to any
depth. See the following sections for more information on :ref:`automatic
differentiaion <auto diff>` and :ref:`automatic vectorization <vmap>`.
differentiation <auto diff>` and :ref:`automatic vectorization <vmap>`.
For more information on :func:`compile` see the :ref:`compile documentation <compile>`.

View File

@@ -36,22 +36,11 @@ array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
init(&cval);
}
array::array(
const std::vector<int>& shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
const std::vector<array>& inputs)
: array_desc_(std::make_shared<ArrayDesc>(
shape,
dtype,
std::move(primitive),
inputs)) {}
array::array(
std::vector<int> shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
std::vector<array>&& inputs)
std::vector<array> inputs)
: array_desc_(std::make_shared<ArrayDesc>(
std::move(shape),
dtype,
@@ -92,10 +81,10 @@ array::array(std::initializer_list<int> data, Dtype dtype)
/* Build an array from a shared buffer */
array::array(
allocator::Buffer data,
const std::vector<int>& shape,
std::vector<int> shape,
Dtype dtype,
deleter_t deleter)
: array_desc_(std::make_shared<ArrayDesc>(shape, dtype)) {
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
set_data(data, deleter);
}
@@ -104,13 +93,11 @@ void array::detach() {
s.array_desc_->inputs.clear();
s.array_desc_->siblings.clear();
s.array_desc_->position = 0;
s.array_desc_->depth = 0;
s.array_desc_->primitive = nullptr;
}
array_desc_->inputs.clear();
array_desc_->siblings.clear();
array_desc_->position = 0;
array_desc_->depth = 0;
array_desc_->primitive = nullptr;
}
@@ -164,41 +151,35 @@ void array::copy_shared_buffer(const array& other) {
copy_shared_buffer(other, other.strides(), other.flags(), other.data_size());
}
void array::move_shared_buffer(array other) {
void array::move_shared_buffer(
array other,
const std::vector<size_t>& strides,
Flags flags,
size_t data_size,
size_t offset /* = 0 */) {
array_desc_->data = std::move(other.array_desc_->data);
array_desc_->strides = other.strides();
array_desc_->flags = other.flags();
array_desc_->data_size = other.data_size();
array_desc_->data_ptr = other.array_desc_->data_ptr;
array_desc_->strides = strides;
array_desc_->flags = flags;
array_desc_->data_size = data_size;
auto char_offset = sizeof(char) * itemsize() * offset;
array_desc_->data_ptr = static_cast<void*>(
static_cast<char*>(other.array_desc_->data_ptr) + char_offset);
}
array::ArrayDesc::ArrayDesc(const std::vector<int>& shape, Dtype dtype)
: shape(shape), dtype(dtype) {
std::tie(size, strides) = cum_prod(shape);
void array::move_shared_buffer(array other) {
move_shared_buffer(other, other.strides(), other.flags(), other.data_size());
}
array::ArrayDesc::ArrayDesc(
const std::vector<int>& shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
const std::vector<array>& inputs)
: shape(shape),
dtype(dtype),
primitive(std::move(primitive)),
inputs(inputs) {
array::ArrayDesc::ArrayDesc(std::vector<int> shape, Dtype dtype)
: shape(std::move(shape)), dtype(dtype) {
std::tie(size, strides) = cum_prod(this->shape);
for (auto& in : this->inputs) {
is_tracer |= in.is_tracer();
depth = std::max(in.graph_depth(), depth);
}
depth++;
}
array::ArrayDesc::ArrayDesc(
std::vector<int>&& shape,
std::vector<int> shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
std::vector<array>&& inputs)
std::vector<array> inputs)
: shape(std::move(shape)),
dtype(dtype),
primitive(std::move(primitive)),
@@ -206,9 +187,7 @@ array::ArrayDesc::ArrayDesc(
std::tie(size, strides) = cum_prod(this->shape);
for (auto& in : this->inputs) {
is_tracer |= in.is_tracer();
depth = std::max(in.graph_depth(), depth);
}
depth++;
}
array::ArrayIterator::ArrayIterator(const array& arr, int idx)

View File

@@ -1,5 +1,6 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include <algorithm>
#include <cstdint>
#include <functional>
@@ -31,7 +32,7 @@ class array {
template <typename It>
array(
It data,
const std::vector<int>& shape,
std::vector<int> shape,
Dtype dtype =
TypeToDtype<typename std::iterator_traits<It>::value_type>());
@@ -47,13 +48,13 @@ class array {
template <typename T>
array(
std::initializer_list<T> data,
const std::vector<int>& shape,
std::vector<int> shape,
Dtype dtype = TypeToDtype<T>());
/* Build an array from a buffer */
array(
allocator::Buffer data,
const std::vector<int>& shape,
std::vector<int> shape,
Dtype dtype,
deleter_t deleter = allocator::free);
@@ -172,17 +173,11 @@ class array {
* API may change.
*/
array(
const std::vector<int>& shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
const std::vector<array>& inputs);
array(
std::vector<int> shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
std::vector<array>&& inputs);
std::vector<array> inputs);
static std::vector<array> make_arrays(
const std::vector<std::vector<int>>& shapes,
@@ -273,11 +268,6 @@ class array {
return outputs;
};
/** The depth of the array in the graph. Evaluated arrays have depth 0. */
uint16_t graph_depth() const {
return array_desc_->depth;
}
/** Detach the array from the graph. */
void detach();
@@ -344,6 +334,13 @@ class array {
void copy_shared_buffer(const array& other);
void move_shared_buffer(
array other,
const std::vector<size_t>& strides,
Flags flags,
size_t data_size,
size_t offset = 0);
void move_shared_buffer(array other);
void overwrite_descriptor(const array& other) {
@@ -360,7 +357,7 @@ class array {
std::vector<size_t> strides;
size_t size;
Dtype dtype;
std::shared_ptr<Primitive> primitive{nullptr};
std::shared_ptr<Primitive> primitive;
// Indicates an array is being used in a graph transform
// and should not be detached from the graph
@@ -368,7 +365,7 @@ class array {
// This is a shared pointer so that *different* arrays
// can share the underlying data buffer.
std::shared_ptr<Data> data{nullptr};
std::shared_ptr<Data> data;
// Properly offset data pointer
void* data_ptr{nullptr};
@@ -388,29 +385,20 @@ class array {
// The arrays position in the output list
uint32_t position{0};
// The depth of the array in the graph.
uint16_t depth{0};
explicit ArrayDesc(const std::vector<int>& shape, Dtype dtype);
explicit ArrayDesc(std::vector<int> shape, Dtype dtype);
explicit ArrayDesc(
const std::vector<int>& shape,
std::vector<int> shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
const std::vector<array>& inputs);
explicit ArrayDesc(
std::vector<int>&& shape,
Dtype dtype,
std::shared_ptr<Primitive> primitive,
std::vector<array>&& inputs);
std::vector<array> inputs);
};
// The ArrayDesc contains the details of the materialized array including the
// shape, strides, the data type. It also includes
// the primitive which knows how to compute the array's data from its inputs
// and the list of array's inputs for the primitive.
std::shared_ptr<ArrayDesc> array_desc_{nullptr};
std::shared_ptr<ArrayDesc> array_desc_;
};
template <typename T>
@@ -422,9 +410,9 @@ array::array(T val, Dtype dtype /* = TypeToDtype<T>() */)
template <typename It>
array::array(
It data,
const std::vector<int>& shape,
std::vector<int> shape,
Dtype dtype /* = TypeToDtype<typename std::iterator_traits<It>::value_type>() */) :
array_desc_(std::make_shared<ArrayDesc>(shape, dtype)) {
array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
init(data);
}
@@ -441,9 +429,9 @@ array::array(
template <typename T>
array::array(
std::initializer_list<T> data,
const std::vector<int>& shape,
std::vector<int> shape,
Dtype dtype /* = TypeToDtype<T>() */)
: array_desc_(std::make_shared<ArrayDesc>(shape, dtype)) {
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
if (data.size() != size()) {
throw std::invalid_argument(
"Data size and provided shape mismatch in array construction.");

View File

@@ -38,6 +38,7 @@ DEFAULT(Copy)
DEFAULT_MULTI(CustomVJP)
DEFAULT_MULTI(Depends)
DEFAULT_MULTI(DivMod)
DEFAULT(NumberOfElements)
DEFAULT(Equal)
DEFAULT(Erf)
DEFAULT(ErfInv)
@@ -68,10 +69,13 @@ DEFAULT(Select)
DEFAULT(Sigmoid)
DEFAULT(Sign)
DEFAULT(Slice)
DEFAULT(SliceUpdate)
DEFAULT_MULTI(Split)
DEFAULT(Sort)
DEFAULT(StopGradient)
DEFAULT_MULTI(SVD)
DEFAULT(Transpose)
DEFAULT(Inverse)
void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);

View File

@@ -10,78 +10,65 @@
namespace mlx::core {
template <typename T, typename VT, int N>
void _vectorized_strided_sum(const T* x, T* accum, int size, size_t stride) {
for (int i = 0; i < size; i++) {
size_t s = stride;
T* a = accum;
while (s >= N) {
VT val = (*(VT*)x);
*(VT*)a += val;
x += N;
a += N;
s -= N;
}
while (s-- > 0) {
*a++ += *x++;
}
}
}
namespace {
// TODO: Add proper templates for the strided reduce algorithm so we don't have
// to write max/min/sum etc.
template <typename T, typename VT, int N>
void _vectorized_strided_max(const T* x, T* accum, int size, size_t stride) {
for (int i = 0; i < size; i++) {
size_t s = stride;
T* a = accum;
while (s >= N) {
*(VT*)a = simd_max((*(VT*)x), (*(VT*)a));
x += N;
a += N;
s -= N;
}
while (s-- > 0) {
*a = std::max(*a, *x);
a++;
x++;
}
template <typename T, typename VT>
struct MinReduction {
T operator()(const T& a, const T& b) {
return std::min(a, b);
}
}
template <typename T, typename VT, int N>
void _vectorized_strided_min(const T* x, T* accum, int size, size_t stride) {
for (int i = 0; i < size; i++) {
size_t s = stride;
T* a = accum;
while (s >= N) {
*(VT*)a = simd_min((*(VT*)x), (*(VT*)a));
x += N;
a += N;
s -= N;
}
while (s-- > 0) {
*a = std::min(*a, *x);
a++;
x++;
}
VT operator()(VT a, VT b) {
return simd_min(a, b);
}
}
};
template <typename T, typename VT, int N>
void _vectorized_sum(const T* x, T* accum, int size) {
VT _sum = {0};
while (size >= N) {
_sum += (*(VT*)x);
x += N;
size -= N;
template <typename T, typename VT>
struct MaxReduction {
T operator()(const T& a, const T& b) {
return std::max(a, b);
}
T sum = _sum[0];
for (int i = 1; i < N; i++) {
sum += _sum[i];
VT operator()(VT a, VT b) {
return simd_max(a, b);
}
*accum += sum;
}
};
template <typename T, typename VT>
struct SumReduction {
T operator()(const T& a, const T& b) {
return a + b;
}
VT operator()(VT a, VT b) {
return a + b;
}
};
template <typename T, typename VT, int N, typename Reduction>
struct StridedReduce {
void operator()(const T* x, T* accum, int size, size_t stride) {
Reduction op;
for (int i = 0; i < size; i++) {
size_t s = stride;
T* a = accum;
while (s >= N) {
*(VT*)a = op((*(VT*)x), (*(VT*)a));
x += N;
a += N;
s -= N;
}
while (s-- > 0) {
*a = op(*a, *x);
a++;
x++;
}
}
}
};
} // namespace
void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
@@ -94,10 +81,11 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
out,
axes_,
0,
[](const auto* x, auto* accum, int size, size_t stride) {
_vectorized_strided_sum<float, simd_float16, 16>(
(const float*)x, (float*)accum, size, stride);
},
StridedReduce<
float,
simd_float16,
16,
SumReduction<float, simd_float16>>(),
[](const auto* x, auto* accum, int size) {
float acc;
vDSP_sve((const float*)x, 1, &acc, size);
@@ -111,10 +99,11 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
out,
axes_,
-std::numeric_limits<float>::infinity(),
[](const auto* x, auto* accum, int size, size_t stride) {
_vectorized_strided_max<float, simd_float16, 16>(
(const float*)x, (float*)accum, size, stride);
},
StridedReduce<
float,
simd_float16,
16,
MaxReduction<float, simd_float16>>(),
[](const auto* x, auto* accum, int size) {
float max;
vDSP_maxv((const float*)x, 1, &max, size);
@@ -128,10 +117,11 @@ void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
out,
axes_,
std::numeric_limits<float>::infinity(),
[](const auto* x, auto* accum, int size, size_t stride) {
_vectorized_strided_min<float, simd_float16, 16>(
(const float*)x, (float*)accum, size, stride);
},
StridedReduce<
float,
simd_float16,
16,
MinReduction<float, simd_float16>>(),
[](const auto* x, auto* accum, int size) {
float min;
vDSP_minv((const float*)x, 1, &min, size);

View File

@@ -44,7 +44,6 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
@@ -53,5 +52,21 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
)
if (IOS)
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/compiled_nocpu.cpp
)
else()
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/compiled_cpu.cpp
)
endif()

View File

@@ -1,57 +1,12 @@
// Copyright © 2023-2024 Apple Inc.
#include <dlfcn.h>
#include <filesystem>
#include <list>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/common/compiled_preamble.h"
#include "mlx/backend/common/utils.h"
#include "mlx/graph_utils.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core {
std::string get_temp_file(const std::string& name) {
return std::filesystem::temp_directory_path().append(name);
}
std::string build_lib_name(
const std::vector<array>& inputs,
const std::vector<array>& outputs,
const std::vector<array>& tape,
const std::unordered_set<uintptr_t>& constant_ids) {
std::ostringstream os;
std::ostringstream constant_hasher;
// The primitives describing the tape. For unary and binary primitives this
// must be enough to describe the full computation.
for (auto& a : tape) {
a.primitive().print(os);
}
os << "_";
for (auto& x : inputs) {
if (constant_ids.find(x.id()) != constant_ids.end()) {
os << "C";
print_constant(constant_hasher, x);
} else {
os << (is_scalar(x) ? "S" : "V");
}
}
os << "_";
for (auto& x : inputs) {
if (constant_ids.find(x.id()) != constant_ids.end()) {
continue;
}
os << kindof(x.dtype()) << x.itemsize();
}
os << "_" << std::hash<std::string>{}(constant_hasher.str());
return os.str();
}
void print_constant(std::ostream& os, const array& x) {
switch (x.dtype()) {
case float32:
@@ -122,386 +77,53 @@ std::string get_type_string(Dtype d) {
}
}
// Return a pointer to a compiled function
void* compile(
const std::string& kernel_name,
const std::string& source_code = "") {
struct DLib {
DLib(const std::string& libname) {
lib = dlopen(libname.c_str(), RTLD_NOW);
if (!lib) {
std::ostringstream msg;
msg << "Could not load C++ shared library " << dlerror();
throw std::runtime_error(msg.str());
}
}
~DLib() {
dlclose(lib);
}
void* lib;
};
// Statics to cache compiled libraries and functions
static std::list<DLib> libs;
static std::unordered_map<std::string, void*> kernels;
if (auto it = kernels.find(kernel_name); it != kernels.end()) {
return it->second;
}
if (source_code.empty()) {
return nullptr;
}
std::ostringstream shared_lib_name;
shared_lib_name << "lib" << kernel_name << ".so";
auto shared_lib_path = get_temp_file(shared_lib_name.str());
bool lib_exists = false;
{
std::ifstream f(shared_lib_path.c_str());
lib_exists = f.good();
}
if (!lib_exists) {
// Open source file and write source code to it
std::ostringstream source_file_name;
source_file_name << kernel_name << ".cpp";
auto source_file_path = get_temp_file(source_file_name.str());
std::ofstream source_file(source_file_path);
source_file << source_code;
source_file.close();
std::ostringstream build_command;
build_command << "g++ -std=c++17 -O2 -Wall -fPIC -shared "
<< source_file_path << " -o " << shared_lib_path;
std::string build_command_str = build_command.str();
auto return_code = system(build_command_str.c_str());
if (return_code) {
std::ostringstream msg;
msg << "[Compile::eval_cpu] Failed to compile function " << kernel_name
<< " with error code " << return_code << "." << std::endl;
throw std::runtime_error(msg.str());
}
}
// load library
libs.emplace_back(shared_lib_path);
// Load function
void* fun = dlsym(libs.back().lib, kernel_name.c_str());
if (!fun) {
std::ostringstream msg;
msg << "[Compile::eval_cpu] Failed to load compiled function "
<< kernel_name << std::endl
<< dlerror();
throw std::runtime_error(msg.str());
}
kernels.insert({kernel_name, fun});
return fun;
}
inline void build_kernel(
std::ostream& os,
const std::string& kernel_name,
std::string build_lib_name(
const std::vector<array>& inputs,
const std::vector<array>& outputs,
const std::vector<array>& tape,
const std::unordered_set<uintptr_t>& constant_ids,
bool contiguous,
int ndim) {
// All outputs should have the exact same shape and will be row contiguous
auto output_shape = outputs[0].shape();
auto output_strides = outputs[0].strides();
// Constants are scalars that are captured by value and cannot change
auto is_constant = [&constant_ids](const array& x) {
return constant_ids.find(x.id()) != constant_ids.end();
};
const std::unordered_set<uintptr_t>& constant_ids) {
NodeNamer namer;
std::ostringstream os;
std::ostringstream constant_hasher;
// Start the kernel
os << "void " << kernel_name << "(void** args) {" << std::endl;
// Add the input arguments
int cnt = 0;
// Fill the input names. This is not really necessary, I just like having A,
// B, C, ... as the inputs.
for (auto& x : inputs) {
auto& xname = namer.get_name(x);
namer.get_name(x);
}
// Skip constants from the input list
if (is_constant(x)) {
continue;
}
auto tstr = get_type_string(x.dtype());
os << " " << tstr << "* " << xname << " = (" << tstr << "*)args[" << cnt++
<< "];" << std::endl;
// Scalars and contiguous need no strides
if (!is_scalar(x) && !contiguous) {
os << " const size_t* " << xname << "_strides = (size_t*)args[" << cnt++
<< "];" << std::endl;
// The primitives describing the tape. For unary and binary primitives this
// must be enough to describe the full computation.
for (auto& a : tape) {
// name and type of output
os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize();
// computation performed
a.primitive().print(os);
// name of inputs to the function
for (auto& inp : a.inputs()) {
os << namer.get_name(inp);
}
}
os << "_";
// Add the output arguments
for (auto& x : outputs) {
auto tstr = get_type_string(x.dtype());
os << " " << tstr << "* " << namer.get_name(x) << " = (" << tstr
<< "*)args[" << cnt++ << "];" << std::endl;
}
// Add output strides and shape to extract the indices.
if (!contiguous) {
os << " const int* shape = (int*)args[" << cnt++ << "];" << std::endl;
} else {
os << " const size_t size = (size_t)args[" << cnt++ << "];" << std::endl;
}
if (contiguous) {
os << " for (size_t i = 0; i < size; ++i) {" << std::endl;
} else {
for (int d = 0; d < ndim; ++d) {
os << " for (int i" << d << " = 0; i" << d << " < shape[" << d
<< "]; ++i" << d << ") {" << std::endl;
}
}
// Read the inputs in tmps
for (auto& x : inputs) {
auto& xname = namer.get_name(x);
if (is_constant(x)) {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = ";
print_constant(os, x);
os << ";" << std::endl;
} else if (is_scalar(x)) {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[0];" << std::endl;
} else if (contiguous) {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[i];" << std::endl;
if (constant_ids.find(x.id()) != constant_ids.end()) {
os << "C";
print_constant(constant_hasher, x);
} else {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = *"
<< xname << ";" << std::endl;
os << (is_scalar(x) ? "S" : "V");
}
}
// Actually write the computation
for (auto& x : tape) {
os << " " << get_type_string(x.dtype()) << " tmp_" << namer.get_name(x)
<< " = ";
if (is_static_cast(x.primitive())) {
os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_"
<< namer.get_name(x.inputs()[0]) << ");" << std::endl;
} else {
x.primitive().print(os);
os << "()(";
for (int i = 0; i < x.inputs().size() - 1; i++) {
os << "tmp_" << namer.get_name(x.inputs()[i]) << ", ";
}
os << "tmp_" << namer.get_name(x.inputs().back()) << ");" << std::endl;
}
}
// Write the outputs from tmps
for (auto& x : outputs) {
if (contiguous) {
os << " " << namer.get_name(x) << "[i] = tmp_" << namer.get_name(x)
<< ";" << std::endl;
} else {
os << " *" << namer.get_name(x) << "++ = tmp_" << namer.get_name(x)
<< ";" << std::endl;
}
}
// Close loops
if (contiguous) {
os << " }" << std::endl;
} else {
for (int d = ndim - 1; d >= 0; --d) {
// Update pointers
for (auto& x : inputs) {
if (is_constant(x) || is_scalar(x)) {
continue;
}
auto& xname = namer.get_name(x);
os << " " << xname << " += " << xname << "_strides[" << d << "];"
<< std::endl;
if (d < ndim - 1) {
os << " " << xname << " -= " << xname << "_strides[" << d + 1 << "]"
<< " * shape[" << d + 1 << "];" << std::endl;
}
}
os << " }" << std::endl;
}
}
// Finish the kernel
os << "}" << std::endl;
}
void Compiled::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
if (kernel_lib_.empty()) {
kernel_lib_ = build_lib_name(inputs_, outputs_, tape_, constant_ids_);
}
// Figure out which kernel we are using
auto& shape = outputs[0].shape();
bool contiguous = true;
{
bool all_contig = true;
bool all_row_contig = true;
bool all_col_contig = true;
int non_scalar_inputs = 0;
for (auto& x : inputs) {
if (is_scalar(x)) {
continue;
}
non_scalar_inputs++;
bool shape_eq = x.shape() == shape;
all_contig &= (x.flags().contiguous && shape_eq);
all_row_contig &= (x.flags().row_contiguous && shape_eq);
all_col_contig &= (x.flags().col_contiguous && shape_eq);
}
if (non_scalar_inputs > 1 && !all_row_contig && !all_col_contig) {
contiguous = false;
} else if (non_scalar_inputs == 1 && !all_contig) {
contiguous = false;
}
}
// Handle all broadcasting and collect function input arguments
std::vector<void*> args;
std::vector<std::vector<size_t>> strides;
for (int i = 0; i < inputs.size(); i++) {
// Skip constants.
if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
os << "_";
for (auto& x : inputs) {
if (constant_ids.find(x.id()) != constant_ids.end()) {
continue;
}
auto& x = inputs[i];
args.push_back((void*)x.data<void>());
if (contiguous || is_scalar(x)) {
continue;
}
// Broadcast the input to the output shape.
std::vector<size_t> xstrides;
int j = 0;
for (; j < shape.size() - x.ndim(); j++) {
if (shape[j] == 1) {
xstrides.push_back(outputs[0].strides()[j]);
} else {
xstrides.push_back(0);
}
}
for (int i = 0; i < x.ndim(); i++, j++) {
if (x.shape(i) == 1) {
if (shape[j] == 1) {
xstrides.push_back(outputs[0].strides()[j]);
} else {
xstrides.push_back(0);
}
} else {
xstrides.push_back(x.strides()[i]);
}
}
strides.push_back(std::move(xstrides));
args.push_back(strides.back().data());
os << kindof(x.dtype()) << x.itemsize();
}
os << "_" << std::hash<std::string>{}(constant_hasher.str());
// Get the kernel name from the lib
int ndim = shape.size();
auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_");
if (!contiguous) {
kernel_name += std::to_string(shape.size());
}
// Get the function
auto fn_ptr = compile(kernel_name);
// If it doesn't exist, compile it
if (fn_ptr == nullptr) {
std::ostringstream kernel;
kernel << get_kernel_preamble() << std::endl;
kernel << "extern \"C\" {" << std::endl;
build_kernel(
kernel,
kernel_name,
inputs_,
outputs_,
tape_,
constant_ids_,
contiguous,
ndim);
// Close extern "C"
kernel << "}" << std::endl;
// Compile and get function pointer
fn_ptr = compile(kernel_name, kernel.str());
}
// Allocate space for the outputs possibly with input donation
if (contiguous) {
int o = 0;
std::vector<size_t> strides;
size_t data_size;
array::Flags flags;
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
auto& in = inputs[i];
// Conditions for donation
// - Contiguous
// - Donatable
// - Correct size
// - Not a constant
if (in.flags().contiguous && !is_scalar(in) && in.is_donatable() &&
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
outputs[o++].copy_shared_buffer(in);
}
// Get representative input flags to properly set non-donated outputs
if (strides.empty() && in.size() == outputs[0].size()) {
strides = in.strides();
flags = in.flags();
data_size = in.data_size();
}
}
for (; o < outputs.size(); ++o) {
outputs[o].set_data(
allocator::malloc_or_wait(data_size * outputs[o].itemsize()),
data_size,
strides,
flags);
}
} else {
int o = 0;
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
auto& in = inputs[i];
// Conditions for donation
// - Row contiguous
// - Donatable
// - Correct size
// - Not a constant
if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() &&
in.is_donatable() &&
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
outputs[o++].copy_shared_buffer(in);
}
}
for (; o < outputs.size(); ++o) {
outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes()));
}
}
for (auto& x : outputs) {
args.push_back(x.data<void>());
}
if (!contiguous) {
args.push_back((void*)outputs[0].shape().data());
} else {
args.push_back((void*)outputs[0].data_size());
}
auto fun = (void (*)(void**))fn_ptr;
fun(args.data());
return os.str();
}
} // namespace mlx::core

View File

@@ -0,0 +1,410 @@
// Copyright © 2023-2024 Apple Inc.
#include <dlfcn.h>
#include <filesystem>
#include <list>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/common/compiled_preamble.h"
#include "mlx/device.h"
#include "mlx/graph_utils.h"
namespace mlx::core {
// GPU compile is always available if the GPU is available and since we are in
// this file CPU compile is also available.
namespace detail {
bool compile_available_for_device(const Device& device) {
return true;
}
} // namespace detail
std::string get_temp_file(const std::string& name) {
return std::filesystem::temp_directory_path().append(name);
}
// Return a pointer to a compiled function
void* compile(
const std::string& kernel_name,
const std::string& source_code = "") {
struct DLib {
DLib(const std::string& libname) {
lib = dlopen(libname.c_str(), RTLD_NOW);
if (!lib) {
std::ostringstream msg;
msg << "Could not load C++ shared library " << dlerror();
throw std::runtime_error(msg.str());
}
}
~DLib() {
dlclose(lib);
}
void* lib;
};
// Statics to cache compiled libraries and functions
static std::list<DLib> libs;
static std::unordered_map<std::string, void*> kernels;
if (auto it = kernels.find(kernel_name); it != kernels.end()) {
return it->second;
}
if (source_code.empty()) {
return nullptr;
}
std::ostringstream shared_lib_name;
shared_lib_name << "lib" << kernel_name << ".so";
auto shared_lib_path = get_temp_file(shared_lib_name.str());
bool lib_exists = false;
{
std::ifstream f(shared_lib_path.c_str());
lib_exists = f.good();
}
if (!lib_exists) {
// Open source file and write source code to it
std::ostringstream source_file_name;
source_file_name << kernel_name << ".cpp";
auto source_file_path = get_temp_file(source_file_name.str());
std::ofstream source_file(source_file_path);
source_file << source_code;
source_file.close();
std::ostringstream build_command;
build_command << "g++ -std=c++17 -O2 -Wall -fPIC -shared "
<< source_file_path << " -o " << shared_lib_path;
std::string build_command_str = build_command.str();
auto return_code = system(build_command_str.c_str());
if (return_code) {
std::ostringstream msg;
msg << "[Compile::eval_cpu] Failed to compile function " << kernel_name
<< " with error code " << return_code << "." << std::endl;
throw std::runtime_error(msg.str());
}
}
// load library
libs.emplace_back(shared_lib_path);
// Load function
void* fun = dlsym(libs.back().lib, kernel_name.c_str());
if (!fun) {
std::ostringstream msg;
msg << "[Compile::eval_cpu] Failed to load compiled function "
<< kernel_name << std::endl
<< dlerror();
throw std::runtime_error(msg.str());
}
kernels.insert({kernel_name, fun});
return fun;
}
inline void build_kernel(
std::ostream& os,
const std::string& kernel_name,
const std::vector<array>& inputs,
const std::vector<array>& outputs,
const std::vector<array>& tape,
const std::unordered_set<uintptr_t>& constant_ids,
bool contiguous,
int ndim) {
// All outputs should have the exact same shape and will be row contiguous
auto output_shape = outputs[0].shape();
auto output_strides = outputs[0].strides();
// Constants are scalars that are captured by value and cannot change
auto is_constant = [&constant_ids](const array& x) {
return constant_ids.find(x.id()) != constant_ids.end();
};
NodeNamer namer;
// Start the kernel
os << "void " << kernel_name << "(void** args) {" << std::endl;
// Add the input arguments
int cnt = 0;
for (auto& x : inputs) {
auto& xname = namer.get_name(x);
// Skip constants from the input list
if (is_constant(x)) {
continue;
}
auto tstr = get_type_string(x.dtype());
os << " " << tstr << "* " << xname << " = (" << tstr << "*)args[" << cnt++
<< "];" << std::endl;
// Scalars and contiguous need no strides
if (!is_scalar(x) && !contiguous) {
os << " const size_t* " << xname << "_strides = (size_t*)args[" << cnt++
<< "];" << std::endl;
}
}
// Add the output arguments
for (auto& x : outputs) {
auto tstr = get_type_string(x.dtype());
os << " " << tstr << "* " << namer.get_name(x) << " = (" << tstr
<< "*)args[" << cnt++ << "];" << std::endl;
}
// Add output strides and shape to extract the indices.
if (!contiguous) {
os << " const int* shape = (int*)args[" << cnt++ << "];" << std::endl;
} else {
os << " const size_t size = (size_t)args[" << cnt++ << "];" << std::endl;
}
if (contiguous) {
os << " for (size_t i = 0; i < size; ++i) {" << std::endl;
} else {
for (int d = 0; d < ndim; ++d) {
os << " for (int i" << d << " = 0; i" << d << " < shape[" << d
<< "]; ++i" << d << ") {" << std::endl;
}
}
// Read the inputs in tmps
for (auto& x : inputs) {
auto& xname = namer.get_name(x);
if (is_constant(x)) {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = ";
print_constant(os, x);
os << ";" << std::endl;
} else if (is_scalar(x)) {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[0];" << std::endl;
} else if (contiguous) {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[i];" << std::endl;
} else {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = *"
<< xname << ";" << std::endl;
}
}
// Actually write the computation
for (auto& x : tape) {
os << " " << get_type_string(x.dtype()) << " tmp_" << namer.get_name(x)
<< " = ";
if (is_static_cast(x.primitive())) {
os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_"
<< namer.get_name(x.inputs()[0]) << ");" << std::endl;
} else {
x.primitive().print(os);
os << "()(";
for (int i = 0; i < x.inputs().size() - 1; i++) {
os << "tmp_" << namer.get_name(x.inputs()[i]) << ", ";
}
os << "tmp_" << namer.get_name(x.inputs().back()) << ");" << std::endl;
}
}
// Write the outputs from tmps
for (auto& x : outputs) {
if (contiguous) {
os << " " << namer.get_name(x) << "[i] = tmp_" << namer.get_name(x)
<< ";" << std::endl;
} else {
os << " *" << namer.get_name(x) << "++ = tmp_" << namer.get_name(x)
<< ";" << std::endl;
}
}
// Close loops
if (contiguous) {
os << " }" << std::endl;
} else {
for (int d = ndim - 1; d >= 0; --d) {
// Update pointers
for (auto& x : inputs) {
if (is_constant(x) || is_scalar(x)) {
continue;
}
auto& xname = namer.get_name(x);
os << " " << xname << " += " << xname << "_strides[" << d << "];"
<< std::endl;
if (d < ndim - 1) {
os << " " << xname << " -= " << xname << "_strides[" << d + 1 << "]"
<< " * shape[" << d + 1 << "];" << std::endl;
}
}
os << " }" << std::endl;
}
}
// Finish the kernel
os << "}" << std::endl;
}
void Compiled::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
if (kernel_lib_.empty()) {
kernel_lib_ = build_lib_name(inputs_, outputs_, tape_, constant_ids_);
}
// Figure out which kernel we are using
auto& shape = outputs[0].shape();
bool contiguous = true;
{
bool all_contig = true;
bool all_row_contig = true;
bool all_col_contig = true;
int non_scalar_inputs = 0;
for (auto& x : inputs) {
if (is_scalar(x)) {
continue;
}
non_scalar_inputs++;
bool shape_eq = x.shape() == shape;
all_contig &= (x.flags().contiguous && shape_eq);
all_row_contig &= (x.flags().row_contiguous && shape_eq);
all_col_contig &= (x.flags().col_contiguous && shape_eq);
}
if (non_scalar_inputs > 1 && !all_row_contig && !all_col_contig) {
contiguous = false;
} else if (non_scalar_inputs == 1 && !all_contig) {
contiguous = false;
}
}
// Handle all broadcasting and collect function input arguments
std::vector<void*> args;
std::vector<std::vector<size_t>> strides;
for (int i = 0; i < inputs.size(); i++) {
// Skip constants.
if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
continue;
}
auto& x = inputs[i];
args.push_back((void*)x.data<void>());
if (contiguous || is_scalar(x)) {
continue;
}
// Broadcast the input to the output shape.
std::vector<size_t> xstrides;
int j = 0;
for (; j < shape.size() - x.ndim(); j++) {
if (shape[j] == 1) {
xstrides.push_back(outputs[0].strides()[j]);
} else {
xstrides.push_back(0);
}
}
for (int i = 0; i < x.ndim(); i++, j++) {
if (x.shape(i) == 1) {
if (shape[j] == 1) {
xstrides.push_back(outputs[0].strides()[j]);
} else {
xstrides.push_back(0);
}
} else {
xstrides.push_back(x.strides()[i]);
}
}
strides.push_back(std::move(xstrides));
args.push_back(strides.back().data());
}
// Get the kernel name from the lib
int ndim = shape.size();
auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_");
if (!contiguous) {
kernel_name += std::to_string(shape.size());
}
// Get the function
auto fn_ptr = compile(kernel_name);
// If it doesn't exist, compile it
if (fn_ptr == nullptr) {
std::ostringstream kernel;
kernel << get_kernel_preamble() << std::endl;
kernel << "extern \"C\" {" << std::endl;
build_kernel(
kernel,
kernel_name,
inputs_,
outputs_,
tape_,
constant_ids_,
contiguous,
ndim);
// Close extern "C"
kernel << "}" << std::endl;
// Compile and get function pointer
fn_ptr = compile(kernel_name, kernel.str());
}
// Allocate space for the outputs possibly with input donation
if (contiguous) {
int o = 0;
std::vector<size_t> strides;
size_t data_size;
array::Flags flags;
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
auto& in = inputs[i];
// Conditions for donation
// - Contiguous
// - Donatable
// - Correct size
// - Not a constant
if (in.flags().contiguous && !is_scalar(in) && in.is_donatable() &&
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
outputs[o++].copy_shared_buffer(in);
}
// Get representative input flags to properly set non-donated outputs
if (strides.empty() && in.size() == outputs[0].size()) {
strides = in.strides();
flags = in.flags();
data_size = in.data_size();
}
}
for (; o < outputs.size(); ++o) {
outputs[o].set_data(
allocator::malloc_or_wait(data_size * outputs[o].itemsize()),
data_size,
strides,
flags);
}
} else {
int o = 0;
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {
auto& in = inputs[i];
// Conditions for donation
// - Row contiguous
// - Donatable
// - Correct size
// - Not a constant
if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() &&
in.is_donatable() &&
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
outputs[o].copy_shared_buffer(
in, outputs[o].strides(), in.flags(), in.data_size());
o++;
}
}
for (; o < outputs.size(); ++o) {
outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes()));
}
}
for (auto& x : outputs) {
args.push_back(x.data<void>());
}
if (!contiguous) {
args.push_back((void*)outputs[0].shape().data());
} else {
args.push_back((void*)outputs[0].data_size());
}
auto fun = (void (*)(void**))fn_ptr;
fun(args.data());
}
} // namespace mlx::core

View File

@@ -0,0 +1,23 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/common/compiled.h"
namespace mlx::core {
// GPU compile is always available if the GPU is available and since we are in
// this file CPU compile is not available so check if the device is a GPU
// device.
namespace detail {
bool compile_available_for_device(const Device& device) {
return device == Device::gpu;
}
} // namespace detail
void Compiled::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
throw std::runtime_error(
"[Compiled::eval_cpu] CPU compialtion not supported on the platform.");
}
} // namespace mlx::core

View File

@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <numeric>
@@ -25,121 +25,196 @@ void copy_vector(const array& src, array& dst) {
std::copy(src_ptr, src_ptr + src.data_size(), dst_ptr);
}
template <typename SrcT, typename DstT>
void copy_general_dim1(const array& src, array& dst) {
template <typename SrcT, typename DstT, typename stride_t>
void copy_general_dim1(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
int64_t i_offset) {
const SrcT* src_ptr = src.data<SrcT>();
DstT* dst_ptr = dst.data<DstT>();
size_t src_idx = 0;
size_t dst_idx = 0;
for (size_t i = 0; i < src.shape()[0]; ++i) {
stride_t src_idx = i_offset;
stride_t dst_idx = 0;
for (int i = 0; i < data_shape[0]; ++i) {
dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
src_idx += src.strides()[0];
src_idx += i_strides[0];
}
}
template <typename SrcT, typename DstT>
void copy_general_dim2(const array& src, array& dst) {
inline void copy_general_dim1(const array& src, array& dst) {
return copy_general_dim1<SrcT, DstT, size_t>(
src, dst, src.shape(), src.strides(), 0);
}
template <typename SrcT, typename DstT, typename stride_t>
void copy_general_dim2(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
int64_t i_offset) {
const SrcT* src_ptr = src.data<SrcT>();
DstT* dst_ptr = dst.data<DstT>();
size_t src_idx = 0;
size_t dst_idx = 0;
for (size_t i = 0; i < src.shape()[0]; ++i) {
for (size_t j = 0; j < src.shape()[1]; ++j) {
stride_t src_idx = i_offset;
stride_t dst_idx = 0;
for (int i = 0; i < data_shape[0]; ++i) {
for (int j = 0; j < data_shape[1]; ++j) {
dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
src_idx += src.strides()[1];
src_idx += i_strides[1];
}
src_idx += src.strides()[0] - src.strides()[1] * src.shape()[1];
src_idx += i_strides[0] - i_strides[1] * data_shape[1];
}
}
template <typename SrcT, typename DstT>
void copy_general_dim3(const array& src, array& dst) {
inline void copy_general_dim2(const array& src, array& dst) {
return copy_general_dim2<SrcT, DstT, size_t>(
src, dst, src.shape(), src.strides(), 0);
}
template <typename SrcT, typename DstT, typename stride_t>
void copy_general_dim3(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
int64_t i_offset) {
const SrcT* src_ptr = src.data<SrcT>();
DstT* dst_ptr = dst.data<DstT>();
size_t src_idx = 0;
size_t dst_idx = 0;
for (size_t i = 0; i < src.shape()[0]; ++i) {
for (size_t j = 0; j < src.shape()[1]; ++j) {
for (size_t k = 0; k < src.shape()[2]; ++k) {
stride_t src_idx = i_offset;
stride_t dst_idx = 0;
for (int i = 0; i < data_shape[0]; ++i) {
for (int j = 0; j < data_shape[1]; ++j) {
for (int k = 0; k < data_shape[2]; ++k) {
dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
src_idx += src.strides()[2];
src_idx += i_strides[2];
}
src_idx += src.strides()[1] - src.strides()[2] * src.shape()[2];
src_idx += i_strides[1] - i_strides[2] * data_shape[2];
}
src_idx += src.strides()[0] - src.strides()[1] * src.shape()[1];
src_idx += i_strides[0] - i_strides[1] * data_shape[1];
}
}
template <typename SrcT, typename DstT>
void copy_general_dim4(const array& src, array& dst) {
inline void copy_general_dim3(const array& src, array& dst) {
return copy_general_dim3<SrcT, DstT, size_t>(
src, dst, src.shape(), src.strides(), 0);
}
template <typename SrcT, typename DstT, typename stride_t>
void copy_general_dim4(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
int64_t i_offset) {
const SrcT* src_ptr = src.data<SrcT>();
DstT* dst_ptr = dst.data<DstT>();
size_t src_idx = 0;
size_t dst_idx = 0;
for (size_t i = 0; i < src.shape()[0]; ++i) {
for (size_t j = 0; j < src.shape()[1]; ++j) {
for (size_t k = 0; k < src.shape()[2]; ++k) {
for (size_t ii = 0; ii < src.shape()[3]; ++ii) {
stride_t src_idx = i_offset;
stride_t dst_idx = 0;
for (int i = 0; i < data_shape[0]; ++i) {
for (int j = 0; j < data_shape[1]; ++j) {
for (int k = 0; k < data_shape[2]; ++k) {
for (int ii = 0; ii < data_shape[3]; ++ii) {
dst_ptr[dst_idx++] = static_cast<DstT>(src_ptr[src_idx]);
src_idx += src.strides()[3];
src_idx += i_strides[3];
}
src_idx += src.strides()[2] - src.strides()[3] * src.shape()[3];
src_idx += i_strides[2] - i_strides[3] * data_shape[3];
}
src_idx += src.strides()[1] - src.strides()[2] * src.shape()[2];
src_idx += i_strides[1] - i_strides[2] * data_shape[2];
}
src_idx += src.strides()[0] - src.strides()[1] * src.shape()[1];
src_idx += i_strides[0] - i_strides[1] * data_shape[1];
}
}
template <typename SrcT, typename DstT>
void copy_general(const array& src, array& dst) {
inline void copy_general_dim4(const array& src, array& dst) {
return copy_general_dim4<SrcT, DstT, size_t>(
src, dst, src.shape(), src.strides(), 0);
}
template <typename SrcT, typename DstT, typename stride_t>
void copy_general(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
int64_t i_offset) {
switch (src.ndim()) {
case 1:
copy_general_dim1<SrcT, DstT>(src, dst);
copy_general_dim1<SrcT, DstT, stride_t>(
src, dst, data_shape, i_strides, i_offset);
return;
case 2:
copy_general_dim2<SrcT, DstT>(src, dst);
copy_general_dim2<SrcT, DstT, stride_t>(
src, dst, data_shape, i_strides, i_offset);
return;
case 3:
copy_general_dim3<SrcT, DstT>(src, dst);
copy_general_dim3<SrcT, DstT, stride_t>(
src, dst, data_shape, i_strides, i_offset);
return;
case 4:
copy_general_dim4<SrcT, DstT>(src, dst);
copy_general_dim4<SrcT, DstT, stride_t>(
src, dst, data_shape, i_strides, i_offset);
return;
}
auto src_ptr = src.data<SrcT>();
auto src_ptr = src.data<SrcT>() + i_offset;
auto dst_ptr = dst.data<DstT>();
for (size_t i = 0; i < dst.size(); ++i) {
size_t src_elem = elem_to_loc(i, src.shape(), src.strides());
stride_t src_elem = elem_to_loc(i, data_shape, i_strides);
dst_ptr[i] = static_cast<DstT>(src_ptr[src_elem]);
}
}
template <typename SrcT, typename DstT, int D>
template <typename SrcT, typename DstT>
inline void copy_general(const array& src, array& dst) {
return copy_general<SrcT, DstT, size_t>(
src, dst, src.shape(), src.strides(), 0);
}
template <typename SrcT, typename DstT, typename stride_t>
inline void copy_general(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
const std::vector<stride_t>& o_strides,
int64_t i_offset,
int64_t o_offset) {
return copy_general<SrcT, DstT, stride_t>(
src, dst, data_shape, i_strides, i_offset);
}
template <typename SrcT, typename DstT, typename stride_t, int D>
inline void copy_general_general_dims(
const array& src,
array& dst,
size_t offset_src,
size_t offset_dst) {
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
const std::vector<stride_t>& o_strides,
stride_t i_offset,
stride_t o_offset) {
if constexpr (D > 1) {
int axis = src.ndim() - D;
auto stride_src = src.strides()[axis];
auto stride_dst = dst.strides()[axis];
auto N = src.shape(axis);
auto stride_src = i_strides[axis];
auto stride_dst = o_strides[axis];
auto N = data_shape[axis];
for (int i = 0; i < N; i++) {
copy_general_general_dims<SrcT, DstT, D - 1>(
src, dst, offset_src, offset_dst);
offset_src += stride_src;
offset_dst += stride_dst;
copy_general_general_dims<SrcT, DstT, stride_t, D - 1>(
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
i_offset += stride_src;
o_offset += stride_dst;
}
} else {
int axis = src.ndim() - 1;
auto stride_src = src.strides()[axis];
auto stride_dst = dst.strides()[axis];
auto N = src.shape(axis);
const SrcT* src_ptr = src.data<SrcT>() + offset_src;
DstT* dst_ptr = dst.data<DstT>() + offset_dst;
auto stride_src = i_strides[axis];
auto stride_dst = o_strides[axis];
auto N = data_shape[axis];
const SrcT* src_ptr = src.data<SrcT>() + i_offset;
DstT* dst_ptr = dst.data<DstT>() + o_offset;
for (int i = 0; i < N; i++) {
*dst_ptr = static_cast<DstT>(*src_ptr);
src_ptr += stride_src;
@@ -148,37 +223,56 @@ inline void copy_general_general_dims(
}
}
template <typename SrcT, typename DstT>
void copy_general_general(const array& src, array& dst) {
template <typename SrcT, typename DstT, typename stride_t>
void copy_general_general(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
const std::vector<stride_t>& o_strides,
stride_t i_offset,
stride_t o_offset) {
switch (src.ndim()) {
case 1:
copy_general_general_dims<SrcT, DstT, 1>(src, dst, 0, 0);
copy_general_general_dims<SrcT, DstT, stride_t, 1>(
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
return;
case 2:
copy_general_general_dims<SrcT, DstT, 2>(src, dst, 0, 0);
copy_general_general_dims<SrcT, DstT, stride_t, 2>(
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
return;
case 3:
copy_general_general_dims<SrcT, DstT, 3>(src, dst, 0, 0);
copy_general_general_dims<SrcT, DstT, stride_t, 3>(
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
return;
case 4:
copy_general_general_dims<SrcT, DstT, 4>(src, dst, 0, 0);
copy_general_general_dims<SrcT, DstT, stride_t, 4>(
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
return;
case 5:
copy_general_general_dims<SrcT, DstT, 5>(src, dst, 0, 0);
copy_general_general_dims<SrcT, DstT, stride_t, 5>(
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
return;
}
int size = std::accumulate(
src.shape().begin() - 5, src.shape().end(), 1, std::multiplies<int>());
data_shape.begin() - 5, data_shape.end(), 1, std::multiplies<int>());
for (int i = 0; i < src.size(); i += size) {
size_t offset_src = elem_to_loc(i, src.shape(), src.strides());
size_t offset_dst = elem_to_loc(i, dst.shape(), dst.strides());
copy_general_general_dims<SrcT, DstT, 5>(src, dst, offset_src, offset_dst);
stride_t src_offset = i_offset + elem_to_loc(i, data_shape, i_strides);
stride_t dst_offset = o_offset + elem_to_loc(i, dst.shape(), o_strides);
copy_general_general_dims<SrcT, DstT, stride_t, 5>(
src, dst, data_shape, i_strides, o_strides, src_offset, dst_offset);
}
}
template <typename SrcT, typename DstT>
void copy(const array& src, array& dst, CopyType ctype) {
inline void copy_general_general(const array& src, array& dst) {
return copy_general_general<SrcT, DstT, size_t>(
src, dst, src.shape(), src.strides(), dst.strides(), 0, 0);
}
template <typename SrcT, typename DstT, typename... Args>
void copy(const array& src, array& dst, CopyType ctype, Args... args) {
switch (ctype) {
case CopyType::Scalar:
copy_single<SrcT, DstT>(src, dst);
@@ -187,54 +281,103 @@ void copy(const array& src, array& dst, CopyType ctype) {
copy_vector<SrcT, DstT>(src, dst);
return;
case CopyType::General:
copy_general<SrcT, DstT>(src, dst);
copy_general<SrcT, DstT>(src, dst, args...);
return;
case CopyType::GeneralGeneral:
copy_general_general<SrcT, DstT>(src, dst);
copy_general_general<SrcT, DstT>(src, dst, args...);
}
}
template <typename SrcT>
void copy(const array& src, array& dst, CopyType ctype) {
template <typename SrcT, typename... Args>
void copy(const array& src, array& dst, CopyType ctype, Args... args) {
switch (dst.dtype()) {
case bool_:
copy<SrcT, bool>(src, dst, ctype);
copy<SrcT, bool>(src, dst, ctype, args...);
break;
case uint8:
copy<SrcT, uint8_t>(src, dst, ctype);
copy<SrcT, uint8_t>(src, dst, ctype, args...);
break;
case uint16:
copy<SrcT, uint16_t>(src, dst, ctype);
copy<SrcT, uint16_t>(src, dst, ctype, args...);
break;
case uint32:
copy<SrcT, uint32_t>(src, dst, ctype);
copy<SrcT, uint32_t>(src, dst, ctype, args...);
break;
case uint64:
copy<SrcT, uint64_t>(src, dst, ctype);
copy<SrcT, uint64_t>(src, dst, ctype, args...);
break;
case int8:
copy<SrcT, int8_t>(src, dst, ctype);
copy<SrcT, int8_t>(src, dst, ctype, args...);
break;
case int16:
copy<SrcT, int16_t>(src, dst, ctype);
copy<SrcT, int16_t>(src, dst, ctype, args...);
break;
case int32:
copy<SrcT, int32_t>(src, dst, ctype);
copy<SrcT, int32_t>(src, dst, ctype, args...);
break;
case int64:
copy<SrcT, int64_t>(src, dst, ctype);
copy<SrcT, int64_t>(src, dst, ctype, args...);
break;
case float16:
copy<SrcT, float16_t>(src, dst, ctype);
copy<SrcT, float16_t>(src, dst, ctype, args...);
break;
case float32:
copy<SrcT, float>(src, dst, ctype);
copy<SrcT, float>(src, dst, ctype, args...);
break;
case bfloat16:
copy<SrcT, bfloat16_t>(src, dst, ctype);
copy<SrcT, bfloat16_t>(src, dst, ctype, args...);
break;
case complex64:
copy<SrcT, complex64_t>(src, dst, ctype);
copy<SrcT, complex64_t>(src, dst, ctype, args...);
break;
}
}
template <typename... Args>
inline void copy_inplace_dispatch(
const array& src,
array& dst,
CopyType ctype,
Args... args) {
switch (src.dtype()) {
case bool_:
copy<bool>(src, dst, ctype, args...);
break;
case uint8:
copy<uint8_t>(src, dst, ctype, args...);
break;
case uint16:
copy<uint16_t>(src, dst, ctype, args...);
break;
case uint32:
copy<uint32_t>(src, dst, ctype, args...);
break;
case uint64:
copy<uint64_t>(src, dst, ctype, args...);
break;
case int8:
copy<int8_t>(src, dst, ctype, args...);
break;
case int16:
copy<int16_t>(src, dst, ctype, args...);
break;
case int32:
copy<int32_t>(src, dst, ctype, args...);
break;
case int64:
copy<int64_t>(src, dst, ctype, args...);
break;
case float16:
copy<float16_t>(src, dst, ctype, args...);
break;
case float32:
copy<float>(src, dst, ctype, args...);
break;
case bfloat16:
copy<bfloat16_t>(src, dst, ctype, args...);
break;
case complex64:
copy<complex64_t>(src, dst, ctype, args...);
break;
}
}
@@ -242,47 +385,7 @@ void copy(const array& src, array& dst, CopyType ctype) {
} // namespace
void copy_inplace(const array& src, array& dst, CopyType ctype) {
switch (src.dtype()) {
case bool_:
copy<bool>(src, dst, ctype);
break;
case uint8:
copy<uint8_t>(src, dst, ctype);
break;
case uint16:
copy<uint16_t>(src, dst, ctype);
break;
case uint32:
copy<uint32_t>(src, dst, ctype);
break;
case uint64:
copy<uint64_t>(src, dst, ctype);
break;
case int8:
copy<int8_t>(src, dst, ctype);
break;
case int16:
copy<int16_t>(src, dst, ctype);
break;
case int32:
copy<int32_t>(src, dst, ctype);
break;
case int64:
copy<int64_t>(src, dst, ctype);
break;
case float16:
copy<float16_t>(src, dst, ctype);
break;
case float32:
copy<float>(src, dst, ctype);
break;
case bfloat16:
copy<bfloat16_t>(src, dst, ctype);
break;
case complex64:
copy<complex64_t>(src, dst, ctype);
break;
}
return copy_inplace_dispatch(src, dst, ctype);
}
void copy(const array& src, array& dst, CopyType ctype) {
@@ -312,4 +415,62 @@ void copy(const array& src, array& dst, CopyType ctype) {
copy_inplace(src, dst, ctype);
}
template <typename stride_t>
void copy_inplace(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
const std::vector<stride_t>& o_strides,
int64_t i_offset,
int64_t o_offset,
CopyType ctype) {
switch (ctype) {
case CopyType::General:
case CopyType::GeneralGeneral:
return copy_inplace_dispatch(
src,
dst,
ctype,
data_shape,
i_strides,
o_strides,
i_offset,
o_offset);
case CopyType::Scalar:
case CopyType::Vector:
return copy_inplace_dispatch(src, dst, ctype);
}
}
template <>
void copy_inplace<int64_t>(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<int64_t>& i_strides,
const std::vector<int64_t>& o_strides,
int64_t i_offset,
int64_t o_offset,
CopyType ctype) {
switch (ctype) {
case CopyType::General:
case CopyType::GeneralGeneral:
return copy_inplace_dispatch(
src,
dst,
ctype,
data_shape,
i_strides,
o_strides,
i_offset,
o_offset);
case CopyType::Scalar:
case CopyType::Vector:
return copy_inplace_dispatch(src, dst, ctype);
}
}
} // namespace mlx::core

View File

@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#pragma once
@@ -26,4 +26,15 @@ enum class CopyType {
void copy(const array& src, array& dst, CopyType ctype);
void copy_inplace(const array& src, array& dst, CopyType ctype);
template <typename stride_t>
void copy_inplace(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
const std::vector<stride_t>& o_strides,
int64_t i_offset,
int64_t o_offset,
CopyType ctype);
} // namespace mlx::core

View File

@@ -51,6 +51,7 @@ DEFAULT(Cosh)
DEFAULT_MULTI(CustomVJP)
DEFAULT_MULTI(Depends)
DEFAULT(Divide)
DEFAULT(NumberOfElements)
DEFAULT(Remainder)
DEFAULT(Equal)
DEFAULT(Erf)
@@ -93,6 +94,7 @@ DEFAULT(Sign)
DEFAULT(Sin)
DEFAULT(Sinh)
DEFAULT(Slice)
DEFAULT(SliceUpdate)
DEFAULT(Softmax)
DEFAULT(Sort)
DEFAULT_MULTI(Split)
@@ -100,9 +102,11 @@ DEFAULT(Square)
DEFAULT(Sqrt)
DEFAULT(StopGradient)
DEFAULT(Subtract)
DEFAULT_MULTI(SVD)
DEFAULT(Tan)
DEFAULT(Tanh)
DEFAULT(Transpose)
DEFAULT(Inverse)
namespace {

View File

@@ -0,0 +1,104 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <lapack.h>
#endif
namespace mlx::core {
void inverse_impl(const array& a, array& inv) {
// Lapack uses the column-major convention. We take advantage of the following
// identity to avoid transposing (see
// https://math.stackexchange.com/a/340234):
// (A⁻¹)ᵀ = (Aᵀ)⁻¹
// The inverse is computed in place, so just copy the input to the output.
copy(a, inv, a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
const int N = a.shape(-1);
const size_t num_matrices = a.size() / (N * N);
int info;
auto ipiv = array::Data{allocator::malloc_or_wait(sizeof(int) * N)};
for (int i = 0; i < num_matrices; i++) {
// Compute LU factorization.
sgetrf_(
/* m = */ &N,
/* n = */ &N,
/* a = */ inv.data<float>() + N * N * i,
/* lda = */ &N,
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: LU factorization failed with error code " << info;
throw std::runtime_error(ss.str());
}
static const int lwork_query = -1;
float workspace_size = 0;
// Compute workspace size.
sgetri_(
/* m = */ &N,
/* a = */ nullptr,
/* lda = */ &N,
/* ipiv = */ nullptr,
/* work = */ &workspace_size,
/* lwork = */ &lwork_query,
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: LU workspace calculation failed with error code "
<< info;
throw std::runtime_error(ss.str());
}
const int lwork = workspace_size;
auto scratch =
array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
// Compute inverse.
sgetri_(
/* m = */ &N,
/* a = */ inv.data<float>() + N * N * i,
/* lda = */ &N,
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
/* work = */ static_cast<float*>(scratch.buffer.raw_ptr()),
/* lwork = */ &lwork,
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: inversion failed with error code " << info;
throw std::runtime_error(ss.str());
}
}
}
void Inverse::eval(const std::vector<array>& inputs, array& output) {
if (inputs[0].dtype() != float32) {
throw std::runtime_error("[Inverse::eval] only supports float32.");
}
inverse_impl(inputs[0], output);
}
std::pair<std::vector<array>, std::vector<int>> Inverse::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
auto ax = axes[0] >= 0 ? 0 : -1;
auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0];
return {{linalg::inv(a, stream())}, {ax}};
}
} // namespace mlx::core

View File

@@ -0,0 +1,23 @@
// Copyright © 2024 Apple Inc.
#pragma once
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <lapack.h>
#endif
#if defined(LAPACK_GLOBAL) || defined(LAPACK_NAME)
// This is to work around a change in the function signatures of lapack >= 3.9.1
// where functions taking char* also include a strlen argument, see a similar
// change in OpenCV:
// https://github.com/opencv/opencv/blob/1eb061f89de0fb85c4c75a2deeb0f61a961a63ad/cmake/OpenCVFindLAPACK.cmake#L57
#define MLX_LAPACK_FUNC(f) LAPACK_##f
#else
#define MLX_LAPACK_FUNC(f) f##_
#endif

View File

@@ -251,6 +251,62 @@ void Depends::eval(
}
}
void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
double numel = 1;
for (auto ax : axes_) {
numel *= inputs[0].shape(ax);
}
if (inverted_) {
numel = 1.0 / numel;
}
switch (out.dtype()) {
case bool_:
*out.data<bool>() = static_cast<bool>(numel);
break;
case uint8:
*out.data<uint8_t>() = static_cast<uint8_t>(numel);
break;
case uint16:
*out.data<uint16_t>() = static_cast<uint16_t>(numel);
break;
case uint32:
*out.data<uint32_t>() = static_cast<uint32_t>(numel);
break;
case uint64:
*out.data<uint64_t>() = static_cast<uint64_t>(numel);
break;
case int8:
*out.data<int8_t>() = static_cast<int8_t>(numel);
break;
case int16:
*out.data<int16_t>() = static_cast<int16_t>(numel);
break;
case int32:
*out.data<int32_t>() = static_cast<int32_t>(numel);
break;
case int64:
*out.data<int64_t>() = static_cast<int64_t>(numel);
break;
case float16:
*out.data<float16_t>() = static_cast<float16_t>(numel);
break;
case float32:
*out.data<float>() = static_cast<float>(numel);
break;
case bfloat16:
*out.data<bfloat16_t>() = static_cast<bfloat16_t>(numel);
break;
case complex64:
*out.data<complex64_t>() = static_cast<complex64_t>(numel);
break;
}
}
void Erf::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
@@ -468,20 +524,73 @@ void RandomBits::eval(const std::vector<array>& inputs, array& out) {
}
}
void Reshape::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (in.flags().row_contiguous) {
std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(
const array& in,
const array& out) {
// Special case for empty arrays or row contiguous arrays
if (in.size() == 0 || in.flags().row_contiguous) {
return {false, out.strides()};
}
// Special case for scalars
if (in.ndim() == 0) {
std::vector<size_t> out_strides(out.ndim(), 0);
return {false, out_strides};
}
// Firstly let's collapse all the contiguous dimensions of the input
auto [shape, _strides] = collapse_contiguous_dims(in);
auto& strides = _strides[0];
// If shapes fit exactly in the contiguous dims then no copy is necessary so
// let's check.
std::vector<size_t> out_strides;
bool copy_necessary = false;
int j = 0;
for (int i = 0; i < out.ndim(); i++) {
int N = out.shape(i);
if (j < shape.size() && shape[j] % N == 0) {
shape[j] /= N;
out_strides.push_back(shape[j] * strides[j]);
j += (shape[j] == 1);
} else if (N == 1) {
// i > 0 because otherwise j < shape.size() && shape[j] % 1 == 0
out_strides.push_back(out_strides.back());
} else {
copy_necessary = true;
break;
}
}
return {copy_necessary, out_strides};
}
void Reshape::shared_buffer_reshape(
const array& in,
const std::vector<size_t>& out_strides,
array& out) {
auto flags = in.flags();
if (flags.row_contiguous) {
// For row contiguous reshapes:
// - Shallow copy the buffer
// - If reshaping into a vector (all singleton dimensions except one) it
// becomes col contiguous again.
auto flags = in.flags();
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
out.copy_shared_buffer(in, out.strides(), flags, in.data_size());
} else {
}
out.copy_shared_buffer(in, out_strides, flags, in.data_size());
}
void Reshape::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : CopyType::General);
} else {
shared_buffer_reshape(in, out_strides, out);
}
}
@@ -542,36 +651,33 @@ void Sinh::eval(const std::vector<array>& inputs, array& out) {
}
}
void Slice::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
auto& in = inputs[0];
auto strides = in.strides();
auto flags = in.flags();
size_t data_offset = 0;
std::tuple<bool, int64_t, std::vector<int64_t>> Slice::prepare_slice(
const array& in) {
int64_t data_offset = 0;
bool copy_needed = false;
std::vector<int64_t> inp_strides(in.ndim(), 0);
for (int i = 0; i < in.ndim(); ++i) {
data_offset += start_indices_[i] * in.strides()[i];
strides[i] *= strides_[i];
inp_strides[i] = in.strides()[i] * strides_[i];
copy_needed |= strides_[i] < 0;
}
return std::make_tuple(copy_needed, data_offset, inp_strides);
}
void Slice::shared_buffer_slice(
const array& in,
const std::vector<size_t>& out_strides,
size_t data_offset,
array& out) {
// Compute row/col contiguity
size_t data_size = 1;
size_t f_stride = 1;
size_t b_stride = 1;
flags.row_contiguous = true;
flags.col_contiguous = true;
for (int i = 0, ri = out.ndim() - 1; ri >= 0; i++, ri--) {
flags.col_contiguous &= strides[i] == f_stride || out.shape(i) == 1;
flags.row_contiguous &= strides[ri] == b_stride || out.shape(ri) == 1;
f_stride *= out.shape(i);
b_stride *= out.shape(ri);
if (strides[i] > 0) {
data_size *= out.shape(i);
}
}
auto [data_size, is_row_contiguous, is_col_contiguous] =
check_contiguity(out.shape(), out_strides);
auto flags = in.flags();
flags.row_contiguous = is_row_contiguous;
flags.col_contiguous = is_col_contiguous;
if (data_size == 1) {
// Broadcasted scalar array is contiguous.
@@ -585,7 +691,87 @@ void Slice::eval(const std::vector<array>& inputs, array& out) {
flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
}
out.copy_shared_buffer(in, strides, flags, data_size, data_offset);
out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset);
}
void Slice::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
auto& in = inputs[0];
// Calculate out strides, initial offset and if copy needs to be made
auto [copy_needed, data_offset, inp_strides] = prepare_slice(in);
// Do copy if needed
if (copy_needed) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
std::vector<int64_t> ostrides{out.strides().begin(), out.strides().end()};
copy_inplace<int64_t>(
/* const array& src = */ in,
/* array& dst = */ out,
/* const std::vector<int>& data_shape = */ out.shape(),
/* const std::vector<stride_t>& i_strides = */ inp_strides,
/* const std::vector<stride_t>& o_strides = */ ostrides,
/* int64_t i_offset = */ data_offset,
/* int64_t o_offset = */ 0,
/* CopyType ctype = */ CopyType::General);
} else {
std::vector<size_t> ostrides{inp_strides.begin(), inp_strides.end()};
shared_buffer_slice(in, ostrides, data_offset, out);
}
}
std::tuple<int64_t, std::vector<int64_t>> SliceUpdate::prepare_slice(
const array& in) {
int64_t data_offset = 0;
std::vector<int64_t> inp_strides(in.ndim(), 0);
for (int i = 0; i < in.ndim(); ++i) {
data_offset += start_indices_[i] * in.strides()[i];
inp_strides[i] = in.strides()[i] * strides_[i];
}
return std::make_tuple(data_offset, inp_strides);
}
void SliceUpdate::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
auto& in = inputs[0];
auto& upd = inputs[1];
if (upd.size() == 0) {
out.copy_shared_buffer(in);
return;
}
// Check if materialization is needed
auto ctype = in.flags().contiguous && in.size() == in.data_size()
? CopyType::Vector
: CopyType::General;
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype);
// Calculate out strides, initial offset and if copy needs to be made
auto [data_offset, out_strides] = prepare_slice(out);
// Do copy
std::vector<int64_t> upd_strides{upd.strides().begin(), upd.strides().end()};
copy_inplace<int64_t>(
/* const array& src = */ upd,
/* array& dst = */ out,
/* const std::vector<int>& data_shape = */ upd.shape(),
/* const std::vector<stride_t>& i_strides = */ upd_strides,
/* const std::vector<stride_t>& o_strides = */ out_strides,
/* int64_t i_offset = */ 0,
/* int64_t o_offset = */ data_offset,
/* CopyType ctype = */ CopyType::GeneralGeneral);
}
void Split::eval(

View File

@@ -1,13 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/fast_primitives.h"
namespace mlx::core::fast {
void RoPE::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
throw std::runtime_error("NYI");
}
} // namespace mlx::core::fast

View File

@@ -67,11 +67,15 @@ void Softmax::eval(const std::vector<array>& inputs, array& out) {
}
};
array in = check_input(std::move(inputs[0]));
out.set_data(
allocator::malloc_or_wait(in.data_size() * in.itemsize()),
in.data_size(),
in.strides(),
in.flags());
if (in.is_donatable()) {
out.copy_shared_buffer(in);
} else {
out.set_data(
allocator::malloc_or_wait(in.data_size() * in.itemsize()),
in.data_size(),
in.strides(),
in.flags());
}
switch (in.dtype()) {
case bool_:

156
mlx/backend/common/svd.cpp Normal file
View File

@@ -0,0 +1,156 @@
// Copyright © 2024 Apple Inc.
#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack_helper.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"
namespace mlx::core {
void svd_impl(const array& a, array& u, array& s, array& vt) {
// Lapack uses the column-major convention. To avoid having to transpose
// the input and then transpose the outputs, we swap the indices/sizes of the
// matrices and take advantage of the following identity (see
// https://math.stackexchange.com/a/30077)
// A = UΣVᵀ
// Aᵀ = VΣUᵀ
// As a result some of the indices/sizes are swapped as noted above.
// Rows and cols of the original matrix in row-major order.
const int M = a.shape(-2);
const int N = a.shape(-1);
const int K = std::min(M, N);
// A of shape M x N. The leading dimension is N since lapack receives Aᵀ.
const int lda = N;
// U of shape M x M. (N x N in lapack).
const int ldu = N;
// Vᵀ of shape N x N. (M x M in lapack).
const int ldvt = M;
size_t num_matrices = a.size() / (M * N);
// lapack clobbers the input, so we have to make a copy.
array in(a.shape(), float32, nullptr, {});
copy(a, in, a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
// Allocate outputs.
u.set_data(allocator::malloc_or_wait(u.nbytes()));
s.set_data(allocator::malloc_or_wait(s.nbytes()));
vt.set_data(allocator::malloc_or_wait(vt.nbytes()));
static constexpr auto job_u = "V";
static constexpr auto job_vt = "V";
static constexpr auto range = "A";
// Will contain the number of singular values after the call has returned.
int ns = 0;
float workspace_dimension = 0;
// Will contain the indices of eigenvectors that failed to converge (not used
// here but required by lapack).
auto iwork = array::Data{allocator::malloc_or_wait(sizeof(int) * 12 * K)};
static const int lwork_query = -1;
static const int ignored_int = 0;
static const float ignored_float = 0;
int info;
// Compute workspace size.
MLX_LAPACK_FUNC(sgesvdx)
(
/* jobu = */ job_u,
/* jobvt = */ job_vt,
/* range = */ range,
// M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ nullptr,
/* lda = */ &lda,
/* vl = */ &ignored_float,
/* vu = */ &ignored_float,
/* il = */ &ignored_int,
/* iu = */ &ignored_int,
/* ns = */ &ns,
/* s = */ nullptr,
/* u = */ nullptr,
/* ldu = */ &ldu,
/* vt = */ nullptr,
/* ldvt = */ &ldvt,
/* work = */ &workspace_dimension,
/* lwork = */ &lwork_query,
/* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "svd_impl: sgesvdx_ workspace calculation failed with code " << info;
throw std::runtime_error(ss.str());
}
const int lwork = workspace_dimension;
auto scratch = array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
// Loop over matrices.
for (int i = 0; i < num_matrices; i++) {
MLX_LAPACK_FUNC(sgesvdx)
(
/* jobu = */ job_u,
/* jobvt = */ job_vt,
/* range = */ range,
// M and N are swapped since lapack expects column-major.
/* m = */ &N,
/* n = */ &M,
/* a = */ in.data<float>() + M * N * i,
/* lda = */ &lda,
/* vl = */ &ignored_float,
/* vu = */ &ignored_float,
/* il = */ &ignored_int,
/* iu = */ &ignored_int,
/* ns = */ &ns,
/* s = */ s.data<float>() + K * i,
// According to the identity above, lapack will write Vᵀᵀ as U.
/* u = */ vt.data<float>() + N * N * i,
/* ldu = */ &ldu,
// According to the identity above, lapack will write Uᵀ as Vᵀ.
/* vt = */ u.data<float>() + M * M * i,
/* ldvt = */ &ldvt,
/* work = */ static_cast<float*>(scratch.buffer.raw_ptr()),
/* lwork = */ &lwork,
/* iwork = */ static_cast<int*>(iwork.buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "svd_impl: sgesvdx_ failed with code " << info;
throw std::runtime_error(ss.str());
}
if (ns != K) {
std::stringstream ss;
ss << "svd_impl: expected " << K << " singular values, but " << ns
<< " were computed.";
throw std::runtime_error(ss.str());
}
}
}
void SVD::eval(const std::vector<array>& inputs, std::vector<array>& outputs) {
if (!(inputs[0].dtype() == float32)) {
throw std::runtime_error("[SVD::eval] only supports float32.");
}
svd_impl(inputs[0], outputs[0], outputs[1], outputs[2]);
}
std::pair<std::vector<array>, std::vector<int>> SVD::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
auto ax = axes[0] >= 0 ? 0 : -1;
auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0];
return {{linalg::svd(a, stream())}, {ax, ax, ax}};
}
} // namespace mlx::core

View File

@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#pragma once
@@ -8,11 +8,12 @@
namespace mlx::core {
inline size_t elem_to_loc(
template <typename stride_t>
inline stride_t elem_to_loc(
int elem,
const std::vector<int>& shape,
const std::vector<size_t>& strides) {
size_t loc = 0;
const std::vector<stride_t>& strides) {
stride_t loc = 0;
for (int i = shape.size() - 1; i >= 0; --i) {
auto q_and_r = ldiv(elem, shape[i]);
loc += q_and_r.rem * strides[i];
@@ -28,4 +29,94 @@ inline size_t elem_to_loc(int elem, const array& a) {
return elem_to_loc(elem, a.shape(), a.strides());
}
// Collapse dims that are contiguous to possibly route to a better kernel
// e.g. for x = transpose(array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 2, 2}), {2, 0, 1})
// should return {{2, 4}, {{1, 2}}}.
//
// When multiple arrays are passed they should all have the same shape. The
// collapsed axes are also the same so one shape is returned.
template <typename stride_t>
inline std::tuple<std::vector<int>, std::vector<std::vector<stride_t>>>
collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<std::vector<stride_t>> strides) {
// Make a vector that has axes separated with -1. Collapse all axes between
// -1.
std::vector<int> to_collapse;
if (shape.size() > 0) {
to_collapse.push_back(0);
for (int i = 1; i < shape.size(); i++) {
bool contiguous = true;
for (const std::vector<stride_t>& st : strides) {
if (st[i] * shape[i] != st[i - 1]) {
contiguous = false;
}
if (!contiguous) {
break;
}
}
if (!contiguous) {
to_collapse.push_back(-1);
}
to_collapse.push_back(i);
}
to_collapse.push_back(-1);
}
std::vector<int> out_shape;
std::vector<std::vector<stride_t>> out_strides(strides.size());
for (int i = 0; i < to_collapse.size(); i++) {
int current_shape = shape[to_collapse[i]];
while (to_collapse[++i] != -1) {
current_shape *= shape[to_collapse[i]];
}
out_shape.push_back(current_shape);
for (int j = 0; j < strides.size(); j++) {
const std::vector<stride_t>& st = strides[j];
out_strides[j].push_back(st[to_collapse[i - 1]]);
}
}
return std::make_tuple(out_shape, out_strides);
}
inline std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
collapse_contiguous_dims(const std::vector<array>& xs) {
std::vector<std::vector<size_t>> strides;
for (auto& x : xs) {
strides.emplace_back(x.strides());
}
return collapse_contiguous_dims(xs[0].shape(), strides);
}
template <typename... Arrays>
inline std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
collapse_contiguous_dims(Arrays... xs) {
return collapse_contiguous_dims(
std::vector<array>{std::forward<Arrays>(xs)...});
}
template <typename stride_t>
inline auto check_contiguity(
const std::vector<int>& shape,
const std::vector<stride_t>& strides) {
size_t data_size = 1;
size_t f_stride = 1;
size_t b_stride = 1;
bool is_row_contiguous = true;
bool is_col_contiguous = true;
for (int i = 0, ri = shape.size() - 1; ri >= 0; i++, ri--) {
is_row_contiguous &= strides[i] == f_stride || shape[i] == 1;
is_col_contiguous &= strides[ri] == b_stride || shape[ri] == 1;
f_stride *= shape[i];
b_stride *= shape[ri];
if (strides[i] > 0) {
data_size *= shape[i];
}
}
return std::make_tuple(data_size, is_row_contiguous, is_col_contiguous);
}
} // namespace mlx::core

View File

@@ -29,9 +29,11 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/normalization.cpp
${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp

View File

@@ -1,5 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/allocator.h"
#include "mlx/backend/metal/metal.h"
@@ -23,16 +22,6 @@ void* Buffer::raw_ptr() {
namespace metal {
static bool cache_enabled_ = true;
bool cache_enabled() {
return cache_enabled_;
}
void set_cache_enabled(bool enabled) {
cache_enabled_ = enabled;
}
namespace {
BufferCache::BufferCache(MTL::Device* device)
@@ -44,7 +33,6 @@ BufferCache::~BufferCache() {
}
void BufferCache::clear() {
std::lock_guard<std::mutex> lk(cache_mutex_);
for (auto& [size, holder] : buffer_pool_) {
if (holder->buf)
holder->buf->release();
@@ -57,12 +45,9 @@ void BufferCache::clear() {
}
MTL::Buffer* BufferCache::reuse_from_cache(size_t size) {
std::lock_guard<std::mutex> lk(cache_mutex_);
// Find the closest buffer in pool
MTL::Buffer* pbuf = nullptr;
// Make sure we use most of the available memory
auto it = buffer_pool_.lower_bound(size);
// Make sure we use most of the available memory
@@ -85,8 +70,6 @@ MTL::Buffer* BufferCache::reuse_from_cache(size_t size) {
}
void BufferCache::recycle_to_cache(MTL::Buffer* buf) {
std::lock_guard<std::mutex> lk(cache_mutex_);
// Add to cache
if (buf) {
BufferHolder* bh = new BufferHolder(buf);
@@ -100,7 +83,6 @@ void BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
if (min_bytes_to_free >= 0.9 * pool_size_) {
clear();
} else {
std::lock_guard<std::mutex> lk(cache_mutex_);
size_t total_bytes_freed = 0;
while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
@@ -158,9 +140,23 @@ void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) {
MetalAllocator::MetalAllocator()
: device_(device(mlx::core::Device::gpu).mtl_device()),
buffer_cache_(device_),
peak_allocated_size_(0),
block_limit_(1.5 * device_->recommendedMaxWorkingSetSize()),
gc_limit_(0.95 * device_->recommendedMaxWorkingSetSize()) {}
gc_limit_(0.95 * device_->recommendedMaxWorkingSetSize()),
max_pool_size_(block_limit_) {}
size_t MetalAllocator::set_cache_limit(size_t limit) {
std::swap(limit, max_pool_size_);
return limit;
};
size_t MetalAllocator::set_memory_limit(size_t limit, bool relaxed) {
std::swap(limit, block_limit_);
relaxed_ = relaxed;
gc_limit_ = std::min(
block_limit_,
static_cast<size_t>(0.95 * device_->recommendedMaxWorkingSetSize()));
return limit;
};
Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
// Metal doesn't like empty buffers
@@ -174,41 +170,53 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
}
// Try the cache
std::unique_lock lk(mutex_);
MTL::Buffer* buf = buffer_cache_.reuse_from_cache(size);
if (!buf) {
size_t mem_required = get_active_memory() + get_cache_memory() + size;
// If there is too much memory pressure, fail (likely causes a wait).
if (!allow_swap && device_->currentAllocatedSize() + size >= block_limit_) {
if (!(allow_swap && relaxed_) && mem_required >= block_limit_) {
return Buffer{nullptr};
}
auto thread_pool = metal::new_scoped_memory_pool();
// If we have a lot of memory pressure, check if we can reclaim some memory
// from the cache
if (device_->currentAllocatedSize() + size >= gc_limit_) {
size_t min_bytes_to_free =
size + device_->currentAllocatedSize() - gc_limit_;
buffer_cache_.release_cached_buffers(min_bytes_to_free);
// If we have a lot of memory pressure or are over the maximum cache size,
// try to reclaim memory from the cache
if (mem_required >= gc_limit_) {
buffer_cache_.release_cached_buffers(mem_required - gc_limit_);
}
// Allocate new buffer if needed
size_t res_opt = MTL::ResourceStorageModeShared;
res_opt |= MTL::ResourceHazardTrackingModeTracked;
lk.unlock();
buf = device_->newBuffer(size, res_opt);
lk.lock();
}
peak_allocated_size_ =
std::max(peak_allocated_size_, device_->currentAllocatedSize());
active_memory_ += buf->length();
peak_memory_ = std::max(peak_memory_, active_memory_);
// Maintain the cache below the requested limit
if (get_cache_memory() >= max_pool_size_) {
auto thread_pool = metal::new_scoped_memory_pool();
buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
}
return Buffer{static_cast<void*>(buf)};
}
void MetalAllocator::free(Buffer buffer) {
auto buf = static_cast<MTL::Buffer*>(buffer.ptr());
if (cache_enabled()) {
std::unique_lock lk(mutex_);
active_memory_ -= buf->length();
if (get_cache_memory() < max_pool_size_) {
buffer_cache_.recycle_to_cache(buf);
} else {
lk.unlock();
auto thread_pool = metal::new_scoped_memory_pool();
buf->release();
}
}
@@ -218,6 +226,22 @@ MetalAllocator& allocator() {
return allocator_;
}
size_t set_cache_limit(size_t limit) {
return allocator().set_cache_limit(limit);
}
size_t set_memory_limit(size_t limit, bool relaxed /* = true */) {
return allocator().set_memory_limit(limit, relaxed);
}
size_t get_active_memory() {
return allocator().get_active_memory();
}
size_t get_peak_memory() {
return allocator().get_peak_memory();
}
size_t get_cache_memory() {
return allocator().get_cache_memory();
}
} // namespace metal
} // namespace mlx::core

View File

@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#pragma once
@@ -19,11 +19,13 @@ class BufferCache {
public:
BufferCache(MTL::Device* device);
~BufferCache();
void clear();
MTL::Buffer* reuse_from_cache(size_t size);
void recycle_to_cache(MTL::Buffer* buf);
void release_cached_buffers(size_t min_bytes_to_free);
size_t cache_size() {
return pool_size_;
}
private:
struct BufferHolder {
@@ -35,11 +37,11 @@ class BufferCache {
MTL::Buffer* buf;
};
void clear();
void add_at_head(BufferHolder* to_add);
void remove_from_list(BufferHolder* to_remove);
MTL::Device* device_;
std::mutex cache_mutex_;
std::multimap<size_t, BufferHolder*> buffer_pool_;
BufferHolder* head_;
@@ -54,6 +56,17 @@ class MetalAllocator : public allocator::Allocator {
public:
virtual Buffer malloc(size_t size, bool allow_swap = false) override;
virtual void free(Buffer buffer) override;
size_t get_active_memory() {
return active_memory_;
};
size_t get_peak_memory() {
return peak_memory_;
};
size_t get_cache_memory() {
return buffer_cache_.cache_size();
};
size_t set_cache_limit(size_t limit);
size_t set_memory_limit(size_t limit, bool relaxed);
private:
MTL::Device* device_;
@@ -64,9 +77,14 @@ class MetalAllocator : public allocator::Allocator {
BufferCache buffer_cache_;
// Allocation stats
size_t peak_allocated_size_;
size_t block_limit_;
size_t gc_limit_;
size_t active_memory_{0};
size_t peak_memory_{0};
size_t max_pool_size_;
bool relaxed_{true};
std::mutex mutex_;
};
MetalAllocator& allocator();

View File

@@ -3,6 +3,7 @@
#include <sstream>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/metal/compiled_preamble.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/utils.h"
@@ -329,7 +330,9 @@ void Compiled::eval_gpu(
if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() &&
in.is_donatable() &&
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
outputs[o++].move_shared_buffer(in);
outputs[o].move_shared_buffer(
in, outputs[o].strides(), in.flags(), in.data_size());
o++;
}
}
for (; o < outputs.size(); ++o) {

View File

@@ -28,10 +28,12 @@ void explicit_gemm_conv_ND_gpu(
const array& wt,
array out,
const MLXConvParams<N>& conv_params) {
// Get gemm shapes
int implicit_M = out.size() / conv_params.O;
int implicit_K = wt.size() / conv_params.O;
int implicit_N = conv_params.O;
// Prepare unfolding array
std::vector<int> unfolded_shape = {
static_cast<int>(out.size() / conv_params.O),
static_cast<int>(wt.size() / conv_params.O)};
std::vector<int> unfolded_shape{implicit_M, implicit_K};
array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes()));
@@ -59,20 +61,29 @@ void explicit_gemm_conv_ND_gpu(
compute_encoder->dispatchThreads(grid_dims, group_dims);
// Reshape weight
std::vector<int> wt_reshape{implicit_K, implicit_N};
std::vector<size_t> wt_restride{1, static_cast<size_t>(implicit_K)};
array wt_reshaped(wt_reshape, wt.dtype(), nullptr, {});
auto wt_flags = wt.flags();
wt_flags.row_contiguous = false;
wt_flags.col_contiguous = true;
wt_reshaped.copy_shared_buffer(wt, wt_restride, wt_flags, wt.data_size());
// Perform gemm
std::vector<array> copies;
std::vector<array> copies = {in_unfolded, wt_reshaped};
return steel_matmul(
s,
d,
/*a = */ in_unfolded,
/*b = */ wt,
/*b = */ wt_reshaped,
/*c = */ out,
/*M = */ unfolded_shape[0],
/*N = */ conv_params.O,
/*K = */ unfolded_shape[1],
/*M = */ implicit_M,
/*N = */ implicit_N,
/*K = */ implicit_K,
/*batch_size_out = */ 1,
/*a_cols = */ unfolded_shape[1],
/*b_cols = */ unfolded_shape[1],
/*a_cols = */ implicit_K,
/*b_cols = */ implicit_K,
/*a_transposed = */ false,
/*b_transposed = */ true,
/*copies = */ copies);

View File

@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <sstream>
@@ -37,15 +37,22 @@ void copy_gpu(const array& in, array& out, CopyType ctype) {
copy_gpu(in, out, ctype, out.primitive().stream());
}
template <typename stride_t>
void copy_gpu_inplace(
const array& in,
array& out,
const std::vector<int>& data_shape,
const std::vector<stride_t>& strides_in_pre,
const std::vector<stride_t>& strides_out_pre,
int64_t inp_offset,
int64_t out_offset,
CopyType ctype,
const Stream& s) {
// Try to collapse contiguous dims
auto [shape, strides] = collapse_contiguous_dims(in, out);
auto& strides_in = strides[0];
auto& strides_out = strides[1];
auto [shape, strides] = collapse_contiguous_dims(
data_shape, std::vector{strides_in_pre, strides_out_pre});
auto& strides_in_ = strides[0];
auto& strides_out_ = strides[1];
auto& d = metal::device(s.device);
std::ostringstream kname;
@@ -72,39 +79,44 @@ void copy_gpu_inplace(
auto compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
bool donate_in = in.data_shared_ptr() == nullptr;
set_array_buffer(compute_encoder, donate_in ? out : in, 0);
set_array_buffer(compute_encoder, out, 1);
inp_offset *= size_of(in.dtype());
out_offset *= size_of(out.dtype());
set_array_buffer(compute_encoder, donate_in ? out : in, inp_offset, 0);
set_array_buffer(compute_encoder, out, out_offset, 1);
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
size_t ndim = shape.size();
int ndim = shape.size();
std::vector<int64_t> strides_in{strides_in_.begin(), strides_in_.end()};
std::vector<int64_t> strides_out{strides_out_.begin(), strides_out_.end()};
if (ndim > 3) {
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 2);
compute_encoder->setBytes(strides_in.data(), ndim * sizeof(size_t), 3);
if (ctype == CopyType::GeneralGeneral) {
compute_encoder->setBytes(strides_out.data(), ndim * sizeof(size_t), 4);
}
} else {
// The shape is implicit in the grid for <= 3D
compute_encoder->setBytes(strides_in.data(), ndim * sizeof(size_t), 2);
if (ctype == CopyType::GeneralGeneral) {
compute_encoder->setBytes(strides_out.data(), ndim * sizeof(size_t), 3);
}
set_vector_bytes(compute_encoder, shape, ndim, 2);
}
set_vector_bytes(compute_encoder, strides_in, ndim, 3);
if (ctype == CopyType::GeneralGeneral) {
set_vector_bytes(compute_encoder, strides_out, ndim, 4);
}
if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
compute_encoder->setBytes(
&ndim, sizeof(int), (ctype == CopyType::GeneralGeneral) ? 5 : 4);
compute_encoder->setBytes(&ndim, sizeof(int), 5);
}
int dim0 = ndim > 0 ? shape[ndim - 1] : 1;
int dim1 = ndim > 1 ? shape[ndim - 2] : 1;
int rest = in.size() / (dim0 * dim1);
size_t data_size = 1;
for (auto& s : shape)
data_size *= s;
int rest = data_size / (dim0 * dim1);
// NB assuming thread_group_size is a power of 2 larger than 32 x 32
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size != 1024) {
throw std::runtime_error("[Metal::copy] Must use 1024 sized block");
}
auto group_dims = get_block_dims(dim0, dim1, rest);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder->dispatchThreads(grid_dims, group_dims);
@@ -120,4 +132,25 @@ void copy_gpu_inplace(
}
}
void copy_gpu_inplace(
const array& in,
array& out,
CopyType ctype,
const Stream& s) {
return copy_gpu_inplace(
in, out, in.shape(), in.strides(), out.strides(), 0, 0, ctype, s);
}
void copy_gpu_inplace(
const array& in,
array& out,
const std::vector<int64_t>& istride,
int64_t ioffset,
CopyType ctype,
const Stream& s) {
std::vector<int64_t> ostrides{out.strides().begin(), out.strides().end()};
return copy_gpu_inplace(
in, out, in.shape(), istride, ostrides, ioffset, 0, ctype, s);
}
} // namespace mlx::core

View File

@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#pragma once
@@ -7,12 +7,34 @@
namespace mlx::core {
// Generic copy inplace
template <typename stride_t>
void copy_gpu_inplace(
const array& in,
array& out,
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
const std::vector<stride_t>& o_strides,
int64_t i_offset,
int64_t o_offset,
CopyType ctype,
const Stream& s);
void copy_gpu(const array& src, array& out, CopyType ctype, const Stream& s);
void copy_gpu(const array& src, array& out, CopyType ctype);
void copy_gpu_inplace(
const array& src,
array& out,
CopyType ctype,
const Stream& s);
void copy_gpu_inplace(
const array& in,
array& out,
const std::vector<int64_t>& istride,
int64_t ioffset,
CopyType ctype,
const Stream& s);
} // namespace mlx::core

View File

@@ -20,9 +20,9 @@ namespace mlx::core::metal {
namespace {
// TODO nicer way to set this or possibly expose as an environment variable
static constexpr int MAX_BUFFERS_PER_QUEUE = 12;
constexpr int MAX_BUFFERS_PER_QUEUE = 12;
static constexpr const char* default_mtllib_path = METAL_PATH;
constexpr const char* default_mtllib_path = METAL_PATH;
auto load_device() {
auto devices = MTL::CopyAllDevices();

View File

@@ -16,7 +16,7 @@ namespace mlx::core {
namespace {
static constexpr int METAL_MAX_INDEX_ARRAYS = 10;
constexpr int METAL_MAX_INDEX_ARRAYS = 10;
} // namespace
@@ -201,15 +201,12 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
for (int i = idx_ndim; i < upd.ndim(); ++i) {
upd_size *= upd.shape(i);
}
if (index_nd1_specialization) {
bool upd_col_contiguous = upd.flags().col_contiguous;
compute_encoder->setBytes(
out.shape().data(), out.shape().size() * sizeof(int), 3);
compute_encoder->setBytes(
out.strides().data(), out.strides().size() * sizeof(size_t), 4);
compute_encoder->setBytes(&upd_size, sizeof(size_t), 5);
compute_encoder->setBytes(&upd_col_contiguous, sizeof(bool), 6);
// Set index buffers
for (int i = 1; i < nidx + 1; ++i) {

View File

@@ -8,7 +8,6 @@ set(
${CMAKE_CURRENT_SOURCE_DIR}/defines.h
${CMAKE_CURRENT_SOURCE_DIR}/erf.h
${CMAKE_CURRENT_SOURCE_DIR}/indexing.h
${CMAKE_CURRENT_SOURCE_DIR}/reduce.h
${CMAKE_CURRENT_SOURCE_DIR}/unary.h
${CMAKE_CURRENT_SOURCE_DIR}/utils.h
)
@@ -24,9 +23,11 @@ set(
"gemv"
"quantized"
"random"
"reduce"
"rms_norm"
"layer_norm"
"rope"
"scan"
"scaled_dot_product_attention"
"softmax"
"sort"
"ternary"
@@ -68,6 +69,15 @@ foreach(KERNEL ${STEEL_KERNELS})
set(KERNEL_AIR ${TARGET}.air ${KERNEL_AIR})
endforeach()
file(GLOB_RECURSE REDUCE_KERNELS ${CMAKE_CURRENT_SOURCE_DIR}/reduction/*.metal)
file(GLOB_RECURSE REDUCE_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/reduction/*.h)
foreach(KERNEL ${REDUCE_KERNELS})
cmake_path(GET KERNEL STEM TARGET)
build_kernel_base(${TARGET} ${KERNEL} "${REDUCE_HEADERS}")
set(KERNEL_AIR ${TARGET}.air ${KERNEL_AIR})
endforeach()
add_custom_command(
OUTPUT ${MLX_METAL_PATH}/mlx.metallib
COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o ${MLX_METAL_PATH}/mlx.metallib

View File

@@ -1,29 +1,29 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"
template <typename T, typename U>
[[kernel]] void copy_s(
device const T* src,
device U* dst,
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
uint index [[thread_position_in_grid]]) {
dst[index] = static_cast<U>(src[0]);
}
template <typename T, typename U>
[[kernel]] void copy_v(
device const T* src,
device U* dst,
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
uint index [[thread_position_in_grid]]) {
dst[index] = static_cast<U>(src[index]);
}
template <typename T, typename U>
[[kernel]] void copy_g_nd1(
device const T* src,
device U* dst,
constant const size_t& src_stride,
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t& src_stride [[buffer(3)]],
uint index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_1(index, src_stride);
dst[index] = static_cast<U>(src[src_idx]);
@@ -31,61 +31,61 @@ template <typename T, typename U>
template <typename T, typename U>
[[kernel]] void copy_g_nd2(
device const T* src,
device U* dst,
constant const size_t src_strides[2],
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_2(index, src_strides);
size_t dst_idx = index.x + (size_t)grid_dim.x * index.y;
int64_t dst_idx = index.x + (int64_t)grid_dim.x * index.y;
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_g_nd3(
device const T* src,
device U* dst,
constant const size_t src_strides[3],
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_3(index, src_strides);
size_t dst_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
int64_t dst_idx = index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U, int DIM>
[[kernel]] void copy_g_nd(
device const T* src,
device U* dst,
constant const int src_shape[DIM],
constant const size_t src_strides[DIM],
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
size_t dst_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
int64_t dst_idx = index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_g(
device const T* src,
device U* dst,
constant const int* src_shape,
constant const size_t* src_strides,
constant const int& ndim,
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int& ndim [[buffer(5)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim);
size_t dst_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
int64_t dst_idx = index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_gg_nd1(
device const T* src,
device U* dst,
constant const size_t& src_stride,
constant const size_t& dst_stride,
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t& src_stride [[buffer(3)]],
constant const int64_t& dst_stride [[buffer(4)]],
uint index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_1(index, src_stride);
auto dst_idx = elem_to_loc_1(index, dst_stride);
@@ -94,10 +94,10 @@ template <typename T, typename U>
template <typename T, typename U>
[[kernel]] void copy_gg_nd2(
device const T* src,
device U* dst,
constant const size_t src_strides[2],
constant const size_t dst_strides[2],
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint2 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_2(index, src_strides);
auto dst_idx = elem_to_loc_2(index, dst_strides);
@@ -106,10 +106,10 @@ template <typename T, typename U>
template <typename T, typename U>
[[kernel]] void copy_gg_nd3(
device const T* src,
device U* dst,
constant const size_t src_strides[3],
constant const size_t dst_strides[3],
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint3 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_3(index, src_strides);
auto dst_idx = elem_to_loc_3(index, dst_strides);
@@ -118,11 +118,11 @@ template <typename T, typename U>
template <typename T, typename U, int DIM>
[[kernel]] void copy_gg_nd(
device const T* src,
device U* dst,
constant const int src_shape[DIM],
constant const size_t src_strides[DIM],
constant const size_t dst_strides[DIM],
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint3 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
auto dst_idx = elem_to_loc_nd<DIM>(index, src_shape, dst_strides);
@@ -131,12 +131,12 @@ template <typename T, typename U, int DIM>
template <typename T, typename U>
[[kernel]] void copy_gg(
device const T* src,
device U* dst,
constant const int* src_shape,
constant const size_t* src_strides,
constant const size_t* dst_strides,
constant const int& ndim,
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
constant const int& ndim [[buffer(5)]],
uint3 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim);
auto dst_idx = elem_to_loc(index, src_shape, dst_strides, ndim);
@@ -146,70 +146,70 @@ template <typename T, typename U>
#define instantiate_copy(name, itype, otype, ctype) \
template [[host_name(name)]] \
[[kernel]] void copy_##ctype<itype, otype>( \
device const itype* src, \
device otype* dst, \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
uint index [[thread_position_in_grid]]);
#define instantiate_copy_g_dim(name, itype, otype, dims) \
template [[host_name(name "_" #dims)]] \
[[kernel]] void copy_g_nd<itype, otype, dims>( \
device const itype* src, \
device otype* dst, \
constant const int src_shape[dims], \
constant const size_t src_strides[dims], \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \
constant const int64_t* src_strides [[buffer(3)]], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
template [[host_name("g" name "_" #dims)]] \
[[kernel]] void copy_gg_nd<itype, otype, dims>( \
device const itype* src, \
device otype* dst, \
constant const int src_shape[dims], \
constant const size_t src_strides[dims], \
constant const size_t dst_strides[dims], \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int64_t* dst_strides [[buffer(4)]], \
uint3 index [[thread_position_in_grid]]);
#define instantiate_copy_g_nd(name, itype, otype) \
template [[host_name(name "_1")]] \
[[kernel]] void copy_g_nd1<itype, otype>( \
device const itype* src, \
device otype* dst, \
constant const size_t& src_stride, \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t& src_stride [[buffer(3)]], \
uint index [[thread_position_in_grid]]); \
template [[host_name(name "_2")]] \
[[kernel]] void copy_g_nd2<itype, otype>( \
device const itype* src, \
device otype* dst, \
constant const size_t src_strides[2], \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]); \
template [[host_name(name "_3")]] \
[[kernel]] void copy_g_nd3<itype, otype>( \
device const itype* src, \
device otype* dst, \
constant const size_t src_strides[3], \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
template [[host_name("g" name "_1")]] \
[[kernel]] void copy_gg_nd1<itype, otype>( \
device const itype* src, \
device otype* dst, \
constant const size_t& src_stride, \
constant const size_t& dst_stride, \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t& src_stride [[buffer(3)]], \
constant const int64_t& dst_stride [[buffer(4)]], \
uint index [[thread_position_in_grid]]); \
template [[host_name("g" name "_2")]] \
[[kernel]] void copy_gg_nd2<itype, otype>( \
device const itype* src, \
device otype* dst, \
constant const size_t src_strides[2], \
constant const size_t dst_strides[2], \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int64_t* dst_strides [[buffer(4)]], \
uint2 index [[thread_position_in_grid]]); \
template [[host_name("g" name "_3")]] \
[[kernel]] void copy_gg_nd3<itype, otype>( \
device const itype* src, \
device otype* dst, \
constant const size_t src_strides[3], \
constant const size_t dst_strides[3], \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int64_t* dst_strides [[buffer(4)]], \
uint3 index [[thread_position_in_grid]]); \
instantiate_copy_g_dim(name, itype, otype, 4) \
instantiate_copy_g_dim(name, itype, otype, 5)
@@ -218,21 +218,21 @@ template <typename T, typename U>
#define instantiate_copy_g(name, itype, otype) \
template [[host_name(name)]] \
[[kernel]] void copy_g<itype, otype>( \
device const itype* src, \
device otype* dst, \
constant const int* src_shape, \
constant const size_t* src_strides, \
constant const int& ndim, \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int& ndim [[buffer(5)]], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
template [[host_name("g" name)]] \
[[kernel]] void copy_gg<itype, otype>( \
device const itype* src, \
device otype* dst, \
constant const int* src_shape, \
constant const size_t* src_strides, \
constant const size_t* dst_strides, \
constant const int& ndim, \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int64_t* dst_strides [[buffer(4)]], \
constant const int& ndim [[buffer(5)]], \
uint3 index [[thread_position_in_grid]]);
#define instantiate_copy_all(tname, itype, otype) \

View File

@@ -14,3 +14,5 @@ static MTL_CONST constexpr int MAX_REDUCE_SPECIALIZED_DIMS = 4;
static MTL_CONST constexpr int REDUCE_N_READS = 16;
static MTL_CONST constexpr int SOFTMAX_N_READS = 4;
static MTL_CONST constexpr int SOFTMAX_LOOPED_LIMIT = 4096;
static MTL_CONST constexpr int RMS_N_READS = 4;
static MTL_CONST constexpr int RMS_LOOPED_LIMIT = 4096;

View File

@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <metal_stdlib>
#include <metal_simdgroup>
@@ -22,7 +22,8 @@ template <
const int BM, /* Threadgroup rows (in threads) */
const int BN, /* Threadgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN > /* Thread cols (in elements) */
const int TN , /* Thread cols (in elements) */
const bool kDoAxpby> /* Do out = alpha * out + beta * bias */
struct GEMVKernel {
static_assert(BN == SIMD_SIZE, "gemv block must have a width of SIMD_SIZE");
@@ -48,11 +49,16 @@ struct GEMVKernel {
MLX_MTL_CONST short tgp_mem_size = BN * TN * 2;
static METAL_FUNC void run(
const device T* mat,
const device T* in_vec,
device T* out_vec,
const constant int& in_vec_size [[buffer(3)]],
const constant int& out_vec_size [[buffer(4)]],
const device T* mat [[buffer(0)]],
const device T* in_vec [[buffer(1)]],
const device T* bias [[buffer(2)]],
device T* out_vec [[buffer(3)]],
const constant int& in_vec_size [[buffer(4)]],
const constant int& out_vec_size [[buffer(5)]],
const constant int& marix_ld [[buffer(6)]],
const constant float& alpha [[buffer(7)]],
const constant float& beta [[buffer(8)]],
const constant int& bias_stride [[buffer(14)]],
threadgroup T* tgp_memory [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
@@ -81,7 +87,7 @@ struct GEMVKernel {
out_row = out_row + TM <= out_vec_size ? out_row : out_vec_size - TM;
// Advance matrix
mat += out_row * in_vec_size;
mat += out_row * marix_ld;
// Loop over in_vec in blocks of BN * TN
for(int bn = simd_lid * TN; bn < in_vec_size; bn += BN * TN) {
@@ -124,14 +130,14 @@ struct GEMVKernel {
if(bn + TN <= in_vec_size) {
#pragma clang loop unroll(full)
for(int tn = 0; tn < TN; tn++) {
inter[tn] = mat[tm * in_vec_size + bn + tn];
inter[tn] = mat[tm * marix_ld + bn + tn];
}
} else { // Edgecase
#pragma clang loop unroll(full)
for(int tn = 0; tn < TN; tn++) {
int col_idx = (bn + tn) < in_vec_size ? (bn + tn) : (in_vec_size - 1);
inter[tn] = mat[tm * in_vec_size + col_idx];
inter[tn] = mat[tm * marix_ld + col_idx];
}
}
@@ -154,7 +160,13 @@ struct GEMVKernel {
#pragma clang loop unroll(full)
for(int tm = 0; tm < TM; tm++) {
out_vec[out_row + tm] = result[tm];
if(kDoAxpby) {
out_vec[out_row + tm] =
static_cast<T>(alpha) * result[tm] +
static_cast<T>(beta) * bias[(out_row + tm) * bias_stride];
} else {
out_vec[out_row + tm] = result[tm];
}
}
}
@@ -172,7 +184,8 @@ template <
const int BM, /* Threadgroup rows (in threads) */
const int BN, /* Threadgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN > /* Thread cols (in elements) */
const int TN, /* Thread cols (in elements) */
const bool kDoAxpby> /* Do out = alpha * out + beta * bias */
struct GEMVTKernel {
// - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up
@@ -197,11 +210,16 @@ struct GEMVTKernel {
MLX_MTL_CONST short tgp_mem_size = BN * BM * TN;
static METAL_FUNC void run(
const device T* mat,
const device T* in_vec,
device T* out_vec,
const constant int& in_vec_size [[buffer(3)]],
const constant int& out_vec_size [[buffer(4)]],
const device T* mat [[buffer(0)]],
const device T* in_vec [[buffer(1)]],
const device T* bias [[buffer(2)]],
device T* out_vec [[buffer(3)]],
const constant int& in_vec_size [[buffer(4)]],
const constant int& out_vec_size [[buffer(5)]],
const constant int& marix_ld [[buffer(6)]],
const constant float& alpha [[buffer(7)]],
const constant float& beta [[buffer(8)]],
const constant int& bias_stride [[buffer(14)]],
threadgroup T* tgp_memory [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
@@ -245,7 +263,7 @@ struct GEMVTKernel {
#pragma clang loop unroll(full)
for(int tm = 0; tm < TM; tm++) {
for(int tn = 0; tn < TN; tn++) {
inter[tn] = mat[(bm + tm) * out_vec_size + out_col + tn];
inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
}
for(int tn = 0; tn < TN; tn++) {
result[tn] += v_coeff[tm] * inter[tn];
@@ -257,7 +275,7 @@ struct GEMVTKernel {
v_coeff[tm] = in_vec[bm + tm];
for(int tn = 0; tn < TN; tn++) {
inter[tn] = mat[(bm + tm) * out_vec_size + out_col + tn];
inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
}
for(int tn = 0; tn < TN; tn++) {
result[tn] += v_coeff[tm] * inter[tn];
@@ -292,13 +310,17 @@ struct GEMVTKernel {
#pragma clang loop unroll(full)
for(int j = 0; j < TN; j++) {
out_vec[out_col + j] = result[j];
if(kDoAxpby) {
out_vec[out_col + j] =
static_cast<T>(alpha) * result[j] +
static_cast<T>(beta) * bias[(out_col + j) * bias_stride];
} else {
out_vec[out_col + j] = result[j];
}
}
}
}
};
///////////////////////////////////////////////////////////////////////////////
@@ -310,78 +332,64 @@ template <
const int BM, /* Threadgroup rows (in threads) */
const int BN, /* Threadgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN> /* Thread cols (in elements) */
const int TN, /* Thread cols (in elements) */
const bool kDoNCBatch, /* Batch ndim > 1 */
const bool kDoAxpby> /* Do out = alpha * out + beta * bias */
[[kernel, max_total_threads_per_threadgroup(BM * BN)]] void gemv(
const device T* mat [[buffer(0)]],
const device T* in_vec [[buffer(1)]],
device T* out_vec [[buffer(2)]],
const constant int& in_vec_size [[buffer(3)]],
const constant int& out_vec_size [[buffer(4)]],
const constant int& vector_batch_stride [[buffer(5)]],
const constant int& matrix_batch_stride [[buffer(6)]],
const device T* bias [[buffer(2)]],
device T* out_vec [[buffer(3)]],
const constant int& in_vec_size [[buffer(4)]],
const constant int& out_vec_size [[buffer(5)]],
const constant int& marix_ld [[buffer(6)]],
const constant float& alpha [[buffer(7)]],
const constant float& beta [[buffer(8)]],
const constant int& batch_ndim [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]],
const constant size_t* vector_batch_stride [[buffer(11)]],
const constant size_t* matrix_batch_stride [[buffer(12)]],
const constant size_t* bias_batch_stride [[buffer(13)]],
const constant int& bias_stride [[buffer(14)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using gemv_kernel = GEMVKernel<T, BM, BN, TM, TN>;
using gemv_kernel = GEMVKernel<T, BM, BN, TM, TN, kDoAxpby>;
threadgroup T tgp_memory[gemv_kernel::tgp_mem_size];
// Update batch offsets
in_vec += tid.z * vector_batch_stride;
mat += tid.z * matrix_batch_stride;
out_vec += tid.z * out_vec_size;
gemv_kernel::run(
mat,
in_vec,
out_vec,
in_vec_size,
out_vec_size,
tgp_memory,
tid,
lid,
simd_gid,
simd_lid
);
}
template <
typename T,
const int BM, /* Threadgroup rows (in threads) */
const int BN, /* Threadgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN> /* Thread cols (in elements) */
[[kernel, max_total_threads_per_threadgroup(BM * BN)]] void gemv_nc(
const device T* mat [[buffer(0)]],
const device T* in_vec [[buffer(1)]],
device T* out_vec [[buffer(2)]],
const constant int& in_vec_size [[buffer(3)]],
const constant int& out_vec_size [[buffer(4)]],
const constant int& nc_dim [[buffer(5)]],
const device int* nc_shape [[buffer(6)]],
const device size_t* nc_strides_vec [[buffer(7)]],
const device size_t* nc_strides_mat [[buffer(8)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using gemv_kernel = GEMVKernel<T, BM, BN, TM, TN>;
threadgroup T tgp_memory[gemv_kernel::tgp_mem_size];
// Update batch offsets
in_vec += elem_to_loc(tid.z, nc_shape, nc_strides_vec, nc_dim);
mat += elem_to_loc(tid.z, nc_shape, nc_strides_mat, nc_dim);
if(kDoNCBatch) {
in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);
mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);
if(kDoAxpby) {
bias += elem_to_loc(tid.z, batch_shape, bias_batch_stride, batch_ndim);
}
} else {
in_vec += tid.z * vector_batch_stride[0];
mat += tid.z * matrix_batch_stride[0];
if(kDoAxpby) {
bias += tid.z * bias_batch_stride[0];
}
}
out_vec += tid.z * out_vec_size;
gemv_kernel::run(
mat,
in_vec,
bias,
out_vec,
in_vec_size,
out_vec_size,
marix_ld,
alpha,
beta,
bias_stride,
tgp_memory,
tid,
lid,
@@ -392,41 +400,34 @@ template <
}
#define instantiate_gemv_c(name, itype, bm, bn, tm, tn) \
template [[host_name("gemv_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn)]] \
[[kernel]] void gemv<itype, bm, bn, tm, tn>( \
#define instantiate_gemv_helper(name, itype, bm, bn, tm, tn, nc, axpby) \
template [[host_name("gemv_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn "_nc" #nc "_axpby" #axpby)]] \
[[kernel]] void gemv<itype, bm, bn, tm, tn, nc, axpby>( \
const device itype* mat [[buffer(0)]], \
const device itype* vec [[buffer(1)]], \
device itype* out [[buffer(2)]], \
const constant int& in_vec_size [[buffer(3)]], \
const constant int& out_vec_size [[buffer(4)]], \
const constant int& vector_batch_stride [[buffer(5)]], \
const constant int& matrix_batch_stride [[buffer(6)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_gemv_nc(name, itype, bm, bn, tm, tn) \
template [[host_name("gemv_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn "_nc")]] \
[[kernel]] void gemv_nc<itype, bm, bn, tm, tn>( \
const device itype* mat [[buffer(0)]], \
const device itype* vec [[buffer(1)]], \
device itype* out [[buffer(2)]], \
const constant int& in_vec_size [[buffer(3)]], \
const constant int& out_vec_size [[buffer(4)]], \
const constant int& nc_dim [[buffer(5)]], \
const device int* nc_shape [[buffer(6)]], \
const device size_t* nc_strides_vec [[buffer(7)]], \
const device size_t* nc_strides_mat [[buffer(8)]], \
const device itype* in_vec [[buffer(1)]], \
const device itype* bias [[buffer(2)]], \
device itype* out_vec [[buffer(3)]], \
const constant int& in_vec_size [[buffer(4)]], \
const constant int& out_vec_size [[buffer(5)]], \
const constant int& marix_ld [[buffer(6)]], \
const constant float& alpha [[buffer(7)]], \
const constant float& beta [[buffer(8)]], \
const constant int& batch_ndim [[buffer(9)]], \
const constant int* batch_shape [[buffer(10)]], \
const constant size_t* vector_batch_stride [[buffer(11)]], \
const constant size_t* matrix_batch_stride [[buffer(12)]], \
const constant size_t* bias_batch_stride [[buffer(13)]], \
const constant int& bias_stride [[buffer(14)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_gemv(name, itype, bm, bn, tm, tn) \
instantiate_gemv_c(name, itype, bm, bn, tm, tn) \
instantiate_gemv_nc(name, itype, bm, bn, tm, tn)
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 0, 0) \
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 0, 1) \
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 1, 0) \
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 1, 1)
#define instantiate_gemv_blocks(name, itype) \
instantiate_gemv(name, itype, 4, 32, 1, 4) \
@@ -446,77 +447,64 @@ template <
const int BM, /* Threadgroup rows (in threads) */
const int BN, /* Threadgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN> /* Thread cols (in elements) */
const int TN, /* Thread cols (in elements) */
const bool kDoNCBatch, /* Batch ndim > 1 */
const bool kDoAxpby> /* Do out = alpha * out + beta * bias */
[[kernel, max_total_threads_per_threadgroup(BM * BN)]] void gemv_t(
const device T* mat [[buffer(0)]],
const device T* in_vec [[buffer(1)]],
device T* out_vec [[buffer(2)]],
const constant int& in_vec_size [[buffer(3)]],
const constant int& out_vec_size [[buffer(4)]],
const constant int& vector_batch_stride [[buffer(5)]],
const constant int& matrix_batch_stride [[buffer(6)]],
const device T* bias [[buffer(2)]],
device T* out_vec [[buffer(3)]],
const constant int& in_vec_size [[buffer(4)]],
const constant int& out_vec_size [[buffer(5)]],
const constant int& marix_ld [[buffer(6)]],
const constant float& alpha [[buffer(7)]],
const constant float& beta [[buffer(8)]],
const constant int& batch_ndim [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]],
const constant size_t* vector_batch_stride [[buffer(11)]],
const constant size_t* matrix_batch_stride [[buffer(12)]],
const constant size_t* bias_batch_stride [[buffer(13)]],
const constant int& bias_stride [[buffer(14)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using gemv_kernel = GEMVTKernel<T, BM, BN, TM, TN>;
using gemv_kernel = GEMVTKernel<T, BM, BN, TM, TN, kDoAxpby>;
threadgroup T tgp_memory[gemv_kernel::tgp_mem_size];
// Update batch offsets
in_vec += tid.z * vector_batch_stride;
mat += tid.z * matrix_batch_stride;
out_vec += tid.z * out_vec_size;
gemv_kernel::run(
mat,
in_vec,
out_vec,
in_vec_size,
out_vec_size,
tgp_memory,
tid,
lid,
simd_gid,
simd_lid
);
}
template <
typename T,
const int BM, /* Threadgroup rows (in threads) */
const int BN, /* Threadgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN> /* Thread cols (in elements) */
[[kernel, max_total_threads_per_threadgroup(BM * BN)]] void gemv_t_nc(
const device T* mat [[buffer(0)]],
const device T* in_vec [[buffer(1)]],
device T* out_vec [[buffer(2)]],
const constant int& in_vec_size [[buffer(3)]],
const constant int& out_vec_size [[buffer(4)]],
const constant int& nc_dim [[buffer(5)]],
const device int* nc_shape [[buffer(6)]],
const device size_t* nc_strides_vec [[buffer(7)]],
const device size_t* nc_strides_mat [[buffer(8)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using gemv_kernel = GEMVTKernel<T, BM, BN, TM, TN>;
threadgroup T tgp_memory[gemv_kernel::tgp_mem_size];
// Update batch offsets
in_vec += elem_to_loc(tid.z, nc_shape, nc_strides_vec, nc_dim);
mat += elem_to_loc(tid.z, nc_shape, nc_strides_mat, nc_dim);
if(kDoNCBatch) {
in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);
mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);
if(kDoAxpby) {
bias += elem_to_loc(tid.z, batch_shape, bias_batch_stride, batch_ndim);
}
} else {
in_vec += tid.z * vector_batch_stride[0];
mat += tid.z * matrix_batch_stride[0];
if(kDoAxpby) {
bias += tid.z * bias_batch_stride[0];
}
}
out_vec += tid.z * out_vec_size;
gemv_kernel::run(
mat,
in_vec,
bias,
out_vec,
in_vec_size,
out_vec_size,
marix_ld,
alpha,
beta,
bias_stride,
tgp_memory,
tid,
lid,
@@ -526,41 +514,34 @@ template <
}
#define instantiate_gemv_t_c(name, itype, bm, bn, tm, tn) \
template [[host_name("gemv_t_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn)]] \
[[kernel]] void gemv_t<itype, bm, bn, tm, tn>( \
#define instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, nc, axpby) \
template [[host_name("gemv_t_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn "_nc" #nc "_axpby" #axpby)]] \
[[kernel]] void gemv_t<itype, bm, bn, tm, tn, nc, axpby>( \
const device itype* mat [[buffer(0)]], \
const device itype* vec [[buffer(1)]], \
device itype* out [[buffer(2)]], \
const constant int& in_vec_size [[buffer(3)]], \
const constant int& out_vec_size [[buffer(4)]], \
const constant int& vector_batch_stride [[buffer(5)]], \
const constant int& matrix_batch_stride [[buffer(6)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_gemv_t_nc(name, itype, bm, bn, tm, tn) \
template [[host_name("gemv_t_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn "_nc")]] \
[[kernel]] void gemv_t_nc<itype, bm, bn, tm, tn>( \
const device itype* mat [[buffer(0)]], \
const device itype* vec [[buffer(1)]], \
device itype* out [[buffer(2)]], \
const constant int& in_vec_size [[buffer(3)]], \
const constant int& out_vec_size [[buffer(4)]], \
const constant int& nc_dim [[buffer(5)]], \
const device int* nc_shape [[buffer(6)]], \
const device size_t* nc_strides_vec [[buffer(7)]], \
const device size_t* nc_strides_mat [[buffer(8)]], \
const device itype* in_vec [[buffer(1)]], \
const device itype* bias [[buffer(2)]], \
device itype* out_vec [[buffer(3)]], \
const constant int& in_vec_size [[buffer(4)]], \
const constant int& out_vec_size [[buffer(5)]], \
const constant int& marix_ld [[buffer(6)]], \
const constant float& alpha [[buffer(7)]], \
const constant float& beta [[buffer(8)]], \
const constant int& batch_ndim [[buffer(9)]], \
const constant int* batch_shape [[buffer(10)]], \
const constant size_t* vector_batch_stride [[buffer(11)]], \
const constant size_t* matrix_batch_stride [[buffer(12)]], \
const constant size_t* bias_batch_stride [[buffer(13)]], \
const constant int& bias_stride [[buffer(14)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_gemv_t(name, itype, bm, bn, tm, tn) \
instantiate_gemv_t_c(name, itype, bm, bn, tm, tn) \
instantiate_gemv_t_nc(name, itype, bm, bn, tm, tn)
instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 0, 0) \
instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 0, 1) \
instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 1, 0) \
instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 1, 1)
#define instantiate_gemv_t_blocks(name, itype) \
instantiate_gemv_t(name, itype, 8, 8, 4, 1) \

View File

@@ -0,0 +1,251 @@
// Copyright © 2024 Apple Inc.
#include <metal_common>
#include <metal_simdgroup>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"
using namespace metal;
template <typename T, int N_READS = RMS_N_READS>
[[kernel]] void layer_norm_single_row(
const device T* x,
const device T* w,
const device T* b,
device T* out,
constant float& eps,
constant uint& axis_size,
constant uint& w_stride,
constant uint& b_stride,
uint gid [[threadgroup_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
float sumx = 0;
float sumx2 = 0;
float thread_x[N_READS];
constexpr int SIMD_SIZE = 32;
threadgroup float local_sumx[SIMD_SIZE];
threadgroup float local_sumx2[SIMD_SIZE];
threadgroup float local_mean[1];
threadgroup float local_normalizer[1];
x += gid * axis_size + lid * N_READS;
w += w_stride * lid * N_READS;
b += b_stride * lid * N_READS;
if (lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
thread_x[i] = x[i];
sumx2 += thread_x[i] * thread_x[i];
sumx += thread_x[i];
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((lid * N_READS + i) < axis_size) {
thread_x[i] = x[i];
sumx2 += thread_x[i] * thread_x[i];
sumx += thread_x[i];
}
}
}
sumx = simd_sum(sumx);
sumx2 = simd_sum(sumx2);
// Initialize shared memory
if (simd_group_id == 0) {
local_sumx[simd_lane_id] = 0;
local_sumx2[simd_lane_id] = 0;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Write simd accumulations into shared memory
if (simd_lane_id == 0) {
local_sumx[simd_group_id] = sumx;
local_sumx2[simd_group_id] = sumx2;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Accumulate over simd groups
if (simd_group_id == 0) {
sumx = simd_sum(local_sumx[simd_lane_id]);
sumx2 = simd_sum(local_sumx2[simd_lane_id]);
if (simd_lane_id == 0) {
float mean = sumx / axis_size;
float variance = sumx2 / axis_size - mean * mean;
local_mean[0] = mean;
local_normalizer[0] = metal::precise::rsqrt(variance + eps);
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
float mean = local_mean[0];
float normalizer = local_normalizer[0];
// Write the outputs
out += gid * axis_size + lid * N_READS;
if (lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
thread_x[i] = (thread_x[i] - mean) * normalizer;
out[i] = w[w_stride * i] * static_cast<T>(thread_x[i]) + b[b_stride * i];
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((lid * N_READS + i) < axis_size) {
thread_x[i] = (thread_x[i] - mean) * normalizer;
out[i] = w[w_stride * i] * static_cast<T>(thread_x[i]) + b[b_stride * i];
}
}
}
}
template <typename T, int N_READS = RMS_N_READS>
[[kernel]] void layer_norm_looped(
const device T* x,
const device T* w,
const device T* b,
device T* out,
constant float& eps,
constant uint& axis_size,
constant uint& w_stride,
constant uint& b_stride,
uint gid [[threadgroup_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
float sumx = 0;
float sumx2 = 0;
constexpr int SIMD_SIZE = 32;
threadgroup float local_sumx[SIMD_SIZE];
threadgroup float local_sumx2[SIMD_SIZE];
threadgroup float local_mean[1];
threadgroup float local_normalizer[1];
x += gid * axis_size + lid * N_READS;
w += w_stride * lid * N_READS;
b += b_stride * lid * N_READS;
for (uint r = 0; r < axis_size; r += lsize * N_READS) {
if (r + lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
float xi = x[i + r];
sumx2 += xi * xi;
sumx += xi;
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((r + lid * N_READS + i) < axis_size) {
float xi = x[i + r];
sumx2 += xi * xi;
sumx += xi;
}
}
}
}
sumx = simd_sum(sumx);
sumx2 = simd_sum(sumx2);
// Initialize shared memory
if (simd_group_id == 0) {
local_sumx[simd_lane_id] = 0;
local_sumx2[simd_lane_id] = 0;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Write simd accumulations into shared memory
if (simd_lane_id == 0) {
local_sumx[simd_group_id] = sumx;
local_sumx2[simd_group_id] = sumx2;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Accumulate over simd groups
if (simd_group_id == 0) {
sumx = simd_sum(local_sumx[simd_lane_id]);
sumx2 = simd_sum(local_sumx2[simd_lane_id]);
if (simd_lane_id == 0) {
float mean = sumx / axis_size;
float variance = sumx2 / axis_size - mean * mean;
local_mean[0] = mean;
local_normalizer[0] = metal::precise::rsqrt(variance + eps);
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
float mean = local_mean[0];
float normalizer = local_normalizer[0];
// Write the outputs
out += gid * axis_size + lid * N_READS;
for (uint r = 0; r < axis_size; r += lsize * N_READS) {
if (r + lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
float xi = (x[r + i] - mean) * normalizer;
out[r + i] = w[w_stride * (i + r)] * static_cast<T>(xi) + b[b_stride * (i + r)];
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((r + lid * N_READS + i) < axis_size) {
float xi = (x[r + i] - mean) * normalizer;
out[r + i] = w[w_stride * (i + r)] * static_cast<T>(xi) + b[b_stride * (i + r)];
}
}
}
}
}
// clang-format off
#define instantiate_layer_norm_single_row(name, itype) \
template [[host_name("layer_norm" #name)]] [[kernel]] void \
layer_norm_single_row<itype>( \
const device itype* x, \
const device itype* w, \
const device itype* b, \
device itype* out, \
constant float& eps, \
constant uint& axis_size, \
constant uint& w_stride, \
constant uint& b_stride, \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_layer_norm_looped(name, itype) \
template [[host_name("layer_norm_looped" #name)]] [[kernel]] void \
layer_norm_looped<itype>( \
const device itype* x, \
const device itype* w, \
const device itype* b, \
device itype* out, \
constant float& eps, \
constant uint& axis_size, \
constant uint& w_stride, \
constant uint& b_stride, \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_layer_norm(name, itype) \
instantiate_layer_norm_single_row(name, itype) \
instantiate_layer_norm_looped(name, itype)
instantiate_layer_norm(float32, float)
instantiate_layer_norm(float16, half)
instantiate_layer_norm(bfloat16, bfloat16_t)
// clang-format on

View File

@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <metal_stdlib>
#include <metal_simdgroup>
@@ -23,7 +23,143 @@ template <> struct AccT<bfloat16_t> {
typedef float acc_t;
};
template <typename T, const int BM, const int BN, const int group_size, const int bits>
template <typename T, typename U, int values_per_thread, int bits>
inline U load_vector(const device T *x, thread U *x_thread) {
static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");
U sum = 0;
if (bits == 2) {
for (int i = 0; i < values_per_thread; i += 4) {
sum += x[i] + x[i+1] + x[i+2] + x[i+3];
x_thread[i] = x[i];
x_thread[i+1] = x[i+1] / 4.0f;
x_thread[i+2] = x[i+2] / 16.0f;
x_thread[i+3] = x[i+3] / 64.0f;
}
}
else if (bits == 4) {
for (int i = 0; i < values_per_thread; i += 4) {
sum += x[i] + x[i+1] + x[i+2] + x[i+3];
x_thread[i] = x[i];
x_thread[i+1] = x[i+1] / 16.0f;
x_thread[i+2] = x[i+2] / 256.0f;
x_thread[i+3] = x[i+3] / 4096.0f;
}
}
else if (bits == 8) {
for (int i = 0; i < values_per_thread; i++) {
sum += x[i];
x_thread[i] = x[i];
}
}
return sum;
}
template <typename U, int values_per_thread, int bits>
inline U qdot(const device uint8_t* w, const thread U *x_thread, U scale, U bias, U sum) {
static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");
U accum = 0;
if (bits == 2) {
for (int i = 0; i < (values_per_thread / 4); i++) {
accum += (
x_thread[4*i] * (w[i] & 0x03)
+ x_thread[4*i+1] * (w[i] & 0x0c)
+ x_thread[4*i+2] * (w[i] & 0x30)
+ x_thread[4*i+3] * (w[i] & 0xc0));
}
}
else if (bits == 4) {
const device uint16_t* ws = (const device uint16_t*)w;
for (int i = 0; i < (values_per_thread / 4); i++) {
accum += (
x_thread[4*i] * (ws[i] & 0x000f)
+ x_thread[4*i+1] * (ws[i] & 0x00f0)
+ x_thread[4*i+2] * (ws[i] & 0x0f00)
+ x_thread[4*i+3] * (ws[i] & 0xf000));
}
}
else if (bits == 8) {
for (int i = 0; i < values_per_thread; i++) {
accum += x_thread[i] * w[i];
}
}
return scale * accum + sum * bias;
}
template <typename T, int group_size, int bits, int packs_per_thread>
[[kernel]] void qmv_fast(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]],
device T* y [[buffer(4)]],
const constant int& in_vec_size [[buffer(5)]],
const constant int& out_vec_size [[buffer(6)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int num_simdgroups = 2;
constexpr int results_per_simdgroup = 4;
constexpr int pack_factor = 32 / bits;
constexpr int values_per_thread = pack_factor * packs_per_thread;
constexpr int block_size = values_per_thread * SIMD_SIZE;
constexpr int scale_step_per_thread = group_size / values_per_thread;
typedef float U;
thread U x_thread[values_per_thread];
thread U result[results_per_simdgroup] = {0};
// Adjust positions
const int in_vec_size_w = in_vec_size / pack_factor;
const int in_vec_size_g = in_vec_size / group_size;
const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + simd_gid * results_per_simdgroup;
w += out_row * in_vec_size_w + simd_lid * packs_per_thread;
scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
x += tid.z * in_vec_size + simd_lid * values_per_thread;
y += tid.z * out_vec_size + out_row;
for (int k = 0; k < in_vec_size; k += block_size) {
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
for (int row = 0; row < results_per_simdgroup; row++) {
const device uint8_t* wl = (const device uint8_t *)(w + row * in_vec_size_w);
const device T* sl = scales + row * in_vec_size_g;
const device T* bl = biases + row * in_vec_size_g;
U s = sl[0];
U b = bl[0];
result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
}
w += block_size / pack_factor;
scales += block_size / group_size;
biases += block_size / group_size;
x += block_size;
}
for (int row = 0; row < results_per_simdgroup; row++) {
result[row] = simd_sum(result[row]);
if (simd_lid == 0) {
y[row] = static_cast<T>(result[row]);
}
}
}
template <typename T, const int group_size, const int bits>
[[kernel]] void qmv(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
@@ -33,91 +169,101 @@ template <typename T, const int BM, const int BN, const int group_size, const in
const constant int& in_vec_size [[buffer(5)]],
const constant int& out_vec_size [[buffer(6)]],
uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
static_assert(BN == SIMD_SIZE, "qmv expects BN to be equal to SIMD_SIZE");
constexpr int num_simdgroups = 2;
constexpr int results_per_simdgroup = 4;
constexpr int packs_per_thread = 1;
constexpr int pack_factor = 32 / bits;
constexpr int values_per_thread = pack_factor * packs_per_thread;
constexpr int block_size = values_per_thread * SIMD_SIZE;
constexpr int scale_step_per_thread = group_size / values_per_thread;
(void)lid;
typedef float U;
constexpr int bitmask = (1 << bits) - 1;
constexpr int el_per_thread = 32 / bits;
constexpr int colgroup = BN * el_per_thread;
constexpr int groups_per_block = colgroup / group_size;
typedef typename AccT<T>::acc_t U;
threadgroup U scales_block[BM * groups_per_block];
threadgroup U biases_block[BM * groups_per_block];
threadgroup U x_block[colgroup];
thread uint32_t w_local;
thread U result = 0;
thread U scale = 1;
thread U bias = 0;
thread U x_thread[el_per_thread];
thread U x_thread[values_per_thread];
thread U result[results_per_simdgroup] = {0};
// Adjust positions
const int in_vec_size_w = in_vec_size / el_per_thread;
const int in_vec_size_w = in_vec_size / pack_factor;
const int in_vec_size_g = in_vec_size / group_size;
int out_row = tid.y * BM + simd_gid;
w += out_row * in_vec_size_w;
scales += out_row * in_vec_size_g;
biases += out_row * in_vec_size_g;
x += tid.z * in_vec_size;
y += tid.z * out_vec_size;
const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + simd_gid * results_per_simdgroup;
const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row);
if (out_row >= out_vec_size) {
return;
}
// Loop over in_vec in blocks of colgroup
for (int i=0; i<in_vec_size; i+=colgroup) {
// Load the vec to shared memory
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_gid == 0) {
#pragma clang loop unroll(full)
for (int j=0; j<el_per_thread; j++) {
x_block[simd_lid * el_per_thread + j] = x[i + simd_lid * el_per_thread + j];
}
}
if (simd_lid == 0) {
#pragma clang loop unroll(full)
for (int j=0; j<groups_per_block; j++) {
scales_block[simd_gid * groups_per_block + j] = scales[i / group_size + j];
}
#pragma clang loop unroll(full)
for (int j=0; j<groups_per_block; j++) {
biases_block[simd_gid * groups_per_block + j] = biases[i / group_size + j];
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// In this case we need to properly guard all our reads because there isn't
// even 1 tile in the matrix
if (out_vec_size < (num_simdgroups * results_per_simdgroup)) {
w += out_row * in_vec_size_w + simd_lid * packs_per_thread;
scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
x += tid.z * in_vec_size + simd_lid * values_per_thread;
y += tid.z * out_vec_size + out_row;
// Load in_vec, scale, bias to registers
#pragma clang loop unroll(full)
for (int j=0; j<el_per_thread; j++) {
x_thread[j] = x_block[simd_lid*el_per_thread + j];
for (int k = 0; k < in_vec_size; k += block_size) {
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
for (int row = 0; out_row + row < out_vec_size; row++) {
const device uint8_t* wl = (const device uint8_t *)(w + row * in_vec_size_w);
const device T* sl = scales + row * in_vec_size_g;
const device T* bl = biases + row * in_vec_size_g;
U s = sl[0];
U b = bl[0];
result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
}
w += block_size / pack_factor;
scales += block_size / group_size;
biases += block_size / group_size;
x += block_size;
}
scale = scales_block[simd_gid * groups_per_block + simd_lid * el_per_thread / group_size];
bias = biases_block[simd_gid * groups_per_block + simd_lid * el_per_thread / group_size];
// Load the matrix elements
w_local = w[i / el_per_thread + simd_lid];
// Do all the work.
#pragma clang loop unroll(full)
for (int k=0; k<el_per_thread; k++) {
result += (scale * static_cast<U>(w_local & bitmask) + bias) * x_thread[k];
w_local >>= bits;
for (int row = 0; out_row + row < out_vec_size; row++) {
result[row] = simd_sum(result[row]);
if (simd_lid == 0) {
y[row] = static_cast<T>(result[row]);
}
}
}
// Accumulate in the simdgroup
result = simd_sum(result);
// In this case the last tile is moved back to redo some output values
else {
w += used_out_row * in_vec_size_w + simd_lid * packs_per_thread;
scales += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
biases += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
x += tid.z * in_vec_size + simd_lid * values_per_thread;
y += tid.z * out_vec_size + used_out_row;
// Store the result
if (simd_lid == 0) {
y[out_row] = static_cast<T>(result);
for (int k = 0; k < in_vec_size; k += block_size) {
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
for (int row = 0; row < results_per_simdgroup; row++) {
const device uint8_t* wl = (const device uint8_t *)(w + row * in_vec_size_w);
const device T* sl = scales + row * in_vec_size_g;
const device T* bl = biases + row * in_vec_size_g;
U s = sl[0];
U b = bl[0];
result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
}
w += block_size / pack_factor;
scales += block_size / group_size;
biases += block_size / group_size;
x += block_size;
}
for (int row = 0; row < results_per_simdgroup; row++) {
result[row] = simd_sum(result[row]);
if (simd_lid == 0) {
y[row] = static_cast<T>(result[row]);
}
}
}
}
@@ -532,9 +678,38 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
}
#define instantiate_qmv(name, itype, group_size, bits) \
template [[host_name("qmv_" #name "_gs_" #group_size "_b_" #bits)]] \
[[kernel]] void qmv<itype, 32, 32, group_size, bits>( \
#define instantiate_qmv_fast(name, itype, group_size, bits, packs_per_thread) \
template [[host_name("qmv_" #name "_gs_" #group_size "_b_" #bits "_fast")]] \
[[kernel]] void qmv_fast<itype, group_size, bits, packs_per_thread>( \
const device uint32_t* w [[buffer(0)]], \
const device itype* scales [[buffer(1)]], \
const device itype* biases [[buffer(2)]], \
const device itype* x [[buffer(3)]], \
device itype* y [[buffer(4)]], \
const constant int& in_vec_size [[buffer(5)]], \
const constant int& out_vec_size [[buffer(6)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_qmv_fast_types(group_size, bits, packs_per_thread) \
instantiate_qmv_fast(float32, float, group_size, bits, packs_per_thread) \
instantiate_qmv_fast(float16, half, group_size, bits, packs_per_thread) \
instantiate_qmv_fast(bfloat16, bfloat16_t, group_size, bits, packs_per_thread)
instantiate_qmv_fast_types(128, 2, 1)
instantiate_qmv_fast_types(128, 4, 2)
instantiate_qmv_fast_types(128, 8, 2)
instantiate_qmv_fast_types( 64, 2, 1)
instantiate_qmv_fast_types( 64, 4, 2)
instantiate_qmv_fast_types( 64, 8, 2)
instantiate_qmv_fast_types( 32, 2, 1)
instantiate_qmv_fast_types( 32, 4, 2)
instantiate_qmv_fast_types( 32, 8, 2)
#define instantiate_qmv(name, itype, group_size, bits) \
template [[host_name("qmv_" #name "_gs_" #group_size "_b_" #bits)]] \
[[kernel]] void qmv<itype, group_size, bits>( \
const device uint32_t* w [[buffer(0)]], \
const device itype* scales [[buffer(1)]], \
const device itype* biases [[buffer(2)]], \
@@ -543,7 +718,6 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
const constant int& in_vec_size [[buffer(5)]], \
const constant int& out_vec_size [[buffer(6)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint lid [[thread_index_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);

View File

@@ -1,619 +0,0 @@
// Copyright © 2023 Apple Inc.
#include <metal_atomic>
#include <metal_simdgroup>
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/reduce.h"
#include "mlx/backend/metal/kernels/utils.h"
using namespace metal;
static constant uint8_t simd_size = 32;
template <typename T, typename Op>
[[kernel]] void init_reduce(
device T *out [[buffer(0)]],
uint tid [[thread_position_in_grid]]) {
out[tid] = Op::init;
}
#define instantiate_init_reduce(name, otype, op) \
template [[host_name("i" #name)]] \
[[kernel]] void init_reduce<otype, op>( \
device otype *out [[buffer(1)]], \
uint tid [[thread_position_in_grid]]);
///////////////////////////////////////////////////////////////////////////////
// All reduce
///////////////////////////////////////////////////////////////////////////////
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
inline U per_thread_all_reduce(
const device T *in,
const device size_t& in_size,
uint gid,
uint grid_size) {
Op op;
U total_val = Op::init;
if (gid * N_READS < in_size) {
in += gid * N_READS;
int r = 0;
for(; r < (int)ceildiv(in_size, grid_size * N_READS) - 1; r++) {
U vals[N_READS] = {op.init};
for(int i = 0; i < N_READS; i++) {
vals[i] = static_cast<U>(in[i]);
}
for(int i = 0; i < N_READS; i++) {
total_val = op(vals[i], total_val);
}
in += grid_size * N_READS;
}
// Separate case for the last set as we close the reduction size
size_t curr_idx = (gid + r * (size_t)grid_size) * N_READS;
if (curr_idx < in_size) {
int max_reads = in_size - curr_idx;
T vals[N_READS];
for(int i = 0, idx = 0; i < N_READS; i++, idx++) {
idx = idx < max_reads ? idx : max_reads - 1;
vals[i] = in[idx];
}
for(int i = 0; i < N_READS; i++) {
U val = i < max_reads ? vals[i] : Op::init;
total_val = op(static_cast<U>(val), total_val);
}
}
}
return total_val;
}
// NB: This kernel assumes threads_per_threadgroup is at most
// 1024. This way with a simd_size of 32, we are guaranteed to
// complete the reduction in two steps of simd-level reductions.
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
[[kernel]] void all_reduce(
const device T *in [[buffer(0)]],
device mlx_atomic<U> *out [[buffer(1)]],
const device size_t& in_size [[buffer(2)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint grid_size [[threads_per_grid]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
Op op;
threadgroup U local_vals[simd_size];
U total_val = per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);
// Reduction within simd group
total_val = op.simd_reduce(total_val);
if (simd_lane_id == 0) {
local_vals[simd_group_id] = total_val;
}
// Reduction within thread group
threadgroup_barrier(mem_flags::mem_threadgroup);
total_val = lid < simd_per_group ? local_vals[lid] : op.init;
total_val = op.simd_reduce(total_val);
// Reduction across threadgroups
if (lid == 0) {
op.atomic_update(out, total_val);
}
}
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
[[kernel]] void all_reduce_no_atomics(
const device T *in [[buffer(0)]],
device U *out [[buffer(1)]],
const device size_t& in_size [[buffer(2)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint grid_size [[threads_per_grid]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint thread_group_id [[threadgroup_position_in_grid]]) {
Op op;
threadgroup U local_vals[simd_size];
U total_val = per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);
// Reduction within simd group (simd_add isn't supported for uint64/int64 types)
for (uint16_t lane_offset = simd_size/2; lane_offset > 0; lane_offset /= 2) {
total_val = op(total_val, simd_shuffle_down(total_val, lane_offset));
}
// Write simd group reduction results to local memory
if (simd_lane_id == 0) {
local_vals[simd_group_id] = total_val;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Reduction of simdgroup reduction results within threadgroup.
total_val = lid < simd_per_group ? local_vals[lid] : op.init;
for (uint16_t lane_offset = simd_size/2; lane_offset > 0; lane_offset /= 2) {
total_val = op(total_val, simd_shuffle_down(total_val, lane_offset));
}
// Reduction across threadgroups
if (lid == 0) {
out[thread_group_id] = total_val;
}
}
#define instantiate_all_reduce(name, itype, otype, op) \
template [[host_name("all_reduce_" #name)]] \
[[kernel]] void all_reduce<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device mlx_atomic<otype> *out [[buffer(1)]], \
const device size_t& in_size [[buffer(2)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint grid_size [[threads_per_grid]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_all_reduce_no_atomics(name, itype, otype, op) \
template [[host_name("all_reduce_no_atomics_" #name)]] \
[[kernel]] void all_reduce_no_atomics<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device otype *out [[buffer(1)]], \
const device size_t& in_size [[buffer(2)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint grid_size [[threads_per_grid]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint thread_group_id [[threadgroup_position_in_grid]]);
///////////////////////////////////////////////////////////////////////////////
// Row atomics
///////////////////////////////////////////////////////////////////////////////
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
inline U per_thread_row_reduce(
const device T *in,
const constant size_t& reduction_size,
const constant size_t& out_size,
const constant int* shape,
const constant size_t* strides,
const constant int& ndim,
uint lsize_x,
uint lid_x,
uint2 tid) {
Op op;
// Each threadgroup handles 1 reduction
// TODO: Specializing elem_to_loc would be slightly faster
int idx = tid.y * out_size + tid.x;
int extra_offset = elem_to_loc(idx, shape, strides, ndim);
in += extra_offset + lid_x * N_READS;
// The reduction is accumulated here
U total_val = Op::init;
// Loop over the reduction size within thread group
int r = 0;
for (; r < (int)ceildiv(reduction_size, N_READS*lsize_x) - 1; r++) {
T vals[N_READS];
for(int i = 0; i < N_READS; i++) {
vals[i] = in[i];
}
for(int i = 0; i < N_READS; i++) {
total_val = op(static_cast<U>(vals[i]), total_val);
}
in += lsize_x * N_READS;
}
// Separate case for the last set as we close the reduction size
size_t reduction_index = (lid_x + (size_t)lsize_x * r) * N_READS;
if(reduction_index < reduction_size) {
int max_reads = reduction_size - reduction_index;
T vals[N_READS];
for(int i = 0; i < N_READS; i++) {
int idx = min(i, max_reads - 1);
vals[i] = static_cast<U>(in[idx]);
}
for(int i = 0; i < N_READS; i++) {
T val = i < max_reads ? vals[i] : Op::init;
total_val = op(static_cast<U>(val), total_val);
}
}
return total_val;
}
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
[[kernel]] void row_reduce_general(
const device T *in [[buffer(0)]],
device mlx_atomic<U> *out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant int* shape [[buffer(4)]],
const constant size_t* strides [[buffer(5)]],
const constant int& ndim [[buffer(6)]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
Op op;
threadgroup U local_vals[simd_size];
U total_val = per_thread_row_reduce<T, U, Op, N_READS>(in, reduction_size, out_size, shape, strides, ndim, lsize.x, lid.x, tid.xy);
total_val = op.simd_reduce(total_val);
// Prepare next level
if (simd_lane_id == 0) {
local_vals[simd_group_id] = total_val;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Reduction within thread group
// Only needed if multiple simd groups
if(reduction_size > simd_size) {
total_val = lid.x < simd_per_group ? local_vals[lid.x] : op.init;
total_val = op.simd_reduce(total_val);
}
// Update output
if (lid.x == 0) {
op.atomic_update(out, total_val, tid.x);
}
}
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
[[kernel]] void row_reduce_general_no_atomics(
const device T *in [[buffer(0)]],
device U *out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant int* shape [[buffer(4)]],
const constant size_t* strides [[buffer(5)]],
const constant int& ndim [[buffer(6)]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]],
uint3 gsize [[threads_per_grid]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
Op op;
threadgroup U local_vals[simd_size];
U total_val = per_thread_row_reduce<T, U, Op, N_READS>(in, reduction_size, out_size, shape, strides, ndim, lsize.x, lid.x, tid.xy);
// Reduction within simd group - simd_add isn't supported for int64 types
for (uint16_t i = simd_size/2; i > 0; i /= 2) {
total_val = op(total_val, simd_shuffle_down(total_val, i));
}
// Prepare next level
if (simd_lane_id == 0) {
local_vals[simd_group_id] = total_val;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Reduction within thread group
// Only needed if thread group has multiple simd groups
if(ceildiv(reduction_size, N_READS) > simd_size) {
total_val = lid.x < simd_per_group ? local_vals[lid.x] : op.init;
for (uint16_t i = simd_size/2; i > 0; i /= 2) {
total_val = op(total_val, simd_shuffle_down(total_val, i));
}
}
// Write row reduce output for threadgroup with 1st thread in thread group
if (lid.x == 0) {
out[(ceildiv(gsize.y, lsize.y) * tid.x) + tid.y] = total_val;
}
}
#define instantiate_row_reduce_general(name, itype, otype, op) \
template [[host_name("row_reduce_general_" #name)]] \
[[kernel]] void row_reduce_general<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device mlx_atomic<otype> *out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant int* shape [[buffer(4)]], \
const constant size_t* strides [[buffer(5)]], \
const constant int& ndim [[buffer(6)]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_row_reduce_general_no_atomics(name, itype, otype, op) \
template [[host_name("row_reduce_general_no_atomics_" #name)]] \
[[kernel]] void row_reduce_general_no_atomics<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device otype *out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant int* shape [[buffer(4)]], \
const constant size_t* strides [[buffer(5)]], \
const constant int& ndim [[buffer(6)]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]], \
uint3 gsize [[threads_per_grid]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
///////////////////////////////////////////////////////////////////////////////
// Column reduce
///////////////////////////////////////////////////////////////////////////////
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
inline U _contiguous_strided_reduce(
const device T *in,
threadgroup U *local_data,
uint in_idx,
uint reduction_size,
uint reduction_stride,
uint2 tid,
uint2 lid,
uint2 lsize) {
Op op;
U total_val = Op::init;
uint base_offset = (tid.y * lsize.y + lid.y) * N_READS;
for(uint r = 0; r < N_READS && (base_offset + r) < reduction_size; r++) {
uint offset = base_offset + r;
total_val = op(static_cast<U>(total_val), in[in_idx + offset * reduction_stride]);
}
local_data[lsize.y * lid.x + lid.y] = total_val;
threadgroup_barrier(mem_flags::mem_threadgroup);
U val = Op::init;
if(lid.y == 0) {
// Perform reduction across columns in thread group
for(uint i = 0; i < lsize.y; i++) {
val = op(val, local_data[lsize.y * lid.x + i]);
}
}
return val;
}
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void col_reduce_general(
const device T *in [[buffer(0)]],
device mlx_atomic<U> *out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
threadgroup U *local_data [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]]) {
auto out_idx = tid.x * lsize.x + lid.x;
auto in_idx = elem_to_loc(
out_idx + tid.z * out_size,
shape,
strides,
ndim
);
Op op;
if(out_idx < out_size) {
U val = _contiguous_strided_reduce<T, U, Op, N_READS>(
in,
local_data,
in_idx,
reduction_size,
reduction_stride,
tid.xy,
lid.xy,
lsize.xy);
// Write out reduction results generated by threadgroups working on specific output element, contiguously.
if (lid.y == 0) {
op.atomic_update(out, val, out_idx);
}
}
}
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void col_reduce_general_no_atomics(
const device T *in [[buffer(0)]],
device U *out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
threadgroup U *local_data [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 gid [[thread_position_in_grid]],
uint3 lsize [[threads_per_threadgroup]],
uint3 gsize [[threads_per_grid]]) {
auto out_idx = tid.x * lsize.x + lid.x;
auto in_idx = elem_to_loc(
out_idx + tid.z * out_size,
shape,
strides,
ndim
);
if(out_idx < out_size) {
U val = _contiguous_strided_reduce<T, U, Op, N_READS>(
in,
local_data,
in_idx,
reduction_size,
reduction_stride,
tid.xy,
lid.xy,
lsize.xy);
// Write out reduction results generated by threadgroups working on specific output element, contiguously.
if (lid.y == 0) {
uint tgsize_y = ceildiv(gsize.y, lsize.y);
uint tgsize_z = ceildiv(gsize.z, lsize.z);
out[tgsize_y * tgsize_z * gid.x + tgsize_y * tid.z + tid.y] = val;
}
}
}
#define instantiate_col_reduce_general(name, itype, otype, op) \
template [[host_name("col_reduce_general_" #name)]] \
[[kernel]] void col_reduce_general<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device mlx_atomic<otype> *out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
threadgroup otype *local_data [[threadgroup(0)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]]);
#define instantiate_col_reduce_general_no_atomics(name, itype, otype, op) \
template [[host_name("col_reduce_general_no_atomics_" #name)]] \
[[kernel]] void col_reduce_general_no_atomics<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device otype *out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
threadgroup otype *local_data [[threadgroup(0)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 gid [[thread_position_in_grid]], \
uint3 lsize [[threads_per_threadgroup]], \
uint3 gsize [[threads_per_grid]]);
///////////////////////////////////////////////////////////////////////////////
// Instantiations
///////////////////////////////////////////////////////////////////////////////
#define instantiate_reduce(name, itype, otype, op) \
instantiate_all_reduce(name, itype, otype, op) \
instantiate_row_reduce_general(name, itype, otype, op) \
instantiate_col_reduce_general(name, itype, otype, op)
#define instantiate_reduce_no_atomics(name, itype, otype, op) \
instantiate_all_reduce_no_atomics(name, itype, otype, op) \
instantiate_row_reduce_general_no_atomics(name, itype, otype, op) \
instantiate_col_reduce_general_no_atomics(name, itype, otype, op)
#define instantiate_same_reduce_no_atomics(name, tname, type, op) \
instantiate_init_reduce(name ##tname, type, op<type>) \
instantiate_reduce_no_atomics(name ##tname, type, type, op<type>)
#define instantiate_same_reduce(name, tname, type, op) \
instantiate_init_reduce(name ##tname, type, op<type>) \
instantiate_reduce(name ##tname, type, type, op<type>)
#define instantiate_reduce_from_types_helper(name, tname, itype, otype, op) \
instantiate_reduce(name ##tname, itype, otype, op)
#define instantiate_reduce_from_types(name, otype, op) \
instantiate_reduce_from_types_helper(name, bool_, bool, otype, op) \
instantiate_reduce_from_types_helper(name, uint8, uint8_t, otype, op) \
instantiate_reduce_from_types_helper(name, uint16, uint16_t, otype, op) \
instantiate_reduce_from_types_helper(name, uint32, uint32_t, otype, op) \
instantiate_reduce_from_types_helper(name, int8, int8_t, otype, op) \
instantiate_reduce_from_types_helper(name, int16, int16_t, otype, op) \
instantiate_reduce_from_types_helper(name, int32, int32_t, otype, op) \
instantiate_reduce_from_types_helper(name, int64, int64_t, otype, op) \
instantiate_reduce_from_types_helper(name, float16, half, otype, op) \
instantiate_reduce_from_types_helper(name, float32, float, otype, op) \
instantiate_reduce_from_types_helper(name, bfloat16, bfloat16_t, otype, op)
// special case bool with larger output type
instantiate_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>)
instantiate_same_reduce(sum, uint8, uint8_t, Sum)
instantiate_same_reduce(sum, uint16, uint16_t, Sum)
instantiate_same_reduce(sum, uint32, uint32_t, Sum)
instantiate_same_reduce(sum, int8, int8_t, Sum)
instantiate_same_reduce(sum, int16, int16_t, Sum)
instantiate_same_reduce(sum, int32, int32_t, Sum)
instantiate_same_reduce(sum, float16, half, Sum)
instantiate_same_reduce(sum, float32, float, Sum)
instantiate_same_reduce_no_atomics(sum, int64, int64_t, Sum)
instantiate_same_reduce_no_atomics(sum, uint64, uint64_t, Sum)
instantiate_same_reduce(prod, uint8, uint8_t, Prod)
instantiate_same_reduce(prod, uint16, uint16_t, Prod)
instantiate_same_reduce(prod, uint32, uint32_t, Prod)
instantiate_same_reduce(prod, int8, int8_t, Prod)
instantiate_same_reduce(prod, int16, int16_t, Prod)
instantiate_same_reduce(prod, int32, int32_t, Prod)
instantiate_same_reduce(prod, float16, half, Prod)
instantiate_same_reduce(prod, float32, float, Prod)
instantiate_same_reduce_no_atomics(prod, int64, int64_t, Prod)
instantiate_same_reduce_no_atomics(prod, uint64, uint64_t, Prod)
instantiate_same_reduce(sum, bfloat16, bfloat16_t, Sum)
instantiate_same_reduce(prod, bfloat16, bfloat16_t, Prod)
instantiate_init_reduce(andbool_, bool, And)
instantiate_reduce_from_types(and, bool, And)
instantiate_init_reduce(orbool_, bool, Or)
instantiate_reduce_from_types(or, bool, Or)
// Compiler segfaulted with the names "min" or "max" ...
instantiate_same_reduce(min_, uint8, uint8_t, Min)
instantiate_same_reduce(min_, uint16, uint16_t, Min)
instantiate_same_reduce(min_, uint32, uint32_t, Min)
instantiate_same_reduce(min_, int8, int8_t, Min)
instantiate_same_reduce(min_, int16, int16_t, Min)
instantiate_same_reduce(min_, int32, int32_t, Min)
instantiate_same_reduce(min_, float16, half, Min)
instantiate_same_reduce(min_, float32, float, Min)
instantiate_same_reduce_no_atomics(min_, int64, int64_t, Min)
instantiate_same_reduce_no_atomics(min_, uint64, uint64_t, Min)
instantiate_same_reduce(max_, uint8, uint8_t, Max)
instantiate_same_reduce(max_, uint16, uint16_t, Max)
instantiate_same_reduce(max_, uint32, uint32_t, Max)
instantiate_same_reduce(max_, int8, int8_t, Max)
instantiate_same_reduce(max_, int16, int16_t, Max)
instantiate_same_reduce(max_, int32, int32_t, Max)
instantiate_same_reduce(max_, float16, half, Max)
instantiate_same_reduce(max_, float32, float, Max)
instantiate_same_reduce_no_atomics(max_, int64, int64_t, Max)
instantiate_same_reduce_no_atomics(max_, uint64, uint64_t, Max)
instantiate_same_reduce(min_, bfloat16, bfloat16_t, Min)
instantiate_same_reduce(max_, bfloat16, bfloat16_t, Max)

View File

@@ -0,0 +1,185 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/kernels/reduction/utils.h"
#include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/reduction/reduce_inst.h"
using namespace metal;
///////////////////////////////////////////////////////////////////////////////
// All reduce helper
///////////////////////////////////////////////////////////////////////////////
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
METAL_FUNC U per_thread_all_reduce(
const device T* in,
const device size_t& in_size,
uint gid,
uint grid_size) {
Op op;
U total_val = Op::init;
if (gid * N_READS < in_size) {
in += gid * N_READS;
int r = 0;
for (; r < (int)ceildiv(in_size, grid_size * N_READS) - 1; r++) {
U vals[N_READS] = {op.init};
for (int i = 0; i < N_READS; i++) {
vals[i] = static_cast<U>(in[i]);
}
for (int i = 0; i < N_READS; i++) {
total_val = op(vals[i], total_val);
}
in += grid_size * N_READS;
}
// Separate case for the last set as we close the reduction size
size_t curr_idx = (gid + r * (size_t)grid_size) * N_READS;
if (curr_idx < in_size) {
int max_reads = in_size - curr_idx;
T vals[N_READS];
for (int i = 0, idx = 0; i < N_READS; i++, idx++) {
idx = idx < max_reads ? idx : max_reads - 1;
vals[i] = in[idx];
}
for (int i = 0; i < N_READS; i++) {
U val = i < max_reads ? vals[i] : Op::init;
total_val = op(static_cast<U>(val), total_val);
}
}
}
return total_val;
}
///////////////////////////////////////////////////////////////////////////////
// All reduce kernel
///////////////////////////////////////////////////////////////////////////////
// NB: This kernel assumes threads_per_threadgroup is at most
// 1024. This way with a simd_size of 32, we are guaranteed to
// complete the reduction in two steps of simd-level reductions.
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
[[kernel]] void all_reduce(
const device T *in [[buffer(0)]],
device mlx_atomic<U> *out [[buffer(1)]],
const device size_t& in_size [[buffer(2)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint grid_size [[threads_per_grid]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
Op op;
threadgroup U local_vals[simd_size];
U total_val = per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);
// Reduction within simd group
total_val = op.simd_reduce(total_val);
if (simd_lane_id == 0) {
local_vals[simd_group_id] = total_val;
}
// Reduction within thread group
threadgroup_barrier(mem_flags::mem_threadgroup);
total_val = lid < simd_per_group ? local_vals[lid] : op.init;
total_val = op.simd_reduce(total_val);
// Reduction across threadgroups
if (lid == 0) {
op.atomic_update(out, total_val);
}
}
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
[[kernel]] void all_reduce_no_atomics(
const device T *in [[buffer(0)]],
device U *out [[buffer(1)]],
const device size_t& in_size [[buffer(2)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint grid_size [[threads_per_grid]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint thread_group_id [[threadgroup_position_in_grid]]) {
Op op;
threadgroup U local_vals[simd_size];
U total_val = per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);
// Reduction within simd group (simd_add isn't supported for uint64/int64 types)
for (uint16_t lane_offset = simd_size/2; lane_offset > 0; lane_offset /= 2) {
total_val = op(total_val, simd_shuffle_down(total_val, lane_offset));
}
// Write simd group reduction results to local memory
if (simd_lane_id == 0) {
local_vals[simd_group_id] = total_val;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Reduction of simdgroup reduction results within threadgroup.
total_val = lid < simd_per_group ? local_vals[lid] : op.init;
for (uint16_t lane_offset = simd_size/2; lane_offset > 0; lane_offset /= 2) {
total_val = op(total_val, simd_shuffle_down(total_val, lane_offset));
}
// Reduction across threadgroups
if (lid == 0) {
out[thread_group_id] = total_val;
}
}
#define instantiate_all_reduce(name, itype, otype, op) \
template [[host_name("all_reduce_" #name)]] \
[[kernel]] void all_reduce<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device mlx_atomic<otype> *out [[buffer(1)]], \
const device size_t& in_size [[buffer(2)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint grid_size [[threads_per_grid]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_all_reduce_no_atomics(name, itype, otype, op) \
template [[host_name("all_reduce_no_atomics_" #name)]] \
[[kernel]] void all_reduce_no_atomics<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device otype *out [[buffer(1)]], \
const device size_t& in_size [[buffer(2)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint grid_size [[threads_per_grid]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint thread_group_id [[threadgroup_position_in_grid]]);
///////////////////////////////////////////////////////////////////////////////
// Instantiations
///////////////////////////////////////////////////////////////////////////////
#define instantiate_same_all_reduce_helper(name, tname, type, op) \
instantiate_all_reduce(name ##tname, type, type, op<type>)
#define instantiate_same_all_reduce_na_helper(name, tname, type, op) \
instantiate_all_reduce_no_atomics(name ##tname, type, type, op<type>)
instantiate_reduce_ops(instantiate_same_all_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_all_reduce_na_helper, instantiate_reduce_helper_64b)
instantiate_reduce_from_types(instantiate_all_reduce, and, bool, And)
instantiate_reduce_from_types(instantiate_all_reduce, or, bool, Or)
// special case bool with larger output type
instantiate_all_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>)

View File

@@ -0,0 +1,253 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/kernels/reduction/utils.h"
#include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/reduction/reduce_inst.h"
using namespace metal;
///////////////////////////////////////////////////////////////////////////////
// Small column reduce kernel
///////////////////////////////////////////////////////////////////////////////
template <typename T, typename U, typename Op>
[[kernel]] void col_reduce_small(
const device T *in [[buffer(0)]],
device U *out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
const constant size_t& non_col_reductions [[buffer(8)]],
const constant int* non_col_shapes [[buffer(9)]],
const constant size_t* non_col_strides [[buffer(10)]],
const constant int& non_col_ndim [[buffer(11)]],
uint tid [[thread_position_in_grid]]) {
// Appease the compiler
(void)out_size;
Op op;
U total_val = Op::init;
auto out_idx = tid;
in += elem_to_loc(
out_idx,
shape + non_col_ndim,
strides + non_col_ndim,
ndim - non_col_ndim);
for(uint i = 0; i < non_col_reductions; i++) {
size_t in_idx = elem_to_loc(i, non_col_shapes, non_col_strides, non_col_ndim);
for(uint j = 0; j < reduction_size; j++, in_idx += reduction_stride) {
U val = static_cast<U>(in[in_idx]);
total_val = op(total_val, val);
}
}
out[out_idx] = total_val;
}
#define instantiate_col_reduce_small(name, itype, otype, op) \
template [[host_name("col_reduce_small_" #name)]] \
[[kernel]] void col_reduce_small<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device otype *out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
const constant size_t& non_col_reductions [[buffer(8)]], \
const constant int* non_col_shapes [[buffer(9)]], \
const constant size_t* non_col_strides [[buffer(10)]], \
const constant int& non_col_ndim [[buffer(11)]], \
uint tid [[thread_position_in_grid]]);
///////////////////////////////////////////////////////////////////////////////
// Column reduce helper
///////////////////////////////////////////////////////////////////////////////
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
METAL_FUNC U _contiguous_strided_reduce(
const device T* in,
threadgroup U* local_data,
uint in_idx,
uint reduction_size,
uint reduction_stride,
uint2 tid,
uint2 lid,
uint2 lsize) {
Op op;
U total_val = Op::init;
uint base_offset = (tid.y * lsize.y + lid.y) * N_READS;
for (uint r = 0; r < N_READS && (base_offset + r) < reduction_size; r++) {
uint offset = base_offset + r;
total_val =
op(static_cast<U>(total_val), in[in_idx + offset * reduction_stride]);
}
local_data[lsize.y * lid.x + lid.y] = total_val;
threadgroup_barrier(mem_flags::mem_threadgroup);
U val = Op::init;
if (lid.y == 0) {
// Perform reduction across columns in thread group
for (uint i = 0; i < lsize.y; i++) {
val = op(val, local_data[lsize.y * lid.x + i]);
}
}
return val;
}
///////////////////////////////////////////////////////////////////////////////
// Column reduce kernel
///////////////////////////////////////////////////////////////////////////////
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void col_reduce_general(
const device T *in [[buffer(0)]],
device mlx_atomic<U> *out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
threadgroup U *local_data [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]]) {
auto out_idx = tid.x * lsize.x + lid.x;
auto in_idx = elem_to_loc(
out_idx + tid.z * out_size,
shape,
strides,
ndim
);
Op op;
if(out_idx < out_size) {
U val = _contiguous_strided_reduce<T, U, Op, N_READS>(
in,
local_data,
in_idx,
reduction_size,
reduction_stride,
tid.xy,
lid.xy,
lsize.xy);
// Write out reduction results generated by threadgroups working on specific output element, contiguously.
if (lid.y == 0) {
op.atomic_update(out, val, out_idx);
}
}
}
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void col_reduce_general_no_atomics(
const device T *in [[buffer(0)]],
device U *out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
threadgroup U *local_data [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 gid [[thread_position_in_grid]],
uint3 lsize [[threads_per_threadgroup]],
uint3 gsize [[threads_per_grid]]) {
auto out_idx = tid.x * lsize.x + lid.x;
auto in_idx = elem_to_loc(
out_idx + tid.z * out_size,
shape,
strides,
ndim
);
if(out_idx < out_size) {
U val = _contiguous_strided_reduce<T, U, Op, N_READS>(
in,
local_data,
in_idx,
reduction_size,
reduction_stride,
tid.xy,
lid.xy,
lsize.xy);
// Write out reduction results generated by threadgroups working on specific output element, contiguously.
if (lid.y == 0) {
uint tgsize_y = ceildiv(gsize.y, lsize.y);
uint tgsize_z = ceildiv(gsize.z, lsize.z);
out[tgsize_y * tgsize_z * gid.x + tgsize_y * tid.z + tid.y] = val;
}
}
}
#define instantiate_col_reduce_general(name, itype, otype, op) \
template [[host_name("col_reduce_general_" #name)]] \
[[kernel]] void col_reduce_general<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device mlx_atomic<otype> *out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
threadgroup otype *local_data [[threadgroup(0)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]]);
#define instantiate_col_reduce_general_no_atomics(name, itype, otype, op) \
template [[host_name("col_reduce_general_no_atomics_" #name)]] \
[[kernel]] void col_reduce_general_no_atomics<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device otype *out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
threadgroup otype *local_data [[threadgroup(0)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 gid [[thread_position_in_grid]], \
uint3 lsize [[threads_per_threadgroup]], \
uint3 gsize [[threads_per_grid]]);
///////////////////////////////////////////////////////////////////////////////
// Instantiations
///////////////////////////////////////////////////////////////////////////////
#define instantiate_same_col_reduce_helper(name, tname, type, op) \
instantiate_col_reduce_small(name ##tname, type, type, op<type>) \
instantiate_col_reduce_general(name ##tname, type, type, op<type>)
#define instantiate_same_col_reduce_na_helper(name, tname, type, op) \
instantiate_col_reduce_small(name ##tname, type, type, op<type>) \
instantiate_col_reduce_general_no_atomics(name ##tname, type, type, op<type>)
instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_col_reduce_na_helper, instantiate_reduce_helper_64b)
instantiate_col_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>)
instantiate_reduce_from_types(instantiate_col_reduce_general, and, bool, And)
instantiate_reduce_from_types(instantiate_col_reduce_general, or, bool, Or)
instantiate_col_reduce_small(sumbool_, bool, uint32_t, Sum<uint32_t>)
instantiate_reduce_from_types(instantiate_col_reduce_small, and, bool, And)
instantiate_reduce_from_types(instantiate_col_reduce_small, or, bool, Or)

View File

@@ -0,0 +1,33 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/kernels/reduction/utils.h"
#include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/reduction/reduce_inst.h"
using namespace metal;
///////////////////////////////////////////////////////////////////////////////
// Reduce init
///////////////////////////////////////////////////////////////////////////////
template <typename T, typename Op>
[[kernel]] void init_reduce(
device T *out [[buffer(0)]],
uint tid [[thread_position_in_grid]]) {
out[tid] = Op::init;
}
#define instantiate_init_reduce(name, otype, op) \
template [[host_name("i" #name)]] \
[[kernel]] void init_reduce<otype, op>( \
device otype *out [[buffer(1)]], \
uint tid [[thread_position_in_grid]]);
#define instantiate_init_reduce_helper(name, tname, type, op) \
instantiate_init_reduce(name ##tname, type, op<type>)
instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_64b)
instantiate_init_reduce(andbool_, bool, And)
instantiate_init_reduce(orbool_, bool, Or)

View File

@@ -0,0 +1,371 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/kernels/reduction/utils.h"
#include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/reduction/reduce_inst.h"
using namespace metal;
///////////////////////////////////////////////////////////////////////////////
// Small row reductions
///////////////////////////////////////////////////////////////////////////////
// Each thread reduces for one output
template <typename T, typename U, typename Op>
[[kernel]] void row_reduce_general_small(
const device T *in [[buffer(0)]],
device U *out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant size_t& non_row_reductions [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
uint lid [[thread_position_in_grid]]) {
Op op;
uint out_idx = lid;
if(out_idx >= out_size) {
return;
}
U total_val = Op::init;
for(short r = 0; r < short(non_row_reductions); r++) {
uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
const device T * in_row = in + in_idx;
for(short i = 0; i < short(reduction_size); i++) {
total_val = op(static_cast<U>(in_row[i]), total_val);
}
}
out[out_idx] = total_val;
}
// Each simdgroup reduces for one output
template <typename T, typename U, typename Op>
[[kernel]] void row_reduce_general_med(
const device T *in [[buffer(0)]],
device U *out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant size_t& non_row_reductions [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
uint tid [[threadgroup_position_in_grid]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_per_group [[dispatch_simdgroups_per_threadgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
Op op;
uint out_idx = simd_per_group * tid + simd_group_id;
if(out_idx >= out_size) {
return;
}
U total_val = Op::init;
if(short(non_row_reductions) == 1) {
uint in_idx = elem_to_loc(out_idx, shape, strides, ndim);
const device T * in_row = in + in_idx;
for(short i = simd_lane_id; i < short(reduction_size); i += 32) {
total_val = op(static_cast<U>(in_row[i]), total_val);
}
}
else if (short(non_row_reductions) >= 32) {
for(short r = simd_lane_id; r < short(non_row_reductions); r+=32) {
uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
const device T * in_row = in + in_idx;
for(short i = 0; i < short(reduction_size); i++) {
total_val = op(static_cast<U>(in_row[i]), total_val);
}
}
}
else {
const short n_reductions = short(reduction_size) * short(non_row_reductions);
const short reductions_per_thread = (n_reductions + simd_size - 1) / simd_size;
const short r_st = simd_lane_id / reductions_per_thread;
const short r_ed = short(non_row_reductions);
const short r_jump = simd_size / reductions_per_thread;
const short i_st = simd_lane_id % reductions_per_thread;
const short i_ed = short(reduction_size);
const short i_jump = reductions_per_thread;
if(r_st < r_jump) {
for(short r = r_st; r < r_ed; r += r_jump) {
uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
const device T * in_row = in + in_idx;
for(short i = i_st; i < i_ed; i += i_jump) {
total_val = op(static_cast<U>(in_row[i]), total_val);
}
}
}
}
total_val = op.simd_reduce(total_val);
if(simd_lane_id == 0) {
out[out_idx] = total_val;
}
}
#define instantiate_row_reduce_small(name, itype, otype, op) \
template[[host_name("row_reduce_general_small_" #name)]] \
[[kernel]] void row_reduce_general_small<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device otype *out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant size_t& non_row_reductions [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
uint lid [[thread_position_in_grid]]); \
template[[host_name("row_reduce_general_med_" #name)]] \
[[kernel]] void row_reduce_general_med<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device otype *out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant size_t& non_row_reductions [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
uint tid [[threadgroup_position_in_grid]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_per_group [[dispatch_simdgroups_per_threadgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
///////////////////////////////////////////////////////////////////////////////
// Large row reductions
///////////////////////////////////////////////////////////////////////////////
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
METAL_FUNC U per_thread_row_reduce(
const device T* in,
const constant size_t& reduction_size,
const constant size_t& out_size,
const constant int* shape,
const constant size_t* strides,
const constant int& ndim,
uint lsize_x,
uint lid_x,
uint2 tid) {
Op op;
// Each threadgroup handles 1 reduction
// TODO: Specializing elem_to_loc would be slightly faster
int idx = tid.y * out_size + tid.x;
int extra_offset = elem_to_loc(idx, shape, strides, ndim);
in += extra_offset + lid_x * N_READS;
// The reduction is accumulated here
U total_val = Op::init;
// Loop over the reduction size within thread group
int r = 0;
for (; r < (int)ceildiv(reduction_size, N_READS * lsize_x) - 1; r++) {
T vals[N_READS];
for (int i = 0; i < N_READS; i++) {
vals[i] = in[i];
}
for (int i = 0; i < N_READS; i++) {
total_val = op(static_cast<U>(vals[i]), total_val);
}
in += lsize_x * N_READS;
}
// Separate case for the last set as we close the reduction size
size_t reduction_index = (lid_x + (size_t)lsize_x * r) * N_READS;
if (reduction_index < reduction_size) {
int max_reads = reduction_size - reduction_index;
T vals[N_READS];
for (int i = 0; i < N_READS; i++) {
int idx = min(i, max_reads - 1);
vals[i] = static_cast<U>(in[idx]);
}
for (int i = 0; i < N_READS; i++) {
T val = i < max_reads ? vals[i] : Op::init;
total_val = op(static_cast<U>(val), total_val);
}
}
return total_val;
}
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
[[kernel]] void row_reduce_general(
const device T *in [[buffer(0)]],
device mlx_atomic<U> *out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant size_t& non_row_reductions [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
(void)non_row_reductions;
Op op;
threadgroup U local_vals[simd_size];
U total_val = per_thread_row_reduce<T, U, Op, N_READS>(in, reduction_size, out_size, shape, strides, ndim, lsize.x, lid.x, tid.xy);
total_val = op.simd_reduce(total_val);
// Prepare next level
if (simd_lane_id == 0) {
local_vals[simd_group_id] = total_val;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Reduction within thread group
// Only needed if multiple simd groups
if(reduction_size > simd_size) {
total_val = lid.x < simd_per_group ? local_vals[lid.x] : op.init;
total_val = op.simd_reduce(total_val);
}
// Update output
if (lid.x == 0) {
op.atomic_update(out, total_val, tid.x);
}
}
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
[[kernel]] void row_reduce_general_no_atomics(
const device T *in [[buffer(0)]],
device U *out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant size_t& non_row_reductions [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]],
uint3 gsize [[threads_per_grid]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
(void)non_row_reductions;
Op op;
threadgroup U local_vals[simd_size];
U total_val = per_thread_row_reduce<T, U, Op, N_READS>(in, reduction_size, out_size, shape, strides, ndim, lsize.x, lid.x, tid.xy);
// Reduction within simd group - simd_add isn't supported for int64 types
for (uint16_t i = simd_size/2; i > 0; i /= 2) {
total_val = op(total_val, simd_shuffle_down(total_val, i));
}
// Prepare next level
if (simd_lane_id == 0) {
local_vals[simd_group_id] = total_val;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Reduction within thread group
// Only needed if thread group has multiple simd groups
if(ceildiv(reduction_size, N_READS) > simd_size) {
total_val = lid.x < simd_per_group ? local_vals[lid.x] : op.init;
for (uint16_t i = simd_size/2; i > 0; i /= 2) {
total_val = op(total_val, simd_shuffle_down(total_val, i));
}
}
// Write row reduce output for threadgroup with 1st thread in thread group
if (lid.x == 0) {
out[(ceildiv(gsize.y, lsize.y) * tid.x) + tid.y] = total_val;
}
}
#define instantiate_row_reduce_general(name, itype, otype, op) \
instantiate_row_reduce_small(name, itype, otype, op) \
template [[host_name("row_reduce_general_" #name)]] \
[[kernel]] void row_reduce_general<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device mlx_atomic<otype> *out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant size_t& non_row_reductions [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_row_reduce_general_no_atomics(name, itype, otype, op) \
instantiate_row_reduce_small(name, itype, otype, op) \
template [[host_name("row_reduce_general_no_atomics_" #name)]] \
[[kernel]] void row_reduce_general_no_atomics<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device otype *out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant size_t& non_row_reductions [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]], \
uint3 gsize [[threads_per_grid]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
///////////////////////////////////////////////////////////////////////////////
// Instantiations
///////////////////////////////////////////////////////////////////////////////
#define instantiate_same_row_reduce_helper(name, tname, type, op) \
instantiate_row_reduce_general(name ##tname, type, type, op<type>)
#define instantiate_same_row_reduce_na_helper(name, tname, type, op) \
instantiate_row_reduce_general_no_atomics(name ##tname, type, type, op<type>)
instantiate_reduce_ops(instantiate_same_row_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_row_reduce_na_helper, instantiate_reduce_helper_64b)
instantiate_reduce_from_types(instantiate_row_reduce_general, and, bool, And)
instantiate_reduce_from_types(instantiate_row_reduce_general, or, bool, Or)
instantiate_row_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>)

View File

@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#pragma once

View File

@@ -0,0 +1,71 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include <metal_atomic>
#include <metal_simdgroup>
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/reduction/ops.h"
#define instantiate_reduce_helper_floats(inst_f, name, op) \
inst_f(name, float16, half, op) inst_f(name, float32, float, op) \
inst_f(name, bfloat16, bfloat16_t, op)
#define instantiate_reduce_helper_uints(inst_f, name, op) \
inst_f(name, uint8, uint8_t, op) inst_f(name, uint16, uint16_t, op) \
inst_f(name, uint32, uint32_t, op)
#define instantiate_reduce_helper_ints(inst_f, name, op) \
inst_f(name, int8, int8_t, op) inst_f(name, int16, int16_t, op) \
inst_f(name, int32, int32_t, op)
#define instantiate_reduce_helper_64b(inst_f, name, op) \
inst_f(name, int64, int64_t, op) inst_f(name, uint64, uint64_t, op)
#define instantiate_reduce_helper_types(inst_f, name, op) \
instantiate_reduce_helper_floats(inst_f, name, op) \
instantiate_reduce_helper_uints(inst_f, name, op) \
instantiate_reduce_helper_ints(inst_f, name, op)
#define instantiate_reduce_ops(inst_f, type_f) \
type_f(inst_f, sum, Sum) type_f(inst_f, prod, Prod) \
type_f(inst_f, min_, Min) type_f(inst_f, max_, Max)
// Special case for bool reductions
#define instantiate_reduce_from_types_helper( \
inst_f, name, tname, itype, otype, op) \
inst_f(name##tname, itype, otype, op)
#define instantiate_reduce_from_types(inst_f, name, otype, op) \
instantiate_reduce_from_types_helper(inst_f, name, bool_, bool, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, uint8, uint8_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, uint16, uint16_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, uint32, uint32_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, int8, int8_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, int16, int16_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, int32, int32_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, int64, int64_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, float16, half, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, \
name, \
float32, \
float, \
otype, \
op) \
instantiate_reduce_from_types_helper( \
inst_f, \
name, \
bfloat16, \
bfloat16_t, \
otype, \
op)

View File

@@ -0,0 +1,14 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include <metal_atomic>
#include <metal_simdgroup>
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/steel/utils.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/reduction/ops.h"
static constant constexpr const uint8_t simd_size = 32;

View File

@@ -0,0 +1,194 @@
// Copyright © 2024 Apple Inc.
#include <metal_common>
#include <metal_simdgroup>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"
using namespace metal;
template <typename T, int N_READS = RMS_N_READS>
[[kernel]] void rms_single_row(
const device T* x,
const device T* w,
device T* out,
constant float& eps,
constant uint& axis_size,
constant uint& w_stride,
threadgroup float* local_inv_mean [[threadgroup(0)]],
threadgroup float* local_sums [[threadgroup(1)]],
uint gid [[threadgroup_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
float acc = 0;
x += gid * axis_size + lid * N_READS;
w += w_stride * lid * N_READS;
if (lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
float xi = x[i];
acc += xi * xi;
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((lid * N_READS + i) < axis_size) {
float xi = x[i];
acc += xi * xi;
}
}
}
acc = simd_sum(acc);
// Initialize shared memory
if (simd_group_id == 0) {
local_sums[simd_lane_id] = 0;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Write simd accumulations into shared memory
if (simd_lane_id == 0) {
local_sums[simd_group_id] = acc;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Accumulate over simd groups
if (simd_group_id == 0) {
acc = simd_sum(local_sums[simd_lane_id]);
if (simd_lane_id == 0) {
local_inv_mean[0] = metal::precise::rsqrt(acc / axis_size + eps);
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Write the outputs
out += gid * axis_size + lid * N_READS;
if (lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
out[i] = w[w_stride * i] * static_cast<T>(x[i] * local_inv_mean[0]);
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((lid * N_READS + i) < axis_size) {
out[i] = w[w_stride * i] * static_cast<T>(x[i] * local_inv_mean[0]);
}
}
}
}
template <typename T, int N_READS = RMS_N_READS>
[[kernel]] void rms_looped(
const device T* x,
const device T* w,
device T* out,
constant float& eps,
constant uint& axis_size,
constant uint& w_stride,
threadgroup float* local_inv_mean [[threadgroup(0)]],
threadgroup float* local_sums [[threadgroup(1)]],
uint gid [[threadgroup_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
float acc = 0;
x += gid * axis_size + lid * N_READS;
w += w_stride * lid * N_READS;
for (uint r = 0; r < axis_size; r += lsize * N_READS) {
if (r + lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
float xi = x[i + r];
acc += xi * xi;
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((r + lid * N_READS + i) < axis_size) {
float xi = x[i + r];
acc += xi * xi;
}
}
}
}
acc = simd_sum(acc);
// Initialize shared memory
if (simd_group_id == 0) {
local_sums[simd_lane_id] = 0;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Write simd accumulations into shared memory
if (simd_lane_id == 0) {
local_sums[simd_group_id] = acc;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Accumulate over simd groups
if (simd_group_id == 0) {
acc = simd_sum(local_sums[simd_lane_id]);
if (simd_lane_id == 0) {
local_inv_mean[0] = metal::precise::rsqrt(acc / axis_size + eps);
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Write the outputs
out += gid * axis_size + lid * N_READS;
for (uint r = 0; r < axis_size; r += lsize * N_READS) {
if (r + lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
out[r + i] = w[w_stride * (i + r)] *
static_cast<T>(x[r + i] * local_inv_mean[0]);
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((r + lid * N_READS + i) < axis_size) {
out[r + i] = w[w_stride * (i + r)] *
static_cast<T>(x[r + i] * local_inv_mean[0]);
}
}
}
}
}
// clang-format off
#define instantiate_rms_single_row(name, itype) \
template [[host_name("rms" #name)]] [[kernel]] void \
rms_single_row<itype>( \
const device itype* x, \
const device itype* w, \
device itype* out, \
constant float& eps, \
constant uint& axis_size, \
constant uint& w_stride, \
threadgroup float* local_inv_mean [[threadgroup(0)]], \
threadgroup float* local_sums [[threadgroup(1)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_rms_looped(name, itype) \
template [[host_name("rms_looped" #name)]] [[kernel]] void \
rms_looped<itype>( \
const device itype* x, \
const device itype* w, \
device itype* out, \
constant float& eps, \
constant uint& axis_size, \
constant uint& w_stride, \
threadgroup float* local_inv_mean [[threadgroup(0)]], \
threadgroup float* local_sums [[threadgroup(1)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_rms(name, itype) \
instantiate_rms_single_row(name, itype) \
instantiate_rms_looped(name, itype)
instantiate_rms(float32, float)
instantiate_rms(float16, half)
instantiate_rms(bfloat16, bfloat16_t)
// clang-format on

View File

@@ -0,0 +1,451 @@
#include <metal_stdlib>
#include <metal_simdgroup>
#include "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h"
using namespace metal;
template<typename T, typename T2, typename T4, uint16_t TILE_SIZE_CONST, uint16_t NSIMDGROUPS>
[[kernel]] void fast_inference_sdpa_compute_partials_template(const device T *Q [[buffer(0)]],
const device T *K [[buffer(1)]],
const device T *V [[buffer(2)]],
const device uint64_t& L [[buffer(3)]],
const device MLXScaledDotProductAttentionParams& params [[buffer(4)]],
device float* O_partials [[buffer(5)]],
device float* p_lse [[buffer(6)]],
device float* p_maxes [[buffer(7)]],
threadgroup T* threadgroup_block [[threadgroup(0)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]]) {
constexpr const size_t DK = 128;
constexpr const ulong SIMDGROUP_MATRIX_LOAD_FACTOR = 8;
constexpr const size_t THREADS_PER_SIMDGROUP = 32;
constexpr const uint iter_offset = NSIMDGROUPS * 4;
const bool is_gqa = params.N_KV_HEADS != params.N_Q_HEADS;
uint kv_head_offset_factor = tid.x;
if(is_gqa) {
int q_kv_head_ratio = params.N_Q_HEADS / params.N_KV_HEADS;
kv_head_offset_factor = tid.x / q_kv_head_ratio;
}
constexpr const uint16_t P_VEC4 = TILE_SIZE_CONST / NSIMDGROUPS / 4;
constexpr const size_t MATRIX_LOADS_PER_SIMDGROUP = TILE_SIZE_CONST / (SIMDGROUP_MATRIX_LOAD_FACTOR * NSIMDGROUPS);
constexpr const size_t MATRIX_COLS = DK / SIMDGROUP_MATRIX_LOAD_FACTOR;
constexpr const uint totalSmemV = SIMDGROUP_MATRIX_LOAD_FACTOR * SIMDGROUP_MATRIX_LOAD_FACTOR * (MATRIX_LOADS_PER_SIMDGROUP + 1) * NSIMDGROUPS;
threadgroup T4* smemFlush = (threadgroup T4*)threadgroup_block;
#pragma clang loop unroll(full)
for(uint i = 0; i < 8; i++) {
smemFlush[simd_lane_id + simd_group_id * THREADS_PER_SIMDGROUP + i * NSIMDGROUPS * THREADS_PER_SIMDGROUP] = T4(0.f);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// TODO: multiple query sequence length for speculative decoding
const uint tgroup_query_head_offset = tid.x * DK + tid.z * (params.N_Q_HEADS * DK);
const uint tgroup_k_head_offset = kv_head_offset_factor * DK * L;
const uint tgroup_k_tile_offset = tid.y * TILE_SIZE_CONST * DK;
const uint tgroup_k_batch_offset = tid.z * L * params.N_KV_HEADS * DK;
const device T* baseK = K + tgroup_k_batch_offset + tgroup_k_tile_offset + tgroup_k_head_offset;
const device T* baseQ = Q + tgroup_query_head_offset;
device T4* simdgroupQueryData = (device T4*)baseQ;
constexpr const size_t ACCUM_PER_GROUP = TILE_SIZE_CONST / NSIMDGROUPS;
float threadAccum[ACCUM_PER_GROUP];
#pragma clang loop unroll(full)
for(size_t threadAccumIndex = 0; threadAccumIndex < ACCUM_PER_GROUP; threadAccumIndex++) {
threadAccum[threadAccumIndex] = -INFINITY;
}
uint KROW_ACCUM_INDEX = 0;
const int32_t SEQUENCE_LENGTH_LESS_TILE_SIZE = L - TILE_SIZE_CONST;
const bool LAST_TILE = (tid.y + 1) * TILE_SIZE_CONST >= L;
const bool LAST_TILE_ALIGNED = (SEQUENCE_LENGTH_LESS_TILE_SIZE == int32_t(tid.y * TILE_SIZE_CONST));
T4 thread_data_x4;
T4 thread_data_y4;
if(!LAST_TILE || LAST_TILE_ALIGNED) {
thread_data_x4 = *(simdgroupQueryData + simd_lane_id);
#pragma clang loop unroll(full)
for(size_t KROW = simd_group_id; KROW < TILE_SIZE_CONST; KROW += NSIMDGROUPS) {
const uint KROW_OFFSET = KROW * DK;
const device T* baseKRow = baseK + KROW_OFFSET;
device T4* keysData = (device T4*)baseKRow;
thread_data_y4 = *(keysData + simd_lane_id);
T kq_scalar = dot(thread_data_x4, thread_data_y4);
threadAccum[KROW_ACCUM_INDEX] = float(kq_scalar);
KROW_ACCUM_INDEX++;
}
} else {
thread_data_x4 = *(simdgroupQueryData + simd_lane_id);
const uint START_ROW = tid.y * TILE_SIZE_CONST;
const device T* baseKThisHead = K + tgroup_k_batch_offset + tgroup_k_head_offset;
for(size_t KROW = START_ROW + simd_group_id; KROW < L; KROW += NSIMDGROUPS) {
const uint KROW_OFFSET = KROW * DK;
const device T* baseKRow = baseKThisHead + KROW_OFFSET;
device T4* keysData = (device T4*)baseKRow;
thread_data_y4 = *(keysData + simd_lane_id);
T kq_scalar = dot(thread_data_x4, thread_data_y4);
threadAccum[KROW_ACCUM_INDEX] = float(kq_scalar);
KROW_ACCUM_INDEX++;
}
}
threadgroup float* smemP = (threadgroup float*)threadgroup_block;
#pragma clang loop unroll(full)
for(size_t i = 0; i < P_VEC4; i++) {
thread_data_x4 = T4(threadAccum[4 * i], threadAccum[4 * i + 1], threadAccum[4 * i + 2], threadAccum[4 * i + 3]);
simdgroup_barrier(mem_flags::mem_none);
thread_data_y4 = simd_sum(thread_data_x4);
if(simd_lane_id == 0) {
const uint base_smem_p_offset = i * iter_offset + simd_group_id;
smemP[base_smem_p_offset + NSIMDGROUPS * 0] = float(thread_data_y4.x);
smemP[base_smem_p_offset + NSIMDGROUPS * 1] = float(thread_data_y4.y);
smemP[base_smem_p_offset + NSIMDGROUPS * 2] = float(thread_data_y4.z);
smemP[base_smem_p_offset + NSIMDGROUPS * 3] = float(thread_data_y4.w);
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
float groupMax;
float lse = 0.f;
constexpr const size_t THREADS_PER_THREADGROUP_TIMES_4 = 4 * 32;
constexpr const size_t ACCUM_ARRAY_LENGTH = TILE_SIZE_CONST / THREADS_PER_THREADGROUP_TIMES_4 + 1;
float4 pvals[ACCUM_ARRAY_LENGTH];
#pragma clang loop unroll(full)
for(uint accum_array_iter = 0; accum_array_iter < ACCUM_ARRAY_LENGTH; accum_array_iter++) {
pvals[accum_array_iter] = float4(-INFINITY);
}
if (TILE_SIZE_CONST == 64) {
threadgroup float2* smemPtrFlt2 = (threadgroup float2*)threadgroup_block;
float2 vals = smemPtrFlt2[simd_lane_id];
vals *= params.INV_ALPHA;
float maxval = max(vals.x, vals.y);
simdgroup_barrier(mem_flags::mem_none);
groupMax = simd_max(maxval);
float2 expf_shifted = exp(vals - groupMax);
float sumExpLocal = expf_shifted.x + expf_shifted.y;
simdgroup_barrier(mem_flags::mem_none);
float tgroupExpSum = simd_sum(sumExpLocal);
lse = log(tgroupExpSum);
float2 local_p_hat = expf_shifted / tgroupExpSum;
pvals[0].x = local_p_hat.x;
pvals[0].y = local_p_hat.y;
smemPtrFlt2[simd_lane_id] = float2(0.f);
}
constexpr const bool TILE_SIZE_LARGER_THAN_64 = TILE_SIZE_CONST > 64;
constexpr const int TILE_SIZE_ITERS_128 = TILE_SIZE_CONST / 128;
if (TILE_SIZE_LARGER_THAN_64) {
float maxval = -INFINITY;
threadgroup float4* smemPtrFlt4 = (threadgroup float4*)threadgroup_block;
#pragma clang loop unroll(full)
for(int i = 0; i < TILE_SIZE_ITERS_128; i++) {
float4 vals = smemPtrFlt4[simd_lane_id + i * THREADS_PER_SIMDGROUP];
vals *= params.INV_ALPHA;
pvals[i] = vals;
maxval = fmax3(vals.x, vals.y, maxval);
maxval = fmax3(vals.z, vals.w, maxval);
}
simdgroup_barrier(mem_flags::mem_none);
groupMax = simd_max(maxval);
float sumExpLocal = 0.f;
#pragma clang loop unroll(full)
for(int i = 0; i < TILE_SIZE_ITERS_128; i++) {
pvals[i] = exp(pvals[i] - groupMax);
sumExpLocal += pvals[i].x + pvals[i].y + pvals[i].z + pvals[i].w;
}
simdgroup_barrier(mem_flags::mem_none);
float tgroupExpSum = simd_sum(sumExpLocal);
lse = log(tgroupExpSum);
#pragma clang loop unroll(full)
for(int i = 0; i < TILE_SIZE_ITERS_128; i++) {
pvals[i] = pvals[i] / tgroupExpSum;
smemPtrFlt4[simd_lane_id + i * THREADS_PER_SIMDGROUP] = float4(0.f);
}
}
threadgroup T* smemV = (threadgroup T*)threadgroup_block;
const size_t v_batch_offset = tid.z * params.N_KV_HEADS * L * DK;
const size_t v_head_offset = kv_head_offset_factor * L * DK;
const size_t v_tile_offset = tid.y * TILE_SIZE_CONST * DK;
const size_t v_offset = v_batch_offset + v_head_offset + v_tile_offset;
device T* baseV = (device T*)V + v_offset;
threadgroup float* smemOpartial = (threadgroup float*)(smemV + totalSmemV);
if (!LAST_TILE || LAST_TILE_ALIGNED) {
#pragma clang loop unroll(full)
for(size_t col = 0; col < MATRIX_COLS; col++) {
uint matrix_load_loop_iter = 0;
constexpr const size_t TILE_SIZE_CONST_DIV_8 = TILE_SIZE_CONST / 8;
for(size_t tile_start = simd_group_id; tile_start < TILE_SIZE_CONST_DIV_8; tile_start += NSIMDGROUPS) {
simdgroup_matrix<T, 8, 8> tmp;
ulong simdgroup_matrix_offset = matrix_load_loop_iter * NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR + simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR;
ulong2 matrixOrigin = ulong2(col * SIMDGROUP_MATRIX_LOAD_FACTOR, simdgroup_matrix_offset);
simdgroup_load(tmp, baseV, DK, matrixOrigin, true);
const ulong2 matrixOriginSmem = ulong2(simdgroup_matrix_offset, 0);
const ulong elemsPerRowSmem = TILE_SIZE_CONST;
simdgroup_store(tmp, smemV, elemsPerRowSmem, matrixOriginSmem, false);
matrix_load_loop_iter++;
};
threadgroup_barrier(mem_flags::mem_threadgroup);
if (TILE_SIZE_CONST == 64) {
T2 local_p_hat = T2(pvals[0].x, pvals[0].y);
uint loop_iter = 0;
threadgroup float* oPartialSmem = smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
#pragma clang loop unroll(full)
for(size_t row = simd_group_id; row < SIMDGROUP_MATRIX_LOAD_FACTOR; row += NSIMDGROUPS) {
threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row);
threadgroup T2* smemV2 = (threadgroup T2*)smemV_row;
T2 v_local = *(smemV2 + simd_lane_id);
T val = dot(local_p_hat, v_local);
simdgroup_barrier(mem_flags::mem_none);
T row_sum = simd_sum(val);
oPartialSmem[simd_group_id + loop_iter * NSIMDGROUPS] = float(row_sum);
loop_iter++;
}
}
if (TILE_SIZE_CONST > 64) {
constexpr const size_t TILE_SIZE_CONST_DIV_128 = (TILE_SIZE_CONST + 1) / 128;
threadgroup float* oPartialSmem = smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
uint loop_iter = 0;
for(size_t row = simd_group_id; row < SIMDGROUP_MATRIX_LOAD_FACTOR; row += NSIMDGROUPS) {
threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row);
T row_sum = 0.f;
for(size_t i = 0; i < TILE_SIZE_CONST_DIV_128; i++) {
threadgroup T4* smemV2 = (threadgroup T4*)smemV_row;
T4 v_local = *(smemV2 + simd_lane_id + i * THREADS_PER_SIMDGROUP);
T4 p_local = T4(pvals[i]);
T val = dot(p_local, v_local);
row_sum += val;
}
simdgroup_barrier(mem_flags::mem_none);
row_sum = simd_sum(row_sum);
oPartialSmem[simd_group_id + loop_iter * NSIMDGROUPS] = float(row_sum);
loop_iter++;
}
}
}
} else {
const int32_t START_ROW = tid.y * TILE_SIZE_CONST;
const int32_t MAX_START_ROW = L - SIMDGROUP_MATRIX_LOAD_FACTOR + 1;
const device T* baseVThisHead = V + v_batch_offset + v_head_offset;
constexpr const int ROWS_PER_ITER = 8;
#pragma clang loop unroll(full)
for(size_t col = 0; col < MATRIX_COLS; col++) {
uint smem_col_index = simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR;
int32_t tile_start;
for(tile_start = START_ROW + simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR; tile_start < MAX_START_ROW; tile_start += NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR) {
simdgroup_matrix<T, 8, 8> tmp;
ulong2 matrixOrigin = ulong2(col * SIMDGROUP_MATRIX_LOAD_FACTOR, tile_start);
simdgroup_load(tmp, baseVThisHead, DK, matrixOrigin, /* transpose */ true);
const ulong2 matrixOriginSmem = ulong2(smem_col_index, 0);
constexpr const ulong elemsPerRowSmem = TILE_SIZE_CONST;
simdgroup_store(tmp, smemV, elemsPerRowSmem, matrixOriginSmem, /* transpose */ false);
smem_col_index += NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR;
};
tile_start = ((L / SIMDGROUP_MATRIX_LOAD_FACTOR) * SIMDGROUP_MATRIX_LOAD_FACTOR);
const int32_t INT_L = int32_t(L);
for(int row_index = tile_start + simd_group_id ; row_index < INT_L; row_index += NSIMDGROUPS) {
if(simd_lane_id < SIMDGROUP_MATRIX_LOAD_FACTOR) {
const uint elems_per_row_gmem = DK;
const uint col_index_v_gmem = col * SIMDGROUP_MATRIX_LOAD_FACTOR + simd_lane_id;
const uint row_index_v_gmem = row_index;
const uint elems_per_row_smem = TILE_SIZE_CONST;
const uint col_index_v_smem = row_index % TILE_SIZE_CONST;
const uint row_index_v_smem = simd_lane_id;
const uint scalar_offset_gmem = row_index_v_gmem * elems_per_row_gmem + col_index_v_gmem;
const uint scalar_offset_smem = row_index_v_smem * elems_per_row_smem + col_index_v_smem;
T vdata = T(*(baseVThisHead + scalar_offset_gmem));
smemV[scalar_offset_smem] = vdata;
smem_col_index += NSIMDGROUPS;
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (TILE_SIZE_CONST == 64) {
T2 local_p_hat = T2(pvals[0].x, pvals[0].y);
threadgroup float* oPartialSmem = smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
for(size_t smem_row_index = simd_group_id;
smem_row_index < ROWS_PER_ITER; smem_row_index += NSIMDGROUPS) {
threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * smem_row_index);
threadgroup T2* smemV2 = (threadgroup T2*)smemV_row;
T2 v_local = *(smemV2 + simd_lane_id);
T val = dot(local_p_hat, v_local);
simdgroup_barrier(mem_flags::mem_none);
T row_sum = simd_sum(val);
oPartialSmem[smem_row_index] = float(row_sum);
}
}
if (TILE_SIZE_CONST > 64) {
threadgroup float* oPartialSmem = smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
uint loop_count = 0;
for(size_t row_index = simd_group_id;
row_index < ROWS_PER_ITER; row_index += NSIMDGROUPS) {
T row_sum = 0.f;
for(size_t tile_iters = 0; tile_iters < TILE_SIZE_ITERS_128; tile_iters++) {
threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row_index);
threadgroup T4* smemV2 = (threadgroup T4*)smemV_row;
T4 v_local = *(smemV2 + simd_lane_id + tile_iters * THREADS_PER_SIMDGROUP);
T4 p_local = T4(pvals[tile_iters]);
row_sum += dot(p_local, v_local);
}
simdgroup_barrier(mem_flags::mem_none);
row_sum = simd_sum(row_sum);
oPartialSmem[simd_group_id + NSIMDGROUPS * loop_count] = float(row_sum);
loop_count++;
}
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if(simd_group_id == 0) {
threadgroup float4* oPartialVec4 = (threadgroup float4*)smemOpartial;
float4 vals = *(oPartialVec4 + simd_lane_id);
device float* oPartialGmem = O_partials + tid.x * DK * params.KV_TILES + tid.y * DK;
device float4* oPartialGmemVec4 = (device float4*)oPartialGmem;
oPartialGmemVec4[simd_lane_id] = vals;
}
if(simd_group_id == 0 && simd_lane_id == 0) {
const uint tileIndex = tid.y;
const uint gmem_partial_scalar_offset = tid.z * params.N_Q_HEADS * params.KV_TILES + tid.x * params.KV_TILES + tileIndex;
p_lse[gmem_partial_scalar_offset] = lse;
p_maxes[gmem_partial_scalar_offset] = groupMax;
}
}
#define instantiate_fast_inference_sdpa_to_partials_kernel(itype, itype2, itype4, tile_size, nsimdgroups) \
template [[host_name("fast_inference_sdpa_compute_partials_" #itype "_" #tile_size "_" #nsimdgroups )]] \
[[kernel]] void fast_inference_sdpa_compute_partials_template<itype, itype2, itype4, tile_size, nsimdgroups>( \
const device itype *Q [[buffer(0)]], \
const device itype *K [[buffer(1)]], \
const device itype *V [[buffer(2)]], \
const device uint64_t& L [[buffer(3)]], \
const device MLXScaledDotProductAttentionParams& params [[buffer(4)]], \
device float* O_partials [[buffer(5)]], \
device float* p_lse [[buffer(6)]], \
device float* p_maxes [[buffer(7)]], \
threadgroup itype *threadgroup_block [[threadgroup(0)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]]);
#define instantiate_fast_inference_sdpa_to_partials_shapes_helper(itype, itype2, itype4, tile_size) \
instantiate_fast_inference_sdpa_to_partials_kernel(itype, itype2, itype4, tile_size, 4) \
instantiate_fast_inference_sdpa_to_partials_kernel(itype, itype2, itype4, tile_size, 8) \
instantiate_fast_inference_sdpa_to_partials_shapes_helper(float, float2, float4, 64);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(float, float2, float4, 128);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(float, float2, float4, 256);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(float, float2, float4, 512);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(half, half2, half4, 64);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(half, half2, half4, 128);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(half, half2, half4, 256);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(half, half2, half4, 512);
template <typename T>
void fast_inference_sdpa_reduce_tiles_template(
const device float *O_partials [[buffer(0)]],
const device float *p_lse[[buffer(1)]],
const device float *p_maxes [[buffer(2)]],
const device MLXScaledDotProductAttentionParams& params [[buffer(3)]],
device T* O [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
constexpr const int DK = 128;
const ulong offset_rows = tid.z * params.KV_TILES * params.N_Q_HEADS + tid.x * params.KV_TILES;
const device float* p_lse_row = p_lse + offset_rows;
const device float* p_rowmax_row = p_maxes + offset_rows;
// reserve some number of registers. this constitutes an assumption on max value of KV TILES.
constexpr const uint8_t reserve = 128;
float p_lse_regs[reserve];
float p_rowmax_regs[reserve];
float weights[reserve];
float true_max = -INFINITY;
for(size_t i = 0; i < params.KV_TILES; i++) {
p_lse_regs[i] = float(*(p_lse_row + i));
p_rowmax_regs[i] = float(*(p_rowmax_row + i));
true_max = fmax(p_rowmax_regs[i], true_max);
weights[i] = exp(p_lse_regs[i]);
}
float denom = 0.f;
for(size_t i = 0; i < params.KV_TILES; i++) {
weights[i] *= exp(p_rowmax_regs[i]-true_max);
denom += weights[i];
}
const device float* O_partials_with_offset = O_partials + tid.z * params.N_Q_HEADS * DK * params.KV_TILES + tid.x * DK * params.KV_TILES;
float o_value = 0.f;
for(size_t i = 0; i < params.KV_TILES; i++) {
float val = *(O_partials_with_offset + i * DK + lid.x);
o_value += val * weights[i] / denom;
}
device T* O_gmem = O + tid.z * params.N_Q_HEADS * DK + tid.x * DK;
O_gmem[lid.x] = T(o_value);
return;
}
kernel void fast_inference_sdpa_reduce_tiles_float(
const device float *O_partials [[buffer(0)]],
const device float *p_lse[[buffer(1)]],
const device float *p_maxes [[buffer(2)]],
const device MLXScaledDotProductAttentionParams& params [[buffer(3)]],
device float* O [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]])
{
fast_inference_sdpa_reduce_tiles_template<float>(O_partials, p_lse, p_maxes, params,
O, tid, lid);
}
kernel void fast_inference_sdpa_reduce_tiles_half(
const device float *O_partials [[buffer(0)]],
const device float *p_lse[[buffer(1)]],
const device float *p_maxes [[buffer(2)]],
const device MLXScaledDotProductAttentionParams& params [[buffer(3)]],
device half* O [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]])
{
fast_inference_sdpa_reduce_tiles_template<half>(O_partials, p_lse, p_maxes, params,
O, tid, lid);
}

View File

@@ -0,0 +1,14 @@
//
// scaled_dot_product_attention_params.h
// mlx
#pragma once
struct MLXScaledDotProductAttentionParams {
// Associated dimensions & transposition information
const uint QUERY_SEQUENCE_LENGTH = 1;
const uint N_Q_HEADS = 32;
const uint N_KV_HEADS = 32;
const uint KV_TILES = 1;
const float INV_ALPHA = 0.08838834764831843f;
};

View File

@@ -4,7 +4,7 @@
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/indexing.h"
#include "mlx/backend/metal/kernels/reduce.h"
#include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/utils.h"
using namespace metal;
@@ -20,7 +20,6 @@ METAL_FUNC void scatter_1d_index_impl(
const constant int* out_shape [[buffer(3)]],
const constant size_t* out_strides [[buffer(4)]],
const constant size_t& upd_size [[buffer(5)]],
const constant bool& upd_col_contiguous [[buffer(6)]],
const thread array<const device IdxT*, NIDX>& idx_buffers,
uint2 gid [[thread_position_in_grid]]) {
@@ -33,11 +32,7 @@ METAL_FUNC void scatter_1d_index_impl(
out_idx += idx_val * out_strides[i];
}
if (!upd_col_contiguous) {
op.atomic_update(out, updates[gid.y * upd_size + gid.x], out_idx + gid.x);
} else {
op.atomic_update(out, updates[gid.x * upd_size + gid.y], out_idx + gid.x);
}
op.atomic_update(out, updates[gid.y * upd_size + gid.x], out_idx + gid.x);
}
#define make_scatter_1d_index(IDX_ARG, IDX_ARR) \
@@ -48,7 +43,6 @@ template <typename T, typename IdxT, typename Op, int NIDX> \
const constant int* out_shape [[buffer(3)]], \
const constant size_t* out_strides [[buffer(4)]], \
const constant size_t& upd_size [[buffer(5)]], \
const constant bool& upd_col_contiguous [[buffer(6)]], \
IDX_ARG(IdxT) \
uint2 gid [[thread_position_in_grid]]) { \
\
@@ -60,7 +54,6 @@ template <typename T, typename IdxT, typename Op, int NIDX> \
out_shape, \
out_strides, \
upd_size, \
upd_col_contiguous, \
idx_buffers, \
gid); \
\
@@ -195,7 +188,6 @@ template [[host_name("scatter_1d_index" name "_" #nidx)]] \
const constant int* out_shape [[buffer(3)]], \
const constant size_t* out_strides [[buffer(4)]], \
const constant size_t& upd_size [[buffer(5)]], \
const constant bool& upd_col_contiguous [[buffer(6)]], \
IDX_ARG(idx_t) \
uint2 gid [[thread_position_in_grid]]);

View File

@@ -1,6 +1,5 @@
// Copyright © 2023 Apple Inc.
#include <metal_atomic>
#include <metal_common>
#include <metal_simdgroup>
@@ -224,5 +223,6 @@ template <typename T, int N_READS = SOFTMAX_N_READS>
instantiate_softmax_single_row(name, itype) \
instantiate_softmax_looped(name, itype)
instantiate_softmax(float32, float) instantiate_softmax(float16, half)
instantiate_softmax(bfloat16, bfloat16_t)
instantiate_softmax(float32, float)
instantiate_softmax(float16, half)
instantiate_softmax(bfloat16, bfloat16_t)

View File

@@ -140,7 +140,7 @@ struct GEMMKernel {
static METAL_FUNC void run(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
device U* C [[buffer(2)]],
device U* D [[buffer(2)]],
const constant GEMMParams* params [[buffer(3)]],
threadgroup T* As [[threadgroup(0)]],
threadgroup T* Bs [[threadgroup(1)]],
@@ -167,7 +167,7 @@ struct GEMMKernel {
A += transpose_a ? c_row : c_row * params->lda;
B += transpose_b ? c_col * params->ldb : c_col;
C += c_row * params->ldc + c_col;
D += c_row * params->ldd + c_col;
// Prepare threadgroup loading operations
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
@@ -214,7 +214,7 @@ struct GEMMKernel {
}
// Store results to device memory
mma_op.store_result(C, params->ldc);
mma_op.store_result(D, params->ldd);
return;
}
@@ -237,7 +237,7 @@ struct GEMMKernel {
tgp_bn,
leftover_bk);
mma_op.store_result(C, params->ldc);
mma_op.store_result(D, params->ldd);
return;
} else if (tgp_bn == BN) {
@@ -252,7 +252,7 @@ struct GEMMKernel {
tgp_bn,
leftover_bk);
mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
return;
} else if (tgp_bm == BM) {
@@ -267,7 +267,7 @@ struct GEMMKernel {
tgp_bn,
leftover_bk);
mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
return;
} else {
@@ -282,7 +282,7 @@ struct GEMMKernel {
tgp_bn,
leftover_bk);
mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
return;
}
}

View File

@@ -1,6 +1,7 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
using namespace metal;
@@ -23,8 +24,10 @@ template <typename T,
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void gemm(
const device T *A [[buffer(0)]],
const device T *B [[buffer(1)]],
device T *C [[buffer(2)]],
const constant GEMMParams* params [[buffer(3)]],
device T *D [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]],
const constant int* batch_shape [[buffer(6)]],
const constant size_t* batch_strides [[buffer(7)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
@@ -36,12 +39,25 @@ template <typename T,
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
// Adjust for batch
A += params->batch_stride_a * tid.z;
B += params->batch_stride_b * tid.z;
C += params->batch_stride_c * tid.z;
if(params->batch_ndim > 1) {
const constant size_t* A_bstrides = batch_strides;
const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
A += batch_offsets.x;
B += batch_offsets.y;
} else {
A += params->batch_stride_a * tid.z;
B += params->batch_stride_b * tid.z;
}
D += params->batch_stride_d * tid.z;
gemm_kernel::run(
A, B, C,
A, B, D,
params,
As, Bs,
simd_lane_id, simd_group_id, tid, lid
@@ -57,8 +73,10 @@ template <typename T,
[[kernel]] void gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned>( \
const device itype *A [[buffer(0)]], \
const device itype *B [[buffer(1)]], \
device itype *C [[buffer(2)]], \
const constant GEMMParams* params [[buffer(3)]], \
device itype *D [[buffer(3)]], \
const constant GEMMParams* params [[buffer(4)]], \
const constant int* batch_shape [[buffer(6)]], \
const constant size_t* batch_strides [[buffer(7)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \

View File

@@ -27,7 +27,10 @@ template <typename T,
const device T *B [[buffer(1)]],
const device T *C [[buffer(2)]],
device T *D [[buffer(3)]],
const constant GEMMAddMMParams* params [[buffer(4)]],
const constant GEMMParams* params [[buffer(4)]],
const constant GEMMAddMMParams* addmm_params [[buffer(5)]],
const constant int* batch_shape [[buffer(6)]],
const constant size_t* batch_strides [[buffer(7)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
@@ -50,9 +53,24 @@ template <typename T,
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
// Adjust for batch
A += params->batch_stride_a * tid.z;
B += params->batch_stride_b * tid.z;
C += params->batch_stride_c * tid.z;
if(params->batch_ndim > 1) {
const constant size_t* A_bstrides = batch_strides;
const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
const constant size_t* C_bstrides = B_bstrides + params->batch_ndim;
ulong3 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, A_bstrides, B_bstrides, C_bstrides, params->batch_ndim);
A += batch_offsets.x;
B += batch_offsets.y;
C += batch_offsets.z;
} else {
A += params->batch_stride_a * tid.z;
B += params->batch_stride_b * tid.z;
C += addmm_params->batch_stride_c * tid.z;
}
D += params->batch_stride_d * tid.z;
const int tid_y = ((tid.y) << params->swizzle_log) +
@@ -71,9 +89,10 @@ template <typename T,
A += transpose_a ? c_row : c_row * params->lda;
B += transpose_b ? c_col * params->ldb : c_col;
C += c_row * params->ldc + c_col * params->fdc;
D += c_row * params->ldd + c_col;
C += c_row * addmm_params->ldc + c_col * addmm_params->fdc;
// Prepare threadgroup loading operations
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
@@ -83,7 +102,7 @@ template <typename T,
int gemm_k_iterations = params->gemm_k_iterations_aligned;
const Epilogue epilogue_op(params->alpha, params->beta);
const Epilogue epilogue_op(addmm_params->alpha, addmm_params->beta);
///////////////////////////////////////////////////////////////////////////////
// MNK aligned loop
@@ -121,7 +140,7 @@ template <typename T,
}
// Store results to device memory
mma_op.store_result(D, params->ldd, C, params->ldc, params->fdc, epilogue_op);
mma_op.store_result(D, params->ldd, C, addmm_params->ldc, addmm_params->fdc, epilogue_op);
return;
}
@@ -145,7 +164,7 @@ template <typename T,
leftover_bk,
LoopAlignment<true, true, K_aligned>{});
mma_op.store_result(D, params->ldd, C, params->ldc, params->fdc, epilogue_op);
mma_op.store_result(D, params->ldd, C, addmm_params->ldc, addmm_params->fdc, epilogue_op);
return;
} else if (tgp_bn == BN) {
@@ -163,7 +182,7 @@ template <typename T,
return mma_op.store_result_safe(
D, params->ldd,
C, params->ldc, params->fdc,
C, addmm_params->ldc, addmm_params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op);
@@ -182,7 +201,7 @@ template <typename T,
return mma_op.store_result_safe(
D, params->ldd,
C, params->ldc, params->fdc,
C, addmm_params->ldc, addmm_params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op);
@@ -201,7 +220,7 @@ template <typename T,
return mma_op.store_result_safe(
D, params->ldd,
C, params->ldc, params->fdc,
C, addmm_params->ldc, addmm_params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op);
}
@@ -219,7 +238,10 @@ template <typename T,
const device itype *B [[buffer(1)]], \
const device itype *C [[buffer(2)]], \
device itype *D [[buffer(3)]], \
const constant GEMMAddMMParams* params [[buffer(4)]], \
const constant GEMMParams* gemm_params [[buffer(4)]], \
const constant GEMMAddMMParams* params [[buffer(5)]], \
const constant int* batch_shape [[buffer(6)]], \
const constant size_t* batch_strides [[buffer(7)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \

View File

@@ -144,9 +144,9 @@ struct BlockMMA {
}
/* Store results from simdgroup_matrix results into device memory */
METAL_FUNC void store_result(device U* C, const int ldc) const {
METAL_FUNC void store_result(device U* D, const int ldd) const {
// Adjust for simdgroup and thread location
C += (sm + tm) * ldc + tn + sn;
D += (sm + tm) * ldd + tn + sn;
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
@@ -155,22 +155,22 @@ struct BlockMMA {
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = results[i * TN + j].thread_elements();
int offset = (i * TM_stride) * ldc + (j * TN_stride);
int offset = (i * TM_stride) * ldd + (j * TN_stride);
// Apply epilogue
U outs[2] = {Epilogue::apply(accum[0]), Epilogue::apply(accum[1])};
// Write out C
C[offset] = outs[0];
C[offset + 1] = outs[1];
// Write out D
D[offset] = outs[0];
D[offset + 1] = outs[1];
}
}
}
METAL_FUNC void
store_result_safe(device U* C, const int ldc, short2 dst_tile_dims) const {
store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) const {
// Adjust for simdgroup and thread location
C += (sm + tm) * ldc + (tn + sn);
D += (sm + tm) * ldd + (tn + sn);
dst_tile_dims -= short2(tn + sn, sm + tm);
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
@@ -183,15 +183,15 @@ struct BlockMMA {
for (int j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = results[i * TN + j].thread_elements();
int offset = (i * TM_stride) * ldc + (j * TN_stride);
int offset = (i * TM_stride) * ldd + (j * TN_stride);
// Apply epilogue and output C
if (j * TN_stride < dst_tile_dims.x) {
C[offset] = Epilogue::apply(accum[0]);
D[offset] = Epilogue::apply(accum[0]);
}
if (j * TN_stride + 1 < dst_tile_dims.x) {
C[offset + 1] = Epilogue::apply(accum[1]);
D[offset + 1] = Epilogue::apply(accum[1]);
}
}
}

View File

@@ -16,17 +16,19 @@ struct GEMMParams {
const int lda;
const int ldb;
const int ldc;
const int ldd;
const int tiles_n;
const int tiles_m;
const int batch_stride_a;
const int batch_stride_b;
const int batch_stride_c;
const int batch_stride_d;
const int swizzle_log;
const int gemm_k_iterations_aligned;
const int batch_ndim;
};
struct GEMMSpiltKParams {
@@ -49,30 +51,13 @@ struct GEMMSpiltKParams {
};
struct GEMMAddMMParams {
const int M;
const int N;
const int K;
const int lda;
const int ldb;
const int ldc;
const int ldd;
const int fdc;
const int tiles_n;
const int tiles_m;
const int batch_stride_a;
const int batch_stride_b;
const int batch_stride_c;
const int batch_stride_d;
const int swizzle_log;
const int gemm_k_iterations_aligned;
const float alpha;
const float beta;
const int fdc;
};
} // namespace steel

View File

@@ -5,4 +5,41 @@
#include <metal_stdlib>
#define STEEL_CONST static constant constexpr const
#define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
#define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
METAL_FUNC ulong2 elem_to_loc_broadcast(
uint elem,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
int ndim) {
ulong loc_a{0};
ulong loc_b{0};
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
int pos_in_dim = (elem % shape[i]);
elem /= shape[i];
loc_a += pos_in_dim * a_strides[i];
loc_b += pos_in_dim * b_strides[i];
}
return ulong2(loc_a, loc_b);
}
METAL_FUNC ulong3 elem_to_loc_broadcast(
uint elem,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const size_t* c_strides,
int ndim) {
ulong loc_a{0};
ulong loc_b{0};
ulong loc_c{0};
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
int pos_in_dim = (elem % shape[i]);
elem /= shape[i];
loc_a += pos_in_dim * a_strides[i];
loc_b += pos_in_dim * b_strides[i];
loc_c += pos_in_dim * c_strides[i];
}
return ulong3(loc_a, loc_b, loc_c);
}

View File

@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#pragma once
@@ -65,12 +65,18 @@ struct Limits<bool> {
// Indexing utils
///////////////////////////////////////////////////////////////////////////////
inline size_t elem_to_loc(
#define MLX_MTL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
///////////////////////////////////////////////////////////////////////////////
// Single Array with generic dims
template <typename stride_t>
METAL_FUNC stride_t elem_to_loc(
uint elem,
device const int* shape,
device const size_t* strides,
device const stride_t* strides,
int ndim) {
size_t loc = 0;
stride_t loc = 0;
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
loc += (elem % shape[i]) * strides[i];
elem /= shape[i];
@@ -78,12 +84,13 @@ inline size_t elem_to_loc(
return loc;
}
inline size_t elem_to_loc(
template <typename stride_t>
METAL_FUNC stride_t elem_to_loc(
uint elem,
constant const int* shape,
constant const size_t* strides,
constant const stride_t* strides,
int ndim) {
size_t loc = 0;
stride_t loc = 0;
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
loc += (elem % shape[i]) * strides[i];
elem /= shape[i];
@@ -91,52 +98,59 @@ inline size_t elem_to_loc(
return loc;
}
template <int NDIM>
inline uint3 elem_to_loc_3_nd(
// Non templated version to handle arbitrary dims
template <typename stride_t>
METAL_FUNC stride_t elem_to_loc(
uint3 elem,
constant const int shape[NDIM],
constant const size_t a_strides[NDIM],
constant const size_t b_strides[NDIM],
constant const size_t c_strides[NDIM]) {
uint3 loc = {
static_cast<uint>(
elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
static_cast<uint>(
elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2]),
static_cast<uint>(
elem.x * c_strides[NDIM - 1] + elem.y * c_strides[NDIM - 2])};
for (int d = NDIM - 3; d >= 0; --d) {
uint l = elem.z % shape[d];
loc.x += l * a_strides[d];
loc.y += l * b_strides[d];
loc.z += l * c_strides[d];
constant const int* shape,
constant const stride_t* strides,
int ndim) {
stride_t loc = elem.x * strides[ndim - 1] + elem.y * strides[ndim - 2];
for (int d = ndim - 3; d >= 0; --d) {
loc += (elem.z % shape[d]) * strides[d];
elem.z /= shape[d];
}
return loc;
}
///////////////////////////////////////////////////////////////////////////////
// Single Array with fixed N dims
template <typename stride_t>
METAL_FUNC stride_t elem_to_loc_1(uint elem, constant const stride_t& stride) {
return elem * stride;
}
template <typename stride_t>
METAL_FUNC stride_t
elem_to_loc_2(uint2 elem, constant const stride_t strides[2]) {
return elem.x * strides[1] + elem.y * strides[0];
}
template <typename stride_t>
METAL_FUNC stride_t
elem_to_loc_3(uint3 elem, constant const stride_t strides[3]) {
return elem.x * strides[2] + elem.y * strides[1] + elem.z * strides[0];
}
template <int NDIM>
inline uint2 elem_to_loc_2_nd(
uint3 elem,
constant const int shape[NDIM],
constant const size_t a_strides[NDIM],
constant const size_t b_strides[NDIM]) {
uint2 loc = {
static_cast<uint>(
elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
static_cast<uint>(
elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2])};
for (int d = NDIM - 3; d >= 0; --d) {
uint l = elem.z % shape[d];
loc.x += l * a_strides[d];
loc.y += l * b_strides[d];
elem.z /= shape[d];
METAL_FUNC size_t elem_to_loc_nd(
uint elem,
device const int* shape,
device const size_t* strides) {
size_t loc = (elem % shape[NDIM - 1]) * strides[NDIM - 1];
MLX_MTL_PRAGMA_UNROLL
for (int d = NDIM - 2; d >= 0; --d) {
elem /= shape[d + 1];
loc += (elem % shape[d]) * strides[d];
}
return loc;
}
template <int NDIM>
inline size_t elem_to_loc_nd(
METAL_FUNC size_t elem_to_loc_nd(
uint3 elem,
constant const int shape[NDIM],
constant const size_t strides[NDIM]) {
@@ -148,33 +162,59 @@ inline size_t elem_to_loc_nd(
return loc;
}
inline size_t elem_to_loc_1(uint elem, constant const size_t& stride) {
return elem * stride;
template <int NDIM>
METAL_FUNC int64_t elem_to_loc_nd(
uint elem,
constant const int shape[NDIM],
constant const int64_t strides[NDIM]) {
int64_t loc = (elem % shape[NDIM - 1]) * strides[NDIM - 1];
MLX_MTL_PRAGMA_UNROLL
for (int d = NDIM - 2; d >= 0; --d) {
elem /= shape[d + 1];
loc += (elem % shape[d]) * strides[d];
}
return loc;
}
inline size_t elem_to_loc_2(uint2 elem, constant const size_t strides[2]) {
return elem.x * strides[1] + elem.y * strides[0];
}
inline size_t elem_to_loc_3(uint3 elem, constant const size_t strides[3]) {
return elem.x * strides[2] + elem.y * strides[1] + elem.z * strides[0];
}
// Non templated version to handle arbitrary dims
inline size_t elem_to_loc(
template <int NDIM>
METAL_FUNC int64_t elem_to_loc_nd(
uint3 elem,
constant const int* shape,
constant const size_t* strides,
int ndim) {
size_t loc = elem.x * strides[ndim - 1] + elem.y * strides[ndim - 2];
for (int d = ndim - 3; d >= 0; --d) {
constant const int shape[NDIM],
constant const int64_t strides[NDIM]) {
int64_t loc = elem.x * strides[NDIM - 1] + elem.y * strides[NDIM - 2];
for (int d = NDIM - 3; d >= 0; --d) {
loc += (elem.z % shape[d]) * strides[d];
elem.z /= shape[d];
}
return loc;
}
inline uint3 elem_to_loc_3_nd(
///////////////////////////////////////////////////////////////////////////////
// Multiple Arrays with generic dims
METAL_FUNC uint2 elem_to_loc_2_nd(
uint3 elem,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
int ndim) {
uint2 loc = {
static_cast<uint>(
elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
static_cast<uint>(
elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])};
for (int d = ndim - 3; d >= 0; --d) {
uint l = elem.z % shape[d];
loc.x += l * a_strides[d];
loc.y += l * b_strides[d];
elem.z /= shape[d];
}
return loc;
}
METAL_FUNC uint3 elem_to_loc_3_nd(
uint3 elem,
constant const int* shape,
constant const size_t* a_strides,
@@ -198,18 +238,21 @@ inline uint3 elem_to_loc_3_nd(
return loc;
}
inline uint2 elem_to_loc_2_nd(
///////////////////////////////////////////////////////////////////////////////
// Multiple Arrays with fixed N dims
template <int NDIM>
METAL_FUNC uint2 elem_to_loc_2_nd(
uint3 elem,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
int ndim) {
constant const int shape[NDIM],
constant const size_t a_strides[NDIM],
constant const size_t b_strides[NDIM]) {
uint2 loc = {
static_cast<uint>(
elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
static_cast<uint>(
elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])};
for (int d = ndim - 3; d >= 0; --d) {
elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2])};
for (int d = NDIM - 3; d >= 0; --d) {
uint l = elem.z % shape[d];
loc.x += l * a_strides[d];
loc.y += l * b_strides[d];
@@ -219,55 +262,26 @@ inline uint2 elem_to_loc_2_nd(
}
template <int NDIM>
inline uint elem_to_loc_nd(
uint elem,
device const int* shape,
device const size_t* strides);
template <>
inline uint elem_to_loc_nd<1>(
uint elem,
device const int* shape,
device const size_t* strides) {
return (elem % shape[0]) * strides[0];
}
template <>
inline uint elem_to_loc_nd<2>(
uint elem,
device const int* shape,
device const size_t* strides) {
uint loc = (elem % shape[1]) * strides[1];
elem /= shape[1];
loc += (elem % shape[0]) * strides[0];
return loc;
}
template <>
inline uint elem_to_loc_nd<3>(
uint elem,
device const int* shape,
device const size_t* strides) {
uint loc = (elem % shape[2]) * strides[2];
elem /= shape[2];
loc += (elem % shape[1]) * strides[1];
elem /= shape[1];
loc += (elem % shape[0]) * strides[0];
return loc;
}
template <>
inline uint elem_to_loc_nd<4>(
uint elem,
device const int* shape,
device const size_t* strides) {
uint loc = (elem % shape[3]) * strides[3];
elem /= shape[3];
loc += (elem % shape[2]) * strides[2];
elem /= shape[2];
loc += (elem % shape[1]) * strides[1];
elem /= shape[1];
loc += (elem % shape[0]) * strides[0];
METAL_FUNC uint3 elem_to_loc_3_nd(
uint3 elem,
constant const int shape[NDIM],
constant const size_t a_strides[NDIM],
constant const size_t b_strides[NDIM],
constant const size_t c_strides[NDIM]) {
uint3 loc = {
static_cast<uint>(
elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
static_cast<uint>(
elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2]),
static_cast<uint>(
elem.x * c_strides[NDIM - 1] + elem.y * c_strides[NDIM - 2])};
for (int d = NDIM - 3; d >= 0; --d) {
uint l = elem.z % shape[d];
loc.x += l * a_strides[d];
loc.y += l * b_strides[d];
loc.z += l * c_strides[d];
elem.z /= shape[d];
}
return loc;
}

View File

@@ -191,6 +191,70 @@ inline void mps_matmul(
});
}
inline auto collapse_batches(const array& a, const array& b) {
// Get and check the shape for the batched dims
std::vector<int> A_bshape{a.shape().begin(), a.shape().end() - 2};
std::vector<int> B_bshape{b.shape().begin(), b.shape().end() - 2};
if (A_bshape != B_bshape) {
std::ostringstream msg;
msg << "[matmul] Got matrices with incorrectly broadcasted shapes: "
<< "A " << a.shape() << ", B " << b.shape() << ".";
throw std::runtime_error(msg.str());
}
std::vector<size_t> A_bstride{a.strides().begin(), a.strides().end() - 2};
std::vector<size_t> B_bstride{b.strides().begin(), b.strides().end() - 2};
auto [batch_shape, batch_strides] =
collapse_contiguous_dims(A_bshape, std::vector{A_bstride, B_bstride});
auto A_batch_stride = batch_strides[0];
auto B_batch_stride = batch_strides[1];
if (batch_shape.empty()) {
batch_shape.push_back(1);
A_batch_stride.push_back(0);
B_batch_stride.push_back(0);
}
return std::make_tuple(batch_shape, A_batch_stride, B_batch_stride);
}
inline auto collapse_batches(const array& a, const array& b, const array& c) {
// Get and check the shape for the batched dims
std::vector<int> A_bshape{a.shape().begin(), a.shape().end() - 2};
std::vector<int> B_bshape{b.shape().begin(), b.shape().end() - 2};
std::vector<int> C_bshape{c.shape().begin(), c.shape().end() - 2};
if (A_bshape != B_bshape || A_bshape != C_bshape) {
std::ostringstream msg;
msg << "[addmm] Got matrices with incorrectly broadcasted shapes: "
<< "A " << a.shape() << ", B " << b.shape() << ", B " << c.shape()
<< ".";
throw std::runtime_error(msg.str());
}
std::vector<size_t> A_bstride{a.strides().begin(), a.strides().end() - 2};
std::vector<size_t> B_bstride{b.strides().begin(), b.strides().end() - 2};
std::vector<size_t> C_bstride{c.strides().begin(), c.strides().end() - 2};
auto [batch_shape, batch_strides] = collapse_contiguous_dims(
A_bshape, std::vector{A_bstride, B_bstride, C_bstride});
auto A_batch_stride = batch_strides[0];
auto B_batch_stride = batch_strides[1];
auto C_batch_stride = batch_strides[2];
if (batch_shape.empty()) {
batch_shape.push_back(1);
A_batch_stride.push_back(0);
B_batch_stride.push_back(0);
C_batch_stride.push_back(0);
}
return std::make_tuple(
batch_shape, A_batch_stride, B_batch_stride, C_batch_stride);
}
} // namespace
///////////////////////////////////////////////////////////////////////////////
@@ -211,22 +275,33 @@ void steel_matmul(
int ldb,
bool transpose_a,
bool transpose_b,
std::vector<array>& copies) {
std::vector<array>& copies,
std::vector<int> batch_shape /* = {} */,
std::vector<size_t> A_batch_stride /* = {} */,
std::vector<size_t> B_batch_stride /* = {} */) {
using namespace mlx::steel;
// Coalesce (B, M, K) X (K, N) to (B*M, K) X (K, N)
if (batch_size_out > 1 && !transpose_a &&
a.data_size() == batch_size_out * M * K && b.size() == K * N) {
M = M * batch_size_out;
batch_size_out = 1;
if (batch_shape.empty()) {
/////////////////////////////////////////////////////////////////////////////
// Check and collapse batch dimensions
auto [batch_shape_, A_bstride_, B_bstride_] = collapse_batches(a, b);
batch_shape = batch_shape_;
A_batch_stride = A_bstride_;
B_batch_stride = B_bstride_;
// Collapse batches into M if needed
if (batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 &&
a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K &&
B_batch_stride.back() == 0) {
M *= batch_shape.back();
batch_size_out = 1;
A_batch_stride = {0};
B_batch_stride = {0};
batch_shape = {1};
}
}
// Account for batch sizes and basic broadcasting
int batch_size_a = a.data_size() / (M * K);
int batch_size_b = b.data_size() / (K * N);
int matrix_stride_a = (batch_size_a == 1) ? 0 : M * K;
int matrix_stride_b = (batch_size_b == 1) ? 0 : K * N;
int matrix_stride_out = M * N;
/////////////////////////////////////////////////////////////////////////////
@@ -269,18 +344,18 @@ void steel_matmul(
int tm = (M + bm - 1) / bm;
GEMMSpiltKParams params{
M,
N,
K,
lda,
ldb,
N,
tn,
tm,
split_k_partitions,
split_k_partition_stride,
split_k_partition_size,
gemm_k_iterations};
/* const int M = */ M,
/* const int N = */ N,
/* const int K = */ K,
/* const int lda = */ lda,
/* const int ldb = */ ldb,
/* const int ldc = */ N,
/* const int tiles_n = */ tn,
/* const int tiles_m = */ tm,
/* const int split_k_partitions = */ split_k_partitions,
/* const int split_k_partition_stride = */ split_k_partition_stride,
/* const int split_k_partition_size = */ split_k_partition_size,
/* const int gemm_k_iterations_aligned = */ gemm_k_iterations};
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(tn, tm, split_k_partitions);
@@ -364,19 +439,20 @@ void steel_matmul(
// Prepare steel matmul params
GEMMParams params{
M,
N,
K,
lda,
ldb,
N,
tn,
tm,
matrix_stride_a,
matrix_stride_b,
matrix_stride_out,
swizzle_log,
(K / bk)};
/* const int M = */ M,
/* const int N = */ N,
/* const int K = */ K,
/* const int lda = */ lda,
/* const int ldb = */ ldb,
/* const int ldd = */ N,
/* const int tiles_n = */ tn,
/* const int tiles_m = */ tm,
/* const int batch_stride_a = */ int(A_batch_stride.back()),
/* const int batch_stride_b = */ int(B_batch_stride.back()),
/* const int batch_stride_d = */ matrix_stride_out,
/* const int swizzle_log = */ swizzle_log,
/* const int gemm_k_iterations_aligned = */ (K / bk),
/* const int batch_ndim = */ int(batch_shape.size())};
// Prepare launch grid params
int tile = 1 << swizzle_log;
@@ -386,37 +462,25 @@ void steel_matmul(
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);
// Launch only 1 kernel in the case of simple batching / broadcasting
if (batch_size_out == std::max(batch_size_a, batch_size_b) &&
(batch_size_a == batch_size_b ||
std::min(batch_size_a, batch_size_b) == 1)) {
set_array_buffer(compute_encoder, a, 0);
set_array_buffer(compute_encoder, b, 1);
set_array_buffer(compute_encoder, out, 2);
std::vector<size_t> batch_strides = A_batch_stride;
batch_strides.insert(
batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end());
compute_encoder->setBytes(&params, sizeof(GEMMParams), 3);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
} else { // Otherwise launch kernels with set offsets
// Launch kernel
set_array_buffer(compute_encoder, a, 0);
set_array_buffer(compute_encoder, b, 1);
set_array_buffer(compute_encoder, out, 3);
MTL::Size grid_dims_single = MTL::Size(tn, tm, 1);
compute_encoder->setBytes(&params, sizeof(GEMMParams), 4);
for (int i = 0; i < batch_size_out; ++i) {
auto a_off = elem_to_loc(M * K * i, a.shape(), a.strides());
auto b_off = elem_to_loc(K * N * i, b.shape(), b.strides());
compute_encoder->setBytes(
batch_shape.data(), sizeof(int) * batch_shape.size(), 6);
compute_encoder->setBytes(
batch_strides.data(), sizeof(size_t) * batch_strides.size(), 7);
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
auto b_buf = static_cast<const MTL::Buffer*>(b.buffer().ptr());
auto out_buf = static_cast<const MTL::Buffer*>(out.buffer().ptr());
compute_encoder->setBuffer(a_buf, a_off * a.itemsize(), 0);
compute_encoder->setBuffer(b_buf, b_off * b.itemsize(), 1);
compute_encoder->setBuffer(out_buf, i * M * N * out.itemsize(), 2);
compute_encoder->setBytes(&params, sizeof(GEMMParams), 3);
compute_encoder->dispatchThreadgroups(grid_dims_single, group_dims);
}
}
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
// Clear copies
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
return;
@@ -428,12 +492,21 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
throw std::runtime_error(
"[matmul] Does not yet support non-floating point types.");
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
auto& a_pre = inputs[0];
auto& b_pre = inputs[1];
// Return 0s if either input is empty
if (a_pre.size() == 0 || b_pre.size() == 0) {
array zero = array(0, a_pre.dtype());
copy_gpu(zero, out, CopyType::Scalar, s);
auto command_buffer = d.get_command_buffer(s.index);
command_buffer->addCompletedHandler([zero](MTL::CommandBuffer*) {});
return;
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
/////////////////////////////////////////////////////////////////////////////
// Init checks and prep
@@ -444,9 +517,9 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
auto check_transpose = [&copies, &s](const array& arr) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (stx == arr.shape(-1) && sty == 1) {
if (sty == 1) {
return std::make_tuple(false, stx, arr);
} else if (stx == 1 && sty == arr.shape(-2)) {
} else if (stx == 1) {
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
@@ -464,8 +537,25 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
int N = b.shape(-1);
int K = a.shape(-1);
/////////////////////////////////////////////////////////////////////////////
// Check and collapse batch dimensions
auto [batch_shape, A_batch_stride, B_batch_stride] = collapse_batches(a, b);
auto batch_size_out = out.size() / (M * N);
// Collapse batches into M if needed
if (batch_size_out > 1 && !a_transposed && batch_shape.size() == 1 &&
a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K &&
B_batch_stride.back() == 0) {
M *= batch_shape.back();
batch_size_out = 1;
A_batch_stride = {0};
B_batch_stride = {0};
batch_shape = {1};
}
/////////////////////////////////////////////////////////////////////////////
// Gemv specialization
@@ -482,20 +572,18 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
int mat_cols = transpose_mat ? out_vector_len : in_vector_len;
int mat_rows = transpose_mat ? in_vector_len : out_vector_len;
int mat_ld = is_b_matrix ? b_cols : a_cols;
int batch_size_mat = mat.data_size() / (mat_cols * mat_rows);
int stride_mat = batch_size_mat == 1 ? 0 : mat_cols * mat_rows;
auto batch_strides_mat = is_b_matrix ? B_batch_stride : A_batch_stride;
auto batch_strides_vec = is_b_matrix ? A_batch_stride : B_batch_stride;
int batch_size_vec = vec.data_size() / in_vector_len;
int stride_vec = batch_size_vec == 1 ? 0 : in_vector_len;
int stride_mat = batch_strides_mat.back();
int stride_vec = batch_strides_vec.back();
// Determine if inputs have simple batching / broadcasting
bool contiguous_kernel =
(batch_size_out == std::max(batch_size_mat, batch_size_vec) &&
(batch_size_mat == batch_size_vec ||
std::min(batch_size_mat, batch_size_vec) == 1));
bool contiguous_kernel = (batch_shape.size() == 1);
int nc_dim = out.ndim() - 2;
int batch_ndim = batch_shape.size();
// Determine dispatch kernel
int tm = 4, tn = 4;
@@ -531,10 +619,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
}
kname << "_bm" << bm << "_bn" << bn << "_tm" << tm << "_tn" << tn;
if (!contiguous_kernel) {
kname << "_nc";
}
kname << "_nc" << !contiguous_kernel << "_axpby0";
// Encode and dispatch kernel
auto compute_encoder = d.get_command_encoder(s.index);
@@ -547,25 +632,18 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
set_array_buffer(compute_encoder, mat, 0);
set_array_buffer(compute_encoder, vec, 1);
set_array_buffer(compute_encoder, out, 2);
set_array_buffer(compute_encoder, out, 3);
compute_encoder->setBytes(&in_vector_len, sizeof(int), 3);
compute_encoder->setBytes(&out_vector_len, sizeof(int), 4);
compute_encoder->setBytes(&in_vector_len, sizeof(int), 4);
compute_encoder->setBytes(&out_vector_len, sizeof(int), 5);
compute_encoder->setBytes(&mat_ld, sizeof(int), 6);
if (contiguous_kernel) {
compute_encoder->setBytes(&stride_vec, sizeof(int), 5);
compute_encoder->setBytes(&stride_mat, sizeof(int), 6);
} else {
// In case of complex broadcasting, we consider the shape[:-2] and
// strides [:-2] to determine the location of a batch
// nc_dim = out.ndim() - 2
compute_encoder->setBytes(&nc_dim, sizeof(int), 5);
compute_encoder->setBytes(out.shape().data(), nc_dim * sizeof(int), 6);
compute_encoder->setBytes(
vec.strides().data(), nc_dim * sizeof(size_t), 7);
compute_encoder->setBytes(
mat.strides().data(), nc_dim * sizeof(size_t), 8);
}
compute_encoder->setBytes(&batch_ndim, sizeof(int), 9);
compute_encoder->setBytes(batch_shape.data(), batch_ndim * sizeof(int), 10);
compute_encoder->setBytes(
batch_strides_vec.data(), batch_ndim * sizeof(size_t), 11);
compute_encoder->setBytes(
batch_strides_mat.data(), batch_ndim * sizeof(size_t), 12);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
@@ -573,7 +651,6 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
return;
}
/////////////////////////////////////////////////////////////////////////////
// Gemm specialization
@@ -598,20 +675,23 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
}
return steel_matmul(
s,
d,
a,
b,
out,
M,
N,
K,
batch_size_out,
a_cols,
b_cols,
a_transposed,
b_transposed,
copies);
/* const Stream& s = */ s,
/* metal::Device& d = */ d,
/* const array& a = */ a,
/* const array& b = */ b,
/* array& out = */ out,
/* int M = */ M,
/* int N = */ N,
/* int K = */ K,
/* int batch_size_out = */ batch_size_out,
/* int lda = */ a_cols,
/* int ldb = */ b_cols,
/* bool transpose_a = */ a_transposed,
/* bool transpose_b = */ b_transposed,
/* std::vector<array>& = */ copies,
/* std::vector<int> batch_shape = */ batch_shape,
/* std::vector<size_t> A_batch_stride = */ A_batch_stride,
/* std::vector<size_t> B_batch_stride = */ B_batch_stride);
}
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -637,9 +717,9 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
auto check_transpose = [&copies, &s](const array& arr) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (stx == arr.shape(-1) && sty == 1) {
if (sty == 1) {
return std::make_tuple(false, stx, arr);
} else if (stx == 1 && sty == arr.shape(-2)) {
} else if (stx == 1) {
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
@@ -657,33 +737,151 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
int N = b.shape(-1);
int K = a.shape(-1);
auto batch_size_out = out.size() / (M * N);
array c = c_pre;
int ldc = c.strides()[c.ndim() - 2];
int fdc = c.strides()[c.ndim() - 1];
int matrix_stride_c = c.ndim() <= 2 ? 0 : c.strides()[c.ndim() - 3];
int lda = a_cols;
int ldb = b_cols;
int ldd = N;
/////////////////////////////////////////////////////////////////////////////
// Check and collapse batch dimensions
auto [batch_shape, A_batch_stride, B_batch_stride, C_batch_stride] =
collapse_batches(a, b, c);
auto batch_size_out = out.size() / (M * N);
// Collapse batches into M if needed
if (batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 &&
a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K &&
C_batch_stride.back() == M * c.strides()[c.ndim() - 2] &&
B_batch_stride.back() == 0) {
M *= batch_shape.back();
batch_size_out = 1;
A_batch_stride = {0};
B_batch_stride = {0};
C_batch_stride = {0};
batch_shape = {1};
}
int matrix_stride_out = M * N;
/////////////////////////////////////////////////////////////////////////////
// Gemv specialization
// Route to gemv if needed
if (std::min(M, N) == 1) {
// Collect problem info
bool is_b_matrix = N != 1;
auto& mat = is_b_matrix ? b : a;
auto& vec = is_b_matrix ? a : b;
bool transpose_mat = is_b_matrix ? !transpose_b : transpose_a;
int in_vector_len = K;
int out_vector_len = is_b_matrix ? N : M;
int mat_cols = transpose_mat ? out_vector_len : in_vector_len;
int mat_rows = transpose_mat ? in_vector_len : out_vector_len;
int mat_ld = is_b_matrix ? b_cols : a_cols;
auto batch_strides_mat = is_b_matrix ? B_batch_stride : A_batch_stride;
auto batch_strides_vec = is_b_matrix ? A_batch_stride : B_batch_stride;
int stride_mat = batch_strides_mat.back();
int stride_vec = batch_strides_vec.back();
// Determine if inputs have simple batching / broadcasting
bool contiguous_kernel = (batch_shape.size() == 1);
int batch_ndim = batch_shape.size();
// Determine dispatch kernel
int tm = 4, tn = 4;
int bm, bn, n_out_per_tgp;
std::ostringstream kname;
if (transpose_mat) {
bm = 8;
bn = 8;
if (out_vector_len >= 24576) {
bn = 128;
} else if (out_vector_len >= 16384) {
bn = 64;
} else if (out_vector_len >= 8192) {
bn = 16;
}
// Specialized kernel for very small outputs
tn = out_vector_len < tn ? 1 : tn;
n_out_per_tgp = bn * tn;
kname << "gemv_t_" << type_to_name(out);
} else {
bm = out_vector_len >= 4096 ? 8 : 4;
bn = 32;
// Specialized kernel for very small outputs
tm = out_vector_len < tm ? 1 : tm;
n_out_per_tgp = bm * tm;
kname << "gemv_" << type_to_name(out);
}
kname << "_bm" << bm << "_bn" << bn << "_tm" << tm << "_tn" << tn;
kname << "_nc" << !contiguous_kernel << "_axpby1";
// Encode and dispatch kernel
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
MTL::Size group_dims = MTL::Size(bn, bm, 1);
MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);
set_array_buffer(compute_encoder, mat, 0);
set_array_buffer(compute_encoder, vec, 1);
set_array_buffer(compute_encoder, c, 2);
set_array_buffer(compute_encoder, out, 3);
compute_encoder->setBytes(&in_vector_len, sizeof(int), 4);
compute_encoder->setBytes(&out_vector_len, sizeof(int), 5);
compute_encoder->setBytes(&mat_ld, sizeof(int), 6);
compute_encoder->setBytes(&alpha_, sizeof(float), 7);
compute_encoder->setBytes(&beta_, sizeof(float), 8);
compute_encoder->setBytes(&batch_ndim, sizeof(int), 9);
compute_encoder->setBytes(batch_shape.data(), batch_ndim * sizeof(int), 10);
compute_encoder->setBytes(
batch_strides_vec.data(), batch_ndim * sizeof(size_t), 11);
compute_encoder->setBytes(
batch_strides_mat.data(), batch_ndim * sizeof(size_t), 12);
compute_encoder->setBytes(
C_batch_stride.data(), batch_ndim * sizeof(size_t), 13);
int bias_stride = c.strides()[c.ndim() - 1];
compute_encoder->setBytes(&bias_stride, sizeof(int), 14);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
return;
}
using namespace mlx::steel;
// Account for batch sizes and basic broadcasting
int batch_size_a = a.data_size() / (M * K);
int batch_size_b = b.data_size() / (K * N);
int matrix_stride_a = (batch_size_a == 1) ? 0 : M * K;
int matrix_stride_b = (batch_size_b == 1) ? 0 : K * N;
int matrix_stride_out = M * N;
/////////////////////////////////////////////////////////////////////////////
// Split K specialization
int _tm = M / 16;
int _tn = N / 16;
int _tk = K / 16;
/////////////////////////////////////////////////////////////////////////////
// Split K specialization
if (batch_size_out == 1 && (_tm * _tn) <= 32 && _tk >= 8) {
int bm = M < 40 ? 16 : 32;
int bn = N < 40 ? 16 : 32;
@@ -809,25 +1007,29 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
// TODO: Explore device-based tuning for swizzle
int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 0 : 2);
// Prepare steel matmul params
GEMMParams gemm_params{
/* const int M = */ M,
/* const int N = */ N,
/* const int K = */ K,
/* const int lda = */ lda,
/* const int ldb = */ ldb,
/* const int ldd = */ N,
/* const int tiles_n = */ tn,
/* const int tiles_m = */ tm,
/* const int batch_stride_a = */ int(A_batch_stride.back()),
/* const int batch_stride_b = */ int(B_batch_stride.back()),
/* const int batch_stride_d = */ matrix_stride_out,
/* const int swizzle_log = */ swizzle_log,
/* const int gemm_k_iterations_aligned = */ (K / bk),
/* const int batch_ndim = */ int(batch_shape.size())};
GEMMAddMMParams params{
M,
N,
K,
lda,
ldb,
ldc,
N,
tn,
tm,
matrix_stride_a,
matrix_stride_b,
matrix_stride_c,
matrix_stride_out,
swizzle_log,
(K / bk),
alpha_,
beta_,
fdc};
/* const int ldc = */ ldc,
/* const int fdc = */ fdc,
/* const int batch_stride_c = */ int(C_batch_stride.back()),
/* const float alpha = */ alpha_,
/* const float beta = */ beta_};
int tile = 1 << swizzle_log;
tm = (tm + tile - 1) / tile;
@@ -836,40 +1038,27 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);
// Launch only 1 kernel in the case of simple batching / broadcasting
if (batch_size_out == std::max(batch_size_a, batch_size_b) &&
(batch_size_a == batch_size_b ||
std::min(batch_size_a, batch_size_b) == 1)) {
set_array_buffer(compute_encoder, a, 0);
set_array_buffer(compute_encoder, b, 1);
set_array_buffer(compute_encoder, c, 2);
set_array_buffer(compute_encoder, out, 3);
std::vector<size_t> batch_strides = A_batch_stride;
batch_strides.insert(
batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end());
batch_strides.insert(
batch_strides.end(), C_batch_stride.begin(), C_batch_stride.end());
compute_encoder->setBytes(&params, sizeof(GEMMAddMMParams), 4);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
} else { // Otherwise launch kernels with set offsets
// Launch kernel
set_array_buffer(compute_encoder, a, 0);
set_array_buffer(compute_encoder, b, 1);
set_array_buffer(compute_encoder, c, 2);
set_array_buffer(compute_encoder, out, 3);
MTL::Size grid_dims_single = MTL::Size(tn, tm, 1);
compute_encoder->setBytes(&gemm_params, sizeof(GEMMParams), 4);
compute_encoder->setBytes(&params, sizeof(GEMMAddMMParams), 5);
for (int i = 0; i < batch_size_out; ++i) {
auto a_off = elem_to_loc(M * K * i, a.shape(), a.strides());
auto b_off = elem_to_loc(K * N * i, b.shape(), b.strides());
auto c_off = elem_to_loc(M * N * i, c.shape(), c.strides());
compute_encoder->setBytes(
batch_shape.data(), sizeof(int) * batch_shape.size(), 6);
compute_encoder->setBytes(
batch_strides.data(), sizeof(size_t) * batch_strides.size(), 7);
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
auto b_buf = static_cast<const MTL::Buffer*>(b.buffer().ptr());
auto c_buf = static_cast<const MTL::Buffer*>(c.buffer().ptr());
auto out_buf = static_cast<const MTL::Buffer*>(out.buffer().ptr());
compute_encoder->setBuffer(a_buf, a_off * a.itemsize(), 0);
compute_encoder->setBuffer(b_buf, b_off * b.itemsize(), 1);
compute_encoder->setBuffer(c_buf, c_off * c.itemsize(), 2);
compute_encoder->setBuffer(out_buf, i * M * N * out.itemsize(), 3);
compute_encoder->setBytes(&params, sizeof(GEMMAddMMParams), 4);
compute_encoder->dispatchThreadgroups(grid_dims_single, group_dims);
}
}
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });

View File

@@ -26,6 +26,9 @@ void steel_matmul(
int ldb,
bool transpose_a,
bool transpose_b,
std::vector<array>& copies);
std::vector<array>& copies,
std::vector<int> batch_shape = {},
std::vector<size_t> A_batch_stride = {},
std::vector<size_t> B_batch_stride = {});
} // namespace mlx::core

View File

@@ -12,8 +12,51 @@
namespace mlx::core::metal {
bool is_available();
bool cache_enabled(void);
void set_cache_enabled(bool enabled);
/* Get the actively used memory in bytes.
*
* Note, this will not always match memory use reported by the system because
* it does not include cached memory buffers.
* */
size_t get_active_memory();
/* Get the peak amount of used memory in bytes.
*
* The maximum memory used is recorded from the beginning of the program
* execution.
* */
size_t get_peak_memory();
/* Get the cache size in bytes.
*
* The cache includes memory not currently used that has not been returned
* to the system allocator.
* */
size_t get_cache_memory();
/* Set the memory limit.
* Calls to malloc will wait on scheduled tasks if the limit is exceeded. If
* there are no more scheduled tasks an error will be raised if relaxed
* is false or memory will be allocated (including the potential for
* swap) if relaxed is true.
*
* The memory limit defaults to 1.5 times the maximum recommended working set
* size reported by the device.
*
* Returns the previous memory limit.
* */
size_t set_memory_limit(size_t limit, bool relaxed = true);
/* Set the free cache limit.
* If using more than the given limit, free memory will be reclaimed
* from the cache on the next allocation. To disable the cache,
* set the limit to 0.
*
* The cache limit defaults to the memory limit.
*
* Returns the previous cache limit.
* */
size_t set_cache_limit(size_t limit);
void new_stream(Stream stream);
std::shared_ptr<void> new_scoped_memory_pool();

View File

@@ -0,0 +1,185 @@
// Copyright © 2024 Apple Inc.
#include <algorithm>
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/fast_primitives.h"
namespace mlx::core::fast {
void RMSNorm::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto& s = stream();
auto& d = metal::device(s.device);
auto& out = outputs[0];
// Make sure that the last dimension is contiguous
std::vector<array> copies;
auto check_input = [&copies, &s](const array& x) {
bool no_copy = x.strides()[x.ndim() - 1] == 1;
if (x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2];
no_copy &= (s == 0 || s == x.shape().back());
}
if (no_copy) {
return x;
} else {
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
copies.push_back(x_copy);
return x_copy;
}
};
const array& x = check_input(inputs[0]);
const array& w = inputs[1];
if (x.is_donatable()) {
out.move_shared_buffer(x);
} else {
out.set_data(
allocator::malloc_or_wait(x.data_size() * x.itemsize()),
x.data_size(),
x.strides(),
x.flags());
}
auto axis_size = static_cast<uint32_t>(x.shape().back());
int n_rows = x.data_size() / axis_size;
const int simd_size = 32;
const int n_reads = RMS_N_READS;
const int looped_limit = RMS_LOOPED_LIMIT;
std::string op_name = "rms";
if (axis_size > looped_limit) {
op_name += "_looped";
}
op_name += type_to_name(out);
auto compute_encoder = d.get_command_encoder(s.index);
{
auto kernel = d.get_kernel(op_name);
MTL::Size grid_dims, group_dims;
if (axis_size <= looped_limit) {
size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;
size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size;
size_t threadgroup_size = simd_size * simds_needed;
assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup());
size_t n_threads = n_rows * threadgroup_size;
grid_dims = MTL::Size(n_threads, 1, 1);
group_dims = MTL::Size(threadgroup_size, 1, 1);
} else {
size_t threadgroup_size = kernel->maxTotalThreadsPerThreadgroup();
size_t n_threads = n_rows * threadgroup_size;
grid_dims = MTL::Size(n_threads, 1, 1);
group_dims = MTL::Size(threadgroup_size, 1, 1);
}
uint32_t w_stride = w.strides()[0];
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(
compute_encoder, x.data_shared_ptr() == nullptr ? out : x, 0);
set_array_buffer(compute_encoder, w, 1);
set_array_buffer(compute_encoder, out, 2);
compute_encoder->setBytes(&eps_, sizeof(float), 3);
compute_encoder->setBytes(&axis_size, sizeof(int), 4);
compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 5);
compute_encoder->setThreadgroupMemoryLength(
16 * 8, 0); // minimum of 16 bytes
compute_encoder->setThreadgroupMemoryLength(simd_size * sizeof(float), 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}
void LayerNorm::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto& s = stream();
auto& d = metal::device(s.device);
auto& out = outputs[0];
// Make sure that the last dimension is contiguous
std::vector<array> copies;
auto check_input = [&copies, &s](const array& x) {
bool no_copy = x.strides()[x.ndim() - 1] == 1;
if (x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2];
no_copy &= (s == 0 || s == x.shape().back());
}
if (no_copy) {
return x;
} else {
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
copies.push_back(x_copy);
return x_copy;
}
};
const array& x = check_input(inputs[0]);
const array& w = inputs[1];
const array& b = inputs[2];
if (x.is_donatable()) {
out.move_shared_buffer(x);
} else {
out.set_data(
allocator::malloc_or_wait(x.data_size() * x.itemsize()),
x.data_size(),
x.strides(),
x.flags());
}
auto axis_size = static_cast<uint32_t>(x.shape().back());
int n_rows = x.data_size() / axis_size;
const int simd_size = 32;
const int n_reads = RMS_N_READS;
const int looped_limit = RMS_LOOPED_LIMIT;
std::string op_name = "layer_norm";
if (axis_size > looped_limit) {
op_name += "_looped";
}
op_name += type_to_name(out);
auto compute_encoder = d.get_command_encoder(s.index);
{
auto kernel = d.get_kernel(op_name);
MTL::Size grid_dims, group_dims;
if (axis_size <= looped_limit) {
size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;
size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size;
size_t threadgroup_size = simd_size * simds_needed;
assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup());
size_t n_threads = n_rows * threadgroup_size;
grid_dims = MTL::Size(n_threads, 1, 1);
group_dims = MTL::Size(threadgroup_size, 1, 1);
} else {
size_t threadgroup_size = kernel->maxTotalThreadsPerThreadgroup();
size_t n_threads = n_rows * threadgroup_size;
grid_dims = MTL::Size(n_threads, 1, 1);
group_dims = MTL::Size(threadgroup_size, 1, 1);
}
uint32_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
uint32_t b_stride = (b.ndim() == 1) ? b.strides()[0] : 0;
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(
compute_encoder, x.data_shared_ptr() == nullptr ? out : x, 0);
set_array_buffer(compute_encoder, w, 1);
set_array_buffer(compute_encoder, b, 2);
set_array_buffer(compute_encoder, out, 3);
compute_encoder->setBytes(&eps_, sizeof(float), 4);
compute_encoder->setBytes(&axis_size, sizeof(int), 5);
compute_encoder->setBytes(&w_stride, sizeof(uint32_t), 6);
compute_encoder->setBytes(&b_stride, sizeof(uint32_t), 7);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}
} // namespace mlx::core::fast

View File

@@ -17,7 +17,7 @@ namespace mlx::core {
namespace {
static constexpr int METAL_MAX_INDEX_ARRAYS = 10;
constexpr int METAL_MAX_INDEX_ARRAYS = 10;
void binary_op(
const std::vector<array>& inputs,
@@ -696,6 +696,10 @@ void Minimum::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "min");
}
void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Floor::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op(inputs, out, "floor");
}
@@ -805,13 +809,13 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (in.flags().row_contiguous) {
auto flags = in.flags();
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
out.copy_shared_buffer(in, out.strides(), flags, in.data_size());
} else {
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
copy_gpu(in, out, CopyType::General);
} else {
shared_buffer_reshape(in, out_strides, out);
}
}
@@ -861,7 +865,73 @@ void Sqrt::eval_gpu(const std::vector<array>& inputs, array& out) {
}
void Slice::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
assert(inputs.size() == 1);
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
auto& in = inputs[0];
// Calculate out strides, initial offset and if copy needs to be made
auto [copy_needed, data_offset, inp_strides] = prepare_slice(in);
// Do copy if needed
if (copy_needed) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
std::vector<int64_t> ostrides{out.strides().begin(), out.strides().end()};
copy_gpu_inplace(
/* const array& in = */ in,
/* array& out = */ out,
/* const std::vector<int>& data_shape = */ out.shape(),
/* const std::vector<stride_t>& i_strides = */ inp_strides,
/* const std::vector<stride_t>& o_strides = */ ostrides,
/* int64_t i_offset = */ data_offset,
/* int64_t o_offset = */ 0,
/* CopyType ctype = */ CopyType::General,
/* const Stream& s = */ stream());
} else {
std::vector<size_t> ostrides{inp_strides.begin(), inp_strides.end()};
shared_buffer_slice(in, ostrides, data_offset, out);
}
}
void SliceUpdate::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
auto& in = inputs[0];
auto& upd = inputs[1];
if (upd.size() == 0) {
out.copy_shared_buffer(in);
return;
}
// Check if materialization is needed
auto ctype = in.flags().contiguous && in.size() == in.data_size()
? CopyType::Vector
: CopyType::General;
copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream());
// Calculate out strides, initial offset and if copy needs to be made
auto [data_offset, out_strides] = prepare_slice(out);
// Do copy
std::vector<int64_t> upd_strides{upd.strides().begin(), upd.strides().end()};
copy_gpu_inplace<int64_t>(
/* const array& src = */ upd,
/* array& dst = */ out,
/* const std::vector<int>& data_shape = */ upd.shape(),
/* const std::vector<stride_t>& i_strides = */ upd_strides,
/* const std::vector<stride_t>& o_strides = */ out_strides,
/* int64_t i_offset = */ 0,
/* int64_t o_offset = */ data_offset,
/* CopyType ctype = */ CopyType::GeneralGeneral,
/* const Stream& s = */ stream());
}
void StopGradient::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -890,4 +960,14 @@ void QRF::eval_gpu(
throw std::runtime_error("[QRF::eval_gpu] Metal QR factorization NYI.");
}
void SVD::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
throw std::runtime_error("[SVD::eval_gpu] Metal SVD NYI.");
}
void Inverse::eval_gpu(const std::vector<array>& inputs, array& output) {
throw std::runtime_error("[Inverse::eval_gpu] Metal inversion NYI.");
}
} // namespace mlx::core

View File

@@ -41,8 +41,35 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
int B = x.size() / D;
int O = out.shape(-1);
if (transpose_) {
// Route to the fast qmv kernel that has no bounds checking
if (B < 6 && O % 8 == 0 && D % 512 == 0 && D >= 512) {
std::ostringstream kname;
kname << "qmv_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
<< bits_ << "_fast";
// Encode and dispatch kernel
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
int bo = 8;
int bd = 32;
MTL::Size group_dims = MTL::Size(bd, 2, 1);
MTL::Size grid_dims = MTL::Size(1, O / bo, B);
set_array_buffer(compute_encoder, w, 0);
set_array_buffer(compute_encoder, scales, 1);
set_array_buffer(compute_encoder, biases, 2);
set_array_buffer(compute_encoder, x, 3);
set_array_buffer(compute_encoder, out, 4);
compute_encoder->setBytes(&D, sizeof(int), 5);
compute_encoder->setBytes(&O, sizeof(int), 6);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
}
// Route to the qmv kernel
if (B < 6) {
else if (B < 6) {
std::ostringstream kname;
kname << "qmv_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
<< bits_;
@@ -52,9 +79,9 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
int bo = std::min(32, O);
int bo = 8;
int bd = 32;
MTL::Size group_dims = MTL::Size(bd, bo, 1);
MTL::Size group_dims = MTL::Size(bd, 2, 1);
MTL::Size grid_dims = MTL::Size(1, (O + bo - 1) / bo, B);
set_array_buffer(compute_encoder, w, 0);

View File

@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include <cassert>
@@ -130,15 +130,8 @@ void row_reduce_general_dispatch(
const Stream& s) {
Dtype out_dtype = out.dtype();
bool is_out_64b_int = is_64b_int(out_dtype);
auto kernel = (is_out_64b_int)
? d.get_kernel(
"row_reduce_general_no_atomics_" + op_name + type_to_name(in))
: d.get_kernel("row_reduce_general_" + op_name + type_to_name(in));
compute_encoder->setComputePipelineState(kernel);
// Prepare the arguments for the kernel
int n_reads = REDUCE_N_READS;
size_t reduction_size = plan.shape.back();
auto shape = plan.shape;
auto strides = plan.strides;
@@ -160,32 +153,72 @@ void row_reduce_general_dispatch(
}
int ndim = shape.size();
// Each thread group is responsible for 1 output
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
thread_group_size =
std::min((reduction_size + n_reads - 1) / n_reads, thread_group_size);
// Determine dispatch kernel
std::ostringstream kname;
// Align thread group size with simd_size
uint simd_size = kernel->threadExecutionWidth();
thread_group_size =
(thread_group_size + simd_size - 1) / simd_size * simd_size;
assert(thread_group_size <= kernel->maxTotalThreadsPerThreadgroup());
bool is_small = non_row_reductions * reduction_size < 32;
bool is_med = non_row_reductions * reduction_size <= 256;
is_out_64b_int &= !is_small && !is_med;
// Launch enough thread groups for each output
size_t n_threads = out.size() * thread_group_size;
MTL::Size grid_dims = MTL::Size(n_threads, non_row_reductions, 1);
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
std::string small_desc = "_";
if (is_small) {
small_desc = "_small_";
} else if (is_med) {
small_desc = "_med_";
}
if (is_out_64b_int == false || non_row_reductions == 1) {
small_desc = is_out_64b_int ? "_no_atomics_" : small_desc;
kname << "row_reduce_general" << small_desc << op_name << type_to_name(in);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
// Get dispatch grid dims
MTL::Size grid_dims;
MTL::Size group_dims;
// Each thread handles one output
if (is_small) {
grid_dims = MTL::Size(out.size(), 1, 1);
group_dims = MTL::Size(std::min(1024ul, out.size()), 1, 1);
}
// Each simdgroup handles one output
else if (is_med) {
grid_dims = MTL::Size(out.size() * 32, 1, 1);
group_dims = MTL::Size(std::min(8ul, out.size()) * 32, 1, 1);
}
// Each theadgroup handles one output
else {
int n_reads = REDUCE_N_READS;
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
thread_group_size =
std::min((reduction_size + n_reads - 1) / n_reads, thread_group_size);
// Align thread group size with simd_size
uint simd_size = kernel->threadExecutionWidth();
thread_group_size =
(thread_group_size + simd_size - 1) / simd_size * simd_size;
assert(thread_group_size <= kernel->maxTotalThreadsPerThreadgroup());
// Launch enough thread groups for each output
size_t n_threads = out.size() * thread_group_size;
grid_dims = MTL::Size(n_threads, non_row_reductions, 1);
group_dims = MTL::Size(thread_group_size, 1, 1);
}
// Dispatch kernel
if (!is_out_64b_int || non_row_reductions == 1) {
// Set the arguments for the kernel
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 4);
compute_encoder->setBytes(&non_row_reductions, sizeof(size_t), 4);
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 5);
compute_encoder->setBytes(
strides.data(), strides.size() * sizeof(size_t), 5);
compute_encoder->setBytes(&ndim, sizeof(int), 6);
strides.data(), strides.size() * sizeof(size_t), 6);
compute_encoder->setBytes(&ndim, sizeof(int), 7);
compute_encoder->dispatchThreads(grid_dims, group_dims);
} else {
@@ -203,10 +236,11 @@ void row_reduce_general_dispatch(
set_array_buffer(compute_encoder, intermediate, 1);
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 4);
compute_encoder->setBytes(&non_row_reductions, sizeof(size_t), 4);
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 5);
compute_encoder->setBytes(
strides.data(), strides.size() * sizeof(size_t), 5);
compute_encoder->setBytes(&ndim, sizeof(int), 6);
strides.data(), strides.size() * sizeof(size_t), 6);
compute_encoder->setBytes(&ndim, sizeof(int), 7);
compute_encoder->dispatchThreads(grid_dims, group_dims);
// Set up second dispatch
@@ -230,24 +264,27 @@ void row_reduce_general_dispatch(
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
compute_encoder->setBytes(&non_row_reductions, sizeof(size_t), 4);
compute_encoder->setBytes(
new_shape.data(), new_shape.size() * sizeof(int), 4);
new_shape.data(), new_shape.size() * sizeof(int), 5);
compute_encoder->setBytes(
new_strides.data(), new_strides.size() * sizeof(size_t), 5);
compute_encoder->setBytes(&ndim, sizeof(int), 6);
new_strides.data(), new_strides.size() * sizeof(size_t), 6);
compute_encoder->setBytes(&ndim, sizeof(int), 7);
// Each thread group is responsible for 1 output
thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
int n_reads = REDUCE_N_READS;
size_t thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
thread_group_size =
std::min((reduction_size + n_reads - 1) / n_reads, thread_group_size);
// Align thread group size with simd_size
uint simd_size = kernel->threadExecutionWidth();
thread_group_size =
(thread_group_size + simd_size - 1) / simd_size * simd_size;
assert(thread_group_size <= kernel->maxTotalThreadsPerThreadgroup());
// Launch enough thread groups for each output
n_threads = thread_group_size;
size_t n_threads = thread_group_size;
grid_dims = MTL::Size(n_threads, out.size(), 1);
group_dims = MTL::Size(thread_group_size, 1, 1);
@@ -270,13 +307,6 @@ void strided_reduce_general_dispatch(
metal::Device& d,
const Stream& s) {
Dtype out_dtype = out.dtype();
bool is_out_64b_int = is_64b_int(out_dtype);
auto kernel = (is_out_64b_int)
? d.get_kernel(
"col_reduce_general_no_atomics_" + op_name + type_to_name(in))
: d.get_kernel("col_reduce_general_" + op_name + type_to_name(in));
compute_encoder->setComputePipelineState(kernel);
// Prepare the arguments for the kernel
size_t reduction_size = plan.shape.back();
@@ -290,6 +320,11 @@ void strided_reduce_general_dispatch(
for (auto s : shape) {
non_col_reductions *= static_cast<size_t>(s);
}
std::vector<int> non_col_shapes = shape;
std::vector<size_t> non_col_strides = strides;
int non_col_ndim = shape.size();
auto [rem_shape, rem_strides] = shapes_without_reduction_axes(in, axes);
for (auto s : rem_shape) {
shape.push_back(s);
@@ -299,6 +334,54 @@ void strided_reduce_general_dispatch(
}
int ndim = shape.size();
// Specialize for small dims
if (reduction_size * non_col_reductions < 16) {
// Select kernel
auto kernel =
d.get_kernel("col_reduce_small_" + op_name + type_to_name(in));
compute_encoder->setComputePipelineState(kernel);
// Select block dims
MTL::Size grid_dims = MTL::Size(out_size, 1, 1);
MTL::Size group_dims = MTL::Size(256ul, 1, 1);
if (non_col_ndim == 0) {
non_col_shapes = {1};
non_col_strides = {1};
}
// Encode arrays
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3);
compute_encoder->setBytes(&out_size, sizeof(size_t), 4);
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 5);
compute_encoder->setBytes(
strides.data(), strides.size() * sizeof(size_t), 6);
compute_encoder->setBytes(&ndim, sizeof(int), 7);
compute_encoder->setBytes(&non_col_reductions, sizeof(size_t), 8);
compute_encoder->setBytes(
non_col_shapes.data(), non_col_shapes.size() * sizeof(int), 9);
compute_encoder->setBytes(
non_col_strides.data(), non_col_shapes.size() * sizeof(size_t), 10);
compute_encoder->setBytes(&non_col_ndim, sizeof(int), 11);
// Dispatch threads
compute_encoder->dispatchThreads(grid_dims, group_dims);
return;
}
// Select kernel
bool is_out_64b_int = is_64b_int(out_dtype);
auto kernel = (is_out_64b_int)
? d.get_kernel(
"col_reduce_general_no_atomics_" + op_name + type_to_name(in))
: d.get_kernel("col_reduce_general_" + op_name + type_to_name(in));
compute_encoder->setComputePipelineState(kernel);
// Select block dimensions
// Each thread reads 16 inputs to give it more work
uint n_inputs_per_thread = REDUCE_N_READS;
@@ -417,11 +500,12 @@ void strided_reduce_general_dispatch(
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2);
compute_encoder->setBytes(&out_size, sizeof(size_t), 3);
compute_encoder->setBytes(&reduction_size, sizeof(size_t), 4);
compute_encoder->setBytes(
new_shape.data(), new_shape.size() * sizeof(int), 4);
new_shape.data(), new_shape.size() * sizeof(int), 5);
compute_encoder->setBytes(
new_strides.data(), new_strides.size() * sizeof(size_t), 5);
compute_encoder->setBytes(&ndim, sizeof(int), 6);
new_strides.data(), new_strides.size() * sizeof(size_t), 6);
compute_encoder->setBytes(&ndim, sizeof(int), 7);
// Each thread group is responsible for 1 output
size_t n_reads = REDUCE_N_READS;

View File

@@ -1,5 +1,5 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/fast_primitives.h"
@@ -13,39 +13,63 @@ void RoPE::eval_gpu(
auto& in = inputs[0];
auto& out = outputs[0];
if (in.ndim() != 3) {
throw std::runtime_error(
"[RoPE] Only 3 dimensions are supported (batch x sequence x dims)");
if (in.ndim() < 3) {
throw std::runtime_error("[RoPE] Input must have at least 3 dimensions");
}
if (dims_ != in.shape(-1)) {
throw std::runtime_error("[RoPE] Partial RoPE application not supported");
}
if (in.flags().row_contiguous && in.is_donatable()) {
out.move_shared_buffer(in);
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
auto& s = out.primitive().stream();
auto& d = metal::device(s.device);
size_t strides[3];
bool donated = false;
int ndim = in.ndim();
size_t mat_size = in.shape()[ndim - 2] * in.shape()[ndim - 1];
if (in.flags().row_contiguous) {
if (in.is_donatable()) {
donated = true;
out.move_shared_buffer(in);
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
strides[0] = mat_size;
strides[1] = in.strides()[ndim - 2];
strides[2] = in.strides()[ndim - 1];
} else if (ndim == 3) {
// Handle non-contiguous 3D inputs
out.set_data(allocator::malloc_or_wait(out.nbytes()));
strides[0] = in.strides()[0];
strides[1] = in.strides()[1];
strides[2] = in.strides()[2];
} else {
// Copy non-contiguous > 3D inputs into the output and treat
// input as donated
donated = true;
copy_gpu(in, out, CopyType::General, s);
strides[0] = mat_size;
strides[1] = out.strides()[ndim - 2];
strides[2] = out.strides()[ndim - 1];
}
std::ostringstream kname;
kname << "rope_" << (traditional_ ? "traditional_" : "") << type_to_name(in);
auto kernel = d.get_kernel(kname.str());
auto compute_encoder = d.get_command_encoder(s.index);
bool donated = in.data_shared_ptr() == nullptr;
float base = std::log2(base_);
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, donated ? out : in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(in.strides().data(), 3 * sizeof(size_t), 2);
compute_encoder->setBytes(&strides, 3 * sizeof(size_t), 2);
compute_encoder->setBytes(&offset_, sizeof(int), 3);
compute_encoder->setBytes(&base, sizeof(float), 4);
compute_encoder->setBytes(&scale_, sizeof(float), 5);
int dim0 = in.shape(2) / 2;
int dim1 = in.shape(1);
int dim2 = in.shape(0);
int dim0 = in.shape()[ndim - 1] / 2;
int dim1 = in.shape()[ndim - 2];
int dim2 = in.size() / mat_size;
auto group_dims = get_block_dims(dim0, dim1, dim2);
auto grid_dims = MTL::Size(dim0, dim1, dim2);
compute_encoder->dispatchThreads(grid_dims, group_dims);

View File

@@ -0,0 +1,222 @@
//
// scaled_dot_product_attention.cpp
// mlx
#include <algorithm>
#include <cassert>
#include <numeric>
#include <sstream>
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h"
#include "mlx/backend/metal/metal.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/fast_primitives.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core::fast {
namespace {
void sdpa_metal(
const Stream& s,
metal::Device& d,
const array& q,
const array& k,
const array& v,
const array& p_lse,
const array& p_rowmaxes,
const array& o_partial,
const uint heads,
const uint tile_size,
const uint n_tiles,
const float alpha,
array& out,
std::vector<array>& temporaries) {
std::ostringstream kname_partials;
kname_partials << "fast_inference_sdpa_compute_partials_";
std::ostringstream kname_reduce;
std::string delimiter = "_";
kname_reduce << "fast_inference_sdpa_reduce_tiles" + delimiter;
for (const auto& arr : {k, v, out}) {
if (arr.dtype() != q.dtype()) {
throw std::runtime_error(
"[ScaledDotProductAttention::eval_gpu]: expected matching dtypes for q,k,v,o");
}
}
if (q.dtype() == float32) {
kname_partials << "float" + delimiter;
kname_reduce << "float";
} else if (q.dtype() == float16) {
kname_partials << "half" + delimiter;
kname_reduce << "half";
} else {
throw std::runtime_error(
"[ScaledDotProductAttention::eval_gpu]: unexpected dtype found for queries: expected either float32 or float16.");
}
std::string kname_suffix_tile_size = std::to_string(tile_size) + delimiter;
uint nsimd = 8;
std::string kname_suffix_nsimdgroups = std::to_string(nsimd);
// maximum number of splits == 128 at the moment (reserved tile registers in
// reduction kernel). this is arbitrary and could be changed in the shader.
std::string kname_suffix = kname_suffix_tile_size + kname_suffix_nsimdgroups;
kname_partials << kname_suffix;
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname_partials.str());
compute_encoder->setComputePipelineState(kernel);
constexpr const uint batch = 1;
MTL::Size grid_dims = MTL::Size(heads, n_tiles, batch);
MTL::Size group_dims = MTL::Size(32, nsimd, 1);
const uint64_t KV_sequence_length = k.shape(-2);
const uint query_sequence_length = q.shape(-2);
const uint n_q_heads = q.shape(1);
const uint n_kv_heads = k.shape(1);
MLXScaledDotProductAttentionParams params{
query_sequence_length, n_q_heads, n_kv_heads, n_tiles, alpha};
set_array_buffer(compute_encoder, q, 0);
set_array_buffer(compute_encoder, k, 1);
set_array_buffer(compute_encoder, v, 2);
compute_encoder->setBytes(&KV_sequence_length, sizeof(KV_sequence_length), 3);
compute_encoder->setBytes(
&params, sizeof(MLXScaledDotProductAttentionParams), 4);
set_array_buffer(compute_encoder, o_partial, 5);
set_array_buffer(compute_encoder, p_lse, 6);
set_array_buffer(compute_encoder, p_rowmaxes, 7);
constexpr const uint tgroupMemorySize = 32768;
compute_encoder->setThreadgroupMemoryLength(tgroupMemorySize, 0);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
{
auto kernel_accum = d.get_kernel(kname_reduce.str());
compute_encoder->setComputePipelineState(kernel_accum);
set_array_buffer(compute_encoder, o_partial, 0);
set_array_buffer(compute_encoder, p_lse, 1);
set_array_buffer(compute_encoder, p_rowmaxes, 2);
compute_encoder->setBytes(
&params, sizeof(MLXScaledDotProductAttentionParams), 3);
set_array_buffer(compute_encoder, out, 4);
MTL::Size grid_dims_reduce = MTL::Size(heads, 1, batch);
MTL::Size group_dims_reduce = MTL::Size(128, 1, 1);
compute_encoder->dispatchThreadgroups(grid_dims_reduce, group_dims_reduce);
d.get_command_buffer(s.index)->addCompletedHandler(
[temporaries](MTL::CommandBuffer*) mutable { temporaries.clear(); });
return;
}
}
} // namespace
void ScaledDotProductAttention::eval_gpu(
const std::vector<array>& inputs,
array& out) {
assert(inputs.size() >= 3);
if (!is_floating_point(out.dtype())) {
throw std::runtime_error(
"[ScaledDotProductAttention] Does not yet support non-floating point types.");
}
if (inputs.size() == 4) {
out = fallback_(inputs)[0];
return;
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
auto& q_pre = inputs[0];
auto& k_pre = inputs[1];
auto& v_pre = inputs[2];
auto& o = out;
/////////////////////////////////////////////////////////////////////////////
// Init checks and prep
// Keep a vector with copies to be cleared in the completed buffer to release
// the arrays
std::vector<array> temporaries;
auto check_transpose = [&temporaries, &s](const array& arr) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (stx == arr.shape(-1) && sty == 1) {
return arr;
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
temporaries.push_back(arr_copy);
size_t stx = arr.shape(-1);
return arr_copy;
}
};
auto q = check_transpose(q_pre);
auto k = check_transpose(k_pre);
auto v = check_transpose(v_pre);
const int heads = q.shape(-3);
int tile_size = 64;
const int kv_seq_len = k.shape(-2);
if (kv_seq_len > 8000) {
tile_size = 128;
}
if (kv_seq_len > 16000) {
tile_size = 256;
}
if (kv_seq_len > 32000) {
tile_size = 512;
}
const int n_tiles = (kv_seq_len + tile_size - 1) / tile_size;
array o_partials(
{q.shape(-4), q.shape(-3), q.shape(-2), n_tiles * v.shape(-1)},
float32,
nullptr,
{});
o_partials.set_data(allocator::malloc_or_wait(o_partials.nbytes()));
array p_lse(
{q.shape(-4), q.shape(-3), q.shape(-2), n_tiles}, float32, nullptr, {});
array p_rowmaxes(
{q.shape(-4), q.shape(-3), q.shape(-2), n_tiles}, float32, nullptr, {});
p_lse.set_data(allocator::malloc_or_wait(p_lse.nbytes()));
p_rowmaxes.set_data(allocator::malloc_or_wait(p_rowmaxes.nbytes()));
temporaries.push_back(p_lse);
temporaries.push_back(p_rowmaxes);
temporaries.push_back(o_partials);
return sdpa_metal(
s,
d,
q,
k,
v,
p_lse,
p_rowmaxes,
o_partials,
heads,
tile_size,
n_tiles,
scale_,
out,
temporaries);
}
} // namespace mlx::core::fast

View File

@@ -37,11 +37,15 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
}
};
const array& in = check_input(inputs[0]);
out.set_data(
allocator::malloc_or_wait(in.data_size() * in.itemsize()),
in.data_size(),
in.strides(),
in.flags());
if (in.is_donatable()) {
out.move_shared_buffer(in);
} else {
out.set_data(
allocator::malloc_or_wait(in.data_size() * in.itemsize()),
in.data_size(),
in.strides(),
in.flags());
}
int axis_size = in.shape().back();
int n_rows = in.data_size() / axis_size;
@@ -75,6 +79,8 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
}
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(
compute_encoder, in.data_shared_ptr() == nullptr ? out : in, 0);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(&axis_size, sizeof(int), 2);

View File

@@ -9,16 +9,43 @@ namespace mlx::core {
namespace {
void set_array_buffer(
MTL::ComputeCommandEncoder* enc,
const array& a,
int idx) {
inline void
set_array_buffer(MTL::ComputeCommandEncoder* enc, const array& a, int idx) {
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
auto offset = a.data<char>() -
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
enc->setBuffer(a_buf, offset, idx);
}
inline void set_array_buffer(
MTL::ComputeCommandEncoder* enc,
const array& a,
int64_t offset,
int idx) {
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
auto base_offset = a.data<char>() -
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
base_offset += offset;
enc->setBuffer(a_buf, base_offset, idx);
}
template <typename T>
inline void set_vector_bytes(
MTL::ComputeCommandEncoder* enc,
const std::vector<T>& vec,
size_t nelems,
int idx) {
enc->setBytes(vec.data(), nelems * sizeof(T), idx);
}
template <typename T>
inline void set_vector_bytes(
MTL::ComputeCommandEncoder* enc,
const std::vector<T>& vec,
int idx) {
return set_vector_bytes(enc, vec, vec.size(), idx);
}
std::string type_to_name(const array& a) {
std::string tname;
switch (a.dtype()) {
@@ -96,72 +123,6 @@ MTL::Size get_block_dims(int dim0, int dim1, int dim2) {
return MTL::Size{1ul << pows[0], 1ul << pows[1], 1ul << pows[2]};
}
// Collapse dims that are contiguous to possibly route to a better kernel
// e.g. for x = transpose(array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 2, 2}), {2, 0, 1})
// should return {{2, 4}, {{1, 2}}}.
//
// When multiple arrays are passed they should all have the same shape. The
// collapsed axes are also the same so one shape is returned.
std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<std::vector<size_t>> strides) {
// Make a vector that has axes separated with -1. Collapse all axes between
// -1.
std::vector<int> to_collapse;
if (shape.size() > 0) {
to_collapse.push_back(0);
for (int i = 1; i < shape.size(); i++) {
bool contiguous = true;
for (const std::vector<size_t>& st : strides) {
if (st[i] * shape[i] != st[i - 1]) {
contiguous = false;
}
if (!contiguous) {
break;
}
}
if (!contiguous) {
to_collapse.push_back(-1);
}
to_collapse.push_back(i);
}
to_collapse.push_back(-1);
}
std::vector<int> out_shape;
std::vector<std::vector<size_t>> out_strides(strides.size());
for (int i = 0; i < to_collapse.size(); i++) {
int current_shape = shape[to_collapse[i]];
while (to_collapse[++i] != -1) {
current_shape *= shape[to_collapse[i]];
}
out_shape.push_back(current_shape);
for (int j = 0; j < strides.size(); j++) {
const std::vector<size_t>& st = strides[j];
out_strides[j].push_back(st[to_collapse[i - 1]]);
}
}
return std::make_tuple(out_shape, out_strides);
}
std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
collapse_contiguous_dims(const std::vector<array>& xs) {
std::vector<std::vector<size_t>> strides;
for (auto& x : xs) {
strides.emplace_back(x.strides());
}
return collapse_contiguous_dims(xs[0].shape(), strides);
}
template <typename... Arrays>
std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
collapse_contiguous_dims(Arrays... xs) {
return collapse_contiguous_dims(
std::vector<array>{std::forward<Arrays>(xs)...});
}
} // namespace
} // namespace mlx::core

View File

@@ -23,10 +23,21 @@ std::function<void()> make_task(
"[metal::make_task] Cannot make GPU task without metal backend");
}
// No cache for CPU only
bool cache_enabled(void) {
return false;
// No-ops when Metal is not available.
size_t get_active_memory() {
return 0;
}
size_t get_peak_memory() {
return 0;
}
size_t get_cache_memory() {
return 0;
}
size_t set_memory_limit(size_t, bool) {
return 0;
}
size_t set_cache_limit(size_t) {
return 0;
}
void set_cache_enabled(bool) {}
} // namespace mlx::core::metal

View File

@@ -43,6 +43,7 @@ NO_GPU_MULTI(CustomVJP)
NO_GPU_MULTI(Depends)
NO_GPU(Divide)
NO_GPU_MULTI(DivMod)
NO_GPU(NumberOfElements)
NO_GPU(Remainder)
NO_GPU(Equal)
NO_GPU(Erf)
@@ -86,6 +87,7 @@ NO_GPU(Sign)
NO_GPU(Sin)
NO_GPU(Sinh)
NO_GPU(Slice)
NO_GPU(SliceUpdate)
NO_GPU(Softmax)
NO_GPU(Sort)
NO_GPU_MULTI(Split)
@@ -93,12 +95,17 @@ NO_GPU(Square)
NO_GPU(Sqrt)
NO_GPU(StopGradient)
NO_GPU(Subtract)
NO_GPU_MULTI(SVD)
NO_GPU(Tan)
NO_GPU(Tanh)
NO_GPU(Transpose)
NO_GPU(Inverse)
namespace fast {
NO_GPU_MULTI(LayerNorm)
NO_GPU_MULTI(RMSNorm)
NO_GPU_MULTI(RoPE)
NO_GPU(ScaledDotProductAttention)
} // namespace fast
} // namespace mlx::core

View File

@@ -7,6 +7,7 @@
#include "mlx/allocator.h"
#include "mlx/compile.h"
#include "mlx/compile_impl.h"
#include "mlx/primitives.h"
#include "mlx/transforms.h"
#include "mlx/transforms_impl.h"
@@ -73,7 +74,7 @@ bool allows_shapeless(const Primitive& p) {
is_noop(p) || is_reduction(p) || typeid(p) == typeid(Softmax) ||
typeid(p) == typeid(Sort) || typeid(p) == typeid(ArgSort) ||
typeid(p) == typeid(ArgPartition) || typeid(p) == typeid(Partition) ||
typeid(p) == typeid(Select);
typeid(p) == typeid(Select) || typeid(p) == typeid(NumberOfElements);
}
Compiled::Compiled(
@@ -414,6 +415,11 @@ void compile_simplify(
}
tape = std::move(new_tape);
std::unordered_map<std::uintptr_t, uint32_t> tape_order;
for (uint32_t i = 0; i < tape.size(); ++i) {
tape_order.insert({tape[i].id(), i});
}
std::unordered_set<uintptr_t> output_set;
for (auto& o : outputs) {
output_set.insert(o.id());
@@ -436,16 +442,23 @@ void compile_simplify(
if (mask[j]) {
continue;
}
auto& src = parents->second[j].first;
auto& dst = parents->second[i].first;
if (src.id() != dst.id() && array_equivalent(src, dst)) {
auto src_idx = j;
auto dst_idx = i;
if (tape_order[parents->second[src_idx].first.id()] <
tape_order[parents->second[dst_idx].first.id()]) {
std::swap(src_idx, dst_idx);
}
auto& src = parents->second[src_idx].first;
auto& dst = parents->second[dst_idx].first;
if (src.id() != dst.id() && array_equivalent(src, dst) &&
output_set.find(src.id()) == output_set.end()) {
merge(dst, src, parents_map);
mask[j] = true;
mask[src_idx] = true;
}
}
}
// Erase orphaned parents so we don't keep fusing with them
for (int i = N - 1; i > 0; --i) {
for (int i = N - 1; i >= 0; --i) {
if (mask[i]) {
parents->second.erase(parents->second.begin() + i);
}
@@ -455,7 +468,6 @@ void compile_simplify(
return output_set.find(a.id()) == output_set.end();
}
};
bool discard = maybe_merge_parents(arr);
for (auto& s : arr.siblings()) {
discard &= maybe_merge_parents(s);
@@ -765,7 +777,8 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
size_t fun_id,
bool shapeless /* = false */,
std::vector<uint64_t> constants /* = {} */) {
if (compile_mode() == CompileMode::disabled) {
if (compile_mode() == CompileMode::disabled ||
!(compile_available_for_device(default_device()))) {
return fun;
}
return [fun, fun_id, shapeless, constants = std::move(constants)](

11
mlx/compile_impl.h Normal file
View File

@@ -0,0 +1,11 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
#include "mlx/device.h"
namespace mlx::core::detail {
bool compile_available_for_device(const Device& device);
}

View File

@@ -11,9 +11,9 @@ namespace mlx::core {
namespace {
static constexpr int num_types = 13;
constexpr int num_types = 13;
static constexpr Dtype::Kind type_kinds[num_types] = {
constexpr Dtype::Kind type_kinds[num_types] = {
Dtype::Kind::b, // bool_,
Dtype::Kind::u, // uint8,
Dtype::Kind::u, // uint16,
@@ -32,7 +32,7 @@ static constexpr Dtype::Kind type_kinds[num_types] = {
// Following Jax type promotion rules:
// https://jax.readthedocs.io/en/latest/type_promotion.html
// clang-format off
static constexpr Dtype type_rules[num_types][num_types] = {
constexpr Dtype type_rules[num_types][num_types] = {
// bool uint8 uint16 uint32 uint64 int8 int16 int32 int64 float16 float32 bfloat16 complex64
{bool_, uint8, uint16, uint32, uint64, int8, int16, int32, int64, float16, float32, bfloat16, complex64}, // bool
{uint8, uint8, uint16, uint32, uint64, int16, int16, int32, int64, float16, float32, bfloat16, complex64}, // uint8

View File

@@ -46,26 +46,22 @@ struct Dtype {
};
};
inline bool is_available(const Dtype& dtype) {
return true;
}
inline constexpr Dtype bool_{Dtype::Val::bool_, sizeof(bool)};
static constexpr Dtype bool_{Dtype::Val::bool_, sizeof(bool)};
inline constexpr Dtype uint8{Dtype::Val::uint8, sizeof(uint8_t)};
inline constexpr Dtype uint16{Dtype::Val::uint16, sizeof(uint16_t)};
inline constexpr Dtype uint32{Dtype::Val::uint32, sizeof(uint32_t)};
inline constexpr Dtype uint64{Dtype::Val::uint64, sizeof(uint64_t)};
static constexpr Dtype uint8{Dtype::Val::uint8, sizeof(uint8_t)};
static constexpr Dtype uint16{Dtype::Val::uint16, sizeof(uint16_t)};
static constexpr Dtype uint32{Dtype::Val::uint32, sizeof(uint32_t)};
static constexpr Dtype uint64{Dtype::Val::uint64, sizeof(uint64_t)};
inline constexpr Dtype int8{Dtype::Val::int8, sizeof(int8_t)};
inline constexpr Dtype int16{Dtype::Val::int16, sizeof(int16_t)};
inline constexpr Dtype int32{Dtype::Val::int32, sizeof(int32_t)};
inline constexpr Dtype int64{Dtype::Val::int64, sizeof(int64_t)};
static constexpr Dtype int8{Dtype::Val::int8, sizeof(int8_t)};
static constexpr Dtype int16{Dtype::Val::int16, sizeof(int16_t)};
static constexpr Dtype int32{Dtype::Val::int32, sizeof(int32_t)};
static constexpr Dtype int64{Dtype::Val::int64, sizeof(int64_t)};
static constexpr Dtype float16{Dtype::Val::float16, sizeof(uint16_t)};
static constexpr Dtype float32{Dtype::Val::float32, sizeof(float)};
static constexpr Dtype bfloat16{Dtype::Val::bfloat16, sizeof(uint16_t)};
static constexpr Dtype complex64{Dtype::Val::complex64, sizeof(complex64_t)};
inline constexpr Dtype float16{Dtype::Val::float16, sizeof(uint16_t)};
inline constexpr Dtype float32{Dtype::Val::float32, sizeof(float)};
inline constexpr Dtype bfloat16{Dtype::Val::bfloat16, sizeof(uint16_t)};
inline constexpr Dtype complex64{Dtype::Val::complex64, sizeof(complex64_t)};
Dtype promote_types(const Dtype& t1, const Dtype& t2);

View File

@@ -46,6 +46,141 @@ std::pair<std::vector<array>, std::vector<int>> Custom::vmap(
return {outputs, out_axes};
}
array rms_norm(
const array& x,
const array& weight,
float eps,
StreamOrDevice s_ /* = {} */) {
if (x.ndim() == 0) {
std::ostringstream msg;
msg << "[rms_norm] Input must have at least 1 dimension but got input with "
"0 dimensions.";
throw std::invalid_argument(msg.str());
}
if (weight.ndim() != 1) {
std::ostringstream msg;
msg << "[rms_norm] weight must have 1 dimension but has " << weight.ndim()
<< " dimensions.";
throw std::invalid_argument(msg.str());
}
auto out_type = result_type({x, weight});
if (!is_floating_point(out_type) || is_complex(out_type)) {
std::ostringstream msg;
msg << "[rms_norm] Received unsupported type " << out_type << ".";
throw std::invalid_argument(msg.str());
}
auto s = to_stream(s_);
auto fallback = [eps, out_type, s](const std::vector<array>& inputs) {
auto x = astype(inputs[0], float32, s);
x = multiply(
x,
rsqrt(
add(mean(square(x, s), -1, /* keepdims */ true, s),
array(eps, float32),
s),
s),
s);
x = astype(x, out_type, s);
return std::vector<array>{multiply(inputs[1], x, s)};
};
if (s.device == Device::gpu) {
return array(
x.shape(),
out_type,
std::make_unique<RMSNorm>(s, fallback, eps),
{astype(x, out_type, s), astype(weight, out_type, s)});
}
return fallback({x, weight})[0];
}
bool RMSNorm::is_equivalent(const Primitive& other) const {
const RMSNorm& a_other = static_cast<const RMSNorm&>(other);
return eps_ == a_other.eps_;
}
array layer_norm(
const array& x,
const std::optional<array>& weight,
const std::optional<array>& bias,
float eps,
StreamOrDevice s_ /* = {} */) {
if (x.ndim() == 0) {
std::ostringstream msg;
msg << "[layer_norm] Input must have at least 1 dimension but got input with "
"0 dimensions.";
throw std::invalid_argument(msg.str());
}
if (weight.has_value() && (*weight).ndim() != 1) {
std::ostringstream msg;
msg << "[layer_norm] weight must have 1 dimension but has "
<< (*weight).ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}
if (bias.has_value() && (*bias).ndim() != 1) {
std::ostringstream msg;
msg << "[layer_norm] bias must have 1 dimension but has " << (*bias).ndim()
<< " dimensions.";
throw std::invalid_argument(msg.str());
}
auto out_type = (weight.has_value())
? ((bias.has_value()) ? result_type({x, *weight, *bias})
: result_type({x, *weight}))
: x.dtype();
if (!is_floating_point(out_type) || is_complex(out_type)) {
std::ostringstream msg;
msg << "[layer_norm] Received unsupported type " << out_type << ".";
throw std::invalid_argument(msg.str());
}
auto s = to_stream(s_);
bool has_weight = weight.has_value();
bool has_bias = bias.has_value();
auto fallback = [has_weight, has_bias, eps, out_type, s](
const std::vector<array>& inputs) {
auto x = astype(inputs[0], float32, s);
// Should I not be smart here and leave the double mean to simplify()?
auto mu = mean(x, /* axis= */ -1, /* keepdims= */ true, s);
auto mu2 = square(mu, s);
auto x2 = mean(square(x, s), /* axis= */ -1, /* keepdims= */ true, s);
auto v = subtract(x2, mu2, s);
x = multiply(subtract(x, mu, s), rsqrt(add(v, array(eps, float32), s), s));
x = astype(x, out_type, s);
// If the LN is affine then transform x according to the weight and bias
if (has_weight) {
x = multiply(x, inputs[1], s);
}
if (has_bias) {
x = add(x, inputs[2], s);
}
return std::vector<array>{x};
};
auto passed_weight =
astype((weight.has_value()) ? *weight : array(1, out_type), out_type);
auto passed_bias =
astype((bias.has_value()) ? *bias : array(0, out_type), out_type);
if (s.device == Device::gpu) {
return array(
x.shape(),
out_type,
std::make_unique<LayerNorm>(s, fallback, eps),
{astype(x, out_type, s), passed_weight, passed_bias});
}
return fallback({x, passed_weight, passed_bias})[0];
}
bool LayerNorm::is_equivalent(const Primitive& other) const {
const LayerNorm& a_other = static_cast<const LayerNorm&>(other);
return eps_ == a_other.eps_;
}
array rope(
const array& x,
int dims,
@@ -54,10 +189,10 @@ array rope(
float scale,
int offset,
StreamOrDevice s /* = {} */) {
if (x.ndim() != 3) {
if (x.ndim() < 3) {
std::ostringstream msg;
msg << "[rope] Input must have 3 dimensions but got input with " << x.ndim()
<< " dimensions.";
msg << "[rope] Input must have at least 3 dimensions but got input with "
<< x.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}
if (traditional && x.shape(-1) != dims) {
@@ -67,7 +202,9 @@ array rope(
auto fallback = [dims, traditional, base, scale, offset, s](
const std::vector<array>& inputs) {
auto& x = inputs[0];
auto& shape = inputs[0].shape();
int ndim = shape.size();
auto x = reshape(inputs[0], {-1, shape[ndim - 2], shape[ndim - 1]}, s);
auto t = x.dtype();
auto N = x.shape(1) + offset;
// Compute sines and cosines
@@ -89,7 +226,7 @@ array rope(
for (auto& o : outs) {
o = expand_dims(o, 3, s);
}
return std::vector<array>{reshape(concatenate(outs, 3, s), x.shape(), s)};
return std::vector<array>{reshape(concatenate(outs, 3, s), shape, s)};
} else {
auto out_s = x.shape();
out_s.back() = half_dims;
@@ -103,10 +240,9 @@ array rope(
if (dims < x.shape(-1)) {
outs.push_back(slice(x, {0, 0, dims}, x.shape(), s));
}
return std::vector<array>{concatenate(outs, 2, s)};
return std::vector<array>{reshape(concatenate(outs, 2, s), shape, s)};
}
};
// TODO change to condition for using custom prim
auto stream = to_stream(s);
if (stream.device == Device::gpu && x.shape(-1) == dims) {
return array(
@@ -127,4 +263,144 @@ bool RoPE::is_equivalent(const Primitive& other) const {
offset_ == a_other.offset_);
}
/** Computes: O = softmax(Q @ K.T) @ V **/
array scaled_dot_product_attention(
const array& queries,
const array& keys,
const array& values,
const float scale,
const std::optional<array>& mask,
StreamOrDevice s) {
for (const auto& tensor : {queries, keys, values}) {
if (tensor.ndim() != 4) {
std::ostringstream msg;
msg << "[scaled_dot_product_attention] input with shape "
<< tensor.shape() << " expected to be rank 4";
throw std::invalid_argument(msg.str());
}
}
const size_t batch_dim = queries.shape(0);
for (const auto& tensor : {keys, values}) {
if (tensor.shape(0) != batch_dim) {
std::ostringstream msg;
msg << "[scaled_dot_product_attention] mismatching batch dimension for input with shape "
<< tensor.shape() << ".";
throw std::invalid_argument(msg.str());
}
}
// Q, K must have matching last dims (d_k aka 'head_dim');
if (queries.shape(-1) != keys.shape(-1)) {
std::ostringstream msg;
msg << "[scaled_dot_product_attention] query, keys expected to have matching last dimension; found query shape "
<< queries.shape() << " for keys shape " << keys.shape() << ".";
throw std::invalid_argument(msg.str());
}
// K, V must have matching number of heads (n_kv_heads);
auto n_q_heads = queries.shape(-3);
auto n_kv_heads = keys.shape(-3);
if (keys.shape(-3) != values.shape(-3)) {
std::ostringstream msg;
msg << "[scaled_dot_product_attention] keys, values expected to have matching n_kv_heads; found keys with n_heads "
<< keys.shape(-3) << " for values with n_heads " << values.shape(-3)
<< ".";
throw std::invalid_argument(msg.str());
}
// n_heads % n_kv_heads == 0; n_heads >= 1, n_kv_heads >= 1.
if (n_q_heads % n_kv_heads != 0) {
std::ostringstream msg;
msg << "[scaled_dot_product_attention] n_heads must be a multiple of n_kv_heads, found n_heads "
<< n_q_heads << " for n_kv_heads " << n_kv_heads << ".";
throw std::invalid_argument(msg.str());
}
auto final_type = result_type({queries, keys, values});
if (!is_floating_point(final_type) || is_complex(final_type)) {
std::ostringstream msg;
msg << "[scaled_dot_product_attention] Received unsupported type "
<< final_type << ".";
throw std::invalid_argument(msg.str());
}
auto q = astype(queries, final_type, s);
auto k = astype(keys, final_type, s);
auto v = astype(values, final_type, s);
auto out_shape =
std::vector<int>({q.shape(0), q.shape(1), q.shape(2), v.shape(-1)});
/* generic implementation for use cases that Metal implementation does not
* support. For non-supported cases listed below, use MLX primitives:
* * CPU implementation
* * batch size > 1
* * query sequence length > 1
* * non-null mask
* * dtype is not fp32 or fp16
*/
bool needs_mask = mask.has_value();
auto fallback = [scale, needs_mask, final_type, n_q_heads, n_kv_heads, &s](
const std::vector<array>& inputs) {
auto q = multiply(array(scale, inputs[0].dtype()), inputs[0], s);
int n_repeats = n_q_heads / n_kv_heads;
int B = q.shape(0);
int L = q.shape(2);
auto k = inputs[1];
auto v = inputs[2];
if (n_repeats > 1) {
q = reshape(q, {B, n_kv_heads, n_repeats, L, -1}, s);
k = expand_dims(k, 2, s);
v = expand_dims(v, 2, s);
}
auto scores = matmul(q, swapaxes(k, -1, -2, s), s);
if (needs_mask) {
scores = add(scores, inputs[3], s);
}
scores = astype(
softmax(astype(scores, float32, s), std::vector<int>{-1}, s),
final_type,
s);
auto out = matmul(scores, v, s);
if (n_repeats > 1) {
out = reshape(out, {B, n_q_heads, L, -1}, s);
}
return std::vector<array>{out};
};
auto stream = to_stream(s);
constexpr const int supported_head_dim = 128;
const size_t query_head_dim = q.shape(-1);
const size_t query_sequence_length = q.shape(2);
bool implementation_supports_use_case = batch_dim == 1 &&
query_sequence_length == 1 && !mask.has_value() &&
query_head_dim == supported_head_dim && final_type != bfloat16 &&
stream.device == Device::gpu;
// TODO, update routing conditions post further tuning
implementation_supports_use_case &= false;
if (implementation_supports_use_case) {
auto out = array(
out_shape,
final_type,
std::make_unique<ScaledDotProductAttention>(
stream, fallback, scale, false),
{q, k, v});
return out;
}
if (mask.has_value()) {
return fallback({q, k, v, mask.value()})[0];
} else {
return fallback({q, k, v})[0];
}
}
bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const {
const ScaledDotProductAttention& a_other =
static_cast<const ScaledDotProductAttention&>(other);
return needs_mask_ == a_other.needs_mask_ && scale_ == a_other.scale_;
}
} // namespace mlx::core::fast

View File

@@ -2,10 +2,25 @@
#pragma once
#include <optional>
#include "mlx/utils.h"
namespace mlx::core::fast {
array rms_norm(
const array& x,
const array& weight,
float eps,
StreamOrDevice s = {});
array layer_norm(
const array& x,
const std::optional<array>& weight,
const std::optional<array>& bias,
float eps,
StreamOrDevice s = {});
array rope(
const array& x,
int dims,
@@ -13,6 +28,15 @@ array rope(
float base,
float scale,
int offset,
StreamOrDevice s /* = {} */);
StreamOrDevice s = {});
/** Computes: O = softmax(Q @ K.T) @ V **/
array scaled_dot_product_attention(
const array& queries,
const array& keys,
const array& values,
const float scale,
const std::optional<array>& mask = std::nullopt,
StreamOrDevice s = {});
} // namespace mlx::core::fast

View File

@@ -1,3 +1,5 @@
// Copyright © 2024 Apple Inc.
#include "mlx/primitives.h"
namespace mlx::core::fast {
@@ -31,6 +33,52 @@ class Custom : public Primitive {
std::function<std::vector<array>(std::vector<array>)> fallback_;
};
class RMSNorm : public Custom {
public:
RMSNorm(
Stream stream,
std::function<std::vector<array>(std::vector<array>)> fallback,
float eps)
: Custom(stream, fallback), eps_(eps){};
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override {
throw std::runtime_error("NYI");
};
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
DEFINE_PRINT(RMSNorm)
bool is_equivalent(const Primitive& other) const override;
private:
std::function<std::vector<array>(std::vector<array>)> fallback_;
float eps_;
};
class LayerNorm : public Custom {
public:
LayerNorm(
Stream stream,
std::function<std::vector<array>(std::vector<array>)> fallback,
float eps)
: Custom(stream, fallback), eps_(eps){};
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override {
throw std::runtime_error("NYI");
};
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
DEFINE_PRINT(LayerNorm)
bool is_equivalent(const Primitive& other) const override;
private:
std::function<std::vector<array>(std::vector<array>)> fallback_;
float eps_;
};
class RoPE : public Custom {
public:
RoPE(
@@ -49,7 +97,9 @@ class RoPE : public Custom {
offset_(offset){};
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
override {
throw std::runtime_error("NYI");
};
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
@@ -65,4 +115,34 @@ class RoPE : public Custom {
int offset_;
};
class ScaledDotProductAttention : public Custom {
public:
explicit ScaledDotProductAttention(
Stream stream,
std::function<std::vector<array>(std::vector<array>)> fallback,
const float scale,
const bool needs_mask)
: Custom(stream, fallback), scale_(scale), needs_mask_(needs_mask){};
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override {
outputs[0] = fallback_(inputs)[0];
};
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override {
eval_gpu(inputs, outputs[0]);
};
void eval_gpu(const std::vector<array>& inputs, array& out);
bool is_equivalent(const Primitive& other) const override;
DEFINE_PRINT(ScaledDotProductAttention)
private:
std::function<std::vector<array>(std::vector<array>)> fallback_;
float scale_;
bool needs_mask_;
};
} // namespace mlx::core::fast

View File

@@ -1,5 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include <cstring>
#include <fstream>
@@ -18,7 +17,7 @@ namespace mlx::core {
namespace {
static constexpr uint8_t MAGIC[] = {
constexpr uint8_t MAGIC[] = {
0x93,
0x4e,
0x55,
@@ -122,8 +121,6 @@ void save(std::shared_ptr<io::Writer> out_stream, array a) {
out_stream->write(magic_ver_len.str().c_str(), magic_ver_len.str().length());
out_stream->write(header.str().c_str(), header.str().length());
out_stream->write(a.data<char>(), a.nbytes());
return;
}
/** Save array to file in .npy format */

View File

@@ -200,4 +200,65 @@ std::pair<array, array> qr(const array& a, StreamOrDevice s /* = {} */) {
return std::make_pair(out[0], out[1]);
}
std::vector<array> svd(const array& a, StreamOrDevice s /* = {} */) {
if (a.dtype() != float32) {
std::ostringstream msg;
msg << "[linalg::svd] Input array must have type float32. Received array "
<< "with type " << a.dtype() << ".";
throw std::invalid_argument(msg.str());
}
if (a.ndim() < 2) {
std::ostringstream msg;
msg << "[linalg::svd] Input array must have >= 2 dimensions. Received array "
"with "
<< a.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}
const auto m = a.shape(-2);
const auto n = a.shape(-1);
const auto rank = a.ndim();
std::vector<int> u_shape = a.shape();
u_shape[rank - 2] = m;
u_shape[rank - 1] = m;
std::vector<int> s_shape = a.shape();
s_shape.pop_back();
s_shape[rank - 2] = std::min(m, n);
std::vector<int> vt_shape = a.shape();
vt_shape[rank - 2] = n;
vt_shape[rank - 1] = n;
return array::make_arrays(
{u_shape, s_shape, vt_shape},
{a.dtype(), a.dtype(), a.dtype()},
std::make_unique<SVD>(to_stream(s)),
{a});
}
array inv(const array& a, StreamOrDevice s /* = {} */) {
if (a.dtype() != float32) {
std::ostringstream msg;
msg << "[linalg::inv] Arrays must type float32. Received array "
<< "with type " << a.dtype() << ".";
throw std::invalid_argument(msg.str());
}
if (a.ndim() < 2) {
std::ostringstream msg;
msg << "[linalg::inv] Arrays must have >= 2 dimensions. Received array "
"with "
<< a.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}
if (a.shape(-1) != a.shape(-2)) {
throw std::invalid_argument(
"[linalg::inv] Inverses are only defined for square matrices.");
}
return array(
a.shape(), a.dtype(), std::make_unique<Inverse>(to_stream(s)), {a});
}
} // namespace mlx::core::linalg

View File

@@ -62,4 +62,8 @@ norm(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {}) {
std::pair<array, array> qr(const array& a, StreamOrDevice s = {});
std::vector<array> svd(const array& a, StreamOrDevice s = {});
array inv(const array& a, StreamOrDevice s = {});
} // namespace mlx::core::linalg

View File

@@ -46,14 +46,6 @@ std::pair<std::vector<int>, std::vector<int>> compute_reduce_shape(
return {out_shape, sorted_axes};
}
int compute_number_of_elements(const array& a, const std::vector<int>& axes) {
int nelements = 1;
for (auto axis : axes) {
nelements *= a.shape(axis);
}
return nelements;
}
Dtype at_least_float(const Dtype& d) {
return is_floating_point(d) ? d : promote_types(d, float32);
}
@@ -293,26 +285,35 @@ array reshape(
for (int i = 0; i < shape.size(); ++i) {
if (shape[i] == -1) {
if (infer_idx >= 0) {
throw std::invalid_argument("Reshape can only infer one dimension.");
throw std::invalid_argument(
"[reshape] Reshape can only infer one dimension.");
}
infer_idx = i;
} else {
size *= shape[i];
}
}
// Infer the shape
if (size > 0) {
auto q_and_r = std::ldiv(a.size(), size);
if (infer_idx >= 0) {
shape[infer_idx] = q_and_r.quot;
size *= q_and_r.quot;
}
} else if (infer_idx >= 0) {
throw std::invalid_argument(
"[reshape] Cannot infer the shape of an empty array");
}
// Check the the reshaping is valid
if (a.size() != size) {
std::ostringstream msg;
msg << "Cannot reshape array of size " << a.size() << " into shape "
<< shape << ".";
msg << "[reshape] Cannot reshape array of size " << a.size()
<< " into shape " << shape << ".";
throw std::invalid_argument(msg.str());
}
return array(
shape, a.dtype(), std::make_unique<Reshape>(to_stream(s), shape), {a});
}
@@ -444,6 +445,60 @@ array expand_dims(
return reshape(a, out_shape, s);
}
// Slice helper
namespace {
inline auto normalize_slice(
const std::vector<int>& shape,
std::vector<int>& start,
std::vector<int>& stop,
std::vector<int>& strides) {
std::vector<int> out_shape(shape.size());
bool has_neg_strides = false;
for (int i = 0; i < shape.size(); ++i) {
// Following numpy docs
// Negative i and j are interpreted as n + i and n + j where n is
// the number of elements in the corresponding dimension. Negative
// k makes stepping go towards smaller indices
auto n = shape[i];
auto s = start[i];
s = s < 0 ? s + n : s;
auto e = stop[i];
e = e < 0 ? e + n : e;
// Note: -ve strides require start >= stop
if (strides[i] < 0) {
has_neg_strides = true;
// Clamp to bounds
auto st = std::min(s, n - 1);
auto ed = std::max(-1, e);
start[i] = st;
stop[i] = ed > st ? st : ed;
auto str = -strides[i];
out_shape[i] = (start[i] - stop[i] + str - 1) / str;
} else {
// Clamp to bounds
auto st = std::max(0, std::min(s, n));
auto ed = std::max(0, std::min(e, n));
start[i] = st;
stop[i] = ed < st ? st : ed;
out_shape[i] = (stop[i] - start[i] + strides[i] - 1) / strides[i];
}
}
return std::make_pair(has_neg_strides, out_shape);
}
} // namespace
array slice(
const array& a,
std::vector<int> start,
@@ -458,113 +513,13 @@ array slice(
throw std::invalid_argument(msg.str());
}
std::vector<int> negatively_strided_axes;
std::vector<std::vector<int>> negatively_strided_slices;
std::vector<int> out_shape(a.ndim());
for (int i = 0; i < a.ndim(); ++i) {
// Following numpy docs
// Negative i and j are interpreted as n + i and n + j where n is
// the number of elements in the corresponding dimension. Negative
// k makes stepping go towards smaller indices
auto [has_neg_strides, out_shape] =
normalize_slice(a.shape(), start, stop, strides);
auto n = a.shape(i);
auto s = start[i];
s = s < 0 ? s + n : s;
auto e = stop[i];
e = e < 0 ? e + n : e;
// Note: We pass positive strides to the primitive and then flip
// the axes later as needed
if (strides[i] < 0) {
negatively_strided_axes.push_back(i);
auto st = std::min(s, n - 1);
auto ed = std::max(e, -1);
negatively_strided_slices.push_back({st, ed, strides[i]});
start[i] = 0;
stop[i] = n;
strides[i] = 1;
} else {
start[i] = s;
stop[i] = e < s ? s : e;
}
// Clamp to bounds
start[i] = std::max(0, std::min(start[i], n));
stop[i] = std::max(0, std::min(stop[i], n));
out_shape[i] = (stop[i] - start[i] + strides[i] - 1) / strides[i];
}
// If strides are negative, slice and then make a copy with axes flipped
if (negatively_strided_axes.size() > 0) {
// First, take the slice of the positively strided axes
auto out = array(
out_shape,
a.dtype(),
std::make_unique<Slice>(
to_stream(s),
std::move(start),
std::move(stop),
std::move(strides)),
{a});
std::vector<array> indices;
std::vector<int> slice_sizes = out.shape();
std::vector<int> t_axes(out.ndim(), -1);
std::vector<int> out_reshape(out.ndim(), -1);
int n_axes = negatively_strided_axes.size();
for (int i = 0; i < n_axes; i++) {
// Get axis and corresponding slice
auto ax = negatively_strided_axes[i];
auto sl = negatively_strided_slices[i];
// Get indices for the slice
auto ax_idx = arange(sl[0], sl[1], sl[2], s);
// Reshape indices for broadcast as needed
std::vector<int> ax_idx_shape(n_axes, 1);
ax_idx_shape[i] = ax_idx.size();
ax_idx = reshape(ax_idx, ax_idx_shape, s);
// Add indices to list
indices.push_back(ax_idx);
// Set slice size for axis
slice_sizes[ax] = 1;
// Gather moves the axis up, remainder needs to be squeezed
out_reshape[i] = indices[i].size();
// Gather moves the axis up, needs to be transposed
t_axes[ax] = i;
}
// Prepare out_reshape to squeeze gathered dims
// Prepare to transpose dims as needed
int j = n_axes;
for (int i = 0; j < out.ndim() && i < out.ndim(); i++) {
if (t_axes[i] < 0) {
t_axes[i] = j;
out_reshape[j] = out_shape[i];
j++;
}
}
// Gather
out = gather(out, indices, negatively_strided_axes, slice_sizes, s);
// Squeeze dims
out = reshape(out, out_reshape, s);
// Transpose dims
out = transpose(out, t_axes, s);
return out;
}
if (out_shape == a.shape()) {
if (!has_neg_strides && out_shape == a.shape()) {
return a;
}
return array(
out_shape,
a.dtype(),
@@ -581,6 +536,55 @@ array slice(
return slice(a, start, stop, std::vector<int>(a.ndim(), 1), to_stream(s));
}
/** Update a slice from the source array */
array slice_update(
const array& src,
const array& update,
std::vector<int> start,
std::vector<int> stop,
std::vector<int> strides,
StreamOrDevice s /* = {} */) {
// Check dimensions
if (start.size() != src.ndim() || stop.size() != src.ndim() ||
strides.size() != src.ndim()) {
std::ostringstream msg;
msg << "[slice] Invalid number of indices or strides for "
<< "array with dimension " << src.ndim() << ".";
throw std::invalid_argument(msg.str());
}
// Process slice dimensions
auto [has_neg_strides, upd_shape] =
normalize_slice(src.shape(), start, stop, strides);
// Broadcast update shape to slice shape
auto update_broadcasted = broadcast_to(update, upd_shape, s);
// If the entire src is the slice, just return the update
if (!has_neg_strides && upd_shape == src.shape()) {
return astype(update_broadcasted, src.dtype(), s);
}
return array(
src.shape(),
src.dtype(),
std::make_unique<SliceUpdate>(
to_stream(s), std::move(start), std::move(stop), std::move(strides)),
{src, update_broadcasted});
}
/** Update a slice from the source array with stride 1 in each dimension */
array slice_update(
const array& src,
const array& update,
std::vector<int> start,
std::vector<int> stop,
StreamOrDevice s /* = {} */) {
auto strides = std::vector<int>(src.ndim(), 1);
return slice_update(
src, update, std::move(start), std::move(stop), std::move(strides), s);
}
std::vector<array> split(
const array& a,
const std::vector<int>& indices,
@@ -1347,9 +1351,9 @@ array mean(
throw std::invalid_argument(msg.str());
}
}
auto nelements = compute_number_of_elements(a, axes);
auto dtype = at_least_float(a.dtype());
return multiply(sum(a, axes, keepdims, s), array(1.0 / nelements, dtype), s);
auto normalizer = number_of_elements(a, axes, true, dtype, s);
return multiply(sum(a, axes, keepdims, s), normalizer, s);
}
array mean(
@@ -1382,9 +1386,12 @@ array var(
auto v = subtract(a2, mu2, s);
if (ddof != 0) {
auto nelements = compute_number_of_elements(a, axes);
auto factor = nelements / static_cast<float>(std::max(nelements - ddof, 0));
v = multiply(v, array(factor, dtype), s);
auto nelements = number_of_elements(a, axes, false, dtype, s);
auto factor = divide(
nelements,
maximum(subtract(nelements, array(ddof, dtype), s), array(0, dtype), s),
s);
v = multiply(v, factor, s);
}
return v;
@@ -1509,9 +1516,11 @@ array min(
array argmin(const array& a, bool keepdims, StreamOrDevice s /* = {} */) {
int size = a.size();
auto result = argmin(reshape(a, {size}, s), 0, false, s);
auto result = argmin(reshape(a, {size}, s), 0, true, s);
if (keepdims) {
result = reshape(result, std::vector<int>(a.shape().size(), 1), s);
} else {
result = squeeze(result, s);
}
return result;
}
@@ -1540,9 +1549,11 @@ array argmin(
array argmax(const array& a, bool keepdims, StreamOrDevice s /* = {} */) {
int size = a.size();
auto result = argmax(reshape(a, {size}, s), 0, false, s);
auto result = argmax(reshape(a, {size}, s), 0, true, s);
if (keepdims) {
result = reshape(result, std::vector<int>(a.shape().size(), 1), s);
} else {
result = squeeze(result, s);
}
return result;
}
@@ -1700,7 +1711,7 @@ array argpartition(
int kth_ = kth < 0 ? kth + a.shape(axis) : kth;
if (kth_ < 0 || kth_ >= a.shape(axis_)) {
std::ostringstream msg;
msg << "[argpartition] Received invalid kth " << kth << "along axis "
msg << "[argpartition] Received invalid kth " << kth << " along axis "
<< axis << " for array with shape: " << a.shape();
throw std::invalid_argument(msg.str());
}
@@ -1721,24 +1732,28 @@ array topk(const array& a, int k, StreamOrDevice s /* = {}*/) {
array topk(const array& a, int k, int axis, StreamOrDevice s /* = {}*/) {
// Check for valid axis
int axis_ = axis < 0 ? axis + a.ndim() : axis;
int kth_ = k < 0 ? k + a.shape(axis) : k;
if (axis_ < 0 || axis_ >= static_cast<int>(a.ndim())) {
std::ostringstream msg;
msg << "[topk] Received invalid axis " << axis << " for array with "
<< a.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}
if (kth_ < 0 || kth_ >= a.shape(axis_)) {
if (k < 0 || k > a.shape(axis_)) {
std::ostringstream msg;
msg << "[topk] Received invalid k " << k << "along axis " << axis
msg << "[topk] Received invalid k=" << k << " along axis " << axis
<< " for array with shape: " << a.shape();
throw std::invalid_argument(msg.str());
}
array a_partitioned = partition(a, kth_, axis_, s);
// Return early if the whole input was requested.
if (k == a.shape(axis_)) {
return a;
}
array a_partitioned = partition(a, -k, axis_, s);
std::vector<int> slice_starts(a.ndim(), 0);
std::vector<int> slice_ends = a.shape();
slice_starts[axis_] = kth_;
slice_starts[axis_] = a.shape(axis_) - k;
return slice(a_partitioned, slice_starts, slice_ends, s);
}
@@ -1753,7 +1768,7 @@ array logsumexp(
const std::vector<int>& axes,
bool keepdims /* = false */,
StreamOrDevice s /* = {}*/) {
auto maxval = stop_gradient(max(a, axes, true, s));
auto maxval = stop_gradient(max(a, axes, true, s), s);
auto out = log(sum(exp(subtract(a, maxval, s), s), axes, keepdims, s), s);
out = add(out, reshape(maxval, out.shape(), s), s);
if (!keepdims) {
@@ -3028,6 +3043,15 @@ array quantized_matmul(
}
auto dtype = result_type({x, scales, biases});
if (!is_floating_point(dtype) || is_complex(dtype)) {
std::ostringstream msg;
msg << "[quantized_matmul] Only real floating types are supported but "
<< "the passed types where x.dtype() == " << x.dtype()
<< ", scales.dtype() == " << scales.dtype()
<< " and biases.dtype() == " << biases.dtype();
throw std::invalid_argument(msg.str());
}
auto out = array(
{x.shape(0), w_outer_dims},
dtype,
@@ -3186,42 +3210,41 @@ array dequantize(
array tensordot(
const array& a,
const array& b,
const int dims /* = 2 */,
const int axis /* = 2 */,
StreamOrDevice s /* = {} */
) {
if (dims < 0) {
if (axis < 0) {
throw std::invalid_argument(
"[tensordot] dims must be greater or equal to 0.");
"[tensordot] axis must be greater or equal to 0.");
}
if (dims > std::min(a.ndim(), b.ndim())) {
if (axis > std::min(a.ndim(), b.ndim())) {
throw std::invalid_argument(
"[tensordot] dims must be less than the number of dimensions of a and b.");
"[tensordot] axis must be less than the number of dimensions of a and b.");
}
std::vector<int> adims;
std::vector<int> bdims;
for (int i = 0; i < dims; i++) {
for (int i = 0; i < axis; i++) {
bdims.emplace_back(i);
adims.emplace_back(-dims + i);
adims.emplace_back(i - axis);
}
return tensordot(a, b, {adims, bdims}, s);
return tensordot(a, b, {adims}, {bdims}, s);
}
array tensordot(
const array& a,
const array& b,
const std::pair<std::vector<int>, std::vector<int>>& dims,
StreamOrDevice s /* = {} */
) {
if (dims.first.size() != dims.second.size()) {
throw std::invalid_argument(
"[tensordot] dims[0] and dims[1] must have the same number of dimensions.");
const std::vector<int>& axes_a,
const std::vector<int>& axes_b,
StreamOrDevice s /* = {} */) {
if (axes_a.size() != axes_b.size()) {
throw std::invalid_argument("[tensordot] axes must have the same size.");
}
int csize = 1;
auto x = a;
auto y = b;
for (int i = 0; i < dims.first.size(); i++) {
if (x.shape(dims.first.at(i)) == y.shape(dims.second.at(i))) {
csize *= x.shape(dims.first.at(i));
for (int i = 0; i < axes_a.size(); i++) {
if (x.shape(axes_a.at(i)) == y.shape(axes_b.at(i))) {
csize *= x.shape(axes_a.at(i));
} else {
throw std::invalid_argument(
"[tensordot] a and b must have the same shape on the contracted axes.");
@@ -3230,11 +3253,11 @@ array tensordot(
std::vector<bool> cdims1(x.ndim(), false);
std::vector<bool> cdims2(y.ndim(), false);
for (const auto n : dims.first) {
for (const auto n : axes_a) {
int n_ = (n < 0) ? n + x.ndim() : n;
cdims1[n_] = true;
}
for (const auto n : dims.second) {
for (const auto n : axes_b) {
int n_ = (n < 0) ? n + y.ndim() : n;
cdims2[n_] = true;
}
@@ -3251,10 +3274,10 @@ array tensordot(
rshape.emplace_back(a.shape(i));
}
}
for (const auto x : dims.first) {
for (const auto x : axes_a) {
t1.emplace_back(x);
}
for (const auto x : dims.second) {
for (const auto x : axes_b) {
t2.emplace_back(x);
}
for (int i = 0; i < b.ndim(); i++) {
@@ -3283,7 +3306,7 @@ array inner(const array& a, const array& b, StreamOrDevice s /* = {} */) {
"[inner] a and b must have the same last dimension.");
}
return tensordot(a, b, {{-1}, {-1}}, s);
return tensordot(a, b, {-1}, {-1}, s);
}
/** Compute D = beta * C + alpha * (A @ B) */
@@ -3294,17 +3317,8 @@ array addmm(
const float& alpha /* = 1.f */,
const float& beta /* = 1.f */,
StreamOrDevice s /* = {} */) {
// Divert in the case of vector-matrix multiplication
// TODO: Add the needed specializtion
if (a.ndim() == 1 || b.ndim() == 1) {
array X = matmul(a, b, s);
array alpha_arr = array(alpha, X.dtype());
array aX = multiply(alpha_arr, X, s);
array beta_arr = array(beta, c.dtype());
array bY = multiply(beta_arr, c, s);
return add(aX, bY, s);
}
int in_a_ndim = a.ndim();
int in_b_ndim = b.ndim();
if (a.ndim() == 0 || b.ndim() == 0) {
throw std::invalid_argument(
@@ -3312,6 +3326,15 @@ array addmm(
"have at least one dimension.");
}
if (a.ndim() == 1) {
// Insert a singleton dim in the beginning
a = reshape(a, {1, -1}, s);
}
if (b.ndim() == 1) {
// Insert a singleton dim at the end
b = reshape(b, {-1, 1}, s);
}
if (a.shape(-1) != b.shape(-2)) {
std::ostringstream msg;
msg << "[addmm] Last dimension of first input with shape " << a.shape()
@@ -3340,7 +3363,13 @@ array addmm(
std::vector<int> out_shape = a.shape();
a = reshape(a, {-1, out_shape.back()}, s);
out_shape.back() = b.shape(-1);
if (in_b_ndim == 1) {
out_shape.pop_back();
}
c = broadcast_to(c, {a.shape(0), b.shape(1)}, s);
auto out = array(
{a.shape(0), b.shape(1)},
out_type,
@@ -3368,15 +3397,42 @@ array addmm(
auto out_shape = a.shape();
out_shape.back() = b.shape(-1);
auto c_broadcast_shape = broadcast_shapes(c.shape(), out_shape);
auto out_shape_adjusted = out_shape;
if (in_a_ndim == 1 || in_b_ndim == 1) {
out_shape_adjusted.erase(
out_shape_adjusted.end() - ((in_a_ndim == 1) ? 2 : 1),
out_shape_adjusted.end() - ((in_b_ndim == 1) ? 0 : 1));
}
auto c_broadcast_shape = broadcast_shapes(c.shape(), out_shape_adjusted);
c = broadcast_to(c, c_broadcast_shape, s);
if (in_a_ndim == 1 || in_b_ndim == 1) {
auto c_reshape = c.shape();
if (in_b_ndim == 1) {
c_reshape.push_back(1);
}
if (in_a_ndim == 1) {
c_reshape.push_back(c_reshape.back());
c_reshape[c_reshape.size() - 2] = 1;
}
c = reshape(c, c_reshape, s);
}
auto out = array(
out_shape,
out_type,
std::make_unique<AddMM>(to_stream(s), alpha, beta),
{a, b, c});
// Remove the possibly inserted singleton dimensions
if (in_a_ndim == 1 || in_b_ndim == 1) {
out = reshape(out, out_shape_adjusted, s);
}
return out;
}
@@ -3542,4 +3598,29 @@ std::vector<array> atleast_3d(
return out;
}
array number_of_elements(
const array& a,
std::vector<int> axes,
bool inverted,
Dtype dtype /* = int32 */,
StreamOrDevice s /* = {} */) {
for (auto& ax : axes) {
int normal_axis = (ax + a.ndim()) % a.ndim();
if (normal_axis >= a.ndim() || normal_axis < 0) {
std::ostringstream msg;
msg << "[number_of_elements] Can't get the shape for axis " << ax
<< " from an array with " << a.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}
ax = normal_axis;
}
return stop_gradient(array(
std::vector<int>{},
dtype,
std::make_unique<NumberOfElements>(
to_stream(s), std::move(axes), inverted, dtype),
{a}));
}
} // namespace mlx::core

View File

@@ -177,6 +177,23 @@ array slice(
const std::vector<int>& stop,
StreamOrDevice s = {});
/** Update a slice from the source array */
array slice_update(
const array& src,
const array& update,
std::vector<int> start,
std::vector<int> stop,
std::vector<int> strides,
StreamOrDevice s = {});
/** Update a slice from the source array with stride 1 in each dimension */
array slice_update(
const array& src,
const array& update,
std::vector<int> start,
std::vector<int> stop,
StreamOrDevice s = {});
/** Split an array into sub-arrays along a given axis. */
std::vector<array>
split(const array& a, int num_splits, int axis, StreamOrDevice s = {});
@@ -1110,17 +1127,18 @@ array dequantize(
int bits = 4,
StreamOrDevice s = {});
/** TensorDot returns a contraction of a and b over multiple dimensions. */
/** Returns a contraction of a and b over multiple dimensions. */
array tensordot(
const array& a,
const array& b,
const int dims = 2,
const int axis = 2,
StreamOrDevice s = {});
array tensordot(
const array& a,
const array& b,
const std::pair<std::vector<int>, std::vector<int>>& dims,
const std::vector<int>& axes_a,
const std::vector<int>& axes_b,
StreamOrDevice s = {});
/** Compute the outer product of two vectors. */
@@ -1172,4 +1190,15 @@ std::vector<array> atleast_3d(
const std::vector<array>& a,
StreamOrDevice s = {});
/**
* Extract the number of elements along some axes as a scalar array. Used to
* allow shape dependent shapeless compilation (pun intended).
*/
array number_of_elements(
const array& a,
std::vector<int> axes,
bool inverted,
Dtype dtype = int32,
StreamOrDevice s = {});
} // namespace mlx::core

View File

@@ -23,6 +23,10 @@ std::tuple<array, array, int> vmap_binary_op(
assert(inputs.size() == 2);
assert(axes.size() == 2);
if (axes[0] == -1 && axes[1] == -1) {
return {inputs[0], inputs[1], -1};
}
auto a = inputs[0];
auto b = inputs[1];
int ndim = std::max(a.ndim() + (axes[0] == -1), b.ndim() + (axes[1] == -1));
@@ -55,6 +59,10 @@ std::tuple<array, array, array, int> vmap_ternary_op(
assert(inputs.size() == 3);
assert(axes.size() == 3);
if (axes[0] == -1 && axes[1] == -1 && axes[2] == -1) {
return {inputs[0], inputs[1], inputs[2], -1};
}
auto a = inputs[0];
auto b = inputs[1];
auto c = inputs[2];
@@ -102,7 +110,11 @@ std::vector<array> Primitive::jvp(
const std::vector<array>&,
const std::vector<array>&,
const std::vector<int>&) {
throw std::invalid_argument("Primitive's jvp not implemented.");
std::ostringstream msg;
msg << "[Primitive::jvp] Not implemented for ";
print(msg);
msg << ".";
throw std::invalid_argument(msg.str());
};
std::vector<array> Primitive::vjp(
@@ -110,13 +122,21 @@ std::vector<array> Primitive::vjp(
const std::vector<array>&,
const std::vector<int>&,
const std::vector<array>&) {
throw std::invalid_argument("Primitive's vjp not implemented.");
std::ostringstream msg;
msg << "[Primitive::vip] Not implemented for ";
print(msg);
msg << ".";
throw std::invalid_argument(msg.str());
};
std::pair<std::vector<array>, std::vector<int>> Primitive::vmap(
const std::vector<array>&,
const std::vector<int>&) {
throw std::invalid_argument("Primitive's vmap not implemented.");
std::ostringstream msg;
msg << "[Primitive::vmap] Not implemented for ";
print(msg);
msg << ".";
throw std::invalid_argument(msg.str());
};
std::vector<std::vector<int>> Primitive::output_shapes(
@@ -227,6 +247,18 @@ bool AddMM::is_equivalent(const Primitive& other) const {
return (alpha_ == a_other.alpha_ && beta_ == a_other.beta_);
}
std::pair<std::vector<array>, std::vector<int>> AddMM::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
auto maybe_move_ax = [this](auto& arr, auto ax) {
return ax > 0 ? moveaxis(arr, ax, 0, stream()) : arr;
};
auto a = maybe_move_ax(inputs[0], axes[0]);
auto b = maybe_move_ax(inputs[1], axes[1]);
auto c = maybe_move_ax(inputs[2], axes[2]);
return {{addmm(c, a, b, alpha_, beta_, stream())}, {0}};
}
bool Arange::is_equivalent(const Primitive& other) const {
const Arange& a_other = static_cast<const Arange&>(other);
return (
@@ -403,8 +435,8 @@ std::pair<std::vector<array>, std::vector<int>> ArgPartition::vmap(
assert(inputs.size() == 1);
assert(axes.size() == 1);
return {
{argpartition(inputs[0], axis_ + (axes[0] <= axis_), stream())}, axes};
int axis_left = axes[0] >= 0 && axes[0] <= axis_;
return {{argpartition(inputs[0], axis_ + axis_left, stream())}, axes};
}
bool ArgPartition::is_equivalent(const Primitive& other) const {
@@ -420,7 +452,7 @@ bool ArgReduce::is_equivalent(const Primitive& other) const {
std::pair<std::vector<array>, std::vector<int>> ArgReduce::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
int reduce_ax = axis_ + (axis_ >= axes[0]);
int reduce_ax = axis_ + (axes[0] >= 0 && axis_ >= axes[0]);
auto& in = inputs[0];
std::vector<array> out;
if (reduce_type_ == ArgReduce::ArgMin) {
@@ -437,7 +469,8 @@ std::pair<std::vector<array>, std::vector<int>> ArgSort::vmap(
assert(inputs.size() == 1);
assert(axes.size() == 1);
return {{argsort(inputs[0], axis_ + (axes[0] <= axis_), stream())}, axes};
int axis_left = axes[0] >= 0 && axes[0] <= axis_;
return {{argsort(inputs[0], axis_ + axis_left, stream())}, axes};
}
std::vector<std::vector<int>> ArgReduce::output_shapes(
@@ -563,13 +596,16 @@ std::pair<std::vector<array>, std::vector<int>> Broadcast::vmap(
assert(inputs.size() == 1);
assert(axes.size() == 1);
auto ax = axes[0];
auto in_shape = inputs[0].shape();
int diff = shape_.size() - inputs[0].ndim() + 1;
assert(diff >= 0);
in_shape.insert(in_shape.begin(), diff, 1);
ax += diff;
shape_.insert(shape_.begin() + ax, in_shape[ax]);
auto in = reshape(inputs[0], in_shape, stream());
auto in = inputs[0];
if (ax >= 0) {
auto in_shape = in.shape();
int diff = shape_.size() - in.ndim() + 1;
assert(diff >= 0);
in_shape.insert(in_shape.begin(), diff, 1);
ax += diff;
shape_.insert(shape_.begin() + ax, in_shape[ax]);
in = reshape(in, in_shape, stream());
}
return {{broadcast_to(in, shape_, stream())}, {ax}};
}
@@ -653,24 +689,30 @@ std::pair<std::vector<array>, std::vector<int>> Concatenate::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
std::vector<array> t_inputs;
int out_ax = -1;
// Find the first vmapped input
int i = 0;
for (; i < axes.size(); i++) {
t_inputs.push_back(inputs[i]);
if (axes[i] >= 0) {
out_ax = axes[i];
break;
}
}
auto out_ax = axes[i++];
// Move vmap axes to the same spot.
for (; i < axes.size(); ++i) {
if (out_ax != axes[i] && axes[i] >= 0) {
t_inputs.push_back(moveaxis(inputs[i], axes[i], out_ax, stream()));
} else {
t_inputs.push_back(inputs[i]);
if (out_ax >= 0) {
// Advance to the next input
i++;
// Move vmap axes to the same spot.
for (; i < axes.size(); ++i) {
if (out_ax != axes[i] && axes[i] >= 0) {
t_inputs.push_back(moveaxis(inputs[i], axes[i], out_ax, stream()));
} else {
t_inputs.push_back(inputs[i]);
}
}
}
auto axis = axis_ + (axis_ >= out_ax);
auto axis = axis_ + (out_ax >= 0 && axis_ >= out_ax);
return {{concatenate(t_inputs, axis, stream())}, {out_ax}};
}
@@ -1085,7 +1127,7 @@ std::pair<std::vector<array>, std::vector<int>> Equal::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
return {{equal(a, b, stream())}, axes};
return {{equal(a, b, stream())}, {to_ax}};
}
std::vector<array> Equal::vjp(
@@ -1210,13 +1252,15 @@ std::pair<std::vector<array>, std::vector<int>> FFT::vmap(
int ax = axes[0];
auto fft_axes = axes_;
auto out_shape = in.shape();
for (auto& fft_ax : fft_axes) {
if (fft_ax >= ax) {
fft_ax++;
}
if (real_) {
auto n = out_shape[fft_ax];
out_shape[fft_ax] = inverse_ ? 2 * (n - 1) : n / 2 + 1;
if (ax >= 0) {
for (auto& fft_ax : fft_axes) {
if (fft_ax >= ax) {
fft_ax++;
}
if (real_) {
auto n = out_shape[fft_ax];
out_shape[fft_ax] = inverse_ ? 2 * (n - 1) : n / 2 + 1;
}
}
}
return {
@@ -1424,7 +1468,7 @@ std::pair<std::vector<array>, std::vector<int>> Greater::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
return {{greater(a, b, stream())}, axes};
return {{greater(a, b, stream())}, {to_ax}};
}
std::vector<array> Greater::vjp(
@@ -1451,7 +1495,7 @@ std::pair<std::vector<array>, std::vector<int>> GreaterEqual::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
return {{greater_equal(a, b, stream())}, axes};
return {{greater_equal(a, b, stream())}, {to_ax}};
}
std::vector<array> GreaterEqual::vjp(
@@ -1478,7 +1522,7 @@ std::pair<std::vector<array>, std::vector<int>> Less::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
return {{less(a, b, stream())}, axes};
return {{less(a, b, stream())}, {to_ax}};
}
std::vector<array> Less::vjp(
@@ -1505,7 +1549,7 @@ std::pair<std::vector<array>, std::vector<int>> LessEqual::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream());
return {{less_equal(a, b, stream())}, axes};
return {{less_equal(a, b, stream())}, {to_ax}};
}
std::vector<array> LessEqual::vjp(
@@ -1752,6 +1796,17 @@ std::vector<array> Matmul::vjp(
return vjps;
}
std::pair<std::vector<array>, std::vector<int>> Matmul::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
auto maybe_move_ax = [this](auto& arr, auto ax) {
return ax > 0 ? moveaxis(arr, ax, 0, stream()) : arr;
};
auto a = maybe_move_ax(inputs[0], axes[0]);
auto b = maybe_move_ax(inputs[1], axes[1]);
return {{matmul(a, b, stream())}, {0}};
}
std::vector<array> Maximum::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
@@ -2064,7 +2119,8 @@ std::pair<std::vector<array>, std::vector<int>> Partition::vmap(
assert(inputs.size() == 1);
assert(axes.size() == 1);
return {{partition(inputs[0], axis_ + (axes[0] <= axis_), stream())}, axes};
int axis_left = axes[0] >= 0 && axes[0] <= axis_;
return {{partition(inputs[0], axis_ + axis_left, stream())}, axes};
}
bool Partition::is_equivalent(const Primitive& other) const {
@@ -2185,7 +2241,9 @@ std::pair<std::vector<array>, std::vector<int>> RandomBits::vmap(
}
auto shape = shape_;
shape.insert(shape.begin() + kax, key.shape()[kax]);
if (kax >= 0) {
shape.insert(shape.begin() + kax, key.shape()[kax]);
}
auto get_dtype = [width = width_]() {
switch (width) {
@@ -2217,15 +2275,19 @@ std::pair<std::vector<array>, std::vector<int>> Reshape::vmap(
// Transpose the input so that the vmap dim is first.
auto& in = inputs[0];
auto ax = axes[0];
std::vector<int> reorder(in.ndim());
std::iota(reorder.begin(), reorder.end(), 0);
reorder.erase(reorder.begin() + ax);
reorder.insert(reorder.begin(), ax);
// Insert the vmap dim into the shape at the beginning.
auto out = transpose(in, reorder, stream());
shape_.insert(shape_.begin(), in.shape()[ax]);
// Reshape the transposed input to the new shape.
return {{reshape(out, shape_, stream())}, {0}};
if (ax >= 0) {
std::vector<int> reorder(in.ndim());
std::iota(reorder.begin(), reorder.end(), 0);
reorder.erase(reorder.begin() + ax);
reorder.insert(reorder.begin(), ax);
// Insert the vmap dim into the shape at the beginning.
auto out = transpose(in, reorder, stream());
shape_.insert(shape_.begin(), in.shape()[ax]);
// Reshape the transposed input to the new shape.
return {{reshape(out, shape_, stream())}, {0}};
} else {
return {{reshape(in, shape_, stream())}, {ax}};
}
}
std::vector<array> Reshape::vjp(
@@ -2349,9 +2411,11 @@ std::pair<std::vector<array>, std::vector<int>> Reduce::vmap(
const std::vector<int>& axes) {
auto ax = axes[0];
auto reduce_axes = axes_;
for (auto& rax : reduce_axes) {
if (rax >= ax) {
rax++;
if (ax >= 0) {
for (auto& rax : reduce_axes) {
if (rax >= ax) {
rax++;
}
}
}
auto& in = inputs[0];
@@ -2424,16 +2488,13 @@ std::pair<std::vector<array>, std::vector<int>> Scan::vmap(
auto& in = inputs[0];
auto out_dtype =
(in.dtype() == bool_ && reduce_type_ == Scan::Sum) ? int32 : in.dtype();
int axis_left = axes[0] >= 0 && axes[0] <= axis_;
return {
{array(
in.shape(),
out_dtype,
std::make_unique<Scan>(
stream(),
reduce_type_,
axis_ + (axes[0] <= axis_),
reverse_,
inclusive_),
stream(), reduce_type_, axis_ + axis_left, reverse_, inclusive_),
{in})},
axes};
}
@@ -2698,9 +2759,11 @@ std::pair<std::vector<array>, std::vector<int>> Slice::vmap(
auto strides = strides_;
auto ax = axes[0];
auto& input = inputs[0];
start.insert(start.begin() + ax, 0);
stop.insert(stop.begin() + ax, input.shape(ax));
strides.insert(strides.begin() + ax, 1);
if (ax >= 0) {
start.insert(start.begin() + ax, 0);
stop.insert(stop.begin() + ax, input.shape(ax));
strides.insert(strides.begin() + ax, 1);
}
return {{slice(input, start, stop, strides, stream())}, {ax}};
}
@@ -2786,6 +2849,114 @@ bool Slice::is_equivalent(const Primitive& other) const {
end_indices_ == s_other.end_indices_ && strides_ == s_other.strides_);
}
std::pair<std::vector<array>, std::vector<int>> SliceUpdate::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
assert(inputs.size() == 2);
assert(axes.size() == 2);
auto start = start_indices_;
auto stop = end_indices_;
auto strides = strides_;
auto src = inputs[0];
auto upd = inputs[1];
auto src_ax = axes[0];
auto upd_ax = axes[1];
// No vmapping needed
if (src_ax == -1 && upd_ax == -1) {
return {{slice_update(src, upd, start, stop, strides, stream())}, {-1}};
}
// Broadcast src
if (src_ax == -1) {
src = expand_dims(src, upd_ax, stream());
auto shape = src.shape();
shape[upd_ax] = upd.shape(upd_ax);
src = broadcast_to(src, shape, stream());
src_ax = upd_ax;
}
// Broadcast upd
if (upd_ax == -1) {
upd = expand_dims(upd, src_ax, stream());
upd_ax = src_ax;
}
if (src_ax != upd_ax) {
upd = moveaxis(upd, upd_ax, src_ax, stream());
}
start.insert(start.begin() + src_ax, 0);
stop.insert(stop.begin() + src_ax, src.shape(src_ax));
strides.insert(strides.begin() + src_ax, 1);
return {{slice_update(src, upd, start, stop, strides, stream())}, {src_ax}};
}
std::vector<array> SliceUpdate::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
// Check inputs
assert(primals.size() == 2);
auto& cotan = cotangents[0];
auto& src = primals[0];
auto& upd = primals[1];
std::vector<array> vjps;
for (int num : argnums) {
// Vjp for source
if (num == 0) {
auto grad = slice_update(
cotan,
zeros_like(upd, stream()),
start_indices_,
end_indices_,
strides_,
stream());
vjps.push_back(grad);
}
// Vjp fpr updates
else {
auto grad =
slice(cotan, start_indices_, end_indices_, strides_, stream());
vjps.push_back(grad);
}
}
return vjps;
}
std::vector<array> SliceUpdate::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
// Check inputs
assert(primals.size() == 2);
return {slice_update(
tangents[0],
tangents[1],
start_indices_,
end_indices_,
strides_,
stream())};
}
bool SliceUpdate::is_equivalent(const Primitive& other) const {
const SliceUpdate& s_other = static_cast<const SliceUpdate&>(other);
return (
start_indices_ == s_other.start_indices_ &&
end_indices_ == s_other.end_indices_ && strides_ == s_other.strides_);
}
std::pair<std::vector<array>, std::vector<int>> Softmax::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
@@ -2796,7 +2967,7 @@ std::pair<std::vector<array>, std::vector<int>> Softmax::vmap(
// We are vectorizing over an axis other than the last one so keep the
// softmax axis unchanged
if (axes[0] < inputs[0].ndim() - 1) {
if (axes[0] >= 0 && axes[0] < inputs[0].ndim() - 1) {
softmax_axes.push_back(-1);
} else {
softmax_axes.push_back(-2);
@@ -2837,7 +3008,8 @@ std::pair<std::vector<array>, std::vector<int>> Sort::vmap(
assert(inputs.size() == 1);
assert(axes.size() == 1);
return {{sort(inputs[0], axis_ + (axes[0] <= axis_), stream())}, axes};
int axis_left = axes[0] >= 0 && axes[0] <= axis_;
return {{sort(inputs[0], axis_ + axis_left, stream())}, axes};
}
std::vector<array> Sort::vjp(
@@ -2867,8 +3039,8 @@ bool Sort::is_equivalent(const Primitive& other) const {
std::pair<std::vector<array>, std::vector<int>> Split::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
return {
{split(inputs[0], indices_, axis_ + (axes[0] <= axis_), stream())}, axes};
int axis_left = axes[0] >= 0 && axes[0] <= axis_;
return {{split(inputs[0], indices_, axis_ + axis_left, stream())}, axes};
}
std::vector<array> Split::vjp(
@@ -2971,7 +3143,7 @@ bool Sqrt::is_equivalent(const Primitive& other) const {
std::pair<std::vector<array>, std::vector<int>> StopGradient::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
return {inputs, axes};
return {{stop_gradient(inputs[0], stream())}, axes};
};
std::vector<array> Subtract::vjp(
@@ -3093,12 +3265,14 @@ std::pair<std::vector<array>, std::vector<int>> Transpose::vmap(
assert(inputs.size() == 1);
assert(axes.size() == 1);
auto vdim = axes[0];
for (auto& dim : axes_) {
if (dim >= vdim) {
dim++;
if (vdim >= 0) {
for (auto& dim : axes_) {
if (dim >= vdim) {
dim++;
}
}
axes_.insert(axes_.begin() + vdim, vdim);
}
axes_.insert(axes_.begin() + vdim, vdim);
return {{transpose(inputs[0], axes_, stream())}, {vdim}};
}
@@ -3107,4 +3281,35 @@ bool Transpose::is_equivalent(const Primitive& other) const {
return axes_ == t_other.axes_;
}
std::pair<std::vector<array>, std::vector<int>> NumberOfElements::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
assert(inputs.size() == 1);
assert(axes.size() == 1);
std::vector<int> new_axes = axes_;
auto vdim = axes[0];
if (vdim >= 0) {
for (auto& dim : new_axes) {
if (dim >= vdim) {
dim++;
}
}
}
array out = array(
std::vector<int>{},
dtype_,
std::make_unique<NumberOfElements>(stream(), new_axes, inverted_, dtype_),
inputs);
return {{out}, {-1}};
}
bool NumberOfElements::is_equivalent(const Primitive& other) const {
const NumberOfElements& n_other = static_cast<const NumberOfElements&>(other);
return axes_ == n_other.axes_ && inverted_ == n_other.inverted_ &&
dtype_ == n_other.dtype_;
}
} // namespace mlx::core

View File

@@ -200,6 +200,7 @@ class AddMM : public UnaryPrimitive {
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
DEFINE_VMAP()
DEFINE_PRINT(AddMM)
bool is_equivalent(const Primitive& other) const override;
@@ -1140,6 +1141,7 @@ class Matmul : public UnaryPrimitive {
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
DEFINE_VMAP()
DEFINE_PRINT(Matmul)
DEFINE_DEFAULT_IS_EQUIVALENT()
};
@@ -1229,6 +1231,37 @@ class NotEqual : public UnaryPrimitive {
void eval(const std::vector<array>& inputs, array& out);
};
class NumberOfElements : public UnaryPrimitive {
public:
explicit NumberOfElements(
Stream stream,
std::vector<int> axes,
bool inverted,
Dtype dtype)
: UnaryPrimitive(stream),
axes_(std::move(axes)),
inverted_(inverted),
dtype_(dtype) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
DEFINE_PRINT(NumberOfElements)
bool is_equivalent(const Primitive& other) const override;
std::vector<std::vector<int>> output_shapes(
const std::vector<array>& inputs) override {
return {{}};
}
private:
std::vector<int> axes_;
bool inverted_;
Dtype dtype_;
void eval(const std::vector<array>& inputs, array& out);
};
class Pad : public UnaryPrimitive {
public:
explicit Pad(
@@ -1359,6 +1392,14 @@ class Reshape : public UnaryPrimitive {
std::vector<int> shape_;
void eval(const std::vector<array>& inputs, array& out);
std::pair<bool, std::vector<size_t>> prepare_reshape(
const array& in,
const array& out);
void shared_buffer_reshape(
const array& in,
const std::vector<size_t>& out_strides,
array& out);
};
class Reduce : public UnaryPrimitive {
@@ -1619,6 +1660,44 @@ class Slice : public UnaryPrimitive {
std::vector<int> strides_;
void eval(const std::vector<array>& inputs, array& out);
std::tuple<bool, int64_t, std::vector<int64_t>> prepare_slice(
const array& in);
void shared_buffer_slice(
const array& in,
const std::vector<size_t>& out_strides,
size_t data_offset,
array& out);
};
class SliceUpdate : public UnaryPrimitive {
public:
explicit SliceUpdate(
Stream stream,
const std::vector<int>& start_indices,
const std::vector<int>& end_indices,
const std::vector<int>& strides)
: UnaryPrimitive(stream),
start_indices_(start_indices),
end_indices_(end_indices),
strides_(strides){};
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
DEFINE_GRADS()
DEFINE_PRINT(SliceUpdate)
bool is_equivalent(const Primitive& other) const override;
private:
std::vector<int> start_indices_;
std::vector<int> end_indices_;
std::vector<int> strides_;
void eval(const std::vector<array>& inputs, array& out);
std::tuple<int64_t, std::vector<int64_t>> prepare_slice(const array& in);
};
class Softmax : public UnaryPrimitive {
@@ -1840,4 +1919,36 @@ class QRF : public Primitive {
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
};
/* SVD primitive. */
class SVD : public Primitive {
public:
explicit SVD(Stream stream) : Primitive(stream){};
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
DEFINE_VMAP()
DEFINE_PRINT(SVD)
private:
void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
};
/* Matrix inversion primitive. */
class Inverse : public UnaryPrimitive {
public:
explicit Inverse(Stream stream) : UnaryPrimitive(stream){};
void eval_cpu(const std::vector<array>& inputs, array& output) override;
void eval_gpu(const std::vector<array>& inputs, array& output) override;
DEFINE_VMAP()
DEFINE_PRINT(Inverse)
private:
void eval(const std::vector<array>& inputs, array& output);
};
} // namespace mlx::core

View File

@@ -17,6 +17,9 @@
namespace mlx::core {
// Maximum allowed graph depth for eval
constexpr uint32_t max_graph_depth = 100'000;
/* This class is only meant to be used in eval
* for synchronizing with the main thread. */
class Synchronizer : public Primitive {
@@ -34,8 +37,8 @@ class Synchronizer : public Primitive {
// are currently under a function transformation.
int detail::InTracing::tracing_counter{0};
void eval(const std::vector<array>& outputs) {
std::function<void(const array&, bool)> recurse;
void eval(std::vector<array> outputs) {
std::function<void(const array&)> recurse;
std::queue<array> tape;
std::unordered_set<std::uintptr_t> cache;
std::unordered_map<std::uintptr_t, std::shared_future<void>> deps;
@@ -49,59 +52,35 @@ void eval(const std::vector<array>& outputs) {
}
}
auto synchronizer =
array({}, bool_, std::make_unique<Synchronizer>(stream), outputs);
auto synchronizer = array(
{}, bool_, std::make_unique<Synchronizer>(stream), std::move(outputs));
size_t depth_counter = 0;
recurse = [&](const array& a) {
if (depth_counter > max_graph_depth) {
throw std::runtime_error(
"[eval] Graph depth exceeded maximum allowed limit."
" Try evaluating the graph more frequently.");
}
recurse = [&](const array& a, bool largest_branch_first) {
auto id = a.id();
if (cache.find(id) != cache.end()) {
return;
}
// If the input is being computed on a different stream, we need to manage
// the dependency.
auto check_dependency = [&](const array& in) {
// Recurse to the largest or smallest branch first.
depth_counter++;
for (auto& in : a.inputs()) {
recurse(in);
if (!in.is_evaled()) {
// If the input is being computed on a different stream, we need to
// manage the dependency.
if (a.primitive().stream() != in.primitive().stream()) {
deps.insert({in.primitive_id(), std::shared_future<void>{}});
}
}
};
// Recurse to the largest or smallest branch first.
size_t num_inputs = a.inputs().size();
if (num_inputs == 1) {
auto& in = a.inputs()[0];
recurse(in, true);
check_dependency(in);
} else if (num_inputs == 2) {
auto depth_1 = a.inputs()[0].graph_depth();
auto depth_2 = a.inputs()[1].graph_depth();
auto& in1 = a.inputs()[static_cast<int>(
!((depth_1 > depth_2) == largest_branch_first))];
auto& in2 = a.inputs()[static_cast<int>(
((depth_1 > depth_2) == largest_branch_first))];
recurse(in1, true);
check_dependency(in1);
recurse(in2, true);
check_dependency(in2);
} else if (num_inputs > 2) {
std::vector<int> recursion_order(a.inputs().size());
std::iota(recursion_order.begin(), recursion_order.end(), 0);
std::sort(
recursion_order.begin(),
recursion_order.end(),
[&a, largest_branch_first](int i, int j) {
auto depth_i = a.inputs()[i].graph_depth();
auto depth_j = a.inputs()[j].graph_depth();
return largest_branch_first ? depth_i > depth_j : depth_j < depth_i;
});
for (int idx : recursion_order) {
auto& in = a.inputs()[idx];
recurse(in, true);
check_dependency(in);
}
}
depth_counter--;
cache.insert(id);
for (auto& s : a.siblings()) {
@@ -116,7 +95,7 @@ void eval(const std::vector<array>& outputs) {
}
};
recurse(synchronizer, false);
recurse(synchronizer);
uintptr_t synch_id = synchronizer.primitive_id();
deps.insert({synch_id, std::shared_future<void>{}});
@@ -139,7 +118,7 @@ void eval(const std::vector<array>& outputs) {
arr_deps.push_back(it->second);
}
}
std::shared_ptr<std::promise<void>> p{nullptr};
std::shared_ptr<std::promise<void>> p;
if (auto it = deps.find(arr.primitive_id()); it != deps.end()) {
p = std::make_unique<std::promise<void>>();
ps.push_back(p);
@@ -674,7 +653,9 @@ std::vector<array> vmap_replace(
v_axes.push_back(-1);
}
}
auto [v_outputs, v_out_axes] = a.primitive().vmap(v_inputs, v_axes);
// For each primitive's outputs add its id, the vout id and the vax
auto outputs = a.outputs();
for (int i = 0; i < v_outputs.size(); ++i) {

View File

@@ -6,10 +6,10 @@
namespace mlx::core {
void eval(const std::vector<array>& outputs);
void eval(std::vector<array> outputs);
template <typename... Arrays>
void eval(Arrays... outputs) {
void eval(Arrays&&... outputs) {
eval(std::vector<array>{std::forward<Arrays>(outputs)...});
}

View File

@@ -1,5 +1,7 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
namespace mlx::core::detail {
std::pair<std::vector<array>, std::vector<array>> vmap_trace(

View File

@@ -7,9 +7,28 @@
namespace mlx::core {
struct complex64_t;
struct complex128_t;
template <typename T>
static constexpr bool can_convert_to_complex64 =
inline constexpr bool can_convert_to_complex128 =
!std::is_same_v<T, complex128_t> && std::is_convertible_v<T, double>;
struct complex128_t : public std::complex<double> {
complex128_t(double v, double u) : std::complex<double>(v, u){};
complex128_t(std::complex<double> v) : std::complex<double>(v){};
template <
typename T,
typename = typename std::enable_if<can_convert_to_complex128<T>>::type>
complex128_t(T x) : std::complex<double>(x){};
operator float() const {
return real();
};
};
template <typename T>
inline constexpr bool can_convert_to_complex64 =
!std::is_same_v<T, complex64_t> && std::is_convertible_v<T, float>;
struct complex64_t : public std::complex<float> {

View File

@@ -329,4 +329,13 @@ std::ostream& operator<<(std::ostream& os, const std::vector<size_t>& v) {
return os;
}
std::ostream& operator<<(std::ostream& os, const std::vector<int64_t>& v) {
os << "(";
for (int i = 0; i < v.size(); ++i) {
os << v[i] << ((i == v.size() - 1) ? "" : ",");
}
os << ")";
return os;
}
} // namespace mlx::core

View File

@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#pragma once
@@ -76,6 +76,7 @@ std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k);
std::ostream& operator<<(std::ostream& os, array a);
std::ostream& operator<<(std::ostream& os, const std::vector<int>& v);
std::ostream& operator<<(std::ostream& os, const std::vector<size_t>& v);
std::ostream& operator<<(std::ostream& os, const std::vector<int64_t>& v);
inline std::ostream& operator<<(std::ostream& os, const complex64_t& v) {
return os << v.real() << (v.imag() >= 0 ? "+" : "") << v.imag() << "j";
}

View File

@@ -1,3 +1,7 @@
[build-system]
requires = ["setuptools>=42", "pybind11>=2.10", "cmake>=3.24"]
requires = [
"setuptools>=42",
"nanobind@git+https://github.com/wjakob/nanobind.git#egg=4148debcf91f5ccab0c3b8d67b5c3cabd61f407f",
"cmake>=3.24",
]
build-backend = "setuptools.build_meta"

View File

@@ -0,0 +1,9 @@
import platform
if platform.system() == "Darwin":
version = tuple(map(int, platform.mac_ver()[0].split(".")))
major, minor = version[0], version[1]
if (major, minor) < (13, 5):
raise ImportError(
f"Only macOS 13.5 and newer are supported, not {major}.{minor}"
)

Some files were not shown because too many files have changed in this diff Show More