Compare commits

..

58 Commits

Author SHA1 Message Date
Angelos Katharopoulos
6fc00d2c10 Add rudimentary barrier 2024-11-05 11:34:55 -08:00
Angelos Katharopoulos
44f0de2854 Fix run without distributed 2024-11-05 11:27:41 -08:00
Angelos Katharopoulos
29ec3539ed TCP socket distributed 2024-11-05 11:27:41 -08:00
Angelos Katharopoulos
e94f0028c3 Change the send message size 2024-11-05 11:27:41 -08:00
Angelos Katharopoulos
e5354fcddb Make it work even for donated inputs 2024-11-05 11:27:41 -08:00
Angelos Katharopoulos
34dd079a64 Start a sockets based distributed backend 2024-11-05 11:27:41 -08:00
Angelos Katharopoulos
c3ccd4919f Add MPI barrier 2024-11-05 11:26:53 -08:00
Alex Barron
26be608470 Add split_k qvm for long context (#1564)
* Add splitk qvm

* configurable splitk

* tuning

* remove extra instantiation

* remove refactor

* separate test

* cpu tolerance
2024-11-05 11:25:19 -08:00
Angelos Katharopoulos
248431eb3c Reductions update (#1351) 2024-11-04 22:25:16 -08:00
Awni Hannun
76f275b4df error in rms for wrong size (#1562) 2024-11-04 13:24:02 -08:00
Awni Hannun
f1951d6cce Use fewer barriers (#1561)
* use fewer barriers

* comment
2024-11-04 10:26:49 -08:00
Angelos Katharopoulos
62f297b51d Sdpa fix (#1558) 2024-11-02 21:25:46 -07:00
Awni Hannun
09bc32f62f No extra reshape (#1557)
* no extra reshape

* lint
2024-11-02 19:07:20 -07:00
Chris Offner
46d8b16ab4 Fix vmap example in docs (#1556) 2024-11-02 17:44:14 -07:00
Chris Offner
42533931fa Fix typo "it's" -> "its" (#1555) 2024-11-02 06:06:34 -07:00
Awni Hannun
9bd3a7102f add python 3.13 to circle (#1553) 2024-11-01 20:55:35 -07:00
Alex Barron
9e516b71ea Add dispatchThreads to custom kernel doc (#1551)
* add dispatchThreads info

* update

* add link
2024-11-01 13:07:48 -07:00
Awni Hannun
eac961ddb1 patch (#1550) 2024-10-31 16:10:14 -07:00
Awni Hannun
57c6aa7188 fix multi output leak (#1548) 2024-10-31 09:32:01 -07:00
Awni Hannun
cde5b4ad80 patch (#1546) 2024-10-30 19:31:22 -07:00
Awni Hannun
4f72c66911 improvements to scatter / gather (#1541) 2024-10-30 19:30:54 -07:00
Jagrit Digani
960e3f0f05 Gemm update (#1518) 2024-10-30 19:30:28 -07:00
Awni Hannun
884af42da2 Fix thread group for large arrays (#1543)
* fix thread group for large arrays

* comment

* one more
2024-10-30 16:25:12 -07:00
Alex Barron
048fabdabd Fix vmap constant output size (#1524)
* use inputs to determine output size

* remove noop vmap tests
2024-10-30 16:16:53 -07:00
Léo
917252a5a1 Add favicon to docs (#1545)
* add sphinx's html_favicon config

* removed unneeded newline

* ran pre-commit hooks
2024-10-30 13:54:13 -07:00
Carlo Cabrera
1a992e31e8 Skip using Residency sets in VMs (#1537)
* Skip using Residency sets in VMs

Attempting to use residency sets in a VM throws[^1]

    libc++abi: terminating due to uncaught exception of type std::runtime_error: [metal::Device] Unable to construct residency set.

Not quite sure if this is the best fix, but it does make the error go
away.

Note that it was previously possible to run simple programs that used
mlx in a VM prior to 0eb56d5be0. See
related discussion at Homebrew/homebrew-core#195627.

[^1]: https://github.com/Homebrew/homebrew-core/actions/runs/11525831492/job/32105148462#step:3:56

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* change residency check

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-10-29 19:37:23 -07:00
Awni Hannun
d2ff04a4f2 fix format (#1539) 2024-10-28 18:29:14 -07:00
Awni Hannun
015c247393 change wino dispatch conditoin (#1534) 2024-10-28 11:13:44 -07:00
Awni Hannun
d3cd26820e Faster bits and bernoulli (#1535)
* faster bits and bernoulli

* fix bernoulli
2024-10-28 11:11:00 -07:00
Awni Hannun
91f6c499d7 fix (#1529) 2024-10-25 19:25:35 -07:00
Awni Hannun
35e9c87ab9 patch bump (#1528) 2024-10-25 13:13:23 -07:00
Awni Hannun
8e88e30d95 BFS graph evaluation order (#1525)
* bfs order

* try fix event issue
2024-10-25 10:27:19 -07:00
Awni Hannun
0eb56d5be0 Wired (#1510)
* expose residency sets as wire/unwire

* returns wired size

* fix

* runtime support check

* fix os check

* fix test

* fix no metal build

* docs

* nit

* nits in docs

* nits
2024-10-25 09:35:33 -07:00
Paul Hansel
f70764a162 Fix typo in build docs (#1522) 2024-10-24 20:55:06 -07:00
Awni Hannun
dad1b00b13 fix (#1523) 2024-10-24 19:17:46 -07:00
Venkata Naga Aditya Datta Chivukula
430ffef58a [Feature] Added Sparse Initialization (#1498)
Co-authored-by: Saanidhyavats <saanidhyavats@gmail.com>
2024-10-24 12:31:24 -07:00
Alex Barron
3d17077187 Add mx.array.__format__ (#1521)
* add __format__

* actually test something

* fix
2024-10-24 11:11:39 -07:00
Angelos Katharopoulos
c9b41d460f Working 64-bit scans (#1506) 2024-10-24 11:05:46 -07:00
xnorai
32972a5924 C++20 compatibility for fmt (#1519)
* C++20 compatibility for fmt

* Address review feedback

* Remove stray string

* Add newlines back
2024-10-24 08:54:51 -07:00
Dhruv Govil
f6afb9c09b Remove use of vector<const T> (#1514) 2024-10-22 16:31:52 -07:00
Kashif Rasul
3ddc07e936 Eigenvalues and eigenvectors (#1334)
* initial eigvalsh

* add compute_vectors

* add compute_vectors_

* return a pair

* add eigh to return only eigenvectors

* fixed typo

* merge merge Eighvalsh and Eigh into a single primitive

* use the same primate with the flag

* fix primatives

* use MULTI

* fix eval_gpu

* fix decleration

* rename EighPrimitive to Eigh

* tests

* tests

* fix rebase and format

* cleanup lapack

* format

* add cblas.h

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-10-22 12:18:48 -07:00
Awni Hannun
c26208f67d Remove Hazard tracking with Fences (#1509)
* remove hazard tracking

* with fence map

* no hazard tracking with fences

* nits

* fix fence retain

* cleanup

* fix quantized rebase
2024-10-21 19:33:32 -07:00
Alex Barron
d15fa13daf Batched Quantized Matmul + Fast Small QMV (#1503)
* add fast qmv for small dims

* fix test

* batched cpu

* add batched template param

* refactor metal quantized.cpp
2024-10-21 16:23:17 -07:00
Awni Hannun
58a855682c v0.19.0 (#1502) 2024-10-18 11:55:18 -07:00
Awni Hannun
92d7cb71f8 Fix compile (#1501)
* fix compile

* fix space
2024-10-18 11:06:40 -07:00
Angelos Katharopoulos
50d8bed468 Fused attention for single query (#1497) 2024-10-18 00:58:52 -07:00
Awni Hannun
9dd72cd421 fix gumbel (#1495) 2024-10-17 13:52:39 -07:00
Awni Hannun
343aa46b78 No more 3.8 (#1493) 2024-10-16 17:51:38 -07:00
Awni Hannun
b8ab89b413 Docs in ci (#1491)
* docs in circle
2024-10-15 17:40:00 -07:00
Awni Hannun
f9f8c167d4 fix submodule stubs (#1492) 2024-10-15 16:23:37 -07:00
Awni Hannun
3f86399922 Real and Imag (#1490)
* real and imag

* fix

* fix
2024-10-15 16:23:15 -07:00
LastWhisper
2b8ace6a03 Typing the dropout. (#1479) 2024-10-15 06:45:46 -07:00
Awni Hannun
0ab8e099e8 Fix cpu segfault (#1488)
* fix cpu segfault

* nit in tests
2024-10-14 16:17:03 -07:00
Awni Hannun
020f048cd0 A few updates for CPU (#1482)
* some updates

* format

* fix

* nit
2024-10-14 12:45:49 -07:00
Awni Hannun
881615b072 Faster metal compiled kernels + some fixes (#1486)
* bump mac tests to use py39

* work per thread for compiled kernels

* fixe for large arrays

* fix
2024-10-14 12:45:38 -07:00
Awni Hannun
0eef4febfd bump mac tests to use py39 (#1485) 2024-10-14 10:40:32 -07:00
Awni Hannun
b54a70ec2d Make push button linux distribution (#1476)
* try again

* try again

* try again

* try again

* try again

* try again

* try again

* try again

* .circleci/config.yml

* one more fix

* nit
2024-10-14 06:21:44 -07:00
Awni Hannun
bf6ec92216 Make the GPU device more thread safe (#1478)
* gpu stream safety

* comment

* fix
2024-10-12 17:49:15 -07:00
153 changed files with 5725 additions and 3337 deletions

View File

@@ -13,8 +13,62 @@ parameters:
test_release:
type: boolean
default: false
linux_release:
type: boolean
default: false
jobs:
build_documentation:
parameters:
upload-docs:
type: boolean
default: false
macos:
xcode: "15.2.0"
resource_class: macos.m1.medium.gen1
steps:
- checkout
- run:
name: Install
command: |
brew install python@3.9
brew install doxygen
python3.9 -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
pip install -r docs/requirements.txt
CMAKE_BUILD_PARALLEL_LEVEL=`sysctl -n hw.ncpu` pip install . -v
- when:
condition:
not: << parameters.upload-docs >>
steps:
- run:
name: Build documentation
command: |
source env/bin/activate
cd docs && doxygen && make html O=-W
- when:
condition: << parameters.upload-docs >>
steps:
- add_ssh_keys:
fingerprints:
- "SHA256:OhcVVMovbT0pkgMeiVRyxMnjV9R2t+hKBsNcuxq9h+0"
- run:
name: Upload documentation
command: |
source env/bin/activate
git config user.email "mlx@group.apple.com"
git config user.name "CircleCI Docs"
git checkout gh-pages
git rebase main
cd docs
git rm -rf build/html
doxygen && make html O=-W
git add -f build/html
git commit -m "rebase"
git push -f origin gh-pages
linux_build_and_test:
docker:
- image: cimg/python:3.9
@@ -77,9 +131,9 @@ jobs:
- run:
name: Install dependencies
command: |
brew install python@3.8
brew install python@3.9
brew install openmpi
python3.8 -m venv env
python3.9 -m venv env
source env/bin/activate
pip install --upgrade pip
pip install --upgrade cmake
@@ -208,7 +262,7 @@ jobs:
- store_artifacts:
path: dist/
build_linux_test_release:
build_linux_release:
parameters:
python_version:
type: string
@@ -243,6 +297,7 @@ jobs:
pip install auditwheel
pip install patchelf
pip install build
pip install twine
<< parameters.extra_env >> \
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
pip install . -v
@@ -253,6 +308,11 @@ jobs:
python -m build --wheel
auditwheel show dist/*
auditwheel repair dist/* --plat manylinux_2_31_x86_64
- run:
name: Upload package
command: |
source env/bin/activate
twine upload wheelhouse/*
- store_artifacts:
path: wheelhouse/
@@ -272,6 +332,7 @@ workflows:
parameters:
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
- linux_build_and_test
- build_documentation
build_pypi_release:
when:
@@ -288,9 +349,17 @@ workflows:
ignore: /.*/
matrix:
parameters:
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
xcode_version: ["15.0.0", "15.2.0"]
build_env: ["PYPI_RELEASE=1"]
- build_documentation:
filters:
tags:
only: /^v.*/
branches:
ignore: /.*/
upload-docs: true
prb:
when:
matches:
@@ -317,7 +386,7 @@ workflows:
- build_release:
matrix:
parameters:
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
xcode_version: ["15.0.0", "15.2.0"]
weekly_build:
when:
@@ -328,17 +397,17 @@ workflows:
- build_release:
matrix:
parameters:
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
xcode_version: ["15.0.0", "15.2.0", "16.0.0"]
build_env: ["DEV_RELEASE=1"]
linux_test_release:
when:
and:
- equal: [ main, << pipeline.git.branch >> ]
- << pipeline.parameters.test_release >>
- << pipeline.parameters.linux_release >>
jobs:
- build_linux_test_release:
- build_linux_release:
matrix:
parameters:
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
extra_env: ["PYPI_RELEASE=1"]

View File

@@ -24,7 +24,7 @@ option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
if(NOT MLX_VERSION)
set(MLX_VERSION 0.18.1)
set(MLX_VERSION 0.19.3)
endif()
# --------------------- Processor tests -------------------------

View File

@@ -6,7 +6,7 @@
[![CircleCI](https://circleci.com/gh/ml-explore/mlx.svg?style=svg)](https://circleci.com/gh/ml-explore/mlx)
MLX is an array framework for machine learning research on Apple silicon,
MLX is an array framework for machine learning on Apple silicon,
brought to you by Apple machine learning research.
Some key features of MLX include:

View File

@@ -144,6 +144,13 @@ def reduction(op, axis, x):
mx.eval(ys)
def sum_and_add(axis, x, y):
z = x.sum(axis=axis, keepdims=True)
for i in range(50):
z = (z + y).sum(axis=axis, keepdims=True)
mx.eval(z)
def softmax(axis, x):
ys = []
for i in range(100):
@@ -505,5 +512,8 @@ if __name__ == "__main__":
elif args.benchmark == "selu":
print(bench(selu, x))
elif args.benchmark == "sum_and_add":
print(bench(sum_and_add, axis, *xs))
else:
raise ValueError("Unknown benchmark")

View File

@@ -9,7 +9,7 @@ from time_utils import measure_runtime
def benchmark_scatter_mlx(dst_shape, x_shape, idx_shapes):
def scatter(dst, x, idx):
dst[*idx] = x
dst[tuple(idx)] = x
mx.eval(dst)
idx = []
@@ -23,8 +23,8 @@ def benchmark_scatter_mlx(dst_shape, x_shape, idx_shapes):
def benchmark_scatter_torch(dst_shape, x_shape, idx_shapes, device):
def gather(dst, x, idx, device):
dst[*idx] = x
def scatter(dst, x, idx, device):
dst[tuple(idx)] = x
if device == torch.device("mps"):
torch.mps.synchronize()
@@ -34,7 +34,7 @@ def benchmark_scatter_torch(dst_shape, x_shape, idx_shapes, device):
x = torch.randn(x_shape, dtype=torch.float32).to(device)
dst = torch.randn(dst_shape, dtype=torch.float32).to(device)
runtime = measure_runtime(gather, dst=dst, x=x, idx=idx, device=device)
runtime = measure_runtime(scatter, dst=dst, x=x, idx=idx, device=device)
print(f"PyTorch: {runtime:.3f}ms")
@@ -54,7 +54,7 @@ if __name__ == "__main__":
(100_000, 64),
(1_000_000, 64),
(100_000,),
(2_000_00,),
(200_000,),
(20_000_000,),
(10000, 64),
(100, 64),
@@ -91,6 +91,6 @@ if __name__ == "__main__":
for dst_shape, x_shape, idx_shape in zip(dst_shapes, x_shapes, idx_shapes):
print("=" * 20)
print(f"X {x_shape}, Indices {idx_shape}")
print(f"Dst: {dst_shape}, X {x_shape}, Indices {idx_shape}")
benchmark_scatter_mlx(dst_shape, x_shape, idx_shape)
benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device=device)

View File

@@ -0,0 +1,49 @@
import argparse
import math
import mlx.core as mx
from time_utils import time_fn
L = 1024
H = 32
H_k = 32 // 4
D = 128
def attention(q, k, v):
B, Hq, L, D = q.shape
_, Hk, S, _ = k.shape
q = q.reshape(B, Hk, Hq // Hk, L, D)
k = k[:, :, None, :, :]
v = v[:, :, None, :, :]
s = q @ k.transpose(0, 1, 2, 4, 3)
p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
o = p @ v
return o.reshape(B, Hq, L, D)
def sdpa(q, k, v):
return mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0)
def time_self_attention_primitives():
mx.random.seed(3)
q = mx.random.uniform(shape=(1, H, 1, D))
k = mx.random.uniform(shape=(1, H_k, L, D))
v = mx.random.uniform(shape=(1, H_k, L, D))
mx.eval(q, k, v)
time_fn(attention, q, k, v)
def time_self_attention_sdpa():
mx.random.seed(3)
q = mx.random.uniform(shape=(1, H, 1, D))
k = mx.random.uniform(shape=(1, H_k, L, D))
v = mx.random.uniform(shape=(1, H_k, L, D))
mx.eval(q, k, v)
time_fn(sdpa, q, k, v)
if __name__ == "__main__":
time_self_attention_sdpa()
time_self_attention_primitives()

View File

@@ -60,6 +60,7 @@ html_theme_options = {
},
}
html_favicon = html_theme_options["logo"]["image_light"]
# -- Options for HTMLHelp output ---------------------------------------------

View File

@@ -1,3 +1,5 @@
.. _custom_metal_kernels:
Custom Metal Kernels
====================
@@ -76,6 +78,10 @@ Putting this all together, the generated function signature for ``myexp`` is as
template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;
Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads <https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_ function.
This means we will launch ``mx.prod(grid)`` threads, subdivided into ``threadgroup`` size threadgroups.
For optimal performance, each thread group dimension should be less than or equal to the corresponding grid dimension.
Passing ``verbose=True`` to ``mx.fast.metal_kernel.__call__`` will print the generated code for debugging purposes.
Using Shape/Strides

View File

@@ -14,7 +14,7 @@ silicon computer is
To install from PyPI you must meet the following requirements:
- Using an M series chip (Apple silicon)
- Using a native Python >= 3.8
- Using a native Python >= 3.9
- macOS >= 13.5
.. note::
@@ -240,7 +240,7 @@ x86 Shell
.. _build shell:
If the ouptut of ``uname -p`` is ``x86`` then your shell is running as x86 via
If the output of ``uname -p`` is ``x86`` then your shell is running as x86 via
Rosetta instead of natively.
To fix this, find the application in Finder (``/Applications`` for iTerm,
@@ -264,4 +264,4 @@ Also check that cmake is using the correct architecture:
If you see ``"x86_64"``, try re-installing ``cmake``. If you see ``"arm64"``
but the build errors out with "Building for x86_64 on macOS is not supported."
wipe your build cahce with ``rm -rf build/`` and try again.
wipe your build cache with ``rm -rf build/`` and try again.

View File

@@ -16,3 +16,5 @@ Linear Algebra
cross
qr
svd
eigvalsh
eigh

View File

@@ -14,6 +14,7 @@ Metal
get_cache_memory
set_memory_limit
set_cache_limit
set_wired_limit
clear_cache
start_capture
stop_capture

View File

@@ -80,6 +80,7 @@ Operations
greater_equal
hadamard_transform
identity
imag
inner
isfinite
isclose
@@ -125,6 +126,7 @@ Operations
quantize
quantized_matmul
radians
real
reciprocal
remainder
repeat

View File

@@ -33,12 +33,12 @@ Let's start with a simple example:
# Compile the function
compiled_fun = mx.compile(fun)
# Prints: array(2.36788, dtype=float32)
# Prints: array(2.36788, dtype=float32)
print(compiled_fun(x, y))
The output of both the regular function and the compiled function is the same
up to numerical precision.
The first time you call a compiled function, MLX will build the compute
graph, optimize it, and generate and compile code. This can be relatively
slow. However, MLX will cache compiled functions, so calling a compiled
@@ -96,7 +96,7 @@ element-wise operations:
.. code-block:: python
def gelu(x):
def gelu(x):
return x * (1 + mx.erf(x / math.sqrt(2))) / 2
If you use this function with small arrays, it will be overhead bound. If you
@@ -136,13 +136,6 @@ Now make an array, and benchmark both functions:
On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is
five times faster.
.. note::
As of the latest MLX, CPU functions are not fully compiled. Compiling CPU
functions can still be helpful, but won't typically result in as large a
speedup as compiling operations that run on the GPU.
Debugging
---------
@@ -287,7 +280,7 @@ to the function. In some cases this can be pretty inconvenient. Hence,
print(fun(mx.array(1.0)))
Compiling Training Graphs
Compiling Training Graphs
-------------------------
This section will step through how to use :func:`compile` with a simple example
@@ -297,7 +290,7 @@ full forward, backward, and update with :func:`compile`.
To start, here is the simple example without any compilation:
.. code-block:: python
.. code-block:: python
import mlx.core as mx
import mlx.nn as nn
@@ -330,7 +323,7 @@ To start, here is the simple example without any compilation:
To compile the update we can put it all in a function and compile it with the
appropriate input and output captures. Here's the same example but compiled:
.. code-block:: python
.. code-block:: python
import mlx.core as mx
import mlx.nn as nn
@@ -355,7 +348,7 @@ appropriate input and output captures. Here's the same example but compiled:
# The state that will be captured as input and output
state = [model.state, optimizer.state]
@partial(mx.compile, inputs=state, outputs=state)
def step(x, y):
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
@@ -410,7 +403,7 @@ Compiling transformed functions works just as expected:
In order to compile as much as possible, a transformation of a compiled
function will not by default be compiled. To compile the transformed
function simply pass it through :func:`compile`.
function simply pass it through :func:`compile`.
You can also compile functions which themselves call compiled functions. A
good practice is to compile the outer most function to give :func:`compile`

View File

@@ -25,7 +25,7 @@ Here is a simple example:
The output of :func:`grad` on :func:`sin` is simply another function. In this
case it is the gradient of the sine function which is exactly the cosine
function. To get the second derivative you can do:
function. To get the second derivative you can do:
.. code-block:: shell
@@ -50,7 +50,7 @@ Automatic Differentiation
.. _auto diff:
Automatic differentiation in MLX works on functions rather than on implicit
graphs.
graphs.
.. note::
@@ -114,7 +114,7 @@ way to do that is the following:
def loss_fn(params, x, y):
w, b = params["weight"], params["bias"]
h = w * x + b
h = w * x + b
return mx.mean(mx.square(h - y))
params = {"weight": mx.array(1.0), "bias": mx.array(0.0)}
@@ -132,7 +132,7 @@ way to do that is the following:
Notice the tree structure of the parameters is preserved in the gradients.
In some cases you may want to stop gradients from propagating through a
In some cases you may want to stop gradients from propagating through a
part of the function. You can use the :func:`stop_gradient` for that.
@@ -161,19 +161,19 @@ A naive way to add the elements from two sets of vectors is with a loop:
ys = mx.random.uniform(shape=(100, 4096))
def naive_add(xs, ys):
return [xs[i] + ys[:, i] for i in range(xs.shape[1])]
return [xs[i] + ys[:, i] for i in range(xs.shape[0])]
Instead you can use :func:`vmap` to automatically vectorize the addition:
.. code-block:: python
# Vectorize over the second dimension of x and the
# first dimension of y
vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(1, 0))
vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(0, 1))
The ``in_axes`` parameter can be used to specify which dimensions of the
corresponding input to vectorize over. Similarly, use ``out_axes`` to specify
where the vectorized axes should be in the outputs.
where the vectorized axes should be in the outputs.
Let's time these two different versions:

View File

@@ -51,7 +51,7 @@ You can also use an :obj:`array` to index another :obj:`array`:
.. code-block:: shell
>>> arr = mx.arange(10)
>>> idx = mx.array([5, 7])
>>> idx = mx.array([5, 7])
>>> arr[idx]
array([5, 7], dtype=int32)
@@ -77,12 +77,12 @@ from the GPU. Performing bounds checking for array indices before launching the
kernel would be extremely inefficient.
Indexing with boolean masks is something that MLX may support in the future. In
general, MLX has limited support for operations for which outputs
general, MLX has limited support for operations for which output
*shapes* are dependent on input *data*. Other examples of these types of
operations which MLX does not yet support include :func:`numpy.nonzero` and the
single input version of :func:`numpy.where`.
In Place Updates
In Place Updates
----------------
In place updates to indexed arrays are possible in MLX. For example:

View File

@@ -13,7 +13,7 @@ compute graph is recorded. The actual computation only happens if an
:func:`eval` is performed.
MLX uses lazy evaluation because it has some nice features, some of which we
describe below.
describe below.
Transforming Compute Graphs
^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -109,14 +109,14 @@ Here is a concrete example:
An important behavior to be aware of is when the graph will be implicitly
evaluated. Anytime you ``print`` an array, convert it to an
:obj:`numpy.ndarray`, or otherwise access it's memory via :obj:`memoryview`,
:obj:`numpy.ndarray`, or otherwise access its memory via :obj:`memoryview`,
the graph will be evaluated. Saving arrays via :func:`save` (or any other MLX
saving functions) will also evaluate the array.
Calling :func:`array.item` on a scalar array will also evaluate it. In the
example above, printing the loss (``print(loss)``) or adding the loss scalar to
a list (``losses.append(loss.item())``) would cause a graph evaluation. If
a list (``losses.append(loss.item())``) would cause a graph evaluation. If
these lines are before ``mx.eval(loss, model.parameters())`` then this
will be a partial evaluation, computing only the forward pass.

View File

@@ -3,10 +3,10 @@
Conversion to NumPy and Other Frameworks
========================================
MLX array supports conversion between other frameworks with either:
MLX array supports conversion between other frameworks with either:
* The `Python Buffer Protocol <https://docs.python.org/3/c-api/buffer.html>`_.
* `DLPack <https://dmlc.github.io/dlpack/latest/>`_.
* The `Python Buffer Protocol <https://docs.python.org/3/c-api/buffer.html>`_.
* `DLPack <https://dmlc.github.io/dlpack/latest/>`_.
Let's convert an array to NumPy and back.
@@ -66,7 +66,7 @@ even though no in-place operations on MLX memory are executed.
PyTorch
-------
.. warning::
.. warning::
PyTorch Support for :obj:`memoryview` is experimental and can break for
multi-dimensional arrays. Casting to NumPy first is advised for now.

View File

@@ -64,4 +64,4 @@ Other gradient transformations include :func:`vjp` for vector-Jacobian products
and :func:`jvp` for Jacobian-vector products.
Use :func:`value_and_grad` to efficiently compute both a function's output and
gradient with respect to the function's input.
gradient with respect to the function's input.

View File

@@ -8,33 +8,33 @@ Saving and Loading Arrays
MLX supports multiple array serialization formats.
.. list-table:: Serialization Formats
:widths: 20 8 25 25
:widths: 20 8 25 25
:header-rows: 1
* - Format
- Extension
* - Format
- Extension
- Function
- Notes
* - NumPy
- ``.npy``
- Notes
* - NumPy
- ``.npy``
- :func:`save`
- Single arrays only
* - NumPy archive
- ``.npz``
* - NumPy archive
- ``.npz``
- :func:`savez` and :func:`savez_compressed`
- Multiple arrays
- Multiple arrays
* - Safetensors
- ``.safetensors``
- ``.safetensors``
- :func:`save_safetensors`
- Multiple arrays
* - GGUF
- ``.gguf``
- Multiple arrays
* - GGUF
- ``.gguf``
- :func:`save_gguf`
- Multiple arrays
The :func:`load` function will load any of the supported serialization
formats. It determines the format from the extensions. The output of
:func:`load` depends on the format.
:func:`load` depends on the format.
Here's an example of saving a single array to a file:

View File

@@ -20,7 +20,7 @@ Both ``a`` and ``b`` live in unified memory.
In MLX, rather than moving arrays to devices, you specify the device when you
run the operation. Any device can perform any operation on ``a`` and ``b``
without needing to move them from one memory location to another. For example:
without needing to move them from one memory location to another. For example:
.. code-block:: python

View File

@@ -178,8 +178,10 @@ void array::move_shared_buffer(
array_desc_->flags = flags;
array_desc_->data_size = data_size;
auto char_offset = sizeof(char) * itemsize() * offset;
array_desc_->data_ptr = static_cast<void*>(
static_cast<char*>(other.array_desc_->data_ptr) + char_offset);
auto data_ptr = other.array_desc_->data_ptr;
other.array_desc_->data_ptr = nullptr;
array_desc_->data_ptr =
static_cast<void*>(static_cast<char*>(data_ptr) + char_offset);
}
void array::move_shared_buffer(array other) {
@@ -269,6 +271,9 @@ array::ArrayDesc::~ArrayDesc() {
for (array& a : ad.inputs) {
if (a.array_desc_) {
input_map.insert({a.id(), a});
for (auto& s : a.siblings()) {
input_map.insert({s.id(), s});
}
}
}
ad.inputs.clear();

View File

@@ -81,6 +81,7 @@ DEFAULT_MULTI(SVD)
DEFAULT(Transpose)
DEFAULT(Inverse)
DEFAULT(Cholesky)
DEFAULT_MULTI(Eigh)
void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);

View File

@@ -18,49 +18,61 @@ void _qmm_t_4_64(
const float* biases,
int M,
int N,
int K) {
int K,
int B,
bool batched_w) {
constexpr int bits = 4;
constexpr int group_size = 64;
constexpr int bitmask = (1 << bits) - 1;
constexpr int pack_factor = 32 / bits;
constexpr int packs_in_group = group_size / pack_factor;
for (int m = 0; m < M; m++) {
const uint32_t* w_local = w;
const float* scales_local = scales;
const float* biases_local = biases;
int w_els = N * K / pack_factor;
int g_els = w_els * pack_factor / group_size;
for (int n = 0; n < N; n++) {
const simd_float16* x_local = (simd_float16*)x;
simd_float16 sum = 0;
for (int k = 0; k < K; k += group_size) {
float scale = *scales_local++;
float bias = *biases_local++;
for (int i = 0; i < B; i++) {
for (int m = 0; m < M; m++) {
const uint32_t* w_local = w;
const float* scales_local = scales;
const float* biases_local = biases;
for (int kw = 0; kw < packs_in_group; kw += 2) {
// TODO: vectorize this properly
simd_uint16 wi;
for (int e = 0; e < 2; e++) {
uint32_t wii = *w_local++;
for (int p = 0; p < 8; p++) {
wi[e * 8 + p] = wii & bitmask;
wii >>= bits;
for (int n = 0; n < N; n++) {
const simd_float16* x_local = (simd_float16*)x;
simd_float16 sum = 0;
for (int k = 0; k < K; k += group_size) {
float scale = *scales_local++;
float bias = *biases_local++;
for (int kw = 0; kw < packs_in_group; kw += 2) {
// TODO: vectorize this properly
simd_uint16 wi;
for (int e = 0; e < 2; e++) {
uint32_t wii = *w_local++;
for (int p = 0; p < 8; p++) {
wi[e * 8 + p] = wii & bitmask;
wii >>= bits;
}
}
}
simd_float16 wf = simd_float(wi);
wf *= scale;
wf += bias;
simd_float16 wf = simd_float(wi);
wf *= scale;
wf += bias;
sum += (*x_local) * wf;
x_local++;
sum += (*x_local) * wf;
x_local++;
}
}
*result = simd_reduce_add(sum);
result++;
}
*result = simd_reduce_add(sum);
result++;
x += K;
}
if (batched_w) {
w += w_els;
scales += g_els;
biases += g_els;
}
x += K;
}
}
@@ -82,8 +94,10 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
if (condition) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
int K = x.shape(-1);
int M = x.size() / K;
int M = x.shape(-2);
int N = out.shape(-1);
int B = x.size() / K / M;
bool batched_w = w.ndim() > 2;
_qmm_t_4_64(
out.data<float>(),
x.data<float>(),
@@ -92,7 +106,9 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
biases.data<float>(),
M,
N,
K);
K,
B,
batched_w);
} else {
eval(inputs, out);
}

View File

@@ -31,6 +31,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eigh.cpp
${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp

View File

@@ -2,46 +2,12 @@
#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <lapack.h>
#endif
namespace mlx::core {
namespace {
// Delegate to the Cholesky factorization taking into account differences in
// LAPACK implementations (basically how to pass the 'uplo' string to fortran).
int spotrf_wrapper(char uplo, float* matrix, int N) {
int info;
#ifdef LAPACK_FORTRAN_STRLEN_END
spotrf_(
/* uplo = */ &uplo,
/* n = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* info = */ &info,
/* uplo_len = */ static_cast<size_t>(1));
#else
spotrf_(
/* uplo = */ &uplo,
/* n = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* info = */ &info);
#endif
return info;
}
} // namespace
void cholesky_impl(const array& a, array& factor, bool upper) {
// Lapack uses the column-major convention. We take advantage of the fact that
// the matrix should be symmetric:
@@ -66,7 +32,14 @@ void cholesky_impl(const array& a, array& factor, bool upper) {
for (int i = 0; i < num_matrices; i++) {
// Compute Cholesky factorization.
int info = spotrf_wrapper(uplo, matrix, N);
int info;
MLX_LAPACK_FUNC(spotrf)
(
/* uplo = */ &uplo,
/* n = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* info = */ &info);
// TODO: We do nothing when the matrix is not positive semi-definite
// because throwing an error would result in a crash. If we figure out how

View File

@@ -4,6 +4,8 @@
#include <filesystem>
#include <fstream>
#include <list>
#include <mutex>
#include <shared_mutex>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/common/compiled_preamble.h"
@@ -12,22 +14,7 @@
namespace mlx::core {
// GPU compile is always available if the GPU is available and since we are in
// this file CPU compile is also available.
namespace detail {
bool compile_available_for_device(const Device& device) {
return true;
}
} // namespace detail
std::string get_temp_file(const std::string& name) {
return std::filesystem::temp_directory_path().append(name);
}
// Return a pointer to a compiled function
void* compile(
const std::string& kernel_name,
const std::string& source_code = "") {
struct CompilerCache {
struct DLib {
DLib(const std::string& libname) {
lib = dlopen(libname.c_str(), RTLD_NOW);
@@ -44,15 +31,41 @@ void* compile(
void* lib;
};
// Statics to cache compiled libraries and functions
static std::list<DLib> libs;
static std::unordered_map<std::string, void*> kernels;
if (auto it = kernels.find(kernel_name); it != kernels.end()) {
return it->second;
}
if (source_code.empty()) {
return nullptr;
std::list<DLib> libs;
std::unordered_map<std::string, void*> kernels;
std::shared_mutex mtx;
};
static CompilerCache cache{};
// GPU compile is always available if the GPU is available and since we are in
// this file CPU compile is also available.
namespace detail {
bool compile_available_for_device(const Device& device) {
return true;
}
} // namespace detail
std::string get_temp_file(const std::string& name) {
return std::filesystem::temp_directory_path().append(name);
}
// Return a pointer to a compiled function
void* compile(
const std::string& kernel_name,
const std::function<std::string(void)>& source_builder) {
{
std::shared_lock lock(cache.mtx);
if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
return it->second;
}
}
std::unique_lock lock(cache.mtx);
if (auto it = cache.kernels.find(kernel_name); it != cache.kernels.end()) {
return it->second;
}
std::string source_code = source_builder();
std::string kernel_file_name;
// Deal with long kernel names. Maximum length for files on macOS is 255
@@ -90,8 +103,8 @@ void* compile(
source_file.close();
std::ostringstream build_command;
build_command << "g++ -std=c++17 -O2 -Wall -fPIC -shared "
<< source_file_path << " -o " << shared_lib_path;
build_command << "g++ -std=c++17 -O3 -Wall -fPIC -shared '"
<< source_file_path << "' -o '" << shared_lib_path << "'";
std::string build_command_str = build_command.str();
auto return_code = system(build_command_str.c_str());
if (return_code) {
@@ -103,10 +116,10 @@ void* compile(
}
// load library
libs.emplace_back(shared_lib_path);
cache.libs.emplace_back(shared_lib_path);
// Load function
void* fun = dlsym(libs.back().lib, kernel_name.c_str());
void* fun = dlsym(cache.libs.back().lib, kernel_name.c_str());
if (!fun) {
std::ostringstream msg;
msg << "[Compile::eval_cpu] Failed to load compiled function "
@@ -114,7 +127,7 @@ void* compile(
<< dlerror();
throw std::runtime_error(msg.str());
}
kernels.insert({kernel_name, fun});
cache.kernels.insert({kernel_name, fun});
return fun;
}
@@ -316,10 +329,7 @@ void Compiled::eval_cpu(
}
// Get the function
auto fn_ptr = compile(kernel_name);
// If it doesn't exist, compile it
if (fn_ptr == nullptr) {
auto fn_ptr = compile(kernel_name, [&]() {
std::ostringstream kernel;
kernel << get_kernel_preamble() << std::endl;
kernel << "extern \"C\" {" << std::endl;
@@ -334,10 +344,8 @@ void Compiled::eval_cpu(
ndim);
// Close extern "C"
kernel << "}" << std::endl;
// Compile and get function pointer
fn_ptr = compile(kernel_name, kernel.str());
}
return kernel.str();
});
compiled_allocate_outputs(
inputs, outputs, inputs_, constant_ids_, contiguous, false);

View File

@@ -3,13 +3,8 @@
#include <cassert>
#include <numeric>
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <cblas.h>
#endif
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"

View File

@@ -1,14 +1,10 @@
// Copyright © 2023-2024 Apple Inc.
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <cblas.h>
#endif
#include <cstring>
#include "mlx/array.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/backend/common/utils.h"
#include "mlx/primitives.h"
@@ -114,6 +110,7 @@ DEFAULT(Tanh)
DEFAULT(Transpose)
DEFAULT(Inverse)
DEFAULT(Cholesky)
DEFAULT_MULTI(Eigh)
namespace {

117
mlx/backend/common/eigh.cpp Normal file
View File

@@ -0,0 +1,117 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
void ssyevd(
char jobz,
char uplo,
float* a,
int N,
float* w,
float* work,
int lwork,
int* iwork,
int liwork) {
int info;
MLX_LAPACK_FUNC(ssyevd)
(
/* jobz = */ &jobz,
/* uplo = */ &uplo,
/* n = */ &N,
/* a = */ a,
/* lda = */ &N,
/* w = */ w,
/* work = */ work,
/* lwork = */ &lwork,
/* iwork = */ iwork,
/* liwork = */ &liwork,
/* info = */ &info);
if (info != 0) {
std::stringstream msg;
msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
<< info;
throw std::runtime_error(msg.str());
}
}
} // namespace
void Eigh::eval(const std::vector<array>& inputs, std::vector<array>& outputs) {
const auto& a = inputs[0];
auto& values = outputs[0];
auto vectors = compute_eigenvectors_
? outputs[1]
: array(a.shape(), a.dtype(), nullptr, {});
values.set_data(allocator::malloc_or_wait(values.nbytes()));
copy(
a,
vectors,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
if (compute_eigenvectors_) {
// Set the strides and flags so the eigenvectors
// are in the columns of the output
auto flags = vectors.flags();
auto strides = vectors.strides();
auto ndim = a.ndim();
std::swap(strides[ndim - 1], strides[ndim - 2]);
if (a.size() > 1) {
flags.row_contiguous = false;
if (ndim > 2) {
flags.col_contiguous = false;
} else {
flags.col_contiguous = true;
}
}
vectors.move_shared_buffer(vectors, strides, flags, vectors.data_size());
}
auto vec_ptr = vectors.data<float>();
auto eig_ptr = values.data<float>();
char jobz = compute_eigenvectors_ ? 'V' : 'N';
auto N = a.shape(-1);
// Work query
int lwork;
int liwork;
{
float work;
int iwork;
ssyevd(jobz, uplo_[0], nullptr, N, nullptr, &work, -1, &iwork, -1);
lwork = static_cast<int>(work);
liwork = iwork;
}
auto work_buf = array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
auto iwork_buf = array::Data{allocator::malloc_or_wait(sizeof(int) * liwork)};
for (size_t i = 0; i < a.size() / (N * N); ++i) {
ssyevd(
jobz,
uplo_[0],
vec_ptr,
N,
eig_ptr,
static_cast<float*>(work_buf.buffer.raw_ptr()),
lwork,
static_cast<int*>(iwork_buf.buffer.raw_ptr()),
liwork);
vec_ptr += N * N;
eig_ptr += N;
}
}
} // namespace mlx::core

View File

@@ -2,39 +2,19 @@
#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/primitives.h"
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <lapack.h>
#endif
// Wrapper to account for differences in
// LAPACK implementations (basically how to pass the 'uplo' string to fortran).
int strtri_wrapper(char uplo, char diag, float* matrix, int N) {
int info;
#ifdef LAPACK_FORTRAN_STRLEN_END
strtri_(
/* uplo = */ &uplo,
/* diag = */ &diag,
/* N = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* info = */ &info,
/* uplo_len = */ static_cast<size_t>(1),
/* diag_len = */ static_cast<size_t>(1));
#else
strtri_(
MLX_LAPACK_FUNC(strtri)
(
/* uplo = */ &uplo,
/* diag = */ &diag,
/* N = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* info = */ &info);
#endif
return info;
}

View File

@@ -1,10 +1,11 @@
// Copyright © 2024 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#pragma once
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <cblas.h>
#include <lapack.h>
#endif

View File

@@ -18,10 +18,12 @@ if [ "$CLANG" = "TRUE" ]; then
#include <cstdint>
#include <vector>
EOM
CC_FLAGS=""
else
CC_FLAGS="-std=c++17"
fi
CONTENT=$($GCC -I "$SRCDIR" -E "$SRCDIR/mlx/backend/common/compiled_preamble.h" 2>/dev/null)
CONTENT=$($GCC $CC_FLAGS -I "$SRCDIR" -E "$SRCDIR/mlx/backend/common/compiled_preamble.h" 2>/dev/null)
cat << EOF > "$OUTPUT_FILE"
const char* get_kernel_preamble() {

View File

@@ -1,15 +1,10 @@
// Copyright © 2024 Apple Inc.
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <cblas.h>
#endif
#include <cstring>
#include "mlx/array.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/backend/common/utils.h"
#include "mlx/primitives.h"

View File

@@ -295,6 +295,13 @@ struct Floor {
}
};
struct Imag {
template <typename T>
T operator()(T x) {
return std::imag(x);
}
};
struct Log {
template <typename T>
T operator()(T x) {
@@ -337,6 +344,13 @@ struct Negative {
}
};
struct Real {
template <typename T>
T operator()(T x) {
return std::real(x);
}
};
struct Round {
template <typename T>
T operator()(T x) {

View File

@@ -273,6 +273,10 @@ void Full::eval(const std::vector<array>& inputs, array& out) {
copy(in, out, ctype);
}
void Imag::eval_cpu(const std::vector<array>& inputs, array& out) {
unary_op<complex64_t, float>(inputs[0], out, detail::Imag());
}
void Log::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
@@ -398,6 +402,10 @@ void RandomBits::eval(const std::vector<array>& inputs, array& out) {
}
}
void Real::eval_cpu(const std::vector<array>& inputs, array& out) {
unary_op<complex64_t, float>(inputs[0], out, detail::Real());
}
void Reshape::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];

View File

@@ -2,14 +2,9 @@
#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/primitives.h"
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <lapack.h>
#endif
namespace mlx::core {
template <typename T>

View File

@@ -201,55 +201,61 @@ void _qmm_dispatch(
int group_size,
bool transposed_w) {
int K = x.shape(-1);
int M = x.size() / K;
int M = x.shape(-2);
int N = out.shape(-1);
switch (x.dtype()) {
case float32:
_qmm_dispatch_typed<float>(
out.data<float>(),
x.data<float>(),
w.data<uint32_t>(),
scales.data<float>(),
biases.data<float>(),
M,
N,
K,
bits,
group_size,
transposed_w);
break;
case float16:
_qmm_dispatch_typed<float16_t>(
out.data<float16_t>(),
x.data<float16_t>(),
w.data<uint32_t>(),
scales.data<float16_t>(),
biases.data<float16_t>(),
M,
N,
K,
bits,
group_size,
transposed_w);
break;
case bfloat16:
_qmm_dispatch_typed<bfloat16_t>(
out.data<bfloat16_t>(),
x.data<bfloat16_t>(),
w.data<uint32_t>(),
scales.data<bfloat16_t>(),
biases.data<bfloat16_t>(),
M,
N,
K,
bits,
group_size,
transposed_w);
break;
default:
throw std::invalid_argument(
"[quantized_matmul] only floating types are supported");
int w_els = w.ndim() > 2 ? w.shape(-1) * w.shape(-2) : 0;
int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0;
int batch_size = x.size() / x.shape(-1) / x.shape(-2);
for (int i = 0; i < batch_size; i++) {
switch (x.dtype()) {
case float32:
_qmm_dispatch_typed<float>(
out.data<float>() + i * M * N,
x.data<float>() + elem_to_loc(i * M * K, x),
w.data<uint32_t>() + elem_to_loc(i * w_els, w),
scales.data<float>() + elem_to_loc(i * g_els, scales),
biases.data<float>() + elem_to_loc(i * g_els, biases),
M,
N,
K,
bits,
group_size,
transposed_w);
break;
case float16:
_qmm_dispatch_typed<float16_t>(
out.data<float16_t>() + i * M * N,
x.data<float16_t>() + elem_to_loc(i * M * K, x),
w.data<uint32_t>() + elem_to_loc(i * w_els, w),
scales.data<float16_t>() + elem_to_loc(i * g_els, scales),
biases.data<float16_t>() + elem_to_loc(i * g_els, biases),
M,
N,
K,
bits,
group_size,
transposed_w);
break;
case bfloat16:
_qmm_dispatch_typed<bfloat16_t>(
out.data<bfloat16_t>() + i * M * N,
x.data<bfloat16_t>() + elem_to_loc(i * M * K, x),
w.data<uint32_t>() + elem_to_loc(i * w_els, w),
scales.data<bfloat16_t>() + elem_to_loc(i * g_els, scales),
biases.data<bfloat16_t>() + elem_to_loc(i * g_els, biases),
M,
N,
K,
bits,
group_size,
transposed_w);
break;
default:
throw std::invalid_argument(
"[quantized_matmul] only floating types are supported");
}
}
}

View File

@@ -2,7 +2,7 @@
#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack_helper.h"
#include "mlx/backend/common/lapack.h"
#include "mlx/primitives.h"
namespace mlx::core {

View File

@@ -24,26 +24,26 @@ void set_unary_output_data(const array& in, array& out) {
}
}
template <typename T, typename Op>
void unary_op(const T* a, T* out, Op op, size_t shape, size_t stride) {
template <typename T, typename U = T, typename Op>
void unary_op(const T* a, U* out, Op op, size_t shape, size_t stride) {
for (size_t i = 0; i < shape; i += 1) {
out[i] = op(*a);
a += stride;
}
}
template <typename T, typename Op>
template <typename T, typename U = T, typename Op>
void unary_op(const array& a, array& out, Op op) {
const T* a_ptr = a.data<T>();
if (a.flags().contiguous) {
set_unary_output_data(a, out);
T* dst = out.data<T>();
U* dst = out.data<U>();
for (size_t i = 0; i < a.data_size(); ++i) {
dst[i] = op(a_ptr[i]);
}
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
T* dst = out.data<T>();
U* dst = out.data<U>();
size_t shape = a.ndim() > 0 ? a.shape(-1) : 1;
size_t stride = a.ndim() > 0 ? a.strides(-1) : 1;
if (a.ndim() <= 1) {

View File

@@ -26,8 +26,8 @@ make_jit_source(unary_ops kernels/erf.h kernels/expm1f.h)
make_jit_source(binary_ops)
make_jit_source(ternary_ops)
make_jit_source(reduce_utils kernels/atomic.h kernels/reduction/ops.h)
make_jit_source(scatter)
make_jit_source(gather)
make_jit_source(scatter kernels/indexing.h)
make_jit_source(gather kernels/indexing.h)
make_jit_source(hadamard)
if(MLX_METAL_JIT)
@@ -99,6 +99,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/resident.cpp
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp)
if(NOT MLX_METAL_PATH)

View File

@@ -2,6 +2,7 @@
#include "mlx/backend/metal/allocator.h"
#include "mlx/backend/metal/metal.h"
#include "mlx/backend/metal/metal_impl.h"
#include "mlx/backend/metal/resident.h"
#include <mach/vm_page_size.h>
#include <unistd.h>
@@ -140,6 +141,7 @@ void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) {
MetalAllocator::MetalAllocator()
: device_(device(mlx::core::Device::gpu).mtl_device()),
residency_set_(device_),
buffer_cache_(device_) {
auto memsize = std::get<size_t>(device_info()["memory_size"]);
block_limit_ =
@@ -148,6 +150,8 @@ MetalAllocator::MetalAllocator()
static_cast<size_t>(0.95 * device_->recommendedMaxWorkingSetSize()),
block_limit_);
max_pool_size_ = block_limit_;
device(mlx::core::Device::gpu)
.set_residency_set(residency_set_.mtl_residency_set());
}
size_t MetalAllocator::set_cache_limit(size_t limit) {
@@ -164,6 +168,12 @@ size_t MetalAllocator::set_memory_limit(size_t limit, bool relaxed) {
return limit;
};
size_t MetalAllocator::set_wired_limit(size_t limit) {
std::swap(limit, wired_limit_);
residency_set_.resize(wired_limit_);
return limit;
};
Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
// Metal doesn't like empty buffers
if (size == 0) {
@@ -205,7 +215,7 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
// Allocate new buffer if needed
size_t res_opt = MTL::ResourceStorageModeShared;
res_opt |= MTL::ResourceHazardTrackingModeTracked;
res_opt |= MTL::ResourceHazardTrackingModeUntracked;
lk.unlock();
buf = device_->newBuffer(size, res_opt);
lk.lock();
@@ -220,6 +230,8 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
}
residency_set_.insert(buf);
return Buffer{static_cast<void*>(buf)};
}
@@ -231,6 +243,7 @@ void MetalAllocator::clear_cache() {
void MetalAllocator::free(Buffer buffer) {
auto buf = static_cast<MTL::Buffer*>(buffer.ptr());
std::unique_lock lk(mutex_);
residency_set_.erase(buf);
active_memory_ -= buf->length();
if (get_cache_memory() < max_pool_size_) {
buffer_cache_.recycle_to_cache(buf);
@@ -246,15 +259,9 @@ size_t MetalAllocator::size(Buffer buffer) const {
}
MetalAllocator& allocator() {
// By creating the |allocator_| on heap, the destructor of MetalAllocator will
// not be called on exit and all the buffers will be leaked. This is necessary
// because releasing buffers can take more than 30sec when the program holds a
// lot of RAM (for example inferencing a LLM), and it would feel frozen to
// users when exiting.
// TODO(zcbenz): Consider using the `base::NoDestructor` class from Chromium
// when applying this pattern to more places, or when introducing sanitizers
// to MLX.
// https://source.chromium.org/chromium/chromium/src/+/main:base/no_destructor.h
// By creating the |allocator_| on heap, the destructor of MetalAllocator
// will not be called on exit and buffers in the cache will be leaked. This
// can save some time at program exit.
static MetalAllocator* allocator_ = new MetalAllocator;
return *allocator_;
}
@@ -265,6 +272,15 @@ size_t set_cache_limit(size_t limit) {
size_t set_memory_limit(size_t limit, bool relaxed /* = true */) {
return allocator().set_memory_limit(limit, relaxed);
}
size_t set_wired_limit(size_t limit) {
if (limit >
std::get<size_t>(device_info()["max_recommended_working_set_size"])) {
throw std::invalid_argument(
"[metal::set_wired_limit] Setting a wired limit larger than "
"the maximum working set size is not allowed.");
}
return allocator().set_wired_limit(limit);
}
size_t get_active_memory() {
return allocator().get_active_memory();
}

View File

@@ -8,6 +8,7 @@
#include "mlx/allocator.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/resident.h"
namespace mlx::core::metal {
@@ -72,6 +73,7 @@ class MetalAllocator : public allocator::Allocator {
};
size_t set_cache_limit(size_t limit);
size_t set_memory_limit(size_t limit, bool relaxed);
size_t set_wired_limit(size_t limit);
void clear_cache();
private:
@@ -82,12 +84,15 @@ class MetalAllocator : public allocator::Allocator {
// Caching allocator
BufferCache buffer_cache_;
ResidencySet residency_set_;
// Allocation stats
size_t block_limit_;
size_t gc_limit_;
size_t active_memory_{0};
size_t peak_memory_{0};
size_t max_pool_size_;
size_t wired_limit_{0};
bool relaxed_{true};
std::mutex mutex_;

View File

@@ -1,5 +1,4 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/common/binary.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
@@ -84,8 +83,7 @@ void binary_op_gpu_inplace(
bool use_2d = out.data_size() > UINT32_MAX;
auto ndim = shape.size();
int work_per_thread =
(bopt == BinaryOpType::General && shape[ndim - 1] > 4) ? 4 : 1;
int work_per_thread = (bopt == BinaryOpType::General) ? 4 : 1;
std::string kernel_name =
get_kernel_name(bopt, op, a, use_2d, shape.size(), work_per_thread);
auto& d = metal::device(s.device);
@@ -111,6 +109,7 @@ void binary_op_gpu_inplace(
compute_encoder.set_output_array(outputs[1], arg_idx++);
}
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (bopt == BinaryOpType::General) {
// Launch up to 3D grid of threads
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
@@ -133,7 +132,6 @@ void binary_op_gpu_inplace(
strides_b.data(), ndim * sizeof(size_t), arg_idx++);
}
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size != 1024) {
throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
}
@@ -143,13 +141,12 @@ void binary_op_gpu_inplace(
} else {
// Launch a 1D or 2D grid of threads
size_t nthreads = out.data_size();
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
}

View File

@@ -13,6 +13,8 @@
namespace mlx::core {
constexpr int WORK_PER_THREAD = 4;
inline void build_kernel(
std::ostream& os,
const std::string& kernel_name,
@@ -23,7 +25,8 @@ inline void build_kernel(
bool contiguous,
int ndim,
bool dynamic_dims,
bool use_big_index = false) {
bool use_big_index = false,
int work_per_thread = 1) {
// All outputs should have the exact same shape and will be row contiguous
auto output_shape = outputs[0].shape();
auto output_strides = outputs[0].strides();
@@ -38,8 +41,8 @@ inline void build_kernel(
int cnt = 0;
// Start the kernel
os << "[[host_name(\"" << kernel_name << "\")]]" << std::endl
<< "[[kernel]] void " << kernel_name << "(" << std::endl;
os << "[[host_name(\"" << kernel_name << "\")]]\n"
<< "[[kernel]] void " << kernel_name << "(\n";
// Add the input arguments
for (auto& x : inputs) {
@@ -53,11 +56,11 @@ inline void build_kernel(
// Scalars and contiguous need no strides
if (is_scalar(x) || contiguous) {
os << " device const " << get_type_string(x.dtype()) << "* " << xname
<< " [[buffer(" << cnt++ << ")]]," << std::endl;
<< " [[buffer(" << cnt++ << ")]],\n";
} else {
add_indices = true;
os << " device const " << get_type_string(x.dtype()) << "* " << xname
<< " [[buffer(" << cnt++ << ")]]," << std::endl;
<< " [[buffer(" << cnt++ << ")]],\n";
}
}
@@ -69,58 +72,37 @@ inline void build_kernel(
// Add the output arguments
for (auto& x : outputs) {
os << " device " << get_type_string(x.dtype()) << "* "
<< namer.get_name(x) << " [[buffer(" << cnt++ << ")]]," << std::endl;
<< namer.get_name(x) << " [[buffer(" << cnt++ << ")]],\n";
}
// Add output strides and shape to extract the indices.
if (!contiguous) {
os << " constant const size_t* output_strides [[buffer(" << cnt++
<< ")]]," << std::endl
<< " constant const int* output_shape [[buffer(" << cnt++ << ")]],"
<< std::endl;
<< ")]],\n"
<< " constant const int* output_shape [[buffer(" << cnt++ << ")]],\n";
}
if (dynamic_dims) {
os << " constant const int& ndim [[buffer(" << cnt++ << ")]],"
<< std::endl;
os << " constant const int& ndim [[buffer(" << cnt++ << ")]],\n";
}
// The thread index in the whole grid
os << " uint3 pos [[thread_position_in_grid]]," << std::endl
<< " uint3 grid [[threads_per_grid]]) {" << std::endl;
os << " uint3 pos [[thread_position_in_grid]],\n"
<< " uint3 grid [[threads_per_grid]]) {\n";
if (use_big_index) {
// This is only used for contiguous kernels which don't have
// a third grid dimension
os << " size_t index = pos.x + grid.x * size_t(pos.y);";
os << " size_t index = pos.x + grid.x * size_t(pos.y);\n";
} else if (work_per_thread > 1) {
os << " constexpr int N_ = " << std::to_string(work_per_thread) << ";\n"
<< " int xshape = output_shape["
<< (dynamic_dims ? "ndim - 1" : std::to_string(ndim - 1)) << "];\n"
<< " size_t index = N_ * pos.x + xshape * (pos.y + size_t(grid.y) * pos.z);\n";
} else {
os << " uint index = pos.x + grid.x * (pos.y + grid.y * pos.z);";
}
os << std::endl;
// Extract the indices per axis to individual uints if we have arrays that
// are broadcasted or transposed
if (add_indices) {
if (!dynamic_dims) {
if (ndim == 1) {
os << " uint index_0 = pos.x;" << std::endl;
} else if (ndim == 2) {
os << " uint index_0 = pos.y;" << std::endl
<< " uint index_1 = pos.x;" << std::endl;
} else if (ndim == 3) {
os << " uint index_0 = pos.z;" << std::endl
<< " uint index_1 = pos.y;" << std::endl
<< " uint index_2 = pos.x;" << std::endl;
} else {
for (int i = 0; i < ndim - 2; i++) {
os << " uint index_" << i << " = (index / uint(output_strides[" << i
<< "])) % output_shape[" << i << "];" << std::endl;
}
os << " uint index_" << ndim - 2 << " = pos.y;" << std::endl
<< " uint index_" << ndim - 1 << " = pos.x;" << std::endl;
}
}
os << " size_t index = pos.x + grid.x * (pos.y + size_t(grid.y) * pos.z);\n";
}
// Read the inputs in tmps
int nc_in_count = 0;
// Read constant / contiguous inputs in tmps
std::vector<array> nc_inputs;
for (int i = 0; i < inputs.size(); ++i) {
auto& x = inputs[i];
auto& xname = namer.get_name(x);
@@ -130,56 +112,117 @@ inline void build_kernel(
os << " auto tmp_" << xname << " = static_cast<"
<< get_type_string(x.dtype()) << ">(";
print_constant(os, x);
os << ");" << std::endl;
os << ");\n";
} else if (is_scalar(x)) {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[0];" << std::endl;
<< xname << "[0];\n";
} else if (contiguous) {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[index];" << std::endl;
} else if (!dynamic_dims) {
int offset = nc_in_count * ndim;
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[";
os << "index_0 * " << "in_strides[" << offset << "]";
for (int i = 1; i < ndim; i++) {
os << " + index_" << i << " * " << "in_strides[" << offset + i << "]";
}
os << "];" << std::endl;
nc_in_count++;
<< xname << "[index];\n";
} else {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[elem_to_loc(index, output_shape, in_strides + "
<< nc_in_count * ndim << ", ndim)];" << std::endl;
nc_in_count++;
nc_inputs.push_back(x);
}
}
// Initialize the indices for non-contiguous inputs
for (int i = 0; i < nc_inputs.size(); ++i) {
auto& xname = namer.get_name(nc_inputs[i]);
if (ndim == 1) {
int offset = i * ndim;
os << " size_t index_" << xname << " = elem_to_loc_1(pos.x, "
<< "in_strides[" << offset << "]);\n";
} else if (ndim == 2) {
int offset = i * ndim;
os << " size_t index_" << xname << " = elem_to_loc_2({pos.x, pos.y}, "
<< "in_strides + " << offset << ");\n";
} else if (ndim == 3) {
int offset = i * ndim;
os << " size_t index_" << xname << " = elem_to_loc_3(pos, "
<< "in_strides + " << offset << ");\n";
} else if (!dynamic_dims) {
int offset = i * ndim;
os << " size_t index_" << xname << " = N_ * pos.x * in_strides["
<< offset + ndim - 1 << "]"
<< " + pos.y * in_strides[" << offset + ndim - 2 << "];\n";
} else {
os << " size_t index_" << xname << " = N_ * pos.x * in_strides[ndim * "
<< i << " + ndim - 1]"
<< " + pos.y * in_strides[ndim * " << i << " + ndim - 2];\n";
}
}
if (!nc_inputs.empty() && (ndim > 3 || dynamic_dims)) {
os << " uint zpos = pos.z;\n";
if (dynamic_dims) {
os << " for (int d = ndim - 3; d >= 0; --d) {\n";
} else {
os << " for (int d = " << ndim - 3 << "; d >= 0; --d) {\n";
}
os << " uint l = zpos % output_shape[d];\n";
for (int i = 0; i < nc_inputs.size(); ++i) {
auto& xname = namer.get_name(nc_inputs[i]);
os << " index_" << xname << " += ";
if (dynamic_dims) {
os << "l * in_strides[" << i << " * ndim + d];\n";
} else {
os << "l * in_strides[" << i * ndim << " + d];\n";
}
}
os << " zpos /= output_shape[d];\n }\n";
}
// Open per-thread loop
if (work_per_thread > 1) {
os << " for (int i = 0; i < N_ && (int(N_ * pos.x) + i) < xshape; ++i) {\n";
}
// Read non-contiguous inputs into tmps
for (int i = 0; i < nc_inputs.size(); ++i) {
auto& x = nc_inputs[i];
auto& xname = namer.get_name(x);
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[index_" << xname << "];\n";
}
// Actually write the computation
for (auto& x : tape) {
os << " " << get_type_string(x.dtype()) << " tmp_" << namer.get_name(x)
<< " = ";
if (is_static_cast(x.primitive())) {
os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_"
<< namer.get_name(x.inputs()[0]) << ");" << std::endl;
<< namer.get_name(x.inputs()[0]) << ");\n";
} else {
x.primitive().print(os);
os << "()(";
for (int i = 0; i < x.inputs().size() - 1; i++) {
os << "tmp_" << namer.get_name(x.inputs()[i]) << ", ";
}
os << "tmp_" << namer.get_name(x.inputs().back()) << ");" << std::endl;
os << "tmp_" << namer.get_name(x.inputs().back()) << ");\n";
}
}
// Write the outputs from tmps
for (auto& x : outputs) {
os << " " << namer.get_name(x) << "[index] = tmp_" << namer.get_name(x)
<< ";" << std::endl;
<< ";\n";
}
// Increment indices and close per thread loop
if (work_per_thread > 1) {
for (int i = 0; i < nc_inputs.size(); ++i) {
auto& x = nc_inputs[i];
auto& xname = namer.get_name(x);
if (!dynamic_dims) {
os << " index_" << xname << " += "
<< "in_strides[" << i * ndim + ndim - 1 << "];\n";
} else {
os << " index_" << xname << " += "
<< "in_strides[" << i << " * ndim + ndim - 1];\n";
}
}
os << " index++;\n }\n";
}
// Finish the kernel
os << "}" << std::endl;
os << "}\n";
if (cnt > 31) {
std::ostringstream msg;
@@ -202,10 +245,7 @@ void Compiled::eval_gpu(
// Get the kernel if someone else built it already
auto& s = stream();
auto& d = metal::device(s.device);
auto lib = d.get_library(kernel_lib_);
// If not we have to build it ourselves
if (lib == nullptr) {
auto lib = d.get_library(kernel_lib_, [&]() {
std::ostringstream kernel;
kernel << metal::utils() << metal::unary_ops() << metal::binary_ops()
<< metal::ternary_ops();
@@ -240,7 +280,9 @@ void Compiled::eval_gpu(
constant_ids_,
/* contiguous = */ false,
/* ndim = */ i,
/* dynamic_dims = */ false);
/* dynamic_dims = */ false,
/* use_big_index = */ false,
/* work_per_thread = */ i > 3 ? WORK_PER_THREAD : 1);
}
build_kernel(
kernel,
@@ -251,10 +293,11 @@ void Compiled::eval_gpu(
constant_ids_,
/* contiguous = */ false,
/* ndim = */ 0,
/* dynamic_dims = */ true);
lib = d.get_library(kernel_lib_, kernel.str());
}
/* dynamic_dims = */ true,
/* use_big_index = */ false,
/* work_per_thread = */ WORK_PER_THREAD);
return kernel.str();
});
// Figure out which kernel we are using
auto& output_shape = outputs[0].shape();
@@ -378,21 +421,29 @@ void Compiled::eval_gpu(
// Launch the kernel
if (contiguous) {
size_t nthreads = outputs[0].data_size();
MTL::Size group_dims(
std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
MTL::Size grid_dims = use_2d
? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides())
: MTL::Size(nthreads, 1, 1);
MTL::Size group_dims(
std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
size_t rest = outputs[0].size() / (dim0 * dim1);
int work_per_thread = ndim > 3 ? WORK_PER_THREAD : 1;
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size != 1024) {
throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
int pow2;
if (thread_group_size == 1024) {
pow2 = 10;
} else if (thread_group_size > 512) {
pow2 = 9;
} else {
throw std::runtime_error("[Metal::compiled] Must use > 512 sized block");
}
auto group_dims = get_block_dims(dim0, dim1, rest);
auto group_dims = get_block_dims(dim0, dim1, rest, pow2);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}

View File

@@ -752,10 +752,6 @@ void conv_2D_gpu(
bool is_kdil_one = conv_params.kdil[0] == 1 && conv_params.kdil[1] == 1;
bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1;
bool inp_large = (conv_params.in_strides[0] >= 1ul << 18);
bool channels_large = (conv_params.C + conv_params.O) >= 512;
bool channels_med = (conv_params.C + conv_params.O) >= 256;
if (groups > 1) {
const int C_per_group = conv_params.C / groups;
const int O_per_group = conv_params.O / groups;
@@ -769,10 +765,13 @@ void conv_2D_gpu(
}
// Direct to winograd conv
bool inp_large =
(conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 1ul << 12;
bool channels_large = (conv_params.C + conv_params.O) >= 256;
if (!flip && is_stride_one && is_kdil_one && is_idil_one &&
conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
conv_params.C % 32 == 0 && conv_params.O % 32 == 0 &&
(channels_large || (channels_med && inp_large))) {
conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && inp_large &&
channels_large) {
return winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
}
@@ -918,14 +917,8 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
"[Convolution::eval_gpu] Only supports 1D, 2D or 3D convolutions.");
}
// Clear copies
if (!copies.empty()) {
auto command_buffer = d.get_command_buffer(s.index);
command_buffer->addCompletedHandler(
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
}
// Record copies
d.add_temporaries(std::move(copies), s.index);
}
} // namespace mlx::core

View File

@@ -98,7 +98,7 @@ void copy_gpu_inplace(
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
if (shape.size() <= MAX_COPY_SPECIALIZED_DIMS) {
kname << shape.size();
} else if (shape[ndim - 1] >= 4) {
} else {
work_per_thread = 4;
kname << "n4";
}
@@ -120,6 +120,7 @@ void copy_gpu_inplace(
compute_encoder.set_input_array(donate_in ? out : in, 0, inp_offset);
compute_encoder.set_output_array(out, 1, out_offset);
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
std::vector<int64_t> strides_in{strides_in_.begin(), strides_in_.end()};
std::vector<int64_t> strides_out{strides_out_.begin(), strides_out_.end()};
@@ -145,7 +146,6 @@ void copy_gpu_inplace(
}
// NB assuming thread_group_size is a power of 2 larger than 32 x 32
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size != 1024) {
throw std::runtime_error("[Metal::copy] Must use 1024 sized block");
}
@@ -155,13 +155,12 @@ void copy_gpu_inplace(
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
size_t nthreads = out.data_size();
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
}
@@ -205,14 +204,14 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
compute_encoder.set_input_array(val, 0);
compute_encoder.set_output_array(out, 1);
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
size_t nthreads = out.data_size();
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}

View File

@@ -32,17 +32,15 @@ void CustomKernel::eval_gpu(
return copies.back();
}
};
std::vector<const array> checked_inputs;
std::vector<array> checked_inputs;
for (const array& in : inputs) {
checked_inputs.push_back(check_input(in));
}
auto& d = metal::device(s.device);
const auto& lib_name = name_;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
lib = d.get_library(lib_name, metal::utils() + source_);
}
auto lib =
d.get_library(lib_name, [this] { return metal::utils() + source_; });
auto kernel = d.get_kernel(name_, lib);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
@@ -79,12 +77,7 @@ void CustomKernel::eval_gpu(
MTL::Size grid_dims = MTL::Size(gx, gy, gz);
compute_encoder->dispatchThreads(grid_dims, group_dims);
if (!copies.empty()) {
d.get_command_buffer(s.index)->addCompletedHandler(
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
}
d.add_temporaries(std::move(copies), s.index);
}
} // namespace mlx::core::fast

View File

@@ -20,7 +20,6 @@ namespace {
// TODO nicer way to set this or possibly expose as an environment variable
constexpr int MAX_BUFFERS_PER_QUEUE = 12;
constexpr int MAX_DISPATCHES_PER_ENCODER = 2;
constexpr const char* default_mtllib_path = METAL_PATH;
@@ -121,33 +120,29 @@ MTL::Library* load_library(
} // namespace
CommandEncoder::CommandEncoder(MTL::CommandBuffer* cbuf) : cbuf(cbuf) {
enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
enc->retain();
CommandEncoder::CommandEncoder(MTL::CommandBuffer* cbuf) {
enc_ = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
enc_->retain();
}
CommandEncoder::~CommandEncoder() {
enc->endEncoding();
enc->release();
enc_->endEncoding();
enc_->release();
}
void CommandEncoder::set_input_array(
const array& a,
int idx,
int64_t offset /* = 0 */) {
all_inputs_.insert(a.buffer().ptr());
auto r_buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
if (auto it = outputs.find(r_buf); it != outputs.end()) {
// Insert a barrier
enc->memoryBarrier(&r_buf, 1);
// Remove the output
outputs.erase(it);
}
needs_barrier_ =
needs_barrier_ | (prev_outputs_.find(r_buf) != prev_outputs_.end());
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
auto base_offset = a.data<char>() -
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
base_offset += offset;
enc->setBuffer(a_buf, base_offset, idx);
enc_->setBuffer(a_buf, base_offset, idx);
}
void CommandEncoder::set_output_array(
@@ -156,55 +151,49 @@ void CommandEncoder::set_output_array(
int64_t offset /* = 0 */) {
// Add barriers before adding the output to the output set
set_input_array(a, idx, offset);
all_outputs_.insert(a.buffer().ptr());
auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
if (concurrent) {
concurrent_outputs.insert(buf);
if (concurrent_) {
concurrent_outputs_.insert(buf);
} else {
outputs.insert(buf);
next_outputs_.insert(buf);
}
}
void CommandEncoder::maybeInsertBarrier() {
if (needs_barrier_) {
enc_->memoryBarrier(MTL::BarrierScopeBuffers);
needs_barrier_ = false;
prev_outputs_ = std::move(next_outputs_);
} else {
prev_outputs_.insert(next_outputs_.begin(), next_outputs_.end());
}
next_outputs_.clear();
}
void CommandEncoder::dispatchThreadgroups(
MTL::Size grid_dims,
MTL::Size group_dims) {
num_dispatches++;
enc->dispatchThreadgroups(grid_dims, group_dims);
maybe_split();
maybeInsertBarrier();
enc_->dispatchThreadgroups(grid_dims, group_dims);
}
void CommandEncoder::dispatchThreads(
MTL::Size grid_dims,
MTL::Size group_dims) {
num_dispatches++;
enc->dispatchThreads(grid_dims, group_dims);
maybe_split();
}
void CommandEncoder::maybe_split() {
if (num_dispatches > MAX_DISPATCHES_PER_ENCODER && !concurrent) {
enc->endEncoding();
enc->release();
num_dispatches = 0;
outputs.clear();
enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
enc->retain();
}
maybeInsertBarrier();
enc_->dispatchThreads(grid_dims, group_dims);
}
Device::Device() {
auto pool = new_scoped_memory_pool();
device_ = load_device();
library_map_ = {{"mlx", load_library(device_)}};
arch_ = std::string(device_->architecture()->name()->utf8String());
}
Device::~Device() {
auto pool = new_scoped_memory_pool();
for (auto& q : queue_map_) {
q.second->release();
}
for (auto& b : buffer_map_) {
b.second.second->release();
}
for (auto& k : kernel_map_) {
k.second->release();
}
@@ -219,69 +208,134 @@ void Device::new_queue(int index) {
// Multiple threads can ask the device for queues
// We lock this as a critical section for safety
const std::lock_guard<std::mutex> lock(mtx_);
auto q = device_->newCommandQueue(MAX_BUFFERS_PER_QUEUE);
debug_set_stream_queue_label(q, index);
if (!q) {
throw std::runtime_error(
"[metal::Device] Failed to make new command queue.");
}
queue_map_.insert({index, q});
stream_map_.emplace(index, q);
if (residency_set_ != nullptr) {
q->addResidencySet(residency_set_);
}
}
int Device::get_command_buffer_ops(int index) {
auto bit = buffer_map_.find(index);
return bit->second.first;
return get_stream_(index).buffer_ops;
}
void Device::increment_command_buffer_ops(int index) {
auto bit = buffer_map_.find(index);
bit->second.first++;
get_stream_(index).buffer_ops++;
}
MTL::CommandBuffer* Device::get_command_buffer(int index) {
auto bit = buffer_map_.find(index);
if (bit == buffer_map_.end()) {
auto qit = queue_map_.find(index);
if (qit == queue_map_.end()) {
throw std::runtime_error(
"[metal::Device] Attempting to get command buffer for invalid queue.");
}
auto cb = qit->second->commandBufferWithUnretainedReferences();
if (!cb) {
auto& stream = get_stream_(index);
if (stream.buffer == nullptr) {
stream.buffer = stream.queue->commandBufferWithUnretainedReferences();
if (!stream.buffer) {
throw std::runtime_error(
"[metal::Device] Unable to create new command buffer");
}
// Increment ref count so the buffer is not garbage collected
cb->retain();
bit = buffer_map_.insert({index, {0, cb}}).first;
stream.buffer->retain();
}
return bit->second.second;
return stream.buffer;
}
void Device::commit_command_buffer(int index) {
auto bit = buffer_map_.find(index);
bit->second.second->commit();
bit->second.second->release();
buffer_map_.erase(bit);
auto& stream = get_stream_(index);
stream.buffer->commit();
stream.buffer->release();
stream.buffer = nullptr;
stream.buffer_ops = 0;
}
void Device::add_temporary(array arr, int index) {
get_stream_(index).temporaries.push_back(std::move(arr));
}
void Device::add_temporaries(std::vector<array> arrays, int index) {
if (arrays.empty()) {
return;
}
auto& stream = get_stream_(index);
stream.temporaries.insert(
stream.temporaries.end(),
std::make_move_iterator(arrays.begin()),
std::make_move_iterator(arrays.end()));
}
void Device::end_encoding(int index) {
encoder_map_.erase(index);
auto& stream = get_stream_(index);
if (stream.encoder != nullptr) {
// Each command encoder has a unique fence. We also store a map of
// all previous outputs of command encoders to their corresponding fence.
// - The command encoder records its inputs and outputs.
// - Wait on a fence if any inputs in the encoder are outputs of a previous
// encoder.
// - Update the map of outputs to include this command encoder's outputs.
// - Always signal this command encoders fence.
// - Add a completion handler for this command encoder that removes outputs
// from the map to limit the growth of the map and avoid unecessary waits
// - Temporaries are a special case as they do not cross command encoder
// boundaries. These can be removed early from the encoders inputs and
// outputs since they don't need synchronization.
auto& enc = *stream.encoder;
// Remove temporaries from inputs and outputs
for (auto& t : stream.temporaries) {
if (t.data<void>() != nullptr) {
enc.outputs().erase(t.buffer().ptr());
enc.inputs().erase(t.buffer().ptr());
}
}
// Keep references to the fences we waited on and put them
// in the completion handler so they are not prematurely released
std::unordered_set<std::shared_ptr<Fence>> waiting_on;
{
std::lock_guard<std::mutex> lk(stream.fence_mtx);
for (auto in : enc.inputs()) {
if (auto it = stream.outputs.find(in); it != stream.outputs.end()) {
// If we've already waited on a fence, don't wait on it again.
if (waiting_on.find(it->second) == waiting_on.end()) {
enc->waitForFence(it->second->fence);
waiting_on.insert(it->second);
}
}
}
for (auto out : enc.outputs()) {
stream.outputs[out] = stream.fence;
}
}
enc->updateFence(stream.fence->fence);
stream.buffer->addCompletedHandler(
[&stream,
waiting_on = std::move(waiting_on),
fence = std::move(stream.fence),
outputs = std::move(enc.outputs()),
temporaries =
std::move(stream.temporaries)](MTL::CommandBuffer*) mutable {
temporaries.clear();
std::lock_guard<std::mutex> lk(stream.fence_mtx);
for (auto o : outputs) {
if (auto it = stream.outputs.find(o); it != stream.outputs.end()) {
if (it->second == fence) {
stream.outputs.erase(it);
}
}
}
});
}
stream.encoder = nullptr;
}
CommandEncoder& Device::get_command_encoder(int index) {
auto eit = encoder_map_.find(index);
if (eit == encoder_map_.end()) {
auto cb = get_command_buffer(index);
eit =
encoder_map_.emplace(index, std::make_unique<CommandEncoder>(cb)).first;
auto& stream = get_stream_(index);
if (stream.encoder == nullptr) {
stream.encoder = std::make_unique<CommandEncoder>(stream.buffer);
stream.fence = std::make_shared<Fence>(device_->newFence());
}
return *(eit->second);
return *stream.encoder;
}
void Device::register_library(
@@ -293,20 +347,7 @@ void Device::register_library(
}
}
MTL::Library* Device::get_library_cache_(const std::string& lib_name) {
// Search for cached metal lib
MTL::Library* mtl_lib;
if (auto it = library_map_.find(lib_name); it != library_map_.end()) {
mtl_lib = it->second;
} else { // Look for metallib alongside library
register_library(lib_name, get_colocated_mtllib_path(lib_name));
mtl_lib = library_map_[lib_name];
}
return mtl_lib;
}
MTL::Library* Device::get_library_(const std::string& source_string) {
MTL::Library* Device::build_library_(const std::string& source_string) {
auto pool = new_scoped_memory_pool();
auto ns_code =
@@ -322,26 +363,7 @@ MTL::Library* Device::get_library_(const std::string& source_string) {
// Throw error if unable to compile library
if (!mtl_lib) {
std::ostringstream msg;
msg << "[metal::Device] Unable to build metal library from source" << "\n";
if (error) {
msg << error->localizedDescription()->utf8String() << "\n";
}
throw std::runtime_error(msg.str());
}
return mtl_lib;
}
MTL::Library* Device::get_library_(const MTL::StitchedLibraryDescriptor* desc) {
auto pool = new_scoped_memory_pool();
NS::Error* error = nullptr;
auto mtl_lib = device_->newLibrary(desc, &error);
// Throw error if unable to compile library
if (!mtl_lib) {
std::ostringstream msg;
msg << "[metal::Device] Unable to build stitched metal library" << "\n";
msg << "[metal::Device] Unable to build metal library from source\n";
if (error) {
msg << error->localizedDescription()->utf8String() << "\n";
}
@@ -465,68 +487,32 @@ MTL::ComputePipelineState* Device::get_kernel_(
return kernel;
}
MTL::Library* Device::get_library(const std::string& name) {
MTL::Library* Device::get_library_(const std::string& name) {
std::shared_lock lock(library_mtx_);
auto it = library_map_.find(name);
return (it != library_map_.end()) ? it->second : nullptr;
}
MTL::Library* Device::get_library(
const std::string& name,
const std::string& source,
bool cache /* = true */) {
if (cache) {
const std::function<std::string(void)>& builder) {
{
std::shared_lock rlock(library_mtx_);
if (auto it = library_map_.find(name); it != library_map_.end()) {
return it->second;
}
}
auto mtl_lib = get_library_(source);
if (cache) {
library_map_.insert({name, mtl_lib});
std::unique_lock wlock(library_mtx_);
if (auto it = library_map_.find(name); it != library_map_.end()) {
return it->second;
}
auto mtl_lib = build_library_(builder());
library_map_.insert({name, mtl_lib});
return mtl_lib;
}
MTL::Library* Device::get_library(
const std::string& name,
const MTL::StitchedLibraryDescriptor* desc,
bool cache /* = true */) {
if (cache) {
if (auto it = library_map_.find(name); it != library_map_.end()) {
return it->second;
}
}
auto mtl_lib = get_library_(desc);
if (cache) {
library_map_.insert({name, mtl_lib});
}
return mtl_lib;
}
MTL::Function* Device::get_function(
const std::string& base_name,
MTL::Library* mtl_lib,
const std::string& specialized_name /* = "" */,
const MTLFCList& func_consts /* = {} */) {
return get_function_(base_name, specialized_name, func_consts, mtl_lib);
}
MTL::Function* Device::get_function(
const std::string& base_name,
const std::string& lib_name /* = "mlx" */,
const std::string& specialized_name /* = "" */,
const MTLFCList& func_consts /* = {} */) {
// Search for cached metal lib
MTL::Library* mtl_lib = get_library_cache_(lib_name);
return get_function(base_name, mtl_lib, specialized_name, func_consts);
}
MTL::LinkedFunctions* Device::get_linked_functions_(
const std::vector<MTL::Function*>& funcs) {
if (funcs.empty()) {
@@ -547,34 +533,55 @@ MTL::LinkedFunctions* Device::get_linked_functions_(
return lfuncs;
}
MTL::ComputePipelineState* Device::get_kernel_(
const std::string& base_name,
MTL::Library* mtl_lib,
const std::string& hash_name,
const MTLFCList& func_consts /* = {} */,
const std::vector<MTL::Function*>& linked_functions /* = {} */) {
// Single writer allowed
std::unique_lock wlock(kernel_mtx_);
// Try loading again to avoid loading twice
if (auto it = kernel_map_.find(hash_name); it != kernel_map_.end()) {
return it->second;
}
auto pool = new_scoped_memory_pool();
// Pull kernel from library
auto mtl_function = get_function_(base_name, hash_name, func_consts, mtl_lib);
// Compile kernel to compute pipeline
auto mtl_linked_funcs = get_linked_functions_(linked_functions);
auto kernel = get_kernel_(hash_name, mtl_function, mtl_linked_funcs);
mtl_function->release();
mtl_linked_funcs->release();
// Add kernel to cache
auto inserted = kernel_map_.insert({hash_name, kernel});
return kernel;
}
MTL::ComputePipelineState* Device::get_kernel(
const std::string& base_name,
MTL::Library* mtl_lib,
const std::string& hash_name /* = "" */,
const MTLFCList& func_consts /* = {} */,
const std::vector<MTL::Function*>& linked_functions /* = {} */) {
auto pool = new_scoped_memory_pool();
// Look for cached kernel
const auto& kname = hash_name.empty() ? base_name : hash_name;
if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
return it->second;
{
// Multiple readers allowed
std::shared_lock lock(kernel_mtx_);
// Look for cached kernel
if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
return it->second;
}
}
// Pull kernel from library
auto mtl_function = get_function_(base_name, kname, func_consts, mtl_lib);
// Compile kernel to compute pipeline
auto mtl_linked_funcs = get_linked_functions_(linked_functions);
auto kernel = get_kernel_(kname, mtl_function, mtl_linked_funcs);
mtl_function->release();
mtl_linked_funcs->release();
// Add kernel to cache
kernel_map_.insert({kname, kernel});
return kernel;
return get_kernel_(base_name, mtl_lib, kname, func_consts, linked_functions);
}
MTL::ComputePipelineState* Device::get_kernel(
@@ -583,16 +590,34 @@ MTL::ComputePipelineState* Device::get_kernel(
const std::string& hash_name /* = "" */,
const MTLFCList& func_consts /* = {} */,
const std::vector<MTL::Function*>& linked_functions /* = {} */) {
// Look for cached kernel
const auto& kname = hash_name.size() == 0 ? base_name : hash_name;
if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
return it->second;
{
// Multiple readers allowed
std::shared_lock lock(kernel_mtx_);
// Look for cached kernel
if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
return it->second;
}
}
// Search for cached metal lib
MTL::Library* mtl_lib = get_library_cache_(lib_name);
MTL::Library* mtl_lib = get_library_(lib_name);
return get_kernel_(base_name, mtl_lib, kname, func_consts, linked_functions);
}
return get_kernel(base_name, mtl_lib, kname, func_consts, linked_functions);
void Device::set_residency_set(const MTL::ResidencySet* residency_set) {
if (residency_set_ != nullptr) {
throw std::runtime_error(
"[Device::set_residency_set] Can only be set once.");
}
if (residency_set == nullptr) {
return;
}
residency_set_ = residency_set;
// Attach residency set to existing command queues
for (auto& [_, stream] : stream_map_) {
stream.queue->addResidencySet(residency_set_);
}
}
Device& device(mlx::core::Device) {

View File

@@ -7,6 +7,7 @@
#include <filesystem>
#include <functional>
#include <mutex>
#include <shared_mutex>
#include <string>
#include <unordered_map>
#include <unordered_set>
@@ -44,13 +45,13 @@ struct CommandEncoder {
struct ConcurrentContext {
ConcurrentContext(CommandEncoder& enc) : enc(enc) {
enc.concurrent = true;
enc.concurrent_ = true;
}
~ConcurrentContext() {
enc.concurrent = false;
enc.outputs.insert(
enc.concurrent_outputs.begin(), enc.concurrent_outputs.end());
enc.concurrent_outputs.clear();
enc.concurrent_ = false;
enc.prev_outputs_.insert(
enc.concurrent_outputs_.begin(), enc.concurrent_outputs_.end());
enc.concurrent_outputs_.clear();
}
private:
@@ -58,29 +59,73 @@ struct CommandEncoder {
};
MTL::ComputeCommandEncoder* operator->() {
return enc;
return enc_;
}
void set_input_array(const array& a, int idx, int64_t offset = 0);
void set_output_array(array& a, int idx, int64_t offset = 0);
void dispatchThreadgroups(MTL::Size grid_dims, MTL::Size group_dims);
void dispatchThreads(MTL::Size grid_dims, MTL::Size group_dims);
void maybeInsertBarrier();
ConcurrentContext start_concurrent() {
return ConcurrentContext(*this);
}
~CommandEncoder();
private:
void maybe_split();
// Inputs to all kernels in the encoder including temporaries
std::unordered_set<const void*>& inputs() {
return all_inputs_;
};
int num_dispatches{0};
MTL::CommandBuffer* cbuf;
MTL::ComputeCommandEncoder* enc;
bool concurrent{false};
std::unordered_set<MTL::Resource*> outputs;
std::unordered_set<MTL::Resource*> concurrent_outputs;
// Outputs of all kernels in the encoder including temporaries
std::unordered_set<const void*> outputs() {
return all_outputs_;
};
private:
MTL::ComputeCommandEncoder* enc_;
bool needs_barrier_{false};
bool concurrent_{false};
std::unordered_set<MTL::Resource*> prev_outputs_;
std::unordered_set<MTL::Resource*> next_outputs_;
std::unordered_set<MTL::Resource*> concurrent_outputs_;
std::unordered_set<const void*> all_inputs_;
std::unordered_set<const void*> all_outputs_;
};
struct Fence {
Fence(MTL::Fence* fence) : fence(fence) {}
~Fence() {
fence->release();
}
MTL::Fence* fence;
};
struct DeviceStream {
DeviceStream(MTL::CommandQueue* queue) : queue(queue) {};
~DeviceStream() {
queue->release();
if (buffer != nullptr) {
buffer->release();
}
};
MTL::CommandQueue* queue;
// A map of prior command encoder outputs to their corresponding fence
std::unordered_map<const void*, std::shared_ptr<Fence>> outputs;
// Used to allow thread-safe access to the outputs map
std::mutex fence_mtx;
// The buffer and buffer op count are updated
// between command buffers
MTL::CommandBuffer* buffer{nullptr};
int buffer_ops{0};
// The command encoder, fence, and temporaries are updated between command
// encoders
std::unique_ptr<CommandEncoder> encoder{nullptr};
std::shared_ptr<Fence> fence;
std::vector<array> temporaries;
};
class Device {
@@ -94,6 +139,10 @@ class Device {
return device_;
};
const std::string& get_architecture() {
return arch_;
}
void new_queue(int index);
MTL::CommandBuffer* get_command_buffer(int index);
int get_command_buffer_ops(int index);
@@ -114,29 +163,9 @@ class Device {
}
}
MTL::Library* get_library(const std::string& name);
MTL::Library* get_library(
const std::string& name,
const std::string& source_string,
bool cache = true);
MTL::Library* get_library(
const std::string& name,
const MTL::StitchedLibraryDescriptor* desc,
bool cache = true);
MTL::Function* get_function(
const std::string& base_name,
MTL::Library* mtl_lib,
const std::string& specialized_name = "",
const MTLFCList& func_consts = {});
MTL::Function* get_function(
const std::string& base_name,
const std::string& lib_name = "mlx",
const std::string& specialized_name = "",
const MTLFCList& func_consts = {});
const std::function<std::string(void)>& builder);
MTL::ComputePipelineState* get_kernel(
const std::string& base_name,
@@ -155,11 +184,20 @@ class Device {
MTL::ArgumentEncoder* argument_encoder(
const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const;
// Record temporary arrays for the given stream index
void add_temporary(array arr, int index);
void add_temporaries(std::vector<array> arrays, int index);
void set_residency_set(const MTL::ResidencySet* residency_set);
private:
DeviceStream& get_stream_(int index) {
return stream_map_.find(index)->second;
}
MTL::Library* get_library_cache_(const std::string& name);
MTL::Library* get_library_(const std::string& source_string);
MTL::Library* get_library_(const MTL::StitchedLibraryDescriptor* desc);
MTL::Library* get_library_(const std::string& name);
MTL::Library* build_library_(const std::string& source_string);
MTL::Function* get_function_(const std::string& name, MTL::Library* mtl_lib);
@@ -181,13 +219,23 @@ class Device {
const MTL::Function* mtl_function,
const MTL::LinkedFunctions* linked_functions);
MTL::ComputePipelineState* get_kernel_(
const std::string& base_name,
MTL::Library* mtl_lib,
const std::string& hash_name,
const MTLFCList& func_consts = {},
const std::vector<MTL::Function*>& linked_functions = {});
MTL::Device* device_;
std::unordered_map<int32_t, MTL::CommandQueue*> queue_map_;
std::unordered_map<int32_t, std::pair<int, MTL::CommandBuffer*>> buffer_map_;
std::unordered_map<int32_t, std::unique_ptr<CommandEncoder>> encoder_map_;
std::unordered_map<int32_t, DeviceStream> stream_map_;
std::shared_mutex kernel_mtx_;
std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_;
std::shared_mutex library_mtx_;
std::unordered_map<std::string, MTL::Library*> library_map_;
std::mutex mtx_;
const MTL::ResidencySet* residency_set_{nullptr};
std::string arch_;
};
Device& device(mlx::core::Device);

View File

@@ -575,10 +575,7 @@ void fft_op(
auto plan = plan_fft(n);
if (plan.four_step) {
four_step_fft(in, out, axis, inverse, real, plan, copies, s);
d.get_command_buffer(s.index)->addCompletedHandler(
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
d.add_temporaries(std::move(copies), s.index);
return;
}
@@ -744,12 +741,7 @@ void fft_op(
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
if (!copies.empty()) {
d.get_command_buffer(s.index)->addCompletedHandler(
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
}
d.add_temporaries(std::move(copies), s.index);
}
void fft_op(
@@ -792,8 +784,7 @@ void nd_fft_op(
}
auto& d = metal::device(s.device);
d.get_command_buffer(s.index)->addCompletedHandler(
[temp_arrs](MTL::CommandBuffer*) mutable { temp_arrs.clear(); });
d.add_temporaries(std::move(temp_arrs), s.index);
}
void FFT::eval_gpu(const std::vector<array>& inputs, array& out) {

View File

@@ -60,32 +60,6 @@ std::string gen_hadamard_codelet(int m) {
return source.str();
}
void launch_hadamard(
const array& in,
array& out,
int batch_size,
int threads_per,
const std::string kernel_name,
float scale,
const Stream& s) {
auto& d = metal::device(s.device);
const auto& lib_name = kernel_name.substr(1);
auto lib = d.get_library(lib_name);
auto kernel = d.get_kernel(kernel_name, lib);
assert(threads_per <= kernel->maxTotalThreadsPerThreadgroup());
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(&scale, sizeof(float), 2);
MTL::Size group_dims = MTL::Size(1, threads_per, 1);
MTL::Size grid_dims = MTL::Size(batch_size, threads_per, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = stream();
@@ -113,7 +87,8 @@ void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
auto [n, m] = decompose_hadamard(in.shape(axis));
int n, m;
std::tie(n, m) = decompose_hadamard(in.shape(axis));
if (n * (int)size_of(in.dtype()) > MAX_HADAMARD_BYTES) {
throw std::invalid_argument(
@@ -129,8 +104,7 @@ void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
auto kernel_name = kname.str();
auto& d = metal::device(s.device);
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
auto codelet = gen_hadamard_codelet(m);
kernel_source << metal::utils() << codelet << metal::hadamard();
@@ -148,12 +122,31 @@ void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
n,
m,
read_width);
lib = d.get_library(lib_name, kernel_source.str());
}
return kernel_source.str();
});
int batch_size = in.size() / n;
int threads_per = n / max_radix;
auto& compute_encoder = d.get_command_encoder(s.index);
auto launch_hadamard = [&](const array& in,
array& out,
const std::string& kernel_name,
float scale) {
auto kernel = d.get_kernel(kernel_name, lib);
assert(threads_per <= kernel->maxTotalThreadsPerThreadgroup());
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(&scale, sizeof(float), 2);
MTL::Size group_dims = MTL::Size(1, threads_per, 1);
MTL::Size grid_dims = MTL::Size(batch_size, threads_per, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
};
if (m > 1) {
// When m is greater than 1, we decompose the
// computation into two uploads to the GPU:
@@ -171,37 +164,17 @@ void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
temp.set_data(allocator::malloc_or_wait(temp.nbytes()));
copies.push_back(temp);
launch_hadamard(
in_contiguous,
temp,
batch_size,
threads_per,
"n" + kernel_name,
1.0,
s);
launch_hadamard(in_contiguous, temp, "n" + kernel_name, 1.0);
// Metal sometimes reports 256 max threads per group for hadamard_m kernel
threads_per = std::min(n / read_width, MAX_HADAMARD_THREADS_PER_GROUP);
batch_size = in.size() / m / read_width / threads_per;
launch_hadamard(
temp, out, batch_size, threads_per, "m" + kernel_name, scale_, s);
launch_hadamard(temp, out, "m" + kernel_name, scale_);
} else {
launch_hadamard(
in_contiguous,
out,
batch_size,
threads_per,
"n" + kernel_name,
scale_,
s);
launch_hadamard(in_contiguous, out, "n" + kernel_name, scale_);
}
if (!copies.empty()) {
d.get_command_buffer(s.index)->addCompletedHandler(
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
}
d.add_temporaries(std::move(copies), s.index);
}
} // namespace mlx::core

View File

@@ -64,8 +64,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
kernel_name = lib_name;
}
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::gather();
std::string out_type_str = get_type_string(out.dtype());
@@ -83,8 +82,8 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
idx_args,
idx_arr,
idx_ndim);
lib = d.get_library(lib_name, kernel_source.str());
}
return kernel_source.str();
});
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kernel_name, lib);
@@ -114,17 +113,17 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
// Collect all idx shapes and strides into one place
std::vector<int> idx_shapes;
std::vector<size_t> idx_strides;
std::vector<char> idx_contigs;
for (int i = 0; i < nidx; ++i) {
idx_shapes.insert(
idx_shapes.end(),
inputs[i + 1].shape().begin(),
inputs[i + 1].shape().end());
idx_strides.insert(
idx_strides.end(),
inputs[i + 1].strides().begin(),
inputs[i + 1].strides().end());
idx_contigs.push_back(inputs[i + 1].flags().row_contiguous);
}
// Set all the buffers
@@ -132,21 +131,20 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_output_array(out, 1);
// Set source info
compute_encoder->setBytes(src.shape().data(), ndim * sizeof(int), 2);
compute_encoder->setBytes(src.strides().data(), ndim * sizeof(size_t), 3);
set_vector_bytes(compute_encoder, src.shape(), 2);
set_vector_bytes(compute_encoder, src.strides(), 3);
compute_encoder->setBytes(&ndim, sizeof(size_t), 4);
compute_encoder->setBytes(slice_sizes_.data(), ndim * sizeof(int), 5);
compute_encoder->setBytes(axes_.data(), nidx * sizeof(int), 6);
set_vector_bytes(compute_encoder, slice_sizes_, 5);
set_vector_bytes(compute_encoder, axes_, 6);
// Set index info
//
// We don't need to check for empty idx_shapes because gather has a
// idx_ndim == 0 specialization
compute_encoder->setBytes(
idx_shapes.data(), idx_shapes.size() * sizeof(int), 7);
compute_encoder->setBytes(
idx_strides.data(), idx_strides.size() * sizeof(size_t), 8);
compute_encoder->setBytes(&idx_ndim, sizeof(int), 9);
set_vector_bytes(compute_encoder, idx_shapes, 7);
set_vector_bytes(compute_encoder, idx_strides, 8);
set_vector_bytes(compute_encoder, idx_contigs, 9);
compute_encoder->setBytes(&idx_ndim, sizeof(int), 10);
// Set index buffers
for (int i = 0; i < nidx; ++i) {
@@ -173,12 +171,20 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
}
// Copy src into out
auto copy_type =
inputs[0].data_size() == 1 ? CopyType::Scalar : CopyType::General;
CopyType copy_type;
if (inputs[0].data_size() == 1) {
copy_type = CopyType::Scalar;
} else if (inputs[0].flags().row_contiguous) {
copy_type = CopyType::Vector;
} else {
copy_type = CopyType::General;
}
copy_gpu(inputs[0], out, copy_type);
auto& upd = inputs.back();
// Empty update
if (inputs.back().size() == 0) {
if (upd.size() == 0) {
return;
}
@@ -187,19 +193,20 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& d = metal::device(s.device);
int idx_ndim = nidx ? inputs[1].ndim() : 0;
bool index_nd1_specialization = (idx_ndim == 1);
size_t idx_size = nidx ? inputs[1].size() : 1;
// Bail from fast path (1d index specialization) if scatter dims aren't
// the outermost dims and contiguous since update access won't be raster
// order.
for (auto i = 0; i < axes_.size() && index_nd1_specialization; i++) {
index_nd1_specialization &= (axes_[i] == i);
}
// Bail from fast path (1d index specialization) if any of the dims are
// broadcasted, since we can't rely on linear indexing in that case.
for (int i = 1; i < inputs.size() && index_nd1_specialization; i++) {
index_nd1_specialization &= inputs[i].flags().row_contiguous;
auto idx_to_out = idx_size / out.size();
int nwork;
if (idx_ndim <= 1 || idx_to_out < 1) {
nwork = 1;
} else if (idx_to_out <= 4) {
nwork = 4;
} else if (idx_to_out < 16) {
nwork = 8;
} else if (idx_to_out < 32) {
nwork = 16;
} else {
nwork = 32;
}
std::string lib_name;
@@ -223,21 +230,16 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
op_name = "min";
break;
}
auto upd_contig = upd.flags().row_contiguous;
{
std::ostringstream kname;
if (index_nd1_specialization) {
kname << "scatter_1d_index" << type_to_name(out) << idx_type_name;
} else {
kname << "scatter" << type_to_name(out) << idx_type_name;
}
kname << "_" << op_name << "_" << nidx;
kname << "scatter" << type_to_name(out) << idx_type_name;
kname << "_" << op_name << "_" << nidx << "_"
<< (upd_contig ? "updc_true" : "updc_false") << "_nwork" << nwork;
lib_name = kname.str();
kernel_name = kname.str();
}
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::reduce_utils()
<< metal::scatter();
@@ -264,7 +266,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
break;
}
if (reduce_type_ != Scatter::None) {
op_type = fmt::format(op_type, out_type_str);
op_type = fmt::format(fmt::runtime(op_type), out_type_str);
}
auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx);
@@ -276,14 +278,15 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
op_type,
nidx,
idx_args,
idx_arr);
lib = d.get_library(lib_name, kernel_source.str());
}
idx_arr,
upd_contig,
nwork);
return kernel_source.str();
});
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kernel_name, lib);
auto& upd = inputs.back();
size_t nthreads = upd.size();
compute_encoder->setComputePipelineState(kernel);
@@ -293,109 +296,86 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_output_array(out, 2);
// Set update info
uint upd_ndim = upd.ndim();
size_t upd_ndim = upd.ndim();
size_t upd_size = 1;
for (int i = idx_ndim; i < upd.ndim(); ++i) {
upd_size *= upd.shape(i);
}
if (index_nd1_specialization) {
compute_encoder->setBytes(
out.shape().data(), out.shape().size() * sizeof(int), 3);
compute_encoder->setBytes(
out.strides().data(), out.strides().size() * sizeof(size_t), 4);
size_t out_ndim = out.ndim();
compute_encoder->setBytes(&out_ndim, sizeof(out_ndim), 5);
if (upd_ndim <= 1) {
// Placeholder so Metal doesn't compalain
int shape_ = 0;
compute_encoder->setBytes(&shape_, sizeof(int), 6);
} else {
compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 6);
}
compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 7);
compute_encoder->setBytes(&upd_size, sizeof(size_t), 8);
// Set index buffers
for (int i = 0; i < nidx; ++i) {
compute_encoder.set_input_array(inputs[i + 1], 20 + i);
}
// Launch grid
MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
// Collect all idx shapes and strides into one place
std::vector<int> idx_shapes;
std::vector<size_t> idx_strides;
for (int i = 0; i < nidx; ++i) {
idx_shapes.insert(
idx_shapes.end(),
inputs[i + 1].shape().begin(),
inputs[i + 1].shape().end());
idx_strides.insert(
idx_strides.end(),
inputs[i + 1].strides().begin(),
inputs[i + 1].strides().end());
}
if (upd_ndim == 0) {
// Need placeholders so Metal doesn't compalain
int shape_ = 0;
size_t stride_ = 0;
compute_encoder->setBytes(&shape_, sizeof(int), 3);
compute_encoder->setBytes(&stride_, sizeof(size_t), 4);
} else {
compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 3);
compute_encoder->setBytes(
upd.strides().data(), upd_ndim * sizeof(size_t), 4);
}
compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5);
compute_encoder->setBytes(&upd_size, sizeof(size_t), 6);
// Set output info
size_t out_ndim = out.ndim();
if (out_ndim == 0) {
// Need placeholders so Metal doesn't compalain
int shape_ = 0;
size_t stride_ = 0;
compute_encoder->setBytes(&shape_, sizeof(int), 7);
compute_encoder->setBytes(&stride_, sizeof(size_t), 8);
} else {
compute_encoder->setBytes(out.shape().data(), out_ndim * sizeof(int), 7);
compute_encoder->setBytes(
out.strides().data(), out_ndim * sizeof(size_t), 8);
}
compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9);
compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10);
// Set index info
if (idx_ndim == 0) {
// Add a 0 in idx_shapes and strides to avoid the missing buffer binding
// error in the metal API.
idx_shapes.push_back(0);
idx_strides.push_back(0);
}
compute_encoder->setBytes(
idx_shapes.data(), idx_shapes.size() * sizeof(int), 11);
compute_encoder->setBytes(
idx_strides.data(), idx_strides.size() * sizeof(size_t), 12);
compute_encoder->setBytes(&idx_ndim, sizeof(int), 13);
// Set index buffers
for (int i = 0; i < nidx; ++i) {
compute_encoder.set_input_array(inputs[i + 1], 20 + i);
}
// Launch grid
MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
// Collect all idx shapes and strides into one place
std::vector<int> idx_shapes;
std::vector<size_t> idx_strides;
// To access .data() use char instead of bool
// bool is 1 byte in Metal so this is safe
std::vector<char> idx_contigs;
for (int i = 0; i < nidx; ++i) {
idx_shapes.insert(
idx_shapes.end(),
inputs[i + 1].shape().begin(),
inputs[i + 1].shape().end());
idx_strides.insert(
idx_strides.end(),
inputs[i + 1].strides().begin(),
inputs[i + 1].strides().end());
idx_contigs.push_back(inputs[i + 1].flags().row_contiguous);
}
if (upd_ndim == 0) {
// Need placeholders so Metal doesn't compalain
int shape_ = 0;
size_t stride_ = 0;
compute_encoder->setBytes(&shape_, sizeof(int), 3);
compute_encoder->setBytes(&stride_, sizeof(size_t), 4);
} else {
set_vector_bytes(compute_encoder, upd.shape(), 3);
set_vector_bytes(compute_encoder, upd.strides(), 4);
}
compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5);
compute_encoder->setBytes(&upd_size, sizeof(size_t), 6);
// Set output info
size_t out_ndim = out.ndim();
if (out_ndim == 0) {
// Need placeholders so Metal doesn't compalain
int shape_ = 0;
size_t stride_ = 0;
compute_encoder->setBytes(&shape_, sizeof(int), 7);
compute_encoder->setBytes(&stride_, sizeof(size_t), 8);
} else {
set_vector_bytes(compute_encoder, out.shape(), 7);
set_vector_bytes(compute_encoder, out.strides(), 8);
}
compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9);
compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10);
// Set index info
if (idx_ndim == 0) {
// Add a 0 in idx_shapes and strides to avoid the missing buffer binding
// error in the metal API.
idx_shapes.push_back(0);
idx_strides.push_back(0);
idx_contigs.push_back(false);
}
set_vector_bytes(compute_encoder, idx_shapes, 11);
set_vector_bytes(compute_encoder, idx_strides, 12);
set_vector_bytes(compute_encoder, idx_contigs, 13);
compute_encoder->setBytes(&idx_ndim, sizeof(int), 14);
compute_encoder->setBytes(&idx_size, sizeof(size_t), 15);
// Set index buffers
for (int i = 0; i < nidx; ++i) {
compute_encoder.set_input_array(inputs[i + 1], 20 + i);
}
// Launch grid
auto grid_y = (nthreads / upd_size);
grid_y = (grid_y + nwork - 1) / nwork;
MTL::Size grid_dims = MTL::Size(upd_size, grid_y, 1);
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size != 1024) {
throw std::runtime_error("[Scatter::eval_gpu] Invalid number of threads");
}
MTL::Size group_dims = get_block_dims(upd_size, grid_y, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
} // namespace mlx::core

View File

@@ -11,12 +11,13 @@ constexpr std::string_view gather_kernels = R"(
const constant int* axes [[buffer(6)]],
const constant int* idx_shapes [[buffer(7)]],
const constant size_t* idx_strides [[buffer(8)]],
const constant int& idx_ndim [[buffer(9)]],
const constant bool* idx_contigs [[buffer(9)]],
const constant int& idx_ndim [[buffer(10)]],
{4}
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {{
Indices<{2}, {3}> idxs{{
{{ {5} }}, idx_shapes, idx_strides, idx_ndim}};
{{ {5} }}, idx_shapes, idx_strides, idx_contigs, idx_ndim}};
return gather_impl<{1}, {2}, {3}, {6}>(
src,
@@ -33,32 +34,7 @@ constexpr std::string_view gather_kernels = R"(
)";
constexpr std::string_view scatter_kernels = R"(
[[kernel]] void scatter_1d_index{0}_{4}(
const device {1}* updates [[buffer(1)]],
device mlx_atomic<{1}>* out [[buffer(2)]],
const constant int* out_shape [[buffer(3)]],
const constant size_t* out_strides [[buffer(4)]],
const constant size_t& out_ndim [[buffer(5)]],
const constant int* upd_shape [[buffer(6)]],
const constant size_t& upd_ndim [[buffer(7)]],
const constant size_t& upd_size [[buffer(8)]],
{5}
uint2 gid [[thread_position_in_grid]]) {{
const array<const device {2}*, {4}> idx_buffers = {{ {6} }};
return scatter_1d_index_impl<{1}, {2}, {3}, {4}>(
updates,
out,
out_shape,
out_strides,
out_ndim,
upd_shape,
upd_ndim,
upd_size,
idx_buffers,
gid);
}}
[[kernel]] void scatter{0}_{4}(
[[kernel]] void scatter{0}_{4}_updc_{7}_nwork{8}(
const device {1}* updates [[buffer(1)]],
device mlx_atomic<{1}>* out [[buffer(2)]],
const constant int* upd_shape [[buffer(3)]],
@@ -71,12 +47,14 @@ constexpr std::string_view scatter_kernels = R"(
const constant int* axes [[buffer(10)]],
const constant int* idx_shapes [[buffer(11)]],
const constant size_t* idx_strides [[buffer(12)]],
const constant int& idx_ndim [[buffer(13)]],
const constant bool* idx_contigs [[buffer(13)]],
const constant int& idx_ndim [[buffer(14)]],
const constant size_t& idx_size [[buffer(15)]],
{5}
uint2 gid [[thread_position_in_grid]]) {{
Indices<{2}, {4}> idxs{{ {{ {6} }}, idx_shapes, idx_strides, idx_ndim}};
Indices<{2}, {4}> idxs{{ {{ {6} }}, idx_shapes, idx_strides, idx_contigs, idx_ndim}};
return scatter_impl<{1}, {2}, {3}, {4}>(
return scatter_impl<{1}, {2}, {3}, {4}, {7}, {8}>(
updates,
out,
upd_shape,
@@ -87,6 +65,7 @@ constexpr std::string_view scatter_kernels = R"(
out_strides,
out_ndim,
axes,
idx_size,
idxs,
gid);
}}

View File

@@ -1,26 +0,0 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view scan_kernels = R"(
template [[host_name("contig_{0}")]] [[kernel]] void
contiguous_scan<{1}, {2}, {3}<{2}>, 4, {4}, {5}>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant size_t& axis_size [[buffer(2)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
uint simd_size [[threads_per_simdgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
template [[host_name("strided_{0}")]] [[kernel]] void
strided_scan<{1}, {2}, {3}<{2}>, 4, {4}, {5}>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant size_t& axis_size [[buffer(2)]],
const constant size_t& stride [[buffer(3)]],
uint2 gid [[thread_position_in_grid]],
uint2 lid [[thread_position_in_threadgroup]],
uint2 lsize [[threads_per_threadgroup]],
uint simd_size [[threads_per_simdgroup]]);
)";

View File

@@ -4,7 +4,6 @@
#include "mlx/backend/metal/jit/arange.h"
#include "mlx/backend/metal/jit/gemv_masked.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/jit/scan.h"
#include "mlx/backend/metal/jit/softmax.h"
#include "mlx/backend/metal/jit/steel_conv.h"
#include "mlx/backend/metal/jit/steel_gemm.h"
@@ -25,38 +24,38 @@ MTL::ComputePipelineState* get_arange_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
auto lib = d.get_library(kernel_name, [&]() {
std::ostringstream kernel_source;
kernel_source
<< metal::utils() << metal::arange()
<< fmt::format(arange_kernels, lib_name, get_type_string(out.dtype()));
lib = d.get_library(lib_name, kernel_source.str());
}
kernel_source << metal::utils() << metal::arange()
<< fmt::format(
arange_kernels,
kernel_name,
get_type_string(out.dtype()));
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_unary_kernel(
metal::Device& d,
const std::string& kernel_name,
Dtype in_type,
Dtype out_type,
const std::string op) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
auto lib = d.get_library(lib_name, [&]() {
auto in_t = get_type_string(in_type);
auto out_t = get_type_string(out_type);
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::unary_ops() << metal::unary();
kernel_source << get_template_definition(
"v_" + lib_name, "unary_v", get_type_string(out_type), op);
"v_" + lib_name, "unary_v", in_t, out_t, op);
kernel_source << get_template_definition(
"v2_" + lib_name, "unary_v2", get_type_string(out_type), op);
"v2_" + lib_name, "unary_v2", in_t, out_t, op);
kernel_source << get_template_definition(
"g_" + lib_name, "unary_g", get_type_string(out_type), op);
kernel_source << get_template_definition(
"gn4_" + lib_name, "unary_g", get_type_string(out_type), op, 4);
lib = d.get_library(lib_name, kernel_source.str());
}
"gn4_" + lib_name, "unary_g", in_t, out_t, op, 4);
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}
@@ -66,7 +65,7 @@ void add_binary_kernels(
Dtype out_type,
const std::string op,
std::ostringstream& kernel_source) {
const std::array<std::pair<std::string, std::string>, 11> kernel_types = {{
const std::array<std::pair<std::string, std::string>, 10> kernel_types = {{
{"ss", "binary_ss"},
{"vs", "binary_vs"},
{"sv", "binary_sv"},
@@ -77,7 +76,6 @@ void add_binary_kernels(
{"g1", "binary_g_nd1"},
{"g2", "binary_g_nd2"},
{"g3", "binary_g_nd3"},
{"gn", "binary_g"},
}};
for (auto& [name, func] : kernel_types) {
std::string template_def;
@@ -105,13 +103,12 @@ MTL::ComputePipelineState* get_binary_kernel(
Dtype out_type,
const std::string op) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::binary_ops() << metal::binary();
add_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
lib = d.get_library(lib_name, kernel_source.str());
}
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}
@@ -122,14 +119,13 @@ MTL::ComputePipelineState* get_binary_two_kernel(
Dtype out_type,
const std::string op) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::binary_ops()
<< metal::binary_two();
add_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
lib = d.get_library(lib_name, kernel_source.str());
}
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}
@@ -139,13 +135,11 @@ MTL::ComputePipelineState* get_ternary_kernel(
Dtype type,
const std::string op) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
const std::array<std::pair<std::string, std::string>, 6> kernel_types = {{
const std::array<std::pair<std::string, std::string>, 5> kernel_types = {{
{"v", "ternary_v"},
{"v2", "ternary_v2"},
{"g", "ternary_g"},
{"g1", "ternary_g_nd1"},
{"g2", "ternary_g_nd2"},
{"g3", "ternary_g_nd3"},
@@ -159,8 +153,8 @@ MTL::ComputePipelineState* get_ternary_kernel(
}
kernel_source << get_template_definition(
"gn4_" + lib_name, "ternary_g", get_type_string(type), op, 4);
lib = d.get_library(lib_name, kernel_source.str());
}
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}
@@ -170,36 +164,33 @@ MTL::ComputePipelineState* get_copy_kernel(
const array& in,
const array& out) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
auto in_type = get_type_string(in.dtype());
auto out_type = get_type_string(out.dtype());
kernel_source
<< metal::utils() << metal::copy()
<< get_template_definition("s_" + lib_name, "copy_s", in_type, out_type)
<< get_template_definition("v_" + lib_name, "copy_v", in_type, out_type)
<< get_template_definition(
"g1_" + lib_name, "copy_g_nd1", in_type, out_type)
<< get_template_definition(
"g2_" + lib_name, "copy_g_nd2", in_type, out_type)
<< get_template_definition(
"g3_" + lib_name, "copy_g_nd3", in_type, out_type)
<< get_template_definition("g_" + lib_name, "copy_g", in_type, out_type)
<< get_template_definition(
"gn4_" + lib_name, "copy_g", in_type, out_type, 4)
<< get_template_definition(
"gg1_" + lib_name, "copy_gg_nd1", in_type, out_type)
<< get_template_definition(
"gg2_" + lib_name, "copy_gg_nd2", in_type, out_type)
<< get_template_definition(
"gg3_" + lib_name, "copy_gg_nd3", in_type, out_type)
<< get_template_definition(
"gg_" + lib_name, "copy_gg", in_type, out_type)
<< get_template_definition(
"ggn4_" + lib_name, "copy_gg", in_type, out_type, 4);
lib = d.get_library(lib_name, kernel_source.str());
}
kernel_source << metal::utils() << metal::copy()
<< get_template_definition(
"s_" + lib_name, "copy_s", in_type, out_type)
<< get_template_definition(
"v_" + lib_name, "copy_v", in_type, out_type)
<< get_template_definition(
"g1_" + lib_name, "copy_g_nd1", in_type, out_type)
<< get_template_definition(
"g2_" + lib_name, "copy_g_nd2", in_type, out_type)
<< get_template_definition(
"g3_" + lib_name, "copy_g_nd3", in_type, out_type)
<< get_template_definition(
"gn4_" + lib_name, "copy_g", in_type, out_type, 4)
<< get_template_definition(
"gg1_" + lib_name, "copy_gg_nd1", in_type, out_type)
<< get_template_definition(
"gg2_" + lib_name, "copy_gg_nd2", in_type, out_type)
<< get_template_definition(
"gg3_" + lib_name, "copy_gg_nd3", in_type, out_type)
<< get_template_definition(
"ggn4_" + lib_name, "copy_gg", in_type, out_type, 4);
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}
@@ -209,8 +200,7 @@ MTL::ComputePipelineState* get_softmax_kernel(
bool precise,
const array& out) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
auto lib = d.get_library(lib_name, [&] {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::softmax()
<< fmt::format(
@@ -218,8 +208,8 @@ MTL::ComputePipelineState* get_softmax_kernel(
lib_name,
get_type_string(out.dtype()),
get_type_string(precise ? float32 : out.dtype()));
lib = d.get_library(lib_name, kernel_source.str());
}
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}
@@ -232,22 +222,29 @@ MTL::ComputePipelineState* get_scan_kernel(
const array& in,
const array& out) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::string op_name = "Cum" + reduce_type;
op_name[3] = toupper(op_name[3]);
auto lib = d.get_library(lib_name, [&]() {
auto out_type = get_type_string(out.dtype());
std::string op = "Cum" + reduce_type + "<" + out_type + ">";
op[3] = toupper(op[3]);
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::scan()
<< fmt::format(
scan_kernels,
lib_name,
get_type_string(in.dtype()),
get_type_string(out.dtype()),
op_name,
inclusive,
reverse);
lib = d.get_library(lib_name, kernel_source.str());
}
kernel_source << metal::utils() << metal::scan();
const std::array<std::pair<std::string, std::string>, 2> scan_kernels = {{
{"contig_", "contiguous_scan"},
{"strided_", "strided_scan"},
}};
for (auto& [prefix, kernel] : scan_kernels) {
kernel_source << get_template_definition(
prefix + lib_name,
kernel,
get_type_string(in.dtype()),
get_type_string(out.dtype()),
op,
in.itemsize() <= 4 ? 4 : 2,
inclusive,
reverse);
}
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}
@@ -259,8 +256,7 @@ MTL::ComputePipelineState* get_sort_kernel(
int bn,
int tn) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
auto in_type = get_type_string(in.dtype());
auto out_type = get_type_string(out.dtype());
@@ -285,8 +281,8 @@ MTL::ComputePipelineState* get_sort_kernel(
bn,
tn);
}
lib = d.get_library(lib_name, kernel_source.str());
}
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}
@@ -298,8 +294,7 @@ MTL::ComputePipelineState* get_mb_sort_kernel(
int bn,
int tn) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::sort();
std::array<std::pair<std::string, std::string>, 3> kernel_types = {
@@ -316,27 +311,28 @@ MTL::ComputePipelineState* get_mb_sort_kernel(
bn,
tn);
}
lib = d.get_library(lib_name, kernel_source.str());
}
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_reduce_init_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& func_name,
const std::string& op_name,
const array& out) {
auto lib = d.get_library(kernel_name);
if (lib == nullptr) {
auto lib = d.get_library(kernel_name, [&]() {
std::ostringstream kernel_source;
std::string op_type = op_name(out);
op_type[0] = std::toupper(op_name(out)[0]);
std::string op_type = op_name;
op_type[0] = std::toupper(op_name[0]);
auto out_type = get_type_string(out.dtype());
std::string op = op_type + "<" + out_type + ">";
kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce();
kernel_source << get_template_definition(
kernel_name, "init_reduce", out_type, op);
lib = d.get_library(kernel_name, kernel_source.str());
}
kernel_name, func_name, out_type, op);
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}
@@ -350,8 +346,7 @@ MTL::ComputePipelineState* get_reduce_kernel(
int ndim /* = -1 */,
int bm /* = -1 */,
int bn /* = -1 */) {
auto lib = d.get_library(kernel_name);
if (lib == nullptr) {
auto lib = d.get_library(kernel_name, [&]() {
std::string op_type = op_name;
op_type[0] = std::toupper(op_name[0]);
std::ostringstream kernel_source;
@@ -369,8 +364,8 @@ MTL::ComputePipelineState* get_reduce_kernel(
kernel_source << get_template_definition(
kernel_name, func_name, in_type, out_type, op);
}
lib = d.get_library(kernel_name, kernel_source.str());
}
return kernel_source.str();
});
auto st = d.get_kernel(kernel_name, lib);
return st;
}
@@ -389,8 +384,7 @@ MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
int wm,
int wn) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::gemm()
<< metal::steel_gemm_fused()
@@ -405,8 +399,8 @@ MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
"wn"_a = wn,
"trans_a"_a = transpose_a,
"trans_b"_a = transpose_b);
lib = d.get_library(lib_name, kernel_source.str());
}
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
}
@@ -425,8 +419,7 @@ MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
bool mn_aligned,
bool k_aligned) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::gemm()
<< metal::steel_gemm_splitk()
@@ -444,8 +437,8 @@ MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
"trans_b"_a = transpose_b,
"mn_aligned"_a = mn_aligned,
"k_aligned"_a = k_aligned);
lib = d.get_library(lib_name, kernel_source.str());
}
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}
@@ -456,19 +449,19 @@ MTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel(
const array& out,
bool axbpy) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::gemm()
<< metal::steel_gemm_splitk()
<< fmt::format(
axbpy ? steel_gemm_splitk_accum_axbpy_kernels
: steel_gemm_splitk_accum_kernels,
fmt::runtime(
axbpy ? steel_gemm_splitk_accum_axbpy_kernels
: steel_gemm_splitk_accum_kernels),
"name"_a = lib_name,
"atype"_a = get_type_string(in.dtype()),
"otype"_a = get_type_string(out.dtype()));
lib = d.get_library(lib_name, kernel_source.str());
}
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}
@@ -488,8 +481,7 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
bool mn_aligned,
bool k_aligned) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
auto out_mask_type = mask_out.has_value()
? get_type_string((*mask_out).dtype())
@@ -513,8 +505,8 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
"trans_b"_a = transpose_b,
"mn_aligned"_a = mn_aligned,
"k_aligned"_a = k_aligned);
lib = d.get_library(lib_name, kernel_source.str());
}
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}
@@ -533,8 +525,7 @@ MTL::ComputePipelineState* get_gemv_masked_kernel(
int tn,
bool contiguous) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
auto out_mask_type = mask_out.has_value()
? get_type_string((*mask_out).dtype())
@@ -556,8 +547,8 @@ MTL::ComputePipelineState* get_gemv_masked_kernel(
"tn"_a = tn,
"trans"_a = transpose_mat ? "t_" : "",
"nc"_a = contiguous ? "0" : "1");
lib = d.get_library(lib_name, kernel_source.str());
}
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}
@@ -573,8 +564,7 @@ MTL::ComputePipelineState* get_steel_conv_kernel(
int n_channel_specialization,
bool small_filter) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::conv() << metal::steel_conv()
<< fmt::format(
@@ -588,8 +578,8 @@ MTL::ComputePipelineState* get_steel_conv_kernel(
"wn"_a = wn,
"n_channels"_a = n_channel_specialization,
"small_filter"_a = small_filter);
lib = d.get_library(lib_name, kernel_source.str());
}
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}
@@ -603,8 +593,7 @@ MTL::ComputePipelineState* get_steel_conv_general_kernel(
int wm,
int wn) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::conv()
<< metal::steel_conv_general()
@@ -617,8 +606,8 @@ MTL::ComputePipelineState* get_steel_conv_general_kernel(
"bk"_a = bk,
"wm"_a = wm,
"wn"_a = wn);
lib = d.get_library(lib_name, kernel_source.str());
}
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}
@@ -629,13 +618,12 @@ MTL::ComputePipelineState* get_fft_kernel(
const metal::MTLFCList& func_consts,
const std::string& template_def) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
std::string kernel_string;
kernel_source << metal::fft() << template_def;
lib = d.get_library(lib_name, kernel_source.str());
}
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
}
@@ -644,13 +632,12 @@ MTL::ComputePipelineState* get_quantized_kernel(
const std::string& kernel_name,
const std::string& template_def) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::gemm() << metal::quantized()
<< template_def;
lib = d.get_library(lib_name, kernel_source.str());
}
return kernel_source.str();
});
return d.get_kernel(kernel_name, lib);
}

View File

@@ -15,6 +15,7 @@ MTL::ComputePipelineState* get_arange_kernel(
MTL::ComputePipelineState* get_unary_kernel(
metal::Device& d,
const std::string& kernel_name,
Dtype in_type,
Dtype out_type,
const std::string op);
@@ -78,6 +79,8 @@ MTL::ComputePipelineState* get_mb_sort_kernel(
MTL::ComputePipelineState* get_reduce_init_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& func_name,
const std::string& op_name,
const array& out);
MTL::ComputePipelineState* get_reduce_kernel(
@@ -208,10 +211,10 @@ get_template_definition(std::string name, std::string func, Args... args) {
};
(add_arg(args), ...);
s << ">";
std::string base_string = R"(
template [[host_name("{0}")]] [[kernel]] decltype({1}) {1};
)";
return fmt::format(base_string, name, s.str());
return fmt::format(
"\ntemplate [[host_name(\"{0}\")]] [[kernel]] decltype({1}) {1};\n",
name,
s.str());
}
} // namespace mlx::core

View File

@@ -30,8 +30,9 @@ build_kernel(layer_norm)
build_kernel(random)
build_kernel(rms_norm)
build_kernel(rope)
build_kernel(scaled_dot_product_attention scaled_dot_product_attention_params.h
steel/defines.h steel/gemm/transforms.h steel/utils.h)
build_kernel(
scaled_dot_product_attention scaled_dot_product_attention_params.h
sdpa_vector.h steel/defines.h steel/gemm/transforms.h steel/utils.h)
set(STEEL_HEADERS
steel/defines.h
@@ -49,7 +50,9 @@ set(STEEL_HEADERS
steel/gemm/transforms.h
steel/gemm/kernels/steel_gemm_fused.h
steel/gemm/kernels/steel_gemm_masked.h
steel/gemm/kernels/steel_gemm_splitk.h)
steel/gemm/kernels/steel_gemm_splitk.h
steel/utils/type_traits.h
steel/utils/integral_constant.h)
if(NOT MLX_METAL_JIT)
build_kernel(arange arange.h)

View File

@@ -17,7 +17,6 @@
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
instantiate_kernel("gn_" #op #tname, binary_g, itype, otype, op) \
instantiate_kernel("gn4_" #op #tname, binary_g, itype, otype, op, 4) \
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op) \

View File

@@ -15,7 +15,6 @@
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
instantiate_kernel("gn_" #op #tname, binary_g, itype, otype, op) \
instantiate_kernel("gn4_" #op #tname, binary_g, itype, otype, op, 4) \
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op) \

View File

@@ -16,9 +16,7 @@
instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, itype, otype) \
instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype) \
instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype) \
instantiate_kernel("g_copy" #tname, copy_g, itype, otype) \
instantiate_kernel("gn4_copy" #tname, copy_g, itype, otype, 4) \
instantiate_kernel("gg_copy" #tname, copy_gg, itype, otype) \
instantiate_kernel("ggn4_copy" #tname, copy_gg, itype, otype, 4)
#define instantiate_copy_itype(itname, itype) \

View File

@@ -25,11 +25,13 @@ METAL_FUNC void gather_impl(
idx_loc = index.x * indices.strides[indices.ndim * i];
} else {
idx_loc = index.x * indices.strides[indices.ndim * i];
idx_loc += elem_to_loc(
index.y,
&indices.shapes[indices.ndim * i + 1],
&indices.strides[indices.ndim * i + 1],
indices.ndim - 1);
idx_loc += indices.row_contiguous[i]
? index.y
: elem_to_loc(
index.y,
&indices.shapes[indices.ndim * i + 1],
&indices.strides[indices.ndim * i + 1],
indices.ndim - 1);
}
auto ax = axes[i];
auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]);

View File

@@ -9,6 +9,7 @@ struct Indices {
const array<const device IdxT*, NIDX> buffers;
const constant int* shapes;
const constant size_t* strides;
const constant bool* row_contiguous;
const int ndim;
};

View File

@@ -8,6 +8,7 @@ using namespace metal;
#define MLX_MTL_CONST static constant constexpr const
MLX_MTL_CONST int SIMD_SIZE = 32;
MLX_MTL_CONST int QUAD_SIZE = 4;
template <typename T, typename U, int values_per_thread, int bits>
inline U load_vector(const device T* x, thread U* x_thread) {
@@ -371,6 +372,64 @@ struct QuantizedBlockLoader {
}
};
template <typename T, int group_size, int bits, int D>
METAL_FUNC void qmv_quad_impl(
const device uint32_t* w,
const device T* scales,
const device T* biases,
const device T* x,
device T* y,
constant int& in_vec_size,
const constant int& out_vec_size,
uint3 tid [[threadgroup_position_in_grid]],
uint quad_gid [[quadgroup_index_in_threadgroup]],
uint quad_lid [[thread_index_in_quadgroup]]) {
constexpr int quads_per_simd = SIMD_SIZE / QUAD_SIZE;
constexpr int pack_factor = 32 / bits;
constexpr int values_per_thread = D / QUAD_SIZE;
constexpr int packs_per_thread = values_per_thread / pack_factor;
constexpr int scale_step_per_thread = group_size / values_per_thread;
constexpr int results_per_quadgroup = 8;
typedef float U;
thread U x_thread[values_per_thread];
thread U result[results_per_quadgroup] = {0};
// Adjust positions
const int in_vec_size_w = in_vec_size / pack_factor;
const int in_vec_size_g = in_vec_size / group_size;
const int out_row = tid.x * quads_per_simd * results_per_quadgroup + quad_gid;
w += out_row * in_vec_size_w + quad_lid * packs_per_thread;
scales += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
biases += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
x += tid.y * in_vec_size + quad_lid * values_per_thread;
y += tid.y * out_vec_size + out_row;
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
for (int row = 0; row < results_per_quadgroup; row++) {
const device uint8_t* wl =
(const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd);
const device T* sl = scales + row * in_vec_size_g * quads_per_simd;
const device T* bl = biases + row * in_vec_size_g * quads_per_simd;
U s = sl[0];
U b = bl[0];
if (row * quads_per_simd + out_row < out_vec_size) {
result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
}
}
for (int row = 0; row < results_per_quadgroup; row++) {
result[row] = quad_sum(result[row]);
if (quad_lid == 0 && row * quads_per_simd + out_row < out_vec_size) {
y[row * quads_per_simd] = static_cast<T>(result[row]);
}
}
}
template <typename T, int group_size, int bits>
METAL_FUNC void qmv_fast_impl(
const device uint32_t* w,
@@ -586,13 +645,13 @@ METAL_FUNC void qmv_impl(
template <typename T, const int group_size, const int bits>
METAL_FUNC void qvm_impl(
const device T* x,
const device uint32_t* w,
const device T* scales,
const device T* biases,
const device T* x,
device T* y,
const constant int& in_vec_size,
const constant int& out_vec_size,
const int in_vec_size,
const int out_vec_size,
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
@@ -697,16 +756,16 @@ template <
const int BK = 32,
const int BN = 32>
METAL_FUNC void qmm_t_impl(
const device T* x,
const device uint32_t* w,
const device T* scales,
const device T* biases,
const device T* x,
device T* y,
threadgroup T* Xs,
threadgroup T* Ws,
const constant int& M,
const constant int& N,
const constant int& K,
const constant int& N,
const constant int& M,
uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -818,16 +877,16 @@ template <
const int BK = 32,
const int BN = 32>
METAL_FUNC void qmm_n_impl(
const device T* x,
const device uint32_t* w,
const device T* scales,
const device T* biases,
const device T* x,
device T* y,
threadgroup T* Xs,
threadgroup T* Ws,
const constant int& M,
const constant int& N,
const constant int& K,
const constant int& N,
const constant int& M,
uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -942,6 +1001,45 @@ METAL_FUNC void qmm_n_impl(
}
}
template <typename T>
METAL_FUNC void adjust_matrix_offsets(
const device T*& x,
const device uint32_t*& w,
const device T*& scales,
const device T*& biases,
device T*& y,
int output_stride,
const constant int& x_batch_ndims,
const constant int* x_shape,
const constant size_t* x_strides,
const constant int& w_batch_ndims,
const constant int* w_shape,
const constant size_t* w_strides,
const constant size_t* s_strides,
const constant size_t* b_strides,
uint3 tid [[threadgroup_position_in_grid]]) {
// Set the input/output matrices
uint32_t x_idx = tid.z;
uint32_t w_idx = tid.z;
if (x_batch_ndims == 1) {
x += x_idx * x_strides[0];
} else {
x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);
}
if (w_batch_ndims == 1) {
w += w_idx * w_strides[0];
scales += w_idx * s_strides[0];
biases += w_idx * b_strides[0];
} else {
ulong3 idx = elem_to_loc_broadcast(
w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims);
w += idx.x;
scales += idx.y;
biases += idx.z;
}
y += tid.z * output_stride;
}
template <typename T>
METAL_FUNC void adjust_matrix_offsets(
const device T*& x,
@@ -996,7 +1094,58 @@ METAL_FUNC void adjust_matrix_offsets(
y += tid.z * output_stride;
}
template <typename T, int group_size, int bits>
template <typename T, int group_size, int bits, int D, bool batched>
[[kernel]] void qmv_quad(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]],
device T* y [[buffer(4)]],
const constant int& in_vec_size [[buffer(5)]],
const constant int& out_vec_size [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]],
const constant int* x_shape [[buffer(8)]],
const constant size_t* x_strides [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]],
const constant int* w_shape [[buffer(11)]],
const constant size_t* w_strides [[buffer(12)]],
const constant size_t* s_strides [[buffer(13)]],
const constant size_t* b_strides [[buffer(14)]],
uint3 tid [[threadgroup_position_in_grid]],
uint quad_gid [[quadgroup_index_in_threadgroup]],
uint quad_lid [[thread_index_in_quadgroup]]) {
if (batched) {
adjust_matrix_offsets<T>(
x,
w,
scales,
biases,
y,
out_vec_size,
x_batch_ndims,
x_shape,
x_strides,
w_batch_ndims,
w_shape,
w_strides,
s_strides,
b_strides,
tid);
}
qmv_quad_impl<T, group_size, bits, D>(
w,
scales,
biases,
x,
y,
in_vec_size,
out_vec_size,
tid,
quad_gid,
quad_lid);
}
template <typename T, int group_size, int bits, bool batched>
[[kernel]] void qmv_fast(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
@@ -1005,9 +1154,35 @@ template <typename T, int group_size, int bits>
device T* y [[buffer(4)]],
const constant int& in_vec_size [[buffer(5)]],
const constant int& out_vec_size [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]],
const constant int* x_shape [[buffer(8)]],
const constant size_t* x_strides [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]],
const constant int* w_shape [[buffer(11)]],
const constant size_t* w_strides [[buffer(12)]],
const constant size_t* s_strides [[buffer(13)]],
const constant size_t* b_strides [[buffer(14)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
if (batched) {
adjust_matrix_offsets<T>(
x,
w,
scales,
biases,
y,
out_vec_size,
x_batch_ndims,
x_shape,
x_strides,
w_batch_ndims,
w_shape,
w_strides,
s_strides,
b_strides,
tid);
}
qmv_fast_impl<T, group_size, bits>(
w,
scales,
@@ -1021,7 +1196,7 @@ template <typename T, int group_size, int bits>
simd_lid);
}
template <typename T, const int group_size, const int bits>
template <typename T, const int group_size, const int bits, bool batched>
[[kernel]] void qmv(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
@@ -1030,9 +1205,35 @@ template <typename T, const int group_size, const int bits>
device T* y [[buffer(4)]],
const constant int& in_vec_size [[buffer(5)]],
const constant int& out_vec_size [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]],
const constant int* x_shape [[buffer(8)]],
const constant size_t* x_strides [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]],
const constant int* w_shape [[buffer(11)]],
const constant size_t* w_strides [[buffer(12)]],
const constant size_t* s_strides [[buffer(13)]],
const constant size_t* b_strides [[buffer(14)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
if (batched) {
adjust_matrix_offsets<T>(
x,
w,
scales,
biases,
y,
out_vec_size,
x_batch_ndims,
x_shape,
x_strides,
w_batch_ndims,
w_shape,
w_strides,
s_strides,
b_strides,
tid);
}
qmv_impl<T, group_size, bits>(
w,
scales,
@@ -1046,25 +1247,106 @@ template <typename T, const int group_size, const int bits>
simd_lid);
}
template <typename T, const int group_size, const int bits>
template <typename T, const int group_size, const int bits, bool batched>
[[kernel]] void qvm(
const device T* x [[buffer(0)]],
const device uint32_t* w [[buffer(1)]],
const device T* scales [[buffer(2)]],
const device T* biases [[buffer(3)]],
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]],
device T* y [[buffer(4)]],
const constant int& in_vec_size [[buffer(5)]],
const constant int& out_vec_size [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]],
const constant int* x_shape [[buffer(8)]],
const constant size_t* x_strides [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]],
const constant int* w_shape [[buffer(11)]],
const constant size_t* w_strides [[buffer(12)]],
const constant size_t* s_strides [[buffer(13)]],
const constant size_t* b_strides [[buffer(14)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
if (batched) {
adjust_matrix_offsets<T>(
x,
w,
scales,
biases,
y,
out_vec_size,
x_batch_ndims,
x_shape,
x_strides,
w_batch_ndims,
w_shape,
w_strides,
s_strides,
b_strides,
tid);
}
qvm_impl<T, group_size, bits>(
w,
scales,
biases,
x,
y,
in_vec_size,
out_vec_size,
tid,
simd_gid,
simd_lid);
}
template <typename T, const int group_size, const int bits, int split_k = 32>
[[kernel]] void qvm_split_k(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]],
device T* y [[buffer(4)]],
const constant int& in_vec_size [[buffer(5)]],
const constant int& out_vec_size [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]],
const constant int* x_shape [[buffer(8)]],
const constant size_t* x_strides [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]],
const constant int* w_shape [[buffer(11)]],
const constant size_t* w_strides [[buffer(12)]],
const constant size_t* s_strides [[buffer(13)]],
const constant size_t* b_strides [[buffer(14)]],
const constant int& final_block_size [[buffer(15)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
adjust_matrix_offsets<T>(
x,
w,
scales,
biases,
y,
in_vec_size,
out_vec_size,
x_batch_ndims,
x_shape,
x_strides,
w_batch_ndims,
w_shape,
w_strides,
s_strides,
b_strides,
tid);
// When (in_vec_size % split_k != 0) the final block needs to be smaller
int in_vec_size_adj =
tid.z % split_k == split_k - 1 ? final_block_size : in_vec_size;
qvm_impl<T, group_size, bits>(
w,
scales,
biases,
x,
y,
in_vec_size_adj,
out_vec_size,
tid,
simd_gid,
@@ -1076,18 +1358,27 @@ template <
const int group_size,
const int bits,
const bool aligned_N,
const bool batched,
const int BM = 32,
const int BK = 32,
const int BN = 32>
[[kernel]] void qmm_t(
const device T* x [[buffer(0)]],
const device uint32_t* w [[buffer(1)]],
const device T* scales [[buffer(2)]],
const device T* biases [[buffer(3)]],
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]],
device T* y [[buffer(4)]],
const constant int& M [[buffer(5)]],
const constant int& K [[buffer(5)]],
const constant int& N [[buffer(6)]],
const constant int& K [[buffer(7)]],
const constant int& M [[buffer(7)]],
const constant int& x_batch_ndims [[buffer(8)]],
const constant int* x_shape [[buffer(9)]],
const constant size_t* x_strides [[buffer(10)]],
const constant int& w_batch_ndims [[buffer(11)]],
const constant int* w_shape [[buffer(12)]],
const constant size_t* w_strides [[buffer(13)]],
const constant size_t* s_strides [[buffer(14)]],
const constant size_t* b_strides [[buffer(15)]],
uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -1099,26 +1390,53 @@ template <
threadgroup T Xs[BM * BK_padded];
threadgroup T Ws[BN * BK_padded];
if (batched) {
adjust_matrix_offsets<T>(
x,
w,
scales,
biases,
y,
M * N,
x_batch_ndims,
x_shape,
x_strides,
w_batch_ndims,
w_shape,
w_strides,
s_strides,
b_strides,
tid);
}
qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}
template <
typename T,
const int group_size,
const int bits,
const bool batched,
const int BM = 32,
const int BK = 32,
const int BN = 32>
[[kernel]] void qmm_n(
const device T* x [[buffer(0)]],
const device uint32_t* w [[buffer(1)]],
const device T* scales [[buffer(2)]],
const device T* biases [[buffer(3)]],
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]],
device T* y [[buffer(4)]],
const constant int& M [[buffer(5)]],
const constant int& K [[buffer(5)]],
const constant int& N [[buffer(6)]],
const constant int& K [[buffer(7)]],
const constant int& M [[buffer(7)]],
const constant int& x_batch_ndims [[buffer(8)]],
const constant int* x_shape [[buffer(9)]],
const constant size_t* x_strides [[buffer(10)]],
const constant int& w_batch_ndims [[buffer(11)]],
const constant int* w_shape [[buffer(12)]],
const constant size_t* w_strides [[buffer(13)]],
const constant size_t* s_strides [[buffer(14)]],
const constant size_t* b_strides [[buffer(15)]],
uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -1131,8 +1449,27 @@ template <
threadgroup T Xs[BM * BK_padded];
threadgroup T Ws[BK * BN_padded];
if (batched) {
adjust_matrix_offsets<T>(
x,
w,
scales,
biases,
y,
M * N,
x_batch_ndims,
x_shape,
x_strides,
w_batch_ndims,
w_shape,
w_strides,
s_strides,
b_strides,
tid);
}
qmm_n_impl<T, group_size, bits, BM, BK, BN>(
x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}
template <typename T, int group_size, int bits>
@@ -1141,23 +1478,23 @@ template <typename T, int group_size, int bits>
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]],
const device uint32_t* lhs_indices [[buffer(4)]],
const device uint32_t* rhs_indices [[buffer(5)]],
device T* y [[buffer(6)]],
const constant int& in_vec_size [[buffer(7)]],
const constant int& out_vec_size [[buffer(8)]],
const constant int& batch_ndims [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]],
const constant size_t* lhs_strides [[buffer(11)]],
const constant size_t* rhs_strides [[buffer(12)]],
const constant int& x_batch_ndims [[buffer(13)]],
const constant int* x_shape [[buffer(14)]],
const constant size_t* x_strides [[buffer(15)]],
const constant int& w_batch_ndims [[buffer(16)]],
const constant int* w_shape [[buffer(17)]],
const constant size_t* w_strides [[buffer(18)]],
const constant size_t* s_strides [[buffer(19)]],
const constant size_t* b_strides [[buffer(20)]],
device T* y [[buffer(4)]],
const constant int& in_vec_size [[buffer(5)]],
const constant int& out_vec_size [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]],
const constant int* x_shape [[buffer(8)]],
const constant size_t* x_strides [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]],
const constant int* w_shape [[buffer(11)]],
const constant size_t* w_strides [[buffer(12)]],
const constant size_t* s_strides [[buffer(13)]],
const constant size_t* b_strides [[buffer(14)]],
const constant int& batch_ndims [[buffer(15)]],
const constant int* batch_shape [[buffer(16)]],
const device uint32_t* lhs_indices [[buffer(17)]],
const device uint32_t* rhs_indices [[buffer(18)]],
const constant size_t* lhs_strides [[buffer(19)]],
const constant size_t* rhs_strides [[buffer(20)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
@@ -1202,23 +1539,23 @@ template <typename T, int group_size, int bits>
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]],
const device uint32_t* lhs_indices [[buffer(4)]],
const device uint32_t* rhs_indices [[buffer(5)]],
device T* y [[buffer(6)]],
const constant int& in_vec_size [[buffer(7)]],
const constant int& out_vec_size [[buffer(8)]],
const constant int& batch_ndims [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]],
const constant size_t* lhs_strides [[buffer(11)]],
const constant size_t* rhs_strides [[buffer(12)]],
const constant int& x_batch_ndims [[buffer(13)]],
const constant int* x_shape [[buffer(14)]],
const constant size_t* x_strides [[buffer(15)]],
const constant int& w_batch_ndims [[buffer(16)]],
const constant int* w_shape [[buffer(17)]],
const constant size_t* w_strides [[buffer(18)]],
const constant size_t* s_strides [[buffer(19)]],
const constant size_t* b_strides [[buffer(20)]],
device T* y [[buffer(4)]],
const constant int& in_vec_size [[buffer(5)]],
const constant int& out_vec_size [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]],
const constant int* x_shape [[buffer(8)]],
const constant size_t* x_strides [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]],
const constant int* w_shape [[buffer(11)]],
const constant size_t* w_strides [[buffer(12)]],
const constant size_t* s_strides [[buffer(13)]],
const constant size_t* b_strides [[buffer(14)]],
const constant int& batch_ndims [[buffer(15)]],
const constant int* batch_shape [[buffer(16)]],
const device uint32_t* lhs_indices [[buffer(17)]],
const device uint32_t* rhs_indices [[buffer(18)]],
const constant size_t* lhs_strides [[buffer(19)]],
const constant size_t* rhs_strides [[buffer(20)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
@@ -1259,27 +1596,27 @@ template <typename T, int group_size, int bits>
template <typename T, int group_size, int bits>
[[kernel]] void bs_qvm(
const device T* x [[buffer(0)]],
const device uint32_t* w [[buffer(1)]],
const device T* scales [[buffer(2)]],
const device T* biases [[buffer(3)]],
const device uint32_t* lhs_indices [[buffer(4)]],
const device uint32_t* rhs_indices [[buffer(5)]],
device T* y [[buffer(6)]],
const constant int& in_vec_size [[buffer(7)]],
const constant int& out_vec_size [[buffer(8)]],
const constant int& batch_ndims [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]],
const constant size_t* lhs_strides [[buffer(11)]],
const constant size_t* rhs_strides [[buffer(12)]],
const constant int& x_batch_ndims [[buffer(13)]],
const constant int* x_shape [[buffer(14)]],
const constant size_t* x_strides [[buffer(15)]],
const constant int& w_batch_ndims [[buffer(16)]],
const constant int* w_shape [[buffer(17)]],
const constant size_t* w_strides [[buffer(18)]],
const constant size_t* s_strides [[buffer(19)]],
const constant size_t* b_strides [[buffer(20)]],
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]],
device T* y [[buffer(4)]],
const constant int& in_vec_size [[buffer(5)]],
const constant int& out_vec_size [[buffer(6)]],
const constant int& x_batch_ndims [[buffer(7)]],
const constant int* x_shape [[buffer(8)]],
const constant size_t* x_strides [[buffer(9)]],
const constant int& w_batch_ndims [[buffer(10)]],
const constant int* w_shape [[buffer(11)]],
const constant size_t* w_strides [[buffer(12)]],
const constant size_t* s_strides [[buffer(13)]],
const constant size_t* b_strides [[buffer(14)]],
const constant int& batch_ndims [[buffer(15)]],
const constant int* batch_shape [[buffer(16)]],
const device uint32_t* lhs_indices [[buffer(17)]],
const device uint32_t* rhs_indices [[buffer(18)]],
const constant size_t* lhs_strides [[buffer(19)]],
const constant size_t* rhs_strides [[buffer(20)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
@@ -1306,10 +1643,10 @@ template <typename T, int group_size, int bits>
b_strides,
tid);
qvm_impl<T, group_size, bits>(
x,
w,
scales,
biases,
x,
y,
in_vec_size,
out_vec_size,
@@ -1327,28 +1664,28 @@ template <
const int BK = 32,
const int BN = 32>
[[kernel]] void bs_qmm_t(
const device T* x [[buffer(0)]],
const device uint32_t* w [[buffer(1)]],
const device T* scales [[buffer(2)]],
const device T* biases [[buffer(3)]],
const device uint32_t* lhs_indices [[buffer(4)]],
const device uint32_t* rhs_indices [[buffer(5)]],
device T* y [[buffer(6)]],
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]],
device T* y [[buffer(4)]],
const constant int& K [[buffer(5)]],
const constant int& N [[buffer(6)]],
const constant int& M [[buffer(7)]],
const constant int& N [[buffer(8)]],
const constant int& K [[buffer(9)]],
const constant int& batch_ndims [[buffer(10)]],
const constant int* batch_shape [[buffer(11)]],
const constant size_t* lhs_strides [[buffer(12)]],
const constant size_t* rhs_strides [[buffer(13)]],
const constant int& x_batch_ndims [[buffer(14)]],
const constant int* x_shape [[buffer(15)]],
const constant size_t* x_strides [[buffer(16)]],
const constant int& w_batch_ndims [[buffer(17)]],
const constant int* w_shape [[buffer(18)]],
const constant size_t* w_strides [[buffer(19)]],
const constant size_t* s_strides [[buffer(20)]],
const constant size_t* b_strides [[buffer(21)]],
const constant int& x_batch_ndims [[buffer(8)]],
const constant int* x_shape [[buffer(9)]],
const constant size_t* x_strides [[buffer(10)]],
const constant int& w_batch_ndims [[buffer(11)]],
const constant int* w_shape [[buffer(12)]],
const constant size_t* w_strides [[buffer(13)]],
const constant size_t* s_strides [[buffer(14)]],
const constant size_t* b_strides [[buffer(15)]],
const constant int& batch_ndims [[buffer(16)]],
const constant int* batch_shape [[buffer(17)]],
const device uint32_t* lhs_indices [[buffer(18)]],
const device uint32_t* rhs_indices [[buffer(19)]],
const constant size_t* lhs_strides [[buffer(20)]],
const constant size_t* rhs_strides [[buffer(21)]],
uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -1383,7 +1720,7 @@ template <
b_strides,
tid);
qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}
template <
@@ -1394,28 +1731,28 @@ template <
const int BK = 32,
const int BN = 32>
[[kernel]] void bs_qmm_n(
const device T* x [[buffer(0)]],
const device uint32_t* w [[buffer(1)]],
const device T* scales [[buffer(2)]],
const device T* biases [[buffer(3)]],
const device uint32_t* lhs_indices [[buffer(4)]],
const device uint32_t* rhs_indices [[buffer(5)]],
device T* y [[buffer(6)]],
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
const device T* x [[buffer(3)]],
device T* y [[buffer(4)]],
const constant int& K [[buffer(5)]],
const constant int& N [[buffer(6)]],
const constant int& M [[buffer(7)]],
const constant int& N [[buffer(8)]],
const constant int& K [[buffer(9)]],
const constant int& batch_ndims [[buffer(10)]],
const constant int* batch_shape [[buffer(11)]],
const constant size_t* lhs_strides [[buffer(12)]],
const constant size_t* rhs_strides [[buffer(13)]],
const constant int& x_batch_ndims [[buffer(14)]],
const constant int* x_shape [[buffer(15)]],
const constant size_t* x_strides [[buffer(16)]],
const constant int& w_batch_ndims [[buffer(17)]],
const constant int* w_shape [[buffer(18)]],
const constant size_t* w_strides [[buffer(19)]],
const constant size_t* s_strides [[buffer(20)]],
const constant size_t* b_strides [[buffer(21)]],
const constant int& x_batch_ndims [[buffer(8)]],
const constant int* x_shape [[buffer(9)]],
const constant size_t* x_strides [[buffer(10)]],
const constant int& w_batch_ndims [[buffer(11)]],
const constant int* w_shape [[buffer(12)]],
const constant size_t* w_strides [[buffer(13)]],
const constant size_t* s_strides [[buffer(14)]],
const constant size_t* b_strides [[buffer(15)]],
const constant int& batch_ndims [[buffer(16)]],
const constant int* batch_shape [[buffer(17)]],
const device uint32_t* lhs_indices [[buffer(18)]],
const device uint32_t* rhs_indices [[buffer(19)]],
const constant size_t* lhs_strides [[buffer(20)]],
const constant size_t* rhs_strides [[buffer(21)]],
uint3 tid [[threadgroup_position_in_grid]],
uint lid [[thread_index_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -1451,7 +1788,7 @@ template <
b_strides,
tid);
qmm_n_impl<T, group_size, bits, BM, BK, BN>(
x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}
template <typename T, const int group_size, const int bits>

View File

@@ -5,67 +5,118 @@
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/quantized.h"
#define instantiate_quantized(name, type, group_size, bits) \
instantiate_kernel( \
#name "_" #type "_gs_" #group_size "_b_" #bits, \
name, \
type, \
group_size, \
#define instantiate_quantized(name, type, group_size, bits) \
instantiate_kernel( \
#name "_" #type "_gs_" #group_size "_b_" #bits, \
name, \
type, \
group_size, \
bits)
#define instantiate_quantized_types(name, group_size, bits) \
instantiate_quantized(name, float, group_size, bits) \
instantiate_quantized(name, float16_t, group_size, bits) \
instantiate_quantized(name, bfloat16_t, group_size, bits)
#define instantiate_quantized_batched(name, type, group_size, bits, batched) \
instantiate_kernel( \
#name "_" #type "_gs_" #group_size "_b_" #bits "_batch_" #batched, \
name, \
type, \
group_size, \
bits, \
batched)
#define instantiate_quantized_groups(name, bits) \
instantiate_quantized_types(name, 128, bits) \
instantiate_quantized_types(name, 64, bits) \
instantiate_quantized_types(name, 32, bits)
#define instantiate_quantized_all(name) \
instantiate_quantized_groups(name, 2) \
instantiate_quantized_groups(name, 4) \
instantiate_quantized_groups(name, 8)
instantiate_quantized_all(qmv_fast)
instantiate_quantized_all(qmv)
instantiate_quantized_all(qvm)
instantiate_quantized_all(qmm_n)
instantiate_quantized_all(bs_qmv_fast)
instantiate_quantized_all(bs_qmv)
instantiate_quantized_all(bs_qvm)
instantiate_quantized_all(bs_qmm_n)
instantiate_quantized_all(affine_quantize)
instantiate_quantized_all(affine_quantize_scales_biases)
instantiate_quantized_all(affine_dequantize)
#define instantiate_quantized_aligned(name, type, group_size, bits, aligned) \
instantiate_kernel( \
#name "_" #type "_gs_" #group_size "_b_" #bits "_alN_" #aligned, \
#define instantiate_quantized_aligned(name, type, group_size, bits, aligned) \
instantiate_kernel( \
#name "_" #type "_gs_" #group_size "_b_" #bits "_alN_" #aligned, \
name, \
type, \
group_size, \
bits, \
aligned)
#define instantiate_quantized_types_aligned(name, group_size, bits) \
instantiate_quantized_aligned(name, float, group_size, bits, true) \
instantiate_quantized_aligned(name, float16_t, group_size, bits, true) \
instantiate_quantized_aligned(name, bfloat16_t, group_size, bits, true) \
instantiate_quantized_aligned(name, float, group_size, bits, false) \
instantiate_quantized_aligned(name, float16_t, group_size, bits, false) \
instantiate_quantized_aligned(name, bfloat16_t, group_size, bits, false)
#define instantiate_quantized_aligned_batched(name, type, group_size, bits, aligned, batched) \
instantiate_kernel( \
#name "_" #type "_gs_" #group_size "_b_" #bits "_alN_" #aligned "_batch_" #batched, \
name, \
type, \
group_size, \
bits, \
aligned, \
batched)
#define instantiate_quantized_groups_aligned(name, bits) \
instantiate_quantized_types_aligned(name, 128, bits) \
instantiate_quantized_types_aligned(name, 64, bits) \
instantiate_quantized_types_aligned(name, 32, bits)
#define instantiate_quantized_quad(name, type, group_size, bits, D, batched) \
instantiate_kernel( \
#name "_" #type "_gs_" #group_size "_b_" #bits "_d_" #D "_batch_" #batched, \
name, \
type, \
group_size, \
bits, \
D, \
batched)
#define instantiate_quantized_all_aligned(name) \
instantiate_quantized_groups_aligned(name, 2) \
instantiate_quantized_groups_aligned(name, 4) \
instantiate_quantized_groups_aligned(name, 8) \
#define instantiate_quantized_split_k(name, type, group_size, bits, split_k) \
instantiate_kernel( \
#name "_" #type "_gs_" #group_size "_b_" #bits "_spk_" #split_k, \
name, \
type, \
group_size, \
bits, \
split_k)
instantiate_quantized_all_aligned(qmm_t)
instantiate_quantized_all_aligned(bs_qmm_t) // clang-format on
#define instantiate_quantized_batched_wrap(name, type, group_size, bits) \
instantiate_quantized_batched(name, type, group_size, bits, 1) \
instantiate_quantized_batched(name, type, group_size, bits, 0)
#define instantiate_quantized_all_batched(type, group_size, bits) \
instantiate_quantized_batched_wrap(qmv_fast, type, group_size, bits) \
instantiate_quantized_batched_wrap(qmv, type, group_size, bits) \
instantiate_quantized_batched_wrap(qvm, type, group_size, bits) \
instantiate_quantized_batched_wrap(qmm_n, type, group_size, bits)
#define instantiate_quantized_all_single(type, group_size, bits) \
instantiate_quantized(affine_quantize, type, group_size, bits) \
instantiate_quantized(affine_quantize_scales_biases, type, group_size, bits) \
instantiate_quantized(affine_dequantize, type, group_size, bits) \
instantiate_quantized(bs_qmv_fast, type, group_size, bits) \
instantiate_quantized(bs_qmv, type, group_size, bits) \
instantiate_quantized(bs_qvm, type, group_size, bits) \
instantiate_quantized(bs_qmm_n, type, group_size, bits)
#define instantiate_quantized_all_aligned(type, group_size, bits) \
instantiate_quantized_aligned(bs_qmm_t, type, group_size, bits, true) \
instantiate_quantized_aligned(bs_qmm_t, type, group_size, bits, false) \
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 1) \
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, true, 0) \
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, false, 1) \
instantiate_quantized_aligned_batched(qmm_t, type, group_size, bits, false, 0)
#define instantiate_quantized_all_quad(type, group_size, bits) \
instantiate_quantized_quad(qmv_quad, type, group_size, bits, 64, 1) \
instantiate_quantized_quad(qmv_quad, type, group_size, bits, 64, 0) \
instantiate_quantized_quad(qmv_quad, type, group_size, bits, 128, 1) \
instantiate_quantized_quad(qmv_quad, type, group_size, bits, 128, 0)
#define instantiate_quantized_all_splitk(type, group_size, bits) \
instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 8) \
instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 32)
#define instantiate_quantized_funcs(type, group_size, bits) \
instantiate_quantized_all_single(type, group_size, bits) \
instantiate_quantized_all_batched(type, group_size, bits) \
instantiate_quantized_all_aligned(type, group_size, bits) \
instantiate_quantized_all_quad(type, group_size, bits) \
instantiate_quantized_all_splitk(type, group_size, bits)
#define instantiate_quantized_types(group_size, bits) \
instantiate_quantized_funcs(float, group_size, bits) \
instantiate_quantized_funcs(float16_t, group_size, bits) \
instantiate_quantized_funcs(bfloat16_t, group_size, bits)
#define instantiate_quantized_groups(bits) \
instantiate_quantized_types(128, bits) \
instantiate_quantized_types(64, bits) \
instantiate_quantized_types(32, bits)
#define instantiate_quantized_all() \
instantiate_quantized_groups(2) \
instantiate_quantized_groups(4) \
instantiate_quantized_groups(8)
instantiate_quantized_all() // clang-format on

View File

@@ -34,8 +34,8 @@ rbits threefry2x32_hash(const thread uint2& key, uint2 count) {
[[kernel]] void rbitsc(
device const uint32_t* keys,
device char* out,
device const bool& odd,
device const uint& bytes_per_key,
constant const bool& odd,
constant const uint& bytes_per_key,
uint2 grid_dim [[threads_per_grid]],
uint2 index [[thread_position_in_grid]]) {
auto kidx = 2 * index.x;
@@ -67,8 +67,8 @@ rbits threefry2x32_hash(const thread uint2& key, uint2 count) {
[[kernel]] void rbits(
device const uint32_t* keys,
device char* out,
device const bool& odd,
device const uint& bytes_per_key,
constant const bool& odd,
constant const uint& bytes_per_key,
constant const int& ndim,
constant const int* key_shape,
constant const size_t* key_strides,

View File

@@ -113,9 +113,12 @@ instantiate_reduce_from_types(instantiate_all_reduce, or, bool, Or<bool>)
// special case bool with larger output type
instantiate_all_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>)
#define instantiate_col_reduce_small(name, itype, otype, op, dim) \
instantiate_kernel("col_reduce_small_" #dim "_reduce_" #name, \
col_reduce_small, \
#define instantiate_col_reduce_small(name, itype, otype, op, dim) \
instantiate_kernel("col_reduce_small_" #dim "_reduce_" #name, \
col_reduce_small, \
itype, otype, op, dim) \
instantiate_kernel("col_reduce_longcolumn_" #dim "_reduce_" #name, \
col_reduce_longcolumn, \
itype, otype, op, dim)
#define instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, bm, bn) \
@@ -123,9 +126,14 @@ instantiate_all_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>)
col_reduce_looped, \
itype, otype, op, dim, bm, bn)
#define instantiate_col_reduce_2pass_tile(name, itype, otype, op, dim, bm, bn) \
instantiate_kernel("col_reduce_2pass_" #dim "_" #bm "_" #bn "_reduce_" #name, \
col_reduce_2pass, \
itype, otype, op, dim, bm, bn)
#define instantiate_col_reduce_looped(name, itype, otype, op, dim) \
instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 8, 128) \
instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 32, 32)
instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 32, 32) \
instantiate_col_reduce_2pass_tile(name, itype, otype, op, dim, 32, 32)
#define instantiate_col_reduce_general(name, itype, otype, op) \
instantiate_col_reduce_small(name, itype, otype, op, 0) \

View File

@@ -1,11 +1,6 @@
// Copyright © 2023-2024 Apple Inc.
template <
typename T,
typename U,
typename Op,
int NDIMS,
int N_READS = REDUCE_N_READS>
template <typename T, typename U, typename Op, int NDIMS>
[[kernel]] void col_reduce_small(
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
@@ -20,170 +15,128 @@ template <
const constant size_t& non_col_reductions [[buffer(10)]],
uint3 gid [[threadgroup_position_in_grid]],
uint3 gsize [[threadgroups_per_grid]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[thread_position_in_grid]],
uint3 tsize [[threads_per_grid]]) {
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]]) {
constexpr int n_reads = 4;
Op op;
looped_elem_to_loc<NDIMS> loop;
const device T* row;
// Case 1: Small row small column
if (reduction_size * non_col_reductions < 64 && reduction_stride < 32) {
U totals[31];
for (int i = 0; i < 31; i++) {
totals[i] = Op::init;
U totals[n_reads];
for (int i = 0; i < n_reads; i++) {
totals[i] = Op::init;
}
size_t column = size_t(gid.x) * lsize.x * n_reads + lid.x * n_reads;
if (column >= reduction_stride) {
return;
}
bool safe = column + n_reads <= reduction_stride;
size_t out_idx = gid.y + gsize.y * size_t(gid.z);
size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
in += in_idx + column;
size_t total_rows = non_col_reductions * reduction_size;
loop.next(lid.y, reduce_shape, reduce_strides);
for (size_t r = lid.y; r < total_rows; r += lsize.y) {
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
if (safe) {
for (int i = 0; i < n_reads; i++) {
totals[i] = op(static_cast<U>(row[i]), totals[i]);
}
} else {
U vals[n_reads];
for (int i = 0; i < n_reads; i++) {
vals[i] =
(column + i < reduction_stride) ? static_cast<U>(row[i]) : op.init;
}
for (int i = 0; i < n_reads; i++) {
totals[i] = op(vals[i], totals[i]);
}
}
loop.next(lsize.y, reduce_shape, reduce_strides);
}
short stride = reduction_stride;
short size = reduction_size;
short blocks = stride / N_READS;
short extra = stride - blocks * N_READS;
size_t out_idx = tid.x + tsize.y * size_t(tid.y);
in += elem_to_loc(out_idx, shape, strides, ndim);
for (uint r = 0; r < non_col_reductions; r++) {
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
for (short i = 0; i < size; i++) {
for (short j = 0; j < blocks; j++) {
for (short k = 0; k < N_READS; k++) {
totals[j * N_READS + k] =
op(totals[j * N_READS + k],
static_cast<U>(row[i * stride + j * N_READS + k]));
}
}
for (short k = 0; k < extra; k++) {
totals[blocks * N_READS + k] =
op(totals[blocks * N_READS + k],
static_cast<U>(row[i * stride + blocks * N_READS + k]));
if (lsize.y > 1) {
// lsize.y should be <= 8
threadgroup U shared_vals[32 * 8 * n_reads];
for (int i = 0; i < n_reads; i++) {
shared_vals[lid.y * lsize.x * n_reads + lid.x * n_reads + i] = totals[i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (lid.y == 0) {
for (int i = 0; i < n_reads; i++) {
totals[i] = shared_vals[lid.x * n_reads + i];
}
for (uint j = 1; j < lsize.y; j++) {
for (int i = 0; i < n_reads; i++) {
totals[i] =
op(shared_vals[j * lsize.x * n_reads + lid.x * n_reads + i],
totals[i]);
}
}
loop.next(reduce_shape, reduce_strides);
}
out += out_idx * reduction_stride;
for (short j = 0; j < stride; j++) {
out[j] = totals[j];
}
}
// Case 2: Long row small column
else if (reduction_size * non_col_reductions < 32) {
U totals[N_READS];
for (int i = 0; i < N_READS; i++) {
totals[i] = Op::init;
}
short size = reduction_size;
size_t offset = size_t(tid.x) * N_READS;
bool safe = offset + N_READS <= reduction_stride;
short extra = reduction_stride - offset;
size_t out_idx = tid.y + tsize.z * size_t(tid.z);
in += elem_to_loc(out_idx, shape, strides, ndim) + offset;
for (uint r = 0; r < non_col_reductions; r++) {
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
if (safe) {
for (short i = 0; i < size; i++) {
for (short j = 0; j < N_READS; j++) {
totals[j] =
op(static_cast<U>(row[i * reduction_stride + j]), totals[j]);
}
}
} else {
for (short i = 0; i < size; i++) {
for (short j = 0; j < extra; j++) {
totals[j] =
op(static_cast<U>(row[i * reduction_stride + j]), totals[j]);
}
}
}
loop.next(reduce_shape, reduce_strides);
}
out += out_idx * reduction_stride + offset;
if (lid.y == 0) {
out += out_idx * reduction_stride + column;
if (safe) {
for (short i = 0; i < N_READS; i++) {
for (int i = 0; i < n_reads; i++) {
out[i] = totals[i];
}
} else {
for (short i = 0; i < extra; i++) {
for (int i = 0; column + i < reduction_stride; i++) {
out[i] = totals[i];
}
}
}
}
// Case 3: Long row medium column
else {
threadgroup U shared_vals[1024];
U totals[N_READS];
for (int i = 0; i < N_READS; i++) {
totals[i] = Op::init;
}
short stride = reduction_stride;
short lid = simd_group_id * simd_size + simd_lane_id;
short2 tile((stride + N_READS - 1) / N_READS, 32);
short2 offset((lid % tile.x) * N_READS, lid / tile.x);
short sm_stride = tile.x * N_READS;
bool safe = offset.x + N_READS <= stride;
size_t out_idx = gid.y + gsize.y * size_t(gid.z);
in += elem_to_loc(out_idx, shape, strides, ndim) + offset.x;
// Read cooperatively and contiguously and aggregate the partial results.
size_t total = non_col_reductions * reduction_size;
loop.next(offset.y, reduce_shape, reduce_strides);
for (size_t r = offset.y; r < total; r += simd_size) {
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
if (safe) {
for (int i = 0; i < N_READS; i++) {
totals[i] = op(static_cast<U>(row[i]), totals[i]);
}
} else {
U vals[N_READS];
for (int i = 0; i < N_READS; i++) {
vals[i] = (offset.x + i < stride) ? static_cast<U>(row[i]) : op.init;
}
for (int i = 0; i < N_READS; i++) {
totals[i] = op(vals[i], totals[i]);
}
}
loop.next(simd_size, reduce_shape, reduce_strides);
}
// Each thread holds N_READS partial results but the simdgroups are not
// aligned to do the reduction across the simdgroup so we write our results
// in the shared memory and read them back according to the simdgroup.
for (int i = 0; i < N_READS; i++) {
shared_vals[offset.y * sm_stride + offset.x + i] = totals[i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
for (int i = 0; i < N_READS; i++) {
totals[i] = op.simd_reduce(
shared_vals[simd_lane_id * sm_stride + simd_group_id * N_READS + i]);
}
// Write the output.
if (simd_lane_id == 0) {
short column = simd_group_id * N_READS;
out += out_idx * reduction_stride + column;
if (column + N_READS <= stride) {
for (int i = 0; i < N_READS; i++) {
out[i] = totals[i];
}
} else {
for (int i = 0; column + i < stride; i++) {
out[i] = totals[i];
}
}
template <typename T, typename U, typename Op, int NDIMS>
[[kernel]] void col_reduce_longcolumn(
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant int* shape [[buffer(4)]],
const constant size_t* strides [[buffer(5)]],
const constant int& ndim [[buffer(6)]],
const constant int* reduce_shape [[buffer(7)]],
const constant size_t* reduce_strides [[buffer(8)]],
const constant int& reduce_ndim [[buffer(9)]],
const constant size_t& non_col_reductions [[buffer(10)]],
const constant size_t& out_size [[buffer(11)]],
uint3 gid [[threadgroup_position_in_grid]],
uint3 gsize [[threadgroups_per_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]]) {
Op op;
looped_elem_to_loc<NDIMS> loop;
const device T* row;
size_t out_idx = gid.x + gsize.x * size_t(gid.y);
size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
in += in_idx + lid.x;
U total = Op::init;
size_t total_rows = non_col_reductions * reduction_size;
loop.next(gid.z * lsize.y + lid.y, reduce_shape, reduce_strides);
for (size_t r = gid.z * lsize.y + lid.y; r < total_rows;
r += lsize.y * gsize.z) {
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
total = op(static_cast<U>(*row), total);
loop.next(lsize.y * gsize.z, reduce_shape, reduce_strides);
}
threadgroup U shared_vals[32 * 32];
shared_vals[lid.y * lsize.x + lid.x] = total;
threadgroup_barrier(mem_flags::mem_threadgroup);
if (lid.y == 0) {
for (uint i = 1; i < lsize.y; i++) {
total = op(total, shared_vals[i * lsize.x + lid.x]);
}
out[gid.z * out_size + out_idx * reduction_stride + lid.x] = total;
}
}
@@ -216,7 +169,7 @@ template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
Op op;
constexpr int n_simdgroups = 4;
constexpr int n_simdgroups = 8;
constexpr short tgp_size = n_simdgroups * simd_size;
constexpr short n_reads = (BM * BN) / tgp_size;
constexpr short n_read_blocks = BN / n_reads;
@@ -329,3 +282,103 @@ template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
}
}
}
template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
[[kernel]] void col_reduce_2pass(
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant int* shape [[buffer(4)]],
const constant size_t* strides [[buffer(5)]],
const constant int& ndim [[buffer(6)]],
const constant int* reduce_shape [[buffer(7)]],
const constant size_t* reduce_strides [[buffer(8)]],
const constant int& reduce_ndim [[buffer(9)]],
const constant size_t& non_col_reductions [[buffer(10)]],
const constant size_t& out_size [[buffer(11)]],
uint3 gid [[threadgroup_position_in_grid]],
uint3 gsize [[threadgroups_per_grid]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
Op op;
constexpr int n_simdgroups = 8;
constexpr short tgp_size = n_simdgroups * simd_size;
constexpr short n_reads = (BM * BN) / tgp_size;
constexpr short n_read_blocks = BN / n_reads;
constexpr int n_outputs = BN / n_simdgroups;
constexpr short outer_blocks = 32;
static_assert(BM == 32, "BM should be equal to 32");
threadgroup U shared_vals[BN * BM];
U totals[n_reads];
looped_elem_to_loc<NDIMS> loop;
const device T* row;
for (int i = 0; i < n_reads; i++) {
totals[i] = Op::init;
}
short lid = simd_group_id * simd_size + simd_lane_id;
short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks);
size_t column = BN * gid.x + offset.x;
bool safe = column + n_reads <= reduction_stride;
size_t full_idx = gid.y + gsize.y * size_t(gid.z);
size_t block_idx = full_idx / out_size;
size_t out_idx = full_idx % out_size;
size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim);
in += in_idx + column;
size_t total = non_col_reductions * reduction_size;
loop.next(offset.y + block_idx * BM, reduce_shape, reduce_strides);
for (size_t r = offset.y + block_idx * BM; r < total;
r += outer_blocks * BM) {
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
if (safe) {
for (int i = 0; i < n_reads; i++) {
totals[i] = op(static_cast<U>(row[i]), totals[i]);
}
} else {
U vals[n_reads];
for (int i = 0; i < n_reads; i++) {
vals[i] =
(column + i < reduction_stride) ? static_cast<U>(row[i]) : op.init;
}
for (int i = 0; i < n_reads; i++) {
totals[i] = op(vals[i], totals[i]);
}
}
loop.next(outer_blocks * BM, reduce_shape, reduce_strides);
}
// We can use a simd reduction to accumulate across BM so each thread writes
// the partial output to SM and then each simdgroup does BN / n_simdgroups
// accumulations.
for (int i = 0; i < n_reads; i++) {
shared_vals[offset.y * BN + offset.x + i] = totals[i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
short2 out_offset(simd_group_id * n_outputs, simd_lane_id);
for (int i = 0; i < n_outputs; i++) {
totals[i] =
op.simd_reduce(shared_vals[out_offset.y * BN + out_offset.x + i]);
}
// Write the output.
if (simd_lane_id == 0) {
size_t out_column = BN * gid.x + out_offset.x;
out += full_idx * reduction_stride + out_column;
if (out_column + n_outputs <= reduction_stride) {
for (int i = 0; i < n_outputs; i++) {
out[i] = totals[i];
}
} else {
for (int i = 0; out_column + i < reduction_stride; i++) {
out[i] = totals[i];
}
}
}
}

View File

@@ -1,11 +1,11 @@
#include <metal_simdgroup>
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h"
#include "mlx/backend/metal/kernels/sdpa_vector.h"
#include "mlx/backend/metal/kernels/steel/defines.h"
#include "mlx/backend/metal/kernels/steel/gemm/transforms.h"
#include "mlx/backend/metal/kernels/steel/utils.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h"
using namespace metal;
using namespace mlx::steel;
@@ -886,6 +886,9 @@ template <
}
}
// clang-format off
// SDPA full instantiations
#define instantiate_fast_inference_self_attention_kernel( \
itype, otype, bm, bn, bk, wm, wn) \
template [[host_name("steel_gemm_attention_bm_" #bm "_bn_" #bn "_bk_" #bk \
@@ -922,548 +925,29 @@ instantiate_fast_inference_self_attention_kernel(
instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 64, 2, 2);
instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 128, 2, 2);
template <
typename T,
typename T2,
typename T4,
uint16_t TILE_SIZE_CONST,
uint16_t NSIMDGROUPS>
[[kernel]] void fast_inference_sdpa_compute_partials_template(
const device T* Q [[buffer(0)]],
const device T* K [[buffer(1)]],
const device T* V [[buffer(2)]],
const device uint64_t& L [[buffer(3)]],
const device MLXScaledDotProductAttentionParams& params [[buffer(4)]],
device float* O_partials [[buffer(5)]],
device float* p_lse [[buffer(6)]],
device float* p_maxes [[buffer(7)]],
threadgroup T* threadgroup_block [[threadgroup(0)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]]) {
constexpr const size_t DK = 128;
constexpr const ulong SIMDGROUP_MATRIX_LOAD_FACTOR = 8;
constexpr const size_t THREADS_PER_SIMDGROUP = 32;
constexpr const uint iter_offset = NSIMDGROUPS * 4;
const bool is_gqa = params.N_KV_HEADS != params.N_Q_HEADS;
uint kv_head_offset_factor = tid.x;
if (is_gqa) {
int q_kv_head_ratio = params.N_Q_HEADS / params.N_KV_HEADS;
kv_head_offset_factor = tid.x / q_kv_head_ratio;
}
constexpr const uint16_t P_VEC4 = TILE_SIZE_CONST / NSIMDGROUPS / 4;
constexpr const size_t MATRIX_LOADS_PER_SIMDGROUP =
TILE_SIZE_CONST / (SIMDGROUP_MATRIX_LOAD_FACTOR * NSIMDGROUPS);
constexpr const size_t MATRIX_COLS = DK / SIMDGROUP_MATRIX_LOAD_FACTOR;
constexpr const uint totalSmemV = SIMDGROUP_MATRIX_LOAD_FACTOR *
SIMDGROUP_MATRIX_LOAD_FACTOR * (MATRIX_LOADS_PER_SIMDGROUP + 1) *
NSIMDGROUPS;
// SDPA vector instantiations
#define instantiate_sdpa_vector(type, head_dim) \
template [[host_name("sdpa_vector_" #type "_" #head_dim)]] \
[[kernel]] void sdpa_vector<type, head_dim>( \
const device type* queries [[buffer(0)]], \
const device type* keys [[buffer(1)]], \
const device type* values [[buffer(2)]], \
device type* out [[buffer(3)]], \
const constant int& gqa_factor, \
const constant int& N, \
const constant size_t& k_stride, \
const constant size_t& v_stride, \
const constant float& scale, \
uint3 tid [[threadgroup_position_in_grid]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
threadgroup T4* smemFlush = (threadgroup T4*)threadgroup_block;
#pragma clang loop unroll(full)
for (uint i = 0; i < 8; i++) {
smemFlush
[simd_lane_id + simd_group_id * THREADS_PER_SIMDGROUP +
i * NSIMDGROUPS * THREADS_PER_SIMDGROUP] = T4(0.f);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// TODO: multiple query sequence length for speculative decoding
const uint tgroup_query_head_offset =
tid.x * DK + tid.z * (params.N_Q_HEADS * DK);
#define instantiate_sdpa_vector_heads(type) \
instantiate_sdpa_vector(type, 64) \
instantiate_sdpa_vector(type, 96) \
instantiate_sdpa_vector(type, 128)
const uint tgroup_k_head_offset = kv_head_offset_factor * DK * L;
const uint tgroup_k_tile_offset = tid.y * TILE_SIZE_CONST * DK;
const uint tgroup_k_batch_offset = tid.z * L * params.N_KV_HEADS * DK;
const device T* baseK =
K + tgroup_k_batch_offset + tgroup_k_tile_offset + tgroup_k_head_offset;
const device T* baseQ = Q + tgroup_query_head_offset;
device T4* simdgroupQueryData = (device T4*)baseQ;
constexpr const size_t ACCUM_PER_GROUP = TILE_SIZE_CONST / NSIMDGROUPS;
float threadAccum[ACCUM_PER_GROUP];
#pragma clang loop unroll(full)
for (size_t threadAccumIndex = 0; threadAccumIndex < ACCUM_PER_GROUP;
threadAccumIndex++) {
threadAccum[threadAccumIndex] = -INFINITY;
}
uint KROW_ACCUM_INDEX = 0;
const int32_t SEQUENCE_LENGTH_LESS_TILE_SIZE = L - TILE_SIZE_CONST;
const bool LAST_TILE = (tid.y + 1) * TILE_SIZE_CONST >= L;
const bool LAST_TILE_ALIGNED =
(SEQUENCE_LENGTH_LESS_TILE_SIZE == int32_t(tid.y * TILE_SIZE_CONST));
T4 thread_data_x4;
T4 thread_data_y4;
if (!LAST_TILE || LAST_TILE_ALIGNED) {
thread_data_x4 = *(simdgroupQueryData + simd_lane_id);
#pragma clang loop unroll(full)
for (size_t KROW = simd_group_id; KROW < TILE_SIZE_CONST;
KROW += NSIMDGROUPS) {
const uint KROW_OFFSET = KROW * DK;
const device T* baseKRow = baseK + KROW_OFFSET;
device T4* keysData = (device T4*)baseKRow;
thread_data_y4 = *(keysData + simd_lane_id);
T kq_scalar = dot(thread_data_x4, thread_data_y4);
threadAccum[KROW_ACCUM_INDEX] = float(kq_scalar);
KROW_ACCUM_INDEX++;
}
} else {
thread_data_x4 = *(simdgroupQueryData + simd_lane_id);
const uint START_ROW = tid.y * TILE_SIZE_CONST;
const device T* baseKThisHead =
K + tgroup_k_batch_offset + tgroup_k_head_offset;
for (size_t KROW = START_ROW + simd_group_id; KROW < L;
KROW += NSIMDGROUPS) {
const uint KROW_OFFSET = KROW * DK;
const device T* baseKRow = baseKThisHead + KROW_OFFSET;
device T4* keysData = (device T4*)baseKRow;
thread_data_y4 = *(keysData + simd_lane_id);
T kq_scalar = dot(thread_data_x4, thread_data_y4);
threadAccum[KROW_ACCUM_INDEX] = float(kq_scalar);
KROW_ACCUM_INDEX++;
}
}
threadgroup float* smemP = (threadgroup float*)threadgroup_block;
#pragma clang loop unroll(full)
for (size_t i = 0; i < P_VEC4; i++) {
thread_data_x4 =
T4(threadAccum[4 * i],
threadAccum[4 * i + 1],
threadAccum[4 * i + 2],
threadAccum[4 * i + 3]);
simdgroup_barrier(mem_flags::mem_none);
thread_data_y4 = simd_sum(thread_data_x4);
if (simd_lane_id == 0) {
const uint base_smem_p_offset = i * iter_offset + simd_group_id;
smemP[base_smem_p_offset + NSIMDGROUPS * 0] = float(thread_data_y4.x);
smemP[base_smem_p_offset + NSIMDGROUPS * 1] = float(thread_data_y4.y);
smemP[base_smem_p_offset + NSIMDGROUPS * 2] = float(thread_data_y4.z);
smemP[base_smem_p_offset + NSIMDGROUPS * 3] = float(thread_data_y4.w);
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
float groupMax;
float lse = 0.f;
constexpr const size_t THREADS_PER_THREADGROUP_TIMES_4 = 4 * 32;
constexpr const size_t ACCUM_ARRAY_LENGTH =
TILE_SIZE_CONST / THREADS_PER_THREADGROUP_TIMES_4 + 1;
float4 pvals[ACCUM_ARRAY_LENGTH];
#pragma clang loop unroll(full)
for (uint accum_array_iter = 0; accum_array_iter < ACCUM_ARRAY_LENGTH;
accum_array_iter++) {
pvals[accum_array_iter] = float4(-INFINITY);
}
if (TILE_SIZE_CONST == 64) {
threadgroup float2* smemPtrFlt2 = (threadgroup float2*)threadgroup_block;
float2 vals = smemPtrFlt2[simd_lane_id];
vals *= params.INV_ALPHA;
float maxval = max(vals.x, vals.y);
simdgroup_barrier(mem_flags::mem_none);
groupMax = simd_max(maxval);
float2 expf_shifted = exp(vals - groupMax);
float sumExpLocal = expf_shifted.x + expf_shifted.y;
simdgroup_barrier(mem_flags::mem_none);
float tgroupExpSum = simd_sum(sumExpLocal);
lse = log(tgroupExpSum);
float2 local_p_hat = expf_shifted / tgroupExpSum;
pvals[0].x = local_p_hat.x;
pvals[0].y = local_p_hat.y;
smemPtrFlt2[simd_lane_id] = float2(0.f);
}
constexpr const bool TILE_SIZE_LARGER_THAN_64 = TILE_SIZE_CONST > 64;
constexpr const int TILE_SIZE_ITERS_128 = TILE_SIZE_CONST / 128;
if (TILE_SIZE_LARGER_THAN_64) {
float maxval = -INFINITY;
threadgroup float4* smemPtrFlt4 = (threadgroup float4*)threadgroup_block;
#pragma clang loop unroll(full)
for (int i = 0; i < TILE_SIZE_ITERS_128; i++) {
float4 vals = smemPtrFlt4[simd_lane_id + i * THREADS_PER_SIMDGROUP];
vals *= params.INV_ALPHA;
pvals[i] = vals;
maxval = fmax3(vals.x, vals.y, maxval);
maxval = fmax3(vals.z, vals.w, maxval);
}
simdgroup_barrier(mem_flags::mem_none);
groupMax = simd_max(maxval);
float sumExpLocal = 0.f;
#pragma clang loop unroll(full)
for (int i = 0; i < TILE_SIZE_ITERS_128; i++) {
pvals[i] = exp(pvals[i] - groupMax);
sumExpLocal += pvals[i].x + pvals[i].y + pvals[i].z + pvals[i].w;
}
simdgroup_barrier(mem_flags::mem_none);
float tgroupExpSum = simd_sum(sumExpLocal);
lse = log(tgroupExpSum);
#pragma clang loop unroll(full)
for (int i = 0; i < TILE_SIZE_ITERS_128; i++) {
pvals[i] = pvals[i] / tgroupExpSum;
smemPtrFlt4[simd_lane_id + i * THREADS_PER_SIMDGROUP] = float4(0.f);
}
}
threadgroup T* smemV = (threadgroup T*)threadgroup_block;
const size_t v_batch_offset = tid.z * params.N_KV_HEADS * L * DK;
const size_t v_head_offset = kv_head_offset_factor * L * DK;
const size_t v_tile_offset = tid.y * TILE_SIZE_CONST * DK;
const size_t v_offset = v_batch_offset + v_head_offset + v_tile_offset;
device T* baseV = (device T*)V + v_offset;
threadgroup float* smemOpartial = (threadgroup float*)(smemV + totalSmemV);
if (!LAST_TILE || LAST_TILE_ALIGNED) {
#pragma clang loop unroll(full)
for (size_t col = 0; col < MATRIX_COLS; col++) {
uint matrix_load_loop_iter = 0;
constexpr const size_t TILE_SIZE_CONST_DIV_8 = TILE_SIZE_CONST / 8;
for (size_t tile_start = simd_group_id;
tile_start < TILE_SIZE_CONST_DIV_8;
tile_start += NSIMDGROUPS) {
simdgroup_matrix<T, 8, 8> tmp;
ulong simdgroup_matrix_offset =
matrix_load_loop_iter * NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR +
simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR;
ulong2 matrixOrigin =
ulong2(col * SIMDGROUP_MATRIX_LOAD_FACTOR, simdgroup_matrix_offset);
simdgroup_load(tmp, baseV, DK, matrixOrigin, true);
const ulong2 matrixOriginSmem = ulong2(simdgroup_matrix_offset, 0);
const ulong elemsPerRowSmem = TILE_SIZE_CONST;
simdgroup_store(tmp, smemV, elemsPerRowSmem, matrixOriginSmem, false);
matrix_load_loop_iter++;
};
threadgroup_barrier(mem_flags::mem_threadgroup);
if (TILE_SIZE_CONST == 64) {
T2 local_p_hat = T2(pvals[0].x, pvals[0].y);
uint loop_iter = 0;
threadgroup float* oPartialSmem =
smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
#pragma clang loop unroll(full)
for (size_t row = simd_group_id; row < SIMDGROUP_MATRIX_LOAD_FACTOR;
row += NSIMDGROUPS) {
threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row);
threadgroup T2* smemV2 = (threadgroup T2*)smemV_row;
T2 v_local = *(smemV2 + simd_lane_id);
T val = dot(local_p_hat, v_local);
simdgroup_barrier(mem_flags::mem_none);
T row_sum = simd_sum(val);
oPartialSmem[simd_group_id + loop_iter * NSIMDGROUPS] =
float(row_sum);
loop_iter++;
}
}
if (TILE_SIZE_CONST > 64) {
constexpr const size_t TILE_SIZE_CONST_DIV_128 =
(TILE_SIZE_CONST + 1) / 128;
threadgroup float* oPartialSmem =
smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
uint loop_iter = 0;
for (size_t row = simd_group_id; row < SIMDGROUP_MATRIX_LOAD_FACTOR;
row += NSIMDGROUPS) {
threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row);
T row_sum = 0.f;
for (size_t i = 0; i < TILE_SIZE_CONST_DIV_128; i++) {
threadgroup T4* smemV2 = (threadgroup T4*)smemV_row;
T4 v_local = *(smemV2 + simd_lane_id + i * THREADS_PER_SIMDGROUP);
T4 p_local = T4(pvals[i]);
T val = dot(p_local, v_local);
row_sum += val;
}
simdgroup_barrier(mem_flags::mem_none);
row_sum = simd_sum(row_sum);
oPartialSmem[simd_group_id + loop_iter * NSIMDGROUPS] =
float(row_sum);
loop_iter++;
}
}
}
} else {
const int32_t START_ROW = tid.y * TILE_SIZE_CONST;
const int32_t MAX_START_ROW = L - SIMDGROUP_MATRIX_LOAD_FACTOR + 1;
const device T* baseVThisHead = V + v_batch_offset + v_head_offset;
constexpr const int ROWS_PER_ITER = 8;
#pragma clang loop unroll(full)
for (size_t col = 0; col < MATRIX_COLS; col++) {
uint smem_col_index = simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR;
int32_t tile_start;
for (tile_start =
START_ROW + simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR;
tile_start < MAX_START_ROW;
tile_start += NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR) {
simdgroup_matrix<T, 8, 8> tmp;
ulong2 matrixOrigin =
ulong2(col * SIMDGROUP_MATRIX_LOAD_FACTOR, tile_start);
simdgroup_load(
tmp, baseVThisHead, DK, matrixOrigin, /* transpose */ true);
const ulong2 matrixOriginSmem = ulong2(smem_col_index, 0);
constexpr const ulong elemsPerRowSmem = TILE_SIZE_CONST;
simdgroup_store(
tmp,
smemV,
elemsPerRowSmem,
matrixOriginSmem,
/* transpose */ false);
smem_col_index += NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR;
};
tile_start =
((L / SIMDGROUP_MATRIX_LOAD_FACTOR) * SIMDGROUP_MATRIX_LOAD_FACTOR);
const int32_t INT_L = int32_t(L);
for (int row_index = tile_start + simd_group_id; row_index < INT_L;
row_index += NSIMDGROUPS) {
if (simd_lane_id < SIMDGROUP_MATRIX_LOAD_FACTOR) {
const uint elems_per_row_gmem = DK;
const uint col_index_v_gmem =
col * SIMDGROUP_MATRIX_LOAD_FACTOR + simd_lane_id;
const uint row_index_v_gmem = row_index;
const uint elems_per_row_smem = TILE_SIZE_CONST;
const uint col_index_v_smem = row_index % TILE_SIZE_CONST;
const uint row_index_v_smem = simd_lane_id;
const uint scalar_offset_gmem =
row_index_v_gmem * elems_per_row_gmem + col_index_v_gmem;
const uint scalar_offset_smem =
row_index_v_smem * elems_per_row_smem + col_index_v_smem;
T vdata = T(*(baseVThisHead + scalar_offset_gmem));
smemV[scalar_offset_smem] = vdata;
smem_col_index += NSIMDGROUPS;
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (TILE_SIZE_CONST == 64) {
T2 local_p_hat = T2(pvals[0].x, pvals[0].y);
threadgroup float* oPartialSmem =
smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
for (size_t smem_row_index = simd_group_id;
smem_row_index < ROWS_PER_ITER;
smem_row_index += NSIMDGROUPS) {
threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * smem_row_index);
threadgroup T2* smemV2 = (threadgroup T2*)smemV_row;
T2 v_local = *(smemV2 + simd_lane_id);
T val = dot(local_p_hat, v_local);
simdgroup_barrier(mem_flags::mem_none);
T row_sum = simd_sum(val);
oPartialSmem[smem_row_index] = float(row_sum);
}
}
if (TILE_SIZE_CONST > 64) {
threadgroup float* oPartialSmem =
smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
uint loop_count = 0;
for (size_t row_index = simd_group_id; row_index < ROWS_PER_ITER;
row_index += NSIMDGROUPS) {
T row_sum = 0.f;
for (size_t tile_iters = 0; tile_iters < TILE_SIZE_ITERS_128;
tile_iters++) {
threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row_index);
threadgroup T4* smemV2 = (threadgroup T4*)smemV_row;
T4 v_local =
*(smemV2 + simd_lane_id + tile_iters * THREADS_PER_SIMDGROUP);
T4 p_local = T4(pvals[tile_iters]);
row_sum += dot(p_local, v_local);
}
simdgroup_barrier(mem_flags::mem_none);
row_sum = simd_sum(row_sum);
oPartialSmem[simd_group_id + NSIMDGROUPS * loop_count] =
float(row_sum);
loop_count++;
}
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_group_id == 0) {
threadgroup float4* oPartialVec4 = (threadgroup float4*)smemOpartial;
float4 vals = *(oPartialVec4 + simd_lane_id);
device float* oPartialGmem =
O_partials + tid.x * DK * params.KV_TILES + tid.y * DK;
device float4* oPartialGmemVec4 = (device float4*)oPartialGmem;
oPartialGmemVec4[simd_lane_id] = vals;
}
if (simd_group_id == 0 && simd_lane_id == 0) {
const uint tileIndex = tid.y;
const uint gmem_partial_scalar_offset =
tid.z * params.N_Q_HEADS * params.KV_TILES + tid.x * params.KV_TILES +
tileIndex;
p_lse[gmem_partial_scalar_offset] = lse;
p_maxes[gmem_partial_scalar_offset] = groupMax;
}
}
#define instantiate_fast_inference_sdpa_to_partials_kernel( \
itype, itype2, itype4, tile_size, nsimdgroups) \
template [[host_name("fast_inference_sdpa_compute_partials_" #itype \
"_" #tile_size "_" #nsimdgroups)]] [[kernel]] void \
fast_inference_sdpa_compute_partials_template< \
itype, \
itype2, \
itype4, \
tile_size, \
nsimdgroups>( \
const device itype* Q [[buffer(0)]], \
const device itype* K [[buffer(1)]], \
const device itype* V [[buffer(2)]], \
const device uint64_t& L [[buffer(3)]], \
const device MLXScaledDotProductAttentionParams& params [[buffer(4)]], \
device float* O_partials [[buffer(5)]], \
device float* p_lse [[buffer(6)]], \
device float* p_maxes [[buffer(7)]], \
threadgroup itype* threadgroup_block [[threadgroup(0)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]]);
// clang-format off
#define instantiate_fast_inference_sdpa_to_partials_shapes_helper( \
itype, itype2, itype4, tile_size) \
instantiate_fast_inference_sdpa_to_partials_kernel( \
itype, itype2, itype4, tile_size, 4) \
instantiate_fast_inference_sdpa_to_partials_kernel( \
itype, itype2, itype4, tile_size, 8) // clang-format on
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
float,
float2,
float4,
64);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
float,
float2,
float4,
128);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
float,
float2,
float4,
256);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
float,
float2,
float4,
512);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
half,
half2,
half4,
64);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
half,
half2,
half4,
128);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
half,
half2,
half4,
256);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
half,
half2,
half4,
512);
template <typename T>
void fast_inference_sdpa_reduce_tiles_template(
const device float* O_partials [[buffer(0)]],
const device float* p_lse [[buffer(1)]],
const device float* p_maxes [[buffer(2)]],
const device MLXScaledDotProductAttentionParams& params [[buffer(3)]],
device T* O [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
constexpr const int DK = 128;
const ulong offset_rows =
tid.z * params.KV_TILES * params.N_Q_HEADS + tid.x * params.KV_TILES;
const device float* p_lse_row = p_lse + offset_rows;
const device float* p_rowmax_row = p_maxes + offset_rows;
// reserve some number of registers. this constitutes an assumption on max
// value of KV TILES.
constexpr const uint8_t reserve = 128;
float p_lse_regs[reserve];
float p_rowmax_regs[reserve];
float weights[reserve];
float true_max = -INFINITY;
for (size_t i = 0; i < params.KV_TILES; i++) {
p_lse_regs[i] = float(*(p_lse_row + i));
p_rowmax_regs[i] = float(*(p_rowmax_row + i));
true_max = fmax(p_rowmax_regs[i], true_max);
weights[i] = exp(p_lse_regs[i]);
}
float denom = 0.f;
for (size_t i = 0; i < params.KV_TILES; i++) {
weights[i] *= exp(p_rowmax_regs[i] - true_max);
denom += weights[i];
}
const device float* O_partials_with_offset = O_partials +
tid.z * params.N_Q_HEADS * DK * params.KV_TILES +
tid.x * DK * params.KV_TILES;
float o_value = 0.f;
for (size_t i = 0; i < params.KV_TILES; i++) {
float val = *(O_partials_with_offset + i * DK + lid.x);
o_value += val * weights[i] / denom;
}
device T* O_gmem = O + tid.z * params.N_Q_HEADS * DK + tid.x * DK;
O_gmem[lid.x] = T(o_value);
return;
}
kernel void fast_inference_sdpa_reduce_tiles_float(
const device float* O_partials [[buffer(0)]],
const device float* p_lse [[buffer(1)]],
const device float* p_maxes [[buffer(2)]],
const device MLXScaledDotProductAttentionParams& params [[buffer(3)]],
device float* O [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
fast_inference_sdpa_reduce_tiles_template<float>(
O_partials, p_lse, p_maxes, params, O, tid, lid);
}
kernel void fast_inference_sdpa_reduce_tiles_half(
const device float* O_partials [[buffer(0)]],
const device float* p_lse [[buffer(1)]],
const device float* p_maxes [[buffer(2)]],
const device MLXScaledDotProductAttentionParams& params [[buffer(3)]],
device half* O [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
fast_inference_sdpa_reduce_tiles_template<half>(
O_partials, p_lse, p_maxes, params, O, tid, lid);
}
instantiate_sdpa_vector_heads(float)
instantiate_sdpa_vector_heads(bfloat16_t)
instantiate_sdpa_vector_heads(float16_t)
// clang-format on

View File

@@ -1,7 +1,38 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
#define DEFINE_SIMD_SCAN() \
template <typename T, metal::enable_if_t<sizeof(T) < 8, bool> = true> \
T simd_scan(T val) { \
return simd_scan_impl(val); \
} \
\
template <typename T, metal::enable_if_t<sizeof(T) == 8, bool> = true> \
T simd_scan(T val) { \
for (int i = 1; i <= 16; i *= 2) { \
val = operator()(val, simd_shuffle_and_fill_up(val, init, i)); \
} \
return val; \
}
#define DEFINE_SIMD_EXCLUSIVE_SCAN() \
template <typename T, metal::enable_if_t<sizeof(T) < 8, bool> = true> \
T simd_exclusive_scan(T val) { \
return simd_exclusive_scan_impl(val); \
} \
\
template <typename T, metal::enable_if_t<sizeof(T) == 8, bool> = true> \
T simd_exclusive_scan(T val) { \
val = simd_scan(val); \
return simd_shuffle_and_fill_up(val, init, 1); \
}
template <typename U>
struct CumSum {
DEFINE_SIMD_SCAN()
DEFINE_SIMD_EXCLUSIVE_SCAN()
static constexpr constant U init = static_cast<U>(0);
template <typename T>
@@ -9,17 +40,20 @@ struct CumSum {
return a + b;
}
U simd_scan(U x) {
U simd_scan_impl(U x) {
return simd_prefix_inclusive_sum(x);
}
U simd_exclusive_scan(U x) {
U simd_exclusive_scan_impl(U x) {
return simd_prefix_exclusive_sum(x);
}
};
template <typename U>
struct CumProd {
DEFINE_SIMD_SCAN()
DEFINE_SIMD_EXCLUSIVE_SCAN()
static constexpr constant U init = static_cast<U>(1.0f);
template <typename T>
@@ -27,11 +61,11 @@ struct CumProd {
return a * b;
}
U simd_scan(U x) {
U simd_scan_impl(U x) {
return simd_prefix_inclusive_product(x);
}
U simd_exclusive_scan(U x) {
U simd_exclusive_scan_impl(U x) {
return simd_prefix_exclusive_product(x);
}
};
@@ -47,7 +81,7 @@ struct CumProd<bool> {
bool simd_scan(bool x) {
for (int i = 1; i <= 16; i *= 2) {
bool other = simd_shuffle_up(x, i);
bool other = simd_shuffle_and_fill_up(x, init, i);
x &= other;
}
return x;
@@ -70,7 +104,7 @@ struct CumMax {
U simd_scan(U x) {
for (int i = 1; i <= 16; i *= 2) {
U other = simd_shuffle_up(x, i);
U other = simd_shuffle_and_fill_up(x, init, i);
x = (x >= other) ? x : other;
}
return x;
@@ -93,7 +127,7 @@ struct CumMin {
U simd_scan(U x) {
for (int i = 1; i <= 16; i *= 2) {
U other = simd_shuffle_up(x, i);
U other = simd_shuffle_and_fill_up(x, init, i);
x = (x <= other) ? x : other;
}
return x;
@@ -178,20 +212,22 @@ template <
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const constant size_t& axis_size [[buffer(2)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
uint simd_size [[threads_per_simdgroup]],
uint3 gid [[threadgroup_position_in_grid]],
uint3 gsize [[threadgroups_per_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
constexpr int simd_size = 32;
Op op;
// Position the pointers
in += (gid / lsize) * axis_size;
out += (gid / lsize) * axis_size;
size_t offset = (gid.y + gsize.y * size_t(gid.z)) * axis_size;
in += offset;
out += offset;
// Compute the number of simd_groups
uint simd_groups = lsize / simd_size;
uint simd_groups = lsize.x / simd_size;
// Allocate memory
U prefix = Op::init;
@@ -210,9 +246,9 @@ template <
// value
// Write block
for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize); r++) {
for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize.x); r++) {
// Compute the block offset
uint offset = r * lsize * N_READS + lid * N_READS;
uint offset = r * lsize.x * N_READS + lid.x * N_READS;
// Read the values
if (reverse) {
@@ -275,7 +311,7 @@ template <
values, out + axis_size - offset - N_READS, offset, axis_size);
}
} else {
if (lid == 0 && offset == 0) {
if (lid.x == 0 && offset == 0) {
out[axis_size - 1] = Op::init;
}
if ((offset + N_READS + 1) < axis_size) {
@@ -298,7 +334,7 @@ template <
values, out + offset, offset, axis_size);
}
} else {
if (lid == 0 && offset == 0) {
if (lid.x == 0 && offset == 0) {
out[0] = Op::init;
}
if ((offset + N_READS + 1) < axis_size) {
@@ -332,86 +368,98 @@ template <
device U* out [[buffer(1)]],
const constant size_t& axis_size [[buffer(2)]],
const constant size_t& stride [[buffer(3)]],
uint2 gid [[threadgroup_position_in_grid]],
uint2 lid [[thread_position_in_threadgroup]],
uint2 lsize [[threads_per_threadgroup]],
uint simd_size [[threads_per_simdgroup]]) {
const constant size_t& stride_blocks [[buffer(4)]],
uint3 gid [[threadgroup_position_in_grid]],
uint3 gsize [[threadgroups_per_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
constexpr int simd_size = 32;
constexpr int BM = 32;
constexpr int BN = 32;
constexpr int BN_pad = 32 + 16 / sizeof(U);
constexpr int n_simds = BN / N_READS;
constexpr int n_scans = BN / n_simds;
Op op;
// Allocate memory
threadgroup U read_buffer[N_READS * 32 * 32 + N_READS * 32];
U values[N_READS];
U prefix[N_READS];
for (int i = 0; i < N_READS; i++) {
threadgroup U read_buffer[BM * BN_pad];
U values[n_scans];
U prefix[n_scans];
for (int i = 0; i < n_scans; i++) {
prefix[i] = Op::init;
}
// Compute offsets
int offset = gid.y * axis_size * stride;
int global_index_x = gid.x * lsize.y * N_READS;
size_t full_gid = gid.y + gsize.y * size_t(gid.z);
size_t offset = full_gid / stride_blocks * axis_size * stride;
size_t global_index_x = full_gid % stride_blocks * BN;
uint read_offset_y = (lid.x * N_READS) / BN;
uint read_offset_x = (lid.x * N_READS) % BN;
uint scan_offset_y = simd_lane_id;
uint scan_offset_x = simd_group_id * n_scans;
for (uint j = 0; j < axis_size; j += simd_size) {
uint stride_limit = stride - global_index_x;
in += offset + global_index_x + read_offset_x;
out += offset + global_index_x + read_offset_x;
threadgroup U* read_into =
read_buffer + read_offset_y * BN_pad + read_offset_x;
threadgroup U* read_from =
read_buffer + scan_offset_y * BN_pad + scan_offset_x;
for (uint j = 0; j < axis_size; j += BM) {
// Calculate the indices for the current thread
uint index_y = j + lid.y;
uint index_y = j + read_offset_y;
uint check_index_y = index_y;
uint index_x = global_index_x + lid.x * N_READS;
if (reverse) {
index_y = axis_size - 1 - index_y;
}
// Read in SM
if (check_index_y < axis_size && (index_x + N_READS) < stride) {
if (check_index_y < axis_size && (read_offset_x + N_READS) < stride_limit) {
for (int i = 0; i < N_READS; i++) {
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] =
in[offset + index_y * stride + index_x + i];
read_into[i] = in[index_y * stride + i];
}
} else {
for (int i = 0; i < N_READS; i++) {
if (check_index_y < axis_size && (index_x + i) < stride) {
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] =
in[offset + index_y * stride + index_x + i];
if (check_index_y < axis_size && (read_offset_x + i) < stride_limit) {
read_into[i] = in[index_y * stride + i];
} else {
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] =
Op::init;
read_into[i] = Op::init;
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Read strided into registers
for (int i = 0; i < N_READS; i++) {
values[i] =
read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i];
for (int i = 0; i < n_scans; i++) {
values[i] = read_from[i];
}
// Do we need the following barrier? Shouldn't all simd threads execute
// simultaneously?
simdgroup_barrier(mem_flags::mem_threadgroup);
// Perform the scan
for (int i = 0; i < N_READS; i++) {
for (int i = 0; i < n_scans; i++) {
values[i] = op.simd_scan(values[i]);
values[i] = op(values[i], prefix[i]);
prefix[i] = simd_shuffle(values[i], simd_size - 1);
}
// Write to SM
for (int i = 0; i < N_READS; i++) {
read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i] =
values[i];
for (int i = 0; i < n_scans; i++) {
read_from[i] = values[i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Write to device memory
if (!inclusive) {
if (check_index_y == 0) {
if ((index_x + N_READS) < stride) {
if ((read_offset_x + N_READS) < stride_limit) {
for (int i = 0; i < N_READS; i++) {
out[offset + index_y * stride + index_x + i] = Op::init;
out[index_y * stride + i] = Op::init;
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((index_x + i) < stride) {
out[offset + index_y * stride + index_x + i] = Op::init;
if ((read_offset_x + i) < stride_limit) {
out[index_y * stride + i] = Op::init;
}
}
}
@@ -424,16 +472,14 @@ template <
check_index_y += 1;
}
}
if (check_index_y < axis_size && (index_x + N_READS) < stride) {
if (check_index_y < axis_size && (read_offset_x + N_READS) < stride_limit) {
for (int i = 0; i < N_READS; i++) {
out[offset + index_y * stride + index_x + i] =
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i];
out[index_y * stride + i] = read_into[i];
}
} else {
for (int i = 0; i < N_READS; i++) {
if (check_index_y < axis_size && (index_x + i) < stride) {
out[offset + index_y * stride + index_x + i] =
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i];
if (check_index_y < axis_size && (read_offset_x + i) < stride_limit) {
out[index_y * stride + i] = read_into[i];
}
}
}

View File

@@ -13,15 +13,15 @@ using namespace metal;
#define instantiate_contiguous_scan( \
name, itype, otype, op, inclusive, reverse, nreads) \
template [[host_name("contig_scan_" #name)]] [[kernel]] void \
template [[host_name("contig_scan_" #name)]] [[kernel]] void \
contiguous_scan<itype, otype, op<otype>, nreads, inclusive, reverse>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& axis_size [[buffer(2)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_size [[threads_per_simdgroup]], \
uint3 gid [[threadgroup_position_in_grid]], \
uint3 gsize [[threadgroups_per_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
@@ -33,10 +33,12 @@ using namespace metal;
device otype* out [[buffer(1)]], \
const constant size_t& axis_size [[buffer(2)]], \
const constant size_t& stride [[buffer(3)]], \
uint2 gid [[thread_position_in_grid]], \
uint2 lid [[thread_position_in_threadgroup]], \
uint2 lsize [[threads_per_threadgroup]], \
uint simd_size [[threads_per_simdgroup]]);
const constant size_t& stride_blocks [[buffer(4)]], \
uint3 gid [[threadgroup_position_in_grid]], \
uint3 gsize [[threadgroups_per_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_scan_helper(name, itype, otype, op, nreads) \
instantiate_contiguous_scan(inclusive_##name, itype, otype, op, true, false, nreads) \
@@ -52,51 +54,51 @@ instantiate_scan_helper(sum_bool__int32, bool, int32_t, CumSu
instantiate_scan_helper(sum_uint8_uint8, uint8_t, uint8_t, CumSum, 4)
instantiate_scan_helper(sum_uint16_uint16, uint16_t, uint16_t, CumSum, 4)
instantiate_scan_helper(sum_uint32_uint32, uint32_t, uint32_t, CumSum, 4)
//instantiate_scan_helper(sum_uint64_uint64, uint64_t, uint64_t, CumSum, 2)
instantiate_scan_helper(sum_uint64_uint64, uint64_t, uint64_t, CumSum, 2)
instantiate_scan_helper(sum_int8_int8, int8_t, int8_t, CumSum, 4)
instantiate_scan_helper(sum_int16_int16, int16_t, int16_t, CumSum, 4)
instantiate_scan_helper(sum_int32_int32, int32_t, int32_t, CumSum, 4)
//instantiate_scan_helper(sum_int64_int64, int64_t, int64_t, CumSum, 2)
instantiate_scan_helper(sum_int64_int64, int64_t, int64_t, CumSum, 2)
instantiate_scan_helper(sum_float16_float16, half, half, CumSum, 4)
instantiate_scan_helper(sum_float32_float32, float, float, CumSum, 4)
instantiate_scan_helper(sum_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumSum, 4)
//instantiate_scan_helper(sum_complex64_complex64, complex64_t, complex64_t, CumSum)
//instantiate_scan_helper(prod_bool__bool_, bool, bool, CumProd, 4)
instantiate_scan_helper(sum_complex64_complex64, complex64_t, complex64_t, CumSum, 2)
instantiate_scan_helper(prod_bool__bool_, bool, bool, CumProd, 4)
instantiate_scan_helper(prod_uint8_uint8, uint8_t, uint8_t, CumProd, 4)
instantiate_scan_helper(prod_uint16_uint16, uint16_t, uint16_t, CumProd, 4)
instantiate_scan_helper(prod_uint32_uint32, uint32_t, uint32_t, CumProd, 4)
//instantiate_scan_helper(prod_uint64_uint64, uint64_t, uint64_t, CumProd, 2)
instantiate_scan_helper(prod_uint64_uint64, uint64_t, uint64_t, CumProd, 2)
instantiate_scan_helper(prod_int8_int8, int8_t, int8_t, CumProd, 4)
instantiate_scan_helper(prod_int16_int16, int16_t, int16_t, CumProd, 4)
instantiate_scan_helper(prod_int32_int32, int32_t, int32_t, CumProd, 4)
//instantiate_scan_helper(prod_int64_int64, int64_t, int64_t, CumProd, 2)
instantiate_scan_helper(prod_int64_int64, int64_t, int64_t, CumProd, 2)
instantiate_scan_helper(prod_float16_float16, half, half, CumProd, 4)
instantiate_scan_helper(prod_float32_float32, float, float, CumProd, 4)
instantiate_scan_helper(prod_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumProd, 4)
//instantiate_scan_helper(prod_complex64_complex64, complex64_t, complex64_t, CumProd)
//instantiate_scan_helper(max_bool__bool_, bool, bool, CumMax, 4)
instantiate_scan_helper(prod_complex64_complex64, complex64_t, complex64_t, CumProd, 2)
instantiate_scan_helper(max_bool__bool_, bool, bool, CumMax, 4)
instantiate_scan_helper(max_uint8_uint8, uint8_t, uint8_t, CumMax, 4)
instantiate_scan_helper(max_uint16_uint16, uint16_t, uint16_t, CumMax, 4)
instantiate_scan_helper(max_uint32_uint32, uint32_t, uint32_t, CumMax, 4)
//instantiate_scan_helper(max_uint64_uint64, uint64_t, uint64_t, CumMax, 2)
instantiate_scan_helper(max_uint64_uint64, uint64_t, uint64_t, CumMax, 2)
instantiate_scan_helper(max_int8_int8, int8_t, int8_t, CumMax, 4)
instantiate_scan_helper(max_int16_int16, int16_t, int16_t, CumMax, 4)
instantiate_scan_helper(max_int32_int32, int32_t, int32_t, CumMax, 4)
//instantiate_scan_helper(max_int64_int64, int64_t, int64_t, CumMax, 2)
instantiate_scan_helper(max_int64_int64, int64_t, int64_t, CumMax, 2)
instantiate_scan_helper(max_float16_float16, half, half, CumMax, 4)
instantiate_scan_helper(max_float32_float32, float, float, CumMax, 4)
instantiate_scan_helper(max_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMax, 4)
//instantiate_scan_helper(max_complex64_complex64, complex64_t, complex64_t, CumMax)
//instantiate_scan_helper(min_bool__bool_, bool, bool, CumMin, 4)
instantiate_scan_helper(max_complex64_complex64, complex64_t, complex64_t, CumMax, 2)
instantiate_scan_helper(min_bool__bool_, bool, bool, CumMin, 4)
instantiate_scan_helper(min_uint8_uint8, uint8_t, uint8_t, CumMin, 4)
instantiate_scan_helper(min_uint16_uint16, uint16_t, uint16_t, CumMin, 4)
instantiate_scan_helper(min_uint32_uint32, uint32_t, uint32_t, CumMin, 4)
//instantiate_scan_helper(min_uint64_uint64, uint64_t, uint64_t, CumMin, 2)
instantiate_scan_helper(min_uint64_uint64, uint64_t, uint64_t, CumMin, 2)
instantiate_scan_helper(min_int8_int8, int8_t, int8_t, CumMin, 4)
instantiate_scan_helper(min_int16_int16, int16_t, int16_t, CumMin, 4)
instantiate_scan_helper(min_int32_int32, int32_t, int32_t, CumMin, 4)
//instantiate_scan_helper(min_int64_int64, int64_t, int64_t, CumMin, 2)
instantiate_scan_helper(min_int64_int64, int64_t, int64_t, CumMin, 2)
instantiate_scan_helper(min_float16_float16, half, half, CumMin, 4)
instantiate_scan_helper(min_float32_float32, float, float, CumMin, 4)
instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMin, 4)
//instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin) // clang-format on
instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin, 2) // clang-format on

View File

@@ -4,73 +4,54 @@
#include "mlx/backend/metal/kernels/indexing.h"
template <typename T, typename IdxT, typename Op, int NIDX>
METAL_FUNC void scatter_1d_index_impl(
const device T* updates [[buffer(1)]],
device mlx_atomic<T>* out [[buffer(2)]],
const constant int* out_shape [[buffer(3)]],
const constant size_t* out_strides [[buffer(4)]],
const constant size_t& out_ndim [[buffer(5)]],
const constant int* upd_shape [[buffer(6)]],
const constant size_t& upd_ndim [[buffer(7)]],
const constant size_t& upd_size [[buffer(8)]],
const thread array<const device IdxT*, NIDX>& idx_buffers,
uint2 gid [[thread_position_in_grid]]) {
Op op;
size_t out_idx = 0;
for (int i = 0; i < NIDX; i++) {
auto idx_val = offset_neg_idx(idx_buffers[i][gid.y], out_shape[i]);
out_idx += idx_val * out_strides[i];
}
if (upd_ndim > 1) {
auto out_offset = elem_to_loc(gid.x, upd_shape + 1, out_strides, out_ndim);
out_idx += out_offset;
} else {
out_idx += gid.x;
}
op.atomic_update(out, updates[gid.y * upd_size + gid.x], out_idx);
}
template <typename T, typename IdxT, typename Op, int NIDX>
template <
typename T,
typename IdxT,
typename Op,
int NIDX,
bool UPD_ROW_CONTIG,
int NWORK>
METAL_FUNC void scatter_impl(
const device T* updates [[buffer(1)]],
device mlx_atomic<T>* out [[buffer(2)]],
const constant int* upd_shape [[buffer(3)]],
const constant size_t* upd_strides [[buffer(4)]],
const constant size_t& upd_ndim [[buffer(5)]],
const constant size_t& upd_size [[buffer(6)]],
const constant int* out_shape [[buffer(7)]],
const constant size_t* out_strides [[buffer(8)]],
const constant size_t& out_ndim [[buffer(9)]],
const constant int* axes [[buffer(10)]],
const device T* updates,
device mlx_atomic<T>* out,
const constant int* upd_shape,
const constant size_t* upd_strides,
const constant size_t& upd_ndim,
const constant size_t& upd_size,
const constant int* out_shape,
const constant size_t* out_strides,
const constant size_t& out_ndim,
const constant int* axes,
const constant size_t& idx_size,
const thread Indices<IdxT, NIDX>& indices,
uint2 gid [[thread_position_in_grid]]) {
Op op;
auto ind_idx = gid.y;
auto ind_offset = gid.x;
size_t out_idx = 0;
for (int i = 0; i < NIDX; ++i) {
auto idx_loc = elem_to_loc(
ind_idx,
&indices.shapes[indices.ndim * i],
&indices.strides[indices.ndim * i],
indices.ndim);
auto ax = axes[i];
auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], out_shape[ax]);
out_idx += idx_val * out_strides[ax];
}
auto ind_idx = gid.y * NWORK;
size_t out_offset = 0;
if (upd_size > 1) {
auto out_offset = elem_to_loc(
ind_offset, upd_shape + indices.ndim, out_strides, out_ndim);
out_idx += out_offset;
out_offset =
elem_to_loc(gid.x, upd_shape + indices.ndim, out_strides, out_ndim);
}
auto upd_idx =
elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim);
op.atomic_update(out, updates[upd_idx], out_idx);
for (int j = 0; j < NWORK && ind_idx < idx_size; ++j, ind_idx++) {
size_t out_idx = out_offset;
for (int i = 0; i < NIDX; ++i) {
auto idx_loc = indices.row_contiguous[i]
? ind_idx
: elem_to_loc(
ind_idx,
&indices.shapes[indices.ndim * i],
&indices.strides[indices.ndim * i],
indices.ndim);
auto ax = axes[i];
auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], out_shape[ax]);
out_idx += idx_val * out_strides[ax];
}
auto upd_idx = ind_idx * upd_size + gid.x;
if constexpr (!UPD_ROW_CONTIG) {
upd_idx = elem_to_loc(upd_idx, upd_shape, upd_strides, upd_ndim);
}
op.atomic_update(out, updates[upd_idx], out_idx);
}
}

View File

@@ -0,0 +1,116 @@
// Copyright © 2024 Apple Inc.
#include <metal_simdgroup>
using namespace metal;
template <typename T, int D>
[[kernel]] void sdpa_vector(
const device T* queries [[buffer(0)]],
const device T* keys [[buffer(1)]],
const device T* values [[buffer(2)]],
device T* out [[buffer(3)]],
const constant int& gqa_factor,
const constant int& N,
const constant size_t& k_stride,
const constant size_t& v_stride,
const constant float& scale,
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int BN = 32;
constexpr int BD = 32;
constexpr int elem_per_thread = D / BD;
const int stride = BN * D;
typedef float U;
thread U q[elem_per_thread];
thread U k[elem_per_thread];
thread U o[elem_per_thread];
threadgroup U outputs[BN * BD];
threadgroup U max_scores[BN];
threadgroup U sum_exp_scores[BN];
// Adjust positions
const int head_idx = tid.y;
const int kv_head_idx = head_idx / gqa_factor;
queries += head_idx * D + simd_lid * elem_per_thread;
keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;
values += kv_head_idx * v_stride + simd_gid * D + simd_lid * elem_per_thread;
out += head_idx * D + simd_gid * elem_per_thread;
// Read the query and 0 the output accumulator
for (int i = 0; i < elem_per_thread; i++) {
q[i] = static_cast<U>(scale) * queries[i];
}
for (int i = 0; i < elem_per_thread; i++) {
o[i] = 0;
}
U max_score = -INFINITY;
U sum_exp_score = 0;
// For each key
for (int i = simd_gid; i < N; i += BN) {
// Read the key
for (int i = 0; i < elem_per_thread; i++) {
k[i] = keys[i];
}
// Compute the i-th score
U score = 0;
for (int i = 0; i < elem_per_thread; i++) {
score += q[i] * k[i];
}
score = simd_sum(score);
// Update the accumulators
U new_max = max(max_score, score);
U factor = fast::exp(max_score - new_max);
U exp_score = fast::exp(score - new_max);
max_score = new_max;
sum_exp_score = sum_exp_score * factor + exp_score;
// Update the output accumulator
for (int i = 0; i < elem_per_thread; i++) {
o[i] = o[i] * factor + exp_score * values[i];
}
// Move the pointers to the next kv
keys += stride;
values += stride;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Each thread has a partial part of the output so we need to combine them.
// First let's communicate the max and sum_exp
if (simd_lid == 0) {
max_scores[simd_gid] = max_score;
sum_exp_scores[simd_gid] = sum_exp_score;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
max_score = max_scores[simd_lid];
U new_max = simd_max(max_score);
U factor = fast::exp(max_score - new_max);
sum_exp_score = simd_sum(sum_exp_scores[simd_lid] * factor);
// Now we need to aggregate all the outputs
for (int i = 0; i < elem_per_thread; i++) {
outputs[simd_lid * BD + simd_gid] = o[i];
threadgroup_barrier(mem_flags::mem_threadgroup);
o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score;
threadgroup_barrier(mem_flags::mem_threadgroup);
}
// And write the output
if (simd_lid == 0) {
for (int i = 0; i < elem_per_thread; i++) {
out[i] = static_cast<T>(o[i]);
}
}
}

View File

@@ -142,8 +142,8 @@ implicit_gemm_conv_2d_general(
// Store results to device memory
{
// Adjust for simdgroup and thread locatio
int offset_m = c_row + mma_op.sm + mma_op.tm;
int offset_n = c_col + mma_op.sn + mma_op.tn;
int offset_m = c_row + mma_op.sm;
int offset_n = c_col + mma_op.sn;
C += offset_n;
if (offset_n >= gemm_params->N)
@@ -169,17 +169,17 @@ implicit_gemm_conv_2d_general(
STEEL_PRAGMA_UNROLL
for (int j = 0; j < mma_t::TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum =
mma_op.results[i * mma_t::TN + j].thread_elements();
thread const auto& accum = mma_op.Ctile.frag_at(i, j);
int offset = offset_cm + (j * mma_t::TN_stride);
// Apply epilogue and output C
if (j * mma_t::TN_stride < diff) {
C[offset] = Epilogue::apply(accum[0]);
}
constexpr short kelems = decltype(mma_op.Ctile)::kElemsPerFrag;
if (j * mma_t::TN_stride + 1 < diff) {
C[offset + 1] = Epilogue::apply(accum[1]);
// Apply epilogue and output C
STEEL_PRAGMA_UNROLL
for (short k = 0; k < kelems; k++) {
if ((j * mma_t::TN_stride + k) < diff) {
C[offset + k] = Epilogue::apply(accum[k]);
}
}
}
}

View File

@@ -36,11 +36,11 @@
instantiate_gemm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 1, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2)
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 1, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2)
instantiate_gemm_shapes_helper(float16, half, float16, half);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);

View File

@@ -8,6 +8,7 @@
#include "mlx/backend/metal/kernels/steel/defines.h"
#include "mlx/backend/metal/kernels/steel/gemm/transforms.h"
#include "mlx/backend/metal/kernels/steel/utils/integral_constant.h"
using namespace metal;
@@ -18,6 +19,347 @@ using namespace metal;
namespace mlx {
namespace steel {
template <typename T, int kFragRows_, int kFragCols_>
struct BaseMMAFrag {
static_assert(
kFragRows_ == 8,
"Only 8 x 8 fragment matrices are currently supported");
static_assert(
kFragCols_ == 8,
"Only 8 x 8 fragment matrices are currently supported");
};
template <typename T>
struct BaseMMAFrag<T, 8, 8> {
STEEL_CONST int kFragRows = 8;
STEEL_CONST int kFragCols = 8;
STEEL_CONST int kElemsPerFrag = (kFragRows * kFragCols) / 32;
STEEL_CONST int kElemRows = 1;
STEEL_CONST int kElemCols = 2;
static_assert(
kElemRows * kElemCols == kElemsPerFrag,
"MMAFrag shape is not consistent with MMAFrag size");
typedef metal::simdgroup_matrix<T, kFragRows, kFragCols> mat_type;
typedef metal::vec<T, kElemsPerFrag> frag_type;
METAL_FUNC static constexpr short2 get_coord(ushort simd_lane_id
[[thread_index_in_simdgroup]]) {
const short qid = simd_lane_id / 4;
const short fm = (qid & 4) + ((simd_lane_id / 2) % 4);
const short fn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
return short2{fn, fm};
}
template <typename SrcPtrType, typename StrX, typename StrY>
METAL_FUNC static constexpr void
load(thread frag_type& dst, SrcPtrType src, StrX str_x, StrY str_y) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kElemRows; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kElemCols; j++) {
dst[i * kElemCols + j] = static_cast<T>(src[i * str_x + j * str_y]);
}
}
}
template <
typename SrcPtrType,
typename StrX,
typename StrY,
typename LimX,
typename LimY,
typename OffX,
typename OffY>
METAL_FUNC static constexpr void load_safe(
thread frag_type& dst,
SrcPtrType src,
StrX str_x,
StrY str_y,
LimX lim_x,
LimY lim_y,
OffX off_x = Int<0>{},
OffY off_y = Int<0>{}) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kElemRows; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kElemCols; j++) {
if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
dst[i * kElemCols + j] =
static_cast<T>(src[(off_x + i) * str_x + (off_x + j) * str_y]);
} else {
dst[i * kElemCols + j] = T(0);
}
}
}
}
template <typename DstPtrType, typename StrX, typename StrY>
METAL_FUNC static constexpr void
store(const thread frag_type& src, DstPtrType dst, StrX str_x, StrY str_y) {
using U = pointer_element_t<DstPtrType>;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kElemRows; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kElemCols; j++) {
dst[i * str_x + j * str_y] = static_cast<U>(src[i * kElemCols + j]);
}
}
}
template <
typename DstPtrType,
typename StrX,
typename StrY,
typename LimX,
typename LimY,
typename OffX,
typename OffY>
METAL_FUNC static constexpr void store_safe(
const thread frag_type& src,
DstPtrType dst,
StrX str_x,
StrY str_y,
LimX lim_x,
LimY lim_y,
OffX off_x = Int<0>{},
OffY off_y = Int<0>{}) {
using U = pointer_element_t<DstPtrType>;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kElemRows; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kElemCols; j++) {
if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
dst[(off_x + i) * str_x + (off_y + j) * str_y] =
static_cast<U>(src[i * kElemCols + j]);
}
}
}
}
METAL_FUNC static constexpr void mma(
thread frag_type& D,
thread frag_type& A,
thread frag_type& B,
thread frag_type& C) {
mat_type D_mat;
mat_type A_mat;
mat_type B_mat;
mat_type C_mat;
reinterpret_cast<thread frag_type&>(A_mat.thread_elements()) = A;
reinterpret_cast<thread frag_type&>(B_mat.thread_elements()) = B;
reinterpret_cast<thread frag_type&>(C_mat.thread_elements()) = C;
mma(D_mat, A_mat, B_mat, C_mat);
D = reinterpret_cast<thread frag_type&>(D_mat.thread_elements());
}
METAL_FUNC static constexpr void mma(
thread mat_type& D,
thread mat_type& A,
thread mat_type& B,
thread mat_type& C) {
simdgroup_multiply_accumulate(D, A, B, C);
}
};
template <
typename T,
int kTileRows_,
int kTileCols_,
class MMAFrag_ = BaseMMAFrag<T, 8, 8>>
struct MMATile {
using MMAFrag_t = MMAFrag_;
using elem_type = T;
STEEL_CONST int kFragRows = MMAFrag_t::kFragRows;
STEEL_CONST int kFragCols = MMAFrag_t::kFragCols;
STEEL_CONST int kElemsPerFrag = MMAFrag_t::kElemsPerFrag;
STEEL_CONST int kTileRows = kTileRows_;
STEEL_CONST int kTileCols = kTileCols_;
STEEL_CONST int kRows = kTileRows * kFragRows;
STEEL_CONST int kCols = kTileCols * kFragCols;
STEEL_CONST int kNumFrags = kTileRows * kTileCols;
STEEL_CONST int kElemsPerTile = kNumFrags * kElemsPerFrag;
typedef typename MMAFrag_t::mat_type mat_type;
typedef typename MMAFrag_t::frag_type frag_type;
frag_type val_frags[kNumFrags] = {frag_type(0)};
METAL_FUNC MMATile() thread {}
METAL_FUNC constexpr void clear() {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kNumFrags; ++i) {
val_frags[i] = frag_type(0);
}
}
METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) {
return val_frags[i * kTileCols + j];
}
METAL_FUNC constexpr const thread frag_type& frag_at(
const short i,
const short j) const {
return val_frags[i * kTileCols + j];
}
METAL_FUNC mat_type mat_at(const short i, const short j) {
mat_type val_mat;
STEEL_PRAGMA_UNROLL
for (short ii = 0; ii < kElemsPerFrag; ++ii) {
val_mat.thread_elements()[ii] = frag_at(i, j)[ii];
}
return val_mat;
}
METAL_FUNC thread elem_type* elems() {
return reinterpret_cast<thread elem_type*>(val_frags);
}
METAL_FUNC const thread elem_type* elems() const {
return reinterpret_cast<const thread elem_type*>(val_frags);
}
template <typename U, int w_x, int w_y, int str_x, int str_y>
METAL_FUNC void load(const threadgroup U* src) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kTileCols; ++j) {
MMAFrag_t::load(
frag_at(i, j),
&(
src[(i * kFragRows) * w_x * str_x +
(j * kFragCols) * w_y * str_y]),
Int<str_x>{},
Int<str_y>{});
}
}
}
template <typename U, int w_x, int w_y, int str_x, int str_y>
METAL_FUNC void store(threadgroup U* dst) const {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kTileCols; ++j) {
MMAFrag_t::store(
frag_at(i, j),
&(
dst[(i * kFragRows) * w_x * str_x +
(j * kFragCols) * w_y * str_y]),
Int<str_x>{},
Int<str_y>{});
}
}
}
template <typename U, int w_x, int w_y>
METAL_FUNC void load(const device U* src, const int ld) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kTileCols; ++j) {
MMAFrag_t::load(
frag_at(i, j),
&(src[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]),
ld,
Int<1>{});
}
}
}
template <typename U, int w_x, int w_y>
METAL_FUNC void store(device U* dst, const int ld) const {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < kTileCols; ++j) {
MMAFrag_t::store(
frag_at(i, j),
&(dst[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]),
ld,
Int<1>{});
}
}
}
template <typename U, int w_x, int w_y>
METAL_FUNC void
load_safe(const device U* src, const int ld, const short2 src_tile_dims) {
STEEL_PRAGMA_UNROLL
for (int i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < kTileCols; ++j) {
MMAFrag_t::load_safe(
frag_at(i, j),
src,
ld,
Int<1>{},
src_tile_dims.y,
src_tile_dims.x,
(i * kFragRows) * w_x,
(j * kFragCols) * w_y);
}
}
}
template <typename U, int w_x, int w_y>
METAL_FUNC void
store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const {
STEEL_PRAGMA_UNROLL
for (int i = 0; i < kTileRows; ++i) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < kTileCols; ++j) {
MMAFrag_t::store_safe(
frag_at(i, j),
dst,
ld,
Int<1>{},
dst_tile_dims.y,
dst_tile_dims.x,
(i * kFragRows) * w_x,
(j * kFragCols) * w_y);
}
}
}
};
template <typename T, typename U, int M, int N, int K>
METAL_FUNC void tile_matmad(
thread MMATile<T, M, N>& D,
thread MMATile<U, M, K>& A,
thread MMATile<U, K, N>& B,
thread MMATile<T, M, N>& C) {
STEEL_PRAGMA_UNROLL
for (short m = 0; m < M; ++m) {
STEEL_PRAGMA_UNROLL
for (short n = 0; n < N; ++n) {
short n_serp = (m % 2) ? (N - 1 - n) : n;
STEEL_PRAGMA_UNROLL
for (short k = 0; k < K; ++k) {
MMATile<T, M, N>::MMAFrag_t::mma(
D.frag_at(m, n_serp),
A.frag_at(m, k),
B.frag_at(k, n_serp),
C.frag_at(m, n_serp));
}
}
}
}
template <
typename T,
typename U,
@@ -33,39 +375,38 @@ template <
typename AccumType = float,
typename Epilogue = TransformNone<U, AccumType>>
struct BlockMMA {
// MMAFrag size
STEEL_CONST short kFragSize = 8;
using MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize>;
// Warp tile simdgroup matrix strides along M
STEEL_CONST short TM_stride = 8 * WM;
STEEL_CONST short TM_stride = kFragSize * WM;
// Warp tile simdgroup matrix strides along M
STEEL_CONST short TN_stride = 8 * WN;
STEEL_CONST short TN_stride = kFragSize * WN;
// Warp tile size along M
STEEL_CONST short TM = BM / TM_stride;
// Warp tile size along N
STEEL_CONST short TN = BN / TN_stride;
// Strides of A, B along reduction axis
STEEL_CONST short simd_stride_a = {
transpose_a ? TM_stride : TM_stride * lda_tgp};
STEEL_CONST short simd_stride_b = {
transpose_b ? TN_stride * ldb_tgp : TN_stride};
// Threadgroup A strides
STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M
STEEL_CONST short A_str_k = transpose_a ? lda_tgp : 1; // K
// Jump between elements
STEEL_CONST short jump_a = {transpose_a ? lda_tgp : 1};
STEEL_CONST short jump_b = {transpose_b ? ldb_tgp : 1};
// Threadgroup B strides
STEEL_CONST short B_str_k = transpose_b ? 1 : ldb_tgp; // K
STEEL_CONST short B_str_n = transpose_b ? ldb_tgp : 1; // N
STEEL_CONST short tile_stride_a = {transpose_a ? 8 * lda_tgp : 8};
STEEL_CONST short tile_stride_b = {transpose_b ? 8 : 8 * ldb_tgp};
// Threadgroup strides along K
STEEL_CONST short tile_stride_a = kFragSize * A_str_k;
STEEL_CONST short tile_stride_b = kFragSize * B_str_k;
// Simdgroup matrices
simdgroup_matrix<AccumType, 8, 8> Asimd[TM];
simdgroup_matrix<AccumType, 8, 8> Bsimd[TN];
simdgroup_matrix<AccumType, 8, 8> results[TM * TN] = {
simdgroup_matrix<AccumType, 8, 8>(0)};
MMATile<AccumType, TM, 1, MMAFrag_acc_t> Atile;
MMATile<AccumType, 1, TN, MMAFrag_acc_t> Btile;
MMATile<AccumType, TM, TN, MMAFrag_acc_t> Ctile;
// Offsets within threadgroup
const short tm;
const short tn;
short sm;
short sn;
@@ -75,18 +416,21 @@ struct BlockMMA {
/* Constructor */
METAL_FUNC BlockMMA(
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
ushort simd_lane_id [[thread_index_in_simdgroup]])
: tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) {
ushort simd_lane_id [[thread_index_in_simdgroup]]) {
// Determine thread position in simdgroup matrix
short qid = simd_lane_id / 4;
sm = (qid & 4) + (simd_lane_id / 2) % 4;
sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
short tm = kFragSize * (simd_group_id / WN);
short tn = kFragSize * (simd_group_id % WN);
short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
sm = simd_coord.y;
sn = simd_coord.x;
// Determine thread and simdgroup offset
As_offset =
transpose_a ? ((sn)*lda_tgp + (tm + sm)) : ((sn) + (tm + sm) * lda_tgp);
Bs_offset =
transpose_b ? ((tn + sn) * ldb_tgp + (sm)) : ((sm)*ldb_tgp + (tn + sn));
As_offset = (tm + sm) * A_str_m + (sn)*A_str_k; // M, K
Bs_offset = (sm)*B_str_k + (tn + sn) * B_str_n; // K, N
sm += tm;
sn += tn;
}
/* (BM, BK) X (BK, BN) multiply accumulate function */
@@ -95,47 +439,20 @@ struct BlockMMA {
As += As_offset;
Bs += Bs_offset;
// Iterate over BK in blocks of 8
// Iterate over BK in blocks of kFragSize
STEEL_PRAGMA_UNROLL
for (short kk = 0; kk < BK; kk += 8) {
for (short kk = 0; kk < BK; kk += kFragSize) {
simdgroup_barrier(mem_flags::mem_none);
// Load elements from threadgroup A as simdgroup matrices
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
Asimd[i].thread_elements()[0] =
static_cast<AccumType>(As[i * simd_stride_a + 0]);
Asimd[i].thread_elements()[1] =
static_cast<AccumType>(As[i * simd_stride_a + jump_a]);
}
Atile.template load<T, WM, 1, A_str_m, A_str_k>(As);
simdgroup_barrier(mem_flags::mem_none);
// Load elements from threadgroup B as simdgroup matrices
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
Bsimd[j].thread_elements()[0] =
static_cast<AccumType>(Bs[j * simd_stride_b + 0]);
Bsimd[j].thread_elements()[1] =
static_cast<AccumType>(Bs[j * simd_stride_b + jump_b]);
}
Btile.template load<T, 1, WN, B_str_k, B_str_n>(Bs);
simdgroup_barrier(mem_flags::mem_none);
// Multiply and accumulate into result simdgroup matrices
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
short j_serp = (i % 2) ? (TN - 1 - j) : j;
simdgroup_multiply_accumulate(
results[i * TN + j_serp],
Asimd[i],
Bsimd[j_serp],
results[i * TN + j_serp]);
}
}
tile_matmad(Ctile, Atile, Btile, Ctile);
// Progress to next simdgroup tile
As += tile_stride_a;
@@ -144,58 +461,35 @@ struct BlockMMA {
}
/* Store results from simdgroup_matrix results into device memory */
METAL_FUNC void store_result(device U* D, const int ldd) const {
// Adjust for simdgroup and thread location
D += (sm + tm) * ldd + tn + sn;
// Loop over all simdgroup tiles
METAL_FUNC void store_result(device U* D, const int ldd) {
// Apply epilogue
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = results[i * TN + j].thread_elements();
int offset = (i * TM_stride) * ldd + (j * TN_stride);
// Apply epilogue
U outs[2] = {Epilogue::apply(accum[0]), Epilogue::apply(accum[1])};
// Write out D
D[offset] = outs[0];
D[offset + 1] = outs[1];
}
for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);
}
// Adjust for simdgroup and thread location
D += sm * ldd + sn;
Ctile.template store<U, WM, WN>(D, ldd);
}
METAL_FUNC void
store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) const {
store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) {
// Apply epilogue
STEEL_PRAGMA_UNROLL
for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);
}
// Adjust for simdgroup and thread location
D += (sm + tm) * ldd + (tn + sn);
dst_tile_dims -= short2(tn + sn, sm + tm);
D += sm * ldd + sn;
dst_tile_dims -= short2(sn, sm);
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
return;
STEEL_PRAGMA_UNROLL
for (int i = 0; i < TM; i++) {
if (i * TM_stride < dst_tile_dims.y) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = results[i * TN + j].thread_elements();
int offset = (i * TM_stride) * ldd + (j * TN_stride);
// Apply epilogue and output C
if (j * TN_stride < dst_tile_dims.x) {
D[offset] = Epilogue::apply(accum[0]);
}
if (j * TN_stride + 1 < dst_tile_dims.x) {
D[offset + 1] = Epilogue::apply(accum[1]);
}
}
}
}
Ctile.template store_safe<U, WM, WN>(D, ldd, dst_tile_dims);
}
/* Apply epilogue */
@@ -203,16 +497,8 @@ struct BlockMMA {
METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) {
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread auto& accum = results[i * TN + j].thread_elements();
// Apply epilogue
accum[0] = epilogue_op.apply(accum[0]);
accum[1] = epilogue_op.apply(accum[1]);
}
for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
Ctile.elems()[i] = epilogue_op.apply(Ctile.elems()[i]);
}
}
@@ -224,7 +510,7 @@ struct BlockMMA {
const int fdc,
thread const BinaryEpilogue& epilogue_op) {
// Adjust for simdgroup and thread location
C += (sm + tm) * ldc + (tn + sn) * fdc;
C += (sm)*ldc + (sn)*fdc;
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
@@ -232,12 +518,14 @@ struct BlockMMA {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread auto& accum = results[i * TN + j].thread_elements();
thread auto& accum = Ctile.frag_at(i, j);
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
// Apply epilogue
accum[0] = epilogue_op.apply(accum[0], C[offset_c]);
accum[1] = epilogue_op.apply(accum[1], C[offset_c + fdc]);
STEEL_PRAGMA_UNROLL
for (short k = 0; k < decltype(Ctile)::kElemsPerFrag; k++) {
accum[k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
}
}
}
}
@@ -251,8 +539,8 @@ struct BlockMMA {
short2 dst_tile_dims,
thread const BinaryEpilogue& epilogue_op) {
// Adjust for simdgroup and thread location
C += (sm + tm) * ldc + (tn + sn) * fdc;
dst_tile_dims -= short2(tn + sn, sm + tm);
C += (sm)*ldc + (sn)*fdc;
dst_tile_dims -= short2(sn, sm);
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
return;
@@ -263,22 +551,26 @@ struct BlockMMA {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread auto& accum = results[i * TN + j].thread_elements();
thread auto& accum = Ctile.frag_at(i, j);
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
// Read C
U c_elems[2] = {0};
constexpr short kelems = decltype(Ctile)::kElemsPerFrag;
if ((j * TN_stride + 1) < dst_tile_dims.x) {
c_elems[0] = C[offset_c];
c_elems[1] = C[offset_c + fdc];
} else if ((j * TN_stride) < dst_tile_dims.x) {
c_elems[0] = C[offset_c];
// Read C
U c_elems[kelems] = {0};
STEEL_PRAGMA_UNROLL
for (short k = 0; k < kelems; k++) {
if ((j * TN_stride + k) < dst_tile_dims.x) {
c_elems[k] = C[offset_c + k * fdc];
}
}
// Apply epilogue
accum[0] = epilogue_op.apply(accum[0], c_elems[0]);
accum[1] = epilogue_op.apply(accum[1], c_elems[1]);
STEEL_PRAGMA_UNROLL
for (short k = 0; k < kelems; k++) {
accum[k] = epilogue_op.apply(accum[k], c_elems[k]);
}
}
}
}
@@ -292,8 +584,10 @@ struct BlockMMA {
const int fdc,
thread const Epilogue& epilogue_op) const {
// Adjust for simdgroup and thread location
C += (sm + tm) * ldc + (tn + sn) * fdc;
D += (sm + tm) * ldd + tn + sn;
C += (sm)*ldc + (sn)*fdc;
D += (sm)*ldd + sn;
constexpr short kelems = decltype(Ctile)::kElemsPerFrag;
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
@@ -301,18 +595,15 @@ struct BlockMMA {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = results[i * TN + j].thread_elements();
thread const auto& accum = Ctile.frag_at(i, j);
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
// Apply epilogue
U outs[2] = {
epilogue_op.apply(accum[0], C[offset_c]),
epilogue_op.apply(accum[1], C[offset_c + fdc])};
// Write out D
D[offset_d] = outs[0];
D[offset_d + 1] = outs[1];
STEEL_PRAGMA_UNROLL
for (short k = 0; k < kelems; k++) {
D[offset_d + k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
}
}
}
}
@@ -326,30 +617,32 @@ struct BlockMMA {
short2 dst_tile_dims,
thread const Epilogue& epilogue_op) const {
// Adjust for simdgroup and thread location
C += (sm + tm) * ldc + (tn + sn) * fdc;
D += (sm + tm) * ldd + tn + sn;
dst_tile_dims -= short2(tn + sn, sm + tm);
C += (sm)*ldc + (sn)*fdc;
D += (sm)*ldd + sn;
dst_tile_dims -= short2(sn, sm);
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
return;
constexpr short kelems = decltype(Ctile)::kElemsPerFrag;
STEEL_PRAGMA_UNROLL
for (int i = 0; i < TM; i++) {
if (i * TM_stride < dst_tile_dims.y) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = results[i * TN + j].thread_elements();
thread const auto& accum = Ctile.frag_at(i, j);
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
// Apply epilogue and output C
if (j * TN_stride < dst_tile_dims.x) {
D[offset_d] = epilogue_op.apply(accum[0], C[offset_c]);
}
if (j * TN_stride + 1 < dst_tile_dims.x) {
D[offset_d + 1] = epilogue_op.apply(accum[1], C[offset_c + fdc]);
// Apply epilogue
STEEL_PRAGMA_UNROLL
for (short k = 0; k < kelems; k++) {
if ((j * TN_stride + k) < dst_tile_dims.x) {
D[offset_d + k] =
epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
}
}
}
}

View File

@@ -0,0 +1,96 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/steel/utils/type_traits.h"
#pragma METAL internals : enable
namespace mlx {
namespace steel {
///////////////////////////////////////////////////////////////////////////////
// Integral constant with casting
///////////////////////////////////////////////////////////////////////////////
template <typename T, T v>
struct integral_constant {
static constexpr constant T value = v;
using value_type = T;
using type = integral_constant;
METAL_FUNC constexpr operator value_type() const noexcept {
return value;
}
// METAL_FUNC constexpr value_type operator()() const noexcept {
// return value;
// }
};
template <bool B>
using bool_constant = integral_constant<bool, B>;
using true_type = bool_constant<true>;
using false_type = bool_constant<false>;
template <class T>
struct is_integral : bool_constant<metal::is_integral<T>::value> {};
template <class T, T v>
struct is_integral<integral_constant<T, v>>
: bool_constant<metal::is_integral<T>::value> {};
template <typename T>
constexpr constant bool is_integral_v = is_integral<T>::value;
template <int val>
using Int = integral_constant<int, val>;
///////////////////////////////////////////////////////////////////////////////
// Binary Operators on Integral constants
///////////////////////////////////////////////////////////////////////////////
#define integral_const_binop(__op__, __operator__) \
template <typename T, T tv, typename U, U uv> \
METAL_FUNC constexpr auto __operator__( \
integral_constant<T, tv>, integral_constant<U, uv>) { \
constexpr auto res = tv __op__ uv; \
return integral_constant<decltype(res), res>{}; \
}
integral_const_binop(+, operator+);
integral_const_binop(-, operator-);
integral_const_binop(*, operator*);
integral_const_binop(/, operator/);
integral_const_binop(==, operator==);
integral_const_binop(!=, operator!=);
integral_const_binop(<, operator<);
integral_const_binop(>, operator>);
integral_const_binop(<=, operator<=);
integral_const_binop(>=, operator>=);
integral_const_binop(&&, operator&&);
integral_const_binop(||, operator||);
#undef integral_const_binop
///////////////////////////////////////////////////////////////////////////////
// Reduction operators
///////////////////////////////////////////////////////////////////////////////
template <typename T>
METAL_FUNC constexpr T sum(T x) {
return x;
}
template <typename T, typename... Us>
METAL_FUNC constexpr auto sum(T x, Us... us) {
return x + sum(us...);
}
} // namespace steel
} // namespace mlx
#pragma METAL internals : disable

View File

@@ -0,0 +1,55 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include <metal_stdlib>
#pragma METAL internals : enable
namespace metal {
template <typename T>
struct is_empty : metal::bool_constant<__is_empty(T)> {};
#ifdef __cpp_variable_templates
template <typename T>
constexpr constant bool is_empty_v = is_empty<T>::value;
#endif
template <typename... Ts>
struct make_void {
typedef void type;
};
template <typename... Ts>
using void_t = typename make_void<Ts...>::type;
template <class T>
struct is_static : metal::bool_constant<is_empty<remove_cv_t<T>>::value> {};
template <typename T>
struct pointer_element {};
template <typename T>
struct pointer_element<thread T*> {
using type = remove_cv_t<T>;
};
template <typename T>
struct pointer_element<device T*> {
using type = remove_cv_t<T>;
};
template <typename T>
struct pointer_element<constant T*> {
using type = remove_cv_t<T>;
};
template <typename T>
struct pointer_element<threadgroup T*> {
using type = remove_cv_t<T>;
};
template <typename T>
using pointer_element_t = typename pointer_element<remove_cv_t<T>>::type;
} // namespace metal
#pragma METAL internals : disable

View File

@@ -12,11 +12,10 @@
#define instantiate_ternary_all(op, tname, type) \
instantiate_kernel("v_" #op #tname, ternary_v, type, op) \
instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \
instantiate_kernel("g_" #op #tname, ternary_g, type, op) \
instantiate_kernel("gn4_" #op #tname, ternary_g, type, op, 4) \
instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op) \
instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op) \
instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op) \
instantiate_kernel("g3_" #op #tname, ternary_g_nd3, type, op)
#define instantiate_ternary_types(op) \
instantiate_ternary_all(op, bool_, bool) \

View File

@@ -1,27 +1,27 @@
// Copyright © 2024 Apple Inc.
template <typename T, typename Op>
template <typename T, typename U, typename Op>
[[kernel]] void unary_v(
device const T* in,
device T* out,
device U* out,
uint index [[thread_position_in_grid]]) {
out[index] = Op()(in[index]);
}
template <typename T, typename Op>
template <typename T, typename U, typename Op>
[[kernel]] void unary_v2(
device const T* in,
device T* out,
device U* out,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
out[offset] = Op()(in[offset]);
}
template <typename T, typename Op, int N = 1>
template <typename T, typename U, typename Op, int N = 1>
[[kernel]] void unary_g(
device const T* in,
device T* out,
device U* out,
constant const int* in_shape,
constant const size_t* in_strides,
device const int& ndim,

View File

@@ -5,27 +5,30 @@
#include "mlx/backend/metal/kernels/unary_ops.h"
#include "mlx/backend/metal/kernels/unary.h"
#define instantiate_unary_all(op, tname, type) \
instantiate_kernel("v_" #op #tname, unary_v, type, op) \
instantiate_kernel("v2_" #op #tname, unary_v2, type, op) \
instantiate_kernel("gn4_" #op #tname, unary_g, type, op, 4) \
instantiate_kernel("g_" #op #tname, unary_g, type, op)
#define instantiate_unary_all(op, in_tname, out_tname, in_type, out_type) \
instantiate_kernel("v_" #op #in_tname #out_tname, unary_v, in_type, out_type, op) \
instantiate_kernel("v2_" #op #in_tname #out_tname, unary_v2, in_type, out_type, op) \
instantiate_kernel("gn4_" #op #in_tname #out_tname, unary_g, in_type, out_type, op, 4)
#define instantiate_unary_all_same(op, tname, type) \
instantiate_unary_all(op, tname, tname, type, type)
#define instantiate_unary_float(op) \
instantiate_unary_all(op, float16, half) \
instantiate_unary_all(op, float32, float) \
instantiate_unary_all(op, bfloat16, bfloat16_t)
instantiate_unary_all_same(op, float16, half) \
instantiate_unary_all_same(op, float32, float) \
instantiate_unary_all_same(op, bfloat16, bfloat16_t)
#define instantiate_unary_types(op) \
instantiate_unary_all(op, bool_, bool) \
instantiate_unary_all(op, uint8, uint8_t) \
instantiate_unary_all(op, uint16, uint16_t) \
instantiate_unary_all(op, uint32, uint32_t) \
instantiate_unary_all(op, uint64, uint64_t) \
instantiate_unary_all(op, int8, int8_t) \
instantiate_unary_all(op, int16, int16_t) \
instantiate_unary_all(op, int32, int32_t) \
instantiate_unary_all(op, int64, int64_t) \
instantiate_unary_all_same(op, bool_, bool) \
instantiate_unary_all_same(op, uint8, uint8_t) \
instantiate_unary_all_same(op, uint16, uint16_t) \
instantiate_unary_all_same(op, uint32, uint32_t) \
instantiate_unary_all_same(op, uint64, uint64_t) \
instantiate_unary_all_same(op, int8, int8_t) \
instantiate_unary_all_same(op, int16, int16_t) \
instantiate_unary_all_same(op, int32, int32_t) \
instantiate_unary_all_same(op, int64, int64_t) \
instantiate_unary_float(op)
instantiate_unary_types(Abs)
@@ -59,17 +62,19 @@ instantiate_unary_float(Tan)
instantiate_unary_float(Tanh)
instantiate_unary_float(Round)
instantiate_unary_all(Abs, complex64, complex64_t)
instantiate_unary_all(Conjugate, complex64, complex64_t)
instantiate_unary_all(Cos, complex64, complex64_t)
instantiate_unary_all(Cosh, complex64, complex64_t)
instantiate_unary_all(Exp, complex64, complex64_t)
instantiate_unary_all(Negative, complex64, complex64_t)
instantiate_unary_all(Sign, complex64, complex64_t)
instantiate_unary_all(Sin, complex64, complex64_t)
instantiate_unary_all(Sinh, complex64, complex64_t)
instantiate_unary_all(Tan, complex64, complex64_t)
instantiate_unary_all(Tanh, complex64, complex64_t)
instantiate_unary_all(Round, complex64, complex64_t)
instantiate_unary_all_same(Abs, complex64, complex64_t)
instantiate_unary_all_same(Conjugate, complex64, complex64_t)
instantiate_unary_all_same(Cos, complex64, complex64_t)
instantiate_unary_all_same(Cosh, complex64, complex64_t)
instantiate_unary_all_same(Exp, complex64, complex64_t)
instantiate_unary_all_same(Negative, complex64, complex64_t)
instantiate_unary_all_same(Sign, complex64, complex64_t)
instantiate_unary_all_same(Sin, complex64, complex64_t)
instantiate_unary_all_same(Sinh, complex64, complex64_t)
instantiate_unary_all_same(Tan, complex64, complex64_t)
instantiate_unary_all_same(Tanh, complex64, complex64_t)
instantiate_unary_all_same(Round, complex64, complex64_t)
instantiate_unary_all(Real, complex64, float32, complex64_t, float)
instantiate_unary_all(Imag, complex64, float32, complex64_t, float)
instantiate_unary_all(LogicalNot, bool_, bool) // clang-format on
instantiate_unary_all_same(LogicalNot, bool_, bool) // clang-format on

View File

@@ -238,6 +238,13 @@ struct Floor {
};
};
struct Imag {
template <typename T>
T operator()(T x) {
return x.imag;
};
};
struct Log {
template <typename T>
T operator()(T x) {
@@ -280,6 +287,13 @@ struct Negative {
};
};
struct Real {
template <typename T>
T operator()(T x) {
return x.real;
};
};
struct Round {
template <typename T>
T operator()(T x) {

View File

@@ -320,3 +320,63 @@ inline complex64_t simd_shuffle_down(complex64_t data, uint16_t delta) {
return complex64_t(
simd_shuffle_down(data.real, delta), simd_shuffle_down(data.imag, delta));
}
inline uint64_t simd_shuffle_up(uint64_t data, uint16_t delta) {
return as_type<uint64_t>(metal::simd_shuffle_up(as_type<uint2>(data), delta));
}
inline int64_t simd_shuffle_up(int64_t data, uint16_t delta) {
return as_type<int64_t>(metal::simd_shuffle_up(as_type<uint2>(data), delta));
}
inline bool simd_shuffle_up(bool data, uint16_t delta) {
return simd_shuffle_up(static_cast<uint32_t>(data), delta);
}
inline complex64_t simd_shuffle_up(complex64_t data, uint16_t delta) {
return complex64_t(
simd_shuffle_up(data.real, delta), simd_shuffle_up(data.imag, delta));
}
inline uint64_t
simd_shuffle_and_fill_up(uint64_t data, uint64_t filling, uint16_t delta) {
return as_type<uint64_t>(metal::simd_shuffle_and_fill_up(
as_type<uint2>(data), as_type<uint2>(filling), delta));
}
inline int64_t
simd_shuffle_and_fill_up(int64_t data, int64_t filling, uint16_t delta) {
return as_type<int64_t>(metal::simd_shuffle_and_fill_up(
as_type<uint2>(data), as_type<uint2>(filling), delta));
}
inline bool simd_shuffle_and_fill_up(bool data, bool filling, uint16_t delta) {
return simd_shuffle_and_fill_up(
static_cast<uint32_t>(data), static_cast<uint32_t>(filling), delta);
}
inline complex64_t simd_shuffle_and_fill_up(
complex64_t data,
complex64_t filling,
uint16_t delta) {
return complex64_t(
simd_shuffle_and_fill_up(data.real, filling.real, delta),
simd_shuffle_and_fill_up(data.imag, filling.imag, delta));
}
inline uint64_t simd_shuffle(uint64_t data, uint16_t lane) {
return as_type<uint64_t>(metal::simd_shuffle(as_type<uint2>(data), lane));
}
inline int64_t simd_shuffle(int64_t data, uint16_t lane) {
return as_type<int64_t>(metal::simd_shuffle(as_type<uint2>(data), lane));
}
inline bool simd_shuffle(bool data, uint16_t lane) {
return simd_shuffle(static_cast<uint32_t>(data), lane);
}
inline complex64_t simd_shuffle(complex64_t data, uint16_t lane) {
return complex64_t(
simd_shuffle(data.real, lane), simd_shuffle(data.imag, lane));
}

View File

@@ -88,6 +88,83 @@ inline auto collapse_batches(const array& a, const array& b, const array& c) {
// Steel matmul fallback
///////////////////////////////////////////////////////////////////////////////
#define GEMM_TPARAM_MACRO(devc) \
if (devc == 'g') { /* Small device */ \
if (!transpose_a && transpose_b) { /* nt */ \
bm = 64; \
bn = 32; \
bk = 32; \
wm = 2; \
wn = 2; \
} else if (out.dtype() != float32) { /* half and bfloat */ \
bm = 64; \
bn = 64; \
bk = 16; \
wm = 1; \
wn = 2; \
} \
} else if (devc == 'd') { /* Large device */ \
if ((size_t)batch_size_out * M * N >= 1ul << 20) { /* large matmul */ \
if (out.dtype() != float32) { /* half and bfloat */ \
if (2 * std::max(M, N) > K) { /* Reasonable K */ \
bm = 64; \
bn = 64; \
bk = 16; \
wm = 1; \
wn = 2; \
} else if (!transpose_a && transpose_b) { /* nt with large k */ \
bm = 64; \
bn = 32; \
bk = 32; \
wm = 2; \
wn = 2; \
} else { /* nn with large K */ \
bm = 32; \
bn = 64; \
bk = 16; \
wm = 1; \
wn = 2; \
} \
} /* float takes default */ \
} else { /* smaller matmul */ \
if (out.dtype() != float32) { /* half and bfloat */ \
if (!transpose_a && transpose_b) { /* nt */ \
bm = 64; \
bn = 32; \
bk = 32; \
wm = 2; \
wn = 2; \
} else { /* nn */ \
bm = 64; \
bn = 64; \
bk = 16; \
wm = 1; \
wn = 2; \
} \
} else { /* floats */ \
if (!transpose_a && transpose_b) { /* nt */ \
bm = 32; \
bn = 64; \
bk = 16; \
wm = 1; \
wn = 2; \
} else { /* nn */ \
bm = 64; \
bn = 32; \
bk = 32; \
wm = 2; \
wn = 2; \
} \
} \
} \
} else { /* Medium device */ \
bm = 64; \
bn = 64; \
bk = 16; \
wm = 2; \
wn = 2; \
}
void steel_matmul_regular(
const Stream& s,
metal::Device& d,
@@ -112,19 +189,11 @@ void steel_matmul_regular(
using namespace mlx::steel;
// Determine dispatch kernel
int bm = 32, bn = 32, bk = 16;
int bm = 64, bn = 64, bk = 16;
int wm = 2, wn = 2;
if ((size_t)batch_size_out * M * N >= 1ul << 20) {
if (!transpose_a && transpose_b) {
bm = 64;
bn = (out.dtype() == float32) ? 64 : 32;
bk = (out.dtype() == float32) ? 16 : 32;
} else {
bm = 64;
bn = 64;
}
}
char devc = d.get_architecture().back();
GEMM_TPARAM_MACRO(devc)
// Prepare kernel name
std::ostringstream kname;
@@ -226,13 +295,8 @@ void steel_matmul_regular(
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
// Clear copies
if (!copies.empty()) {
d.get_command_buffer(s.index)->addCompletedHandler(
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
}
// Record copies
d.add_temporaries(std::move(copies), s.index);
}
void steel_matmul(
@@ -382,12 +446,7 @@ void steel_matmul(
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
if (!copies.empty()) {
d.get_command_buffer(s.index)->addCompletedHandler(
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
}
d.add_temporaries(std::move(copies), s.index);
return;
}
@@ -435,8 +494,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
if (a_pre.size() == 0 || b_pre.size() == 0) {
array zero = array(0, a_pre.dtype());
fill_gpu(zero, out, s);
auto command_buffer = d.get_command_buffer(s.index);
command_buffer->addCompletedHandler([zero](MTL::CommandBuffer*) {});
d.add_temporary(std::move(zero), s.index);
return;
}
@@ -588,12 +646,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
if (!copies.empty()) {
d.get_command_buffer(s.index)->addCompletedHandler(
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
}
d.add_temporaries(std::move(copies), s.index);
return;
}
/////////////////////////////////////////////////////////////////////////////
@@ -798,12 +851,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
if (!copies.empty()) {
d.get_command_buffer(s.index)->addCompletedHandler(
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
}
d.add_temporaries(std::move(copies), s.index);
return;
}
@@ -916,12 +964,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
if (!copies.empty()) {
d.get_command_buffer(s.index)->addCompletedHandler(
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
}
d.add_temporaries(std::move(copies), s.index);
return;
}
@@ -929,19 +972,11 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
// Regular addmm dispatch
// Determine dispatch kernel
int bm = 32, bn = 32, bk = 16;
int bm = 64, bn = 64, bk = 16;
int wm = 2, wn = 2;
if ((size_t)batch_size_out * M * N >= 1ul << 20) {
if (!transpose_a && transpose_b) {
bm = 64;
bn = (out.dtype() == float32) ? 64 : 32;
bk = (out.dtype() == float32) ? 16 : 32;
} else {
bm = 64;
bn = 64;
}
}
char devc = d.get_architecture().back();
GEMM_TPARAM_MACRO(devc)
// Prepare kernel name
std::ostringstream kname;
@@ -1056,12 +1091,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
if (!copies.empty()) {
d.get_command_buffer(s.index)->addCompletedHandler(
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
}
d.add_temporaries(std::move(copies), s.index);
}
void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -1080,8 +1110,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
if (a_pre.size() == 0 || b_pre.size() == 0) {
array zero = array(0, a_pre.dtype());
fill_gpu(zero, out, s);
auto command_buffer = d.get_command_buffer(s.index);
command_buffer->addCompletedHandler([zero](MTL::CommandBuffer*) {});
d.add_temporary(std::move(zero), s.index);
return;
}
@@ -1356,12 +1385,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
if (!copies.empty()) {
d.get_command_buffer(s.index)->addCompletedHandler(
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
}
d.add_temporaries(std::move(copies), s.index);
return;
}
@@ -1471,13 +1495,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
// Clear copies
if (!copies.empty()) {
d.get_command_buffer(s.index)->addCompletedHandler(
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
}
d.add_temporaries(std::move(copies), s.index);
}
void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -1496,8 +1514,7 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
if (a_pre.size() == 0 || b_pre.size() == 0) {
array zero = array(0, a_pre.dtype());
fill_gpu(zero, out, s);
auto command_buffer = d.get_command_buffer(s.index);
command_buffer->addCompletedHandler([zero](MTL::CommandBuffer*) {});
d.add_temporary(std::move(zero), s.index);
return;
}
@@ -1703,12 +1720,7 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
if (!copies.empty()) {
d.get_command_buffer(s.index)->addCompletedHandler(
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
}
d.add_temporaries(std::move(copies), s.index);
return;
}
@@ -1716,19 +1728,11 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
// Regular kernel dispatch
// Determine dispatch kernel
int bm = 32, bn = 32, bk = 16;
int bm = 64, bn = 64, bk = 16;
int wm = 2, wn = 2;
if ((size_t)batch_size_out * M * N >= 1ul << 20) {
if (!transpose_a && transpose_b) {
bm = 64;
bn = (out.dtype() == float32) ? 64 : 32;
bk = (out.dtype() == float32) ? 16 : 32;
} else {
bm = 64;
bn = 64;
}
}
char devc = d.get_architecture().back();
GEMM_TPARAM_MACRO(devc)
// Prepare kernel name
std::ostringstream kname;
@@ -1847,13 +1851,7 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
// Clear copies
if (!copies.empty()) {
d.get_command_buffer(s.index)->addCompletedHandler(
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
}
d.add_temporaries(std::move(copies), s.index);
}
} // namespace mlx::core

View File

@@ -1,5 +1,7 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include "mlx/backend/metal/device.h"
namespace mlx::core {

View File

@@ -63,6 +63,19 @@ size_t set_cache_limit(size_t limit);
/* Clear the memory cache. */
void clear_cache();
/* Set the wired size limit.
*
* Note, this function is only useful for macOS 15.0 or higher.
*
* The wired limit is the total size in bytes of memory that will be kept
* resident. The default value is ``0``.
*
* Setting a wired limit larger than system wired limit is an error.
*
* Returns the previous wired limit.
* */
size_t set_wired_limit(size_t limit);
/** Capture a GPU trace, saving it to an absolute file `path` */
void start_capture(std::string path = "");
void stop_capture();

View File

@@ -16,6 +16,7 @@ MTL::ComputePipelineState* get_unary_kernel(
metal::Device& d,
const std::string& kernel_name,
Dtype,
Dtype,
const std::string) {
return d.get_kernel(kernel_name);
}
@@ -96,6 +97,8 @@ MTL::ComputePipelineState* get_mb_sort_kernel(
MTL::ComputePipelineState* get_reduce_init_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string&,
const std::string&,
const array&) {
return d.get_kernel(kernel_name);
}

View File

@@ -91,12 +91,8 @@ void RMSNorm::eval_gpu(
compute_encoder->setThreadgroupMemoryLength(simd_size * sizeof(float), 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
if (!copies.empty()) {
d.get_command_buffer(s.index)->addCompletedHandler(
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
}
d.add_temporaries(std::move(copies), s.index);
}
void RMSNormVJP::eval_gpu(
@@ -204,10 +200,7 @@ void RMSNormVJP::eval_gpu(
strided_reduce_general_dispatch(
gw_temp, gw, "sum", plan, {0}, compute_encoder, d, s);
d.get_command_buffer(s.index)->addCompletedHandler(
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
d.add_temporaries(std::move(copies), s.index);
}
void LayerNorm::eval_gpu(
@@ -292,12 +285,8 @@ void LayerNorm::eval_gpu(
compute_encoder->setBytes(&b_stride, sizeof(uint32_t), 7);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
if (!copies.empty()) {
d.get_command_buffer(s.index)->addCompletedHandler(
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
}
d.add_temporaries(std::move(copies), s.index);
}
void LayerNormVJP::eval_gpu(
@@ -425,10 +414,7 @@ void LayerNormVJP::eval_gpu(
gw_temp, gw, "sum", plan, {0}, compute_encoder, d, s);
}
d.get_command_buffer(s.index)->addCompletedHandler(
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
d.add_temporaries(std::move(copies), s.index);
}
} // namespace mlx::core::fast

View File

@@ -273,7 +273,7 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
// organize into grid nkeys x elem_per_key
MTL::Size grid_dims = MTL::Size(num_keys, half_size + odd, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size group_dims = MTL::Size(1, thread_group_size, 1);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(keys, 0);
@@ -401,6 +401,12 @@ void Cholesky::eval_gpu(const std::vector<array>& inputs, array& out) {
"[Cholesky::eval_gpu] Metal Cholesky decomposition NYI.");
}
void Eigh::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
throw std::runtime_error("[Eigvalsh::eval_gpu] Metal Eigh NYI.");
}
void View::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0];
auto ibytes = size_of(in.dtype());

View File

@@ -6,237 +6,36 @@
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/reduce.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/fast_primitives.h"
#include "mlx/primitives.h"
namespace mlx::core {
void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 4);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
void launch_qmm(
std::string name,
const std::vector<array>& inputs,
array& out,
int group_size,
int bits,
int D,
int O,
int B,
int N,
MTL::Size& group_dims,
MTL::Size& grid_dims,
bool batched,
bool matrix,
bool gather,
bool aligned,
bool quad,
const Stream& s) {
auto& x_pre = inputs[0];
auto& w_pre = inputs[1];
auto& scales_pre = inputs[2];
auto& biases_pre = inputs[3];
std::vector<array> copies;
auto ensure_row_contiguous = [&copies, &s](const array& arr) {
if (arr.flags().row_contiguous) {
return arr;
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
copies.push_back(arr_copy);
return arr_copy;
}
};
auto x = ensure_row_contiguous(x_pre);
auto w = ensure_row_contiguous(w_pre);
auto scales = ensure_row_contiguous(scales_pre);
auto biases = ensure_row_contiguous(biases_pre);
int D = x.shape(-1);
int B = x.size() / D;
int O = out.shape(-1);
if (transpose_) {
// Route to the fast qmv kernel that has no bounds checking
if (B < 6 && O % 8 == 0 && D % 512 == 0 && D >= 512) {
std::ostringstream kname;
auto type_string = get_type_string(x.dtype());
kname << "qmv_fast_" << type_string << "_gs_" << group_size_ << "_b_"
<< bits_;
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto template_def = get_template_definition(
kname.str(), "qmv_fast", type_string, group_size_, bits_);
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
compute_encoder->setComputePipelineState(kernel);
int bo = 8;
int bd = 32;
MTL::Size group_dims = MTL::Size(bd, 2, 1);
MTL::Size grid_dims = MTL::Size(O / bo, B, 1);
compute_encoder.set_input_array(w, 0);
compute_encoder.set_input_array(scales, 1);
compute_encoder.set_input_array(biases, 2);
compute_encoder.set_input_array(x, 3);
compute_encoder.set_output_array(out, 4);
compute_encoder->setBytes(&D, sizeof(int), 5);
compute_encoder->setBytes(&O, sizeof(int), 6);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
// Route to the qmv kernel
else if (B < 6) {
std::ostringstream kname;
auto type_string = get_type_string(x.dtype());
kname << "qmv_" << type_string << "_gs_" << group_size_ << "_b_" << bits_;
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto template_def = get_template_definition(
kname.str(), "qmv", type_string, group_size_, bits_);
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
compute_encoder->setComputePipelineState(kernel);
int bo = 8;
int bd = 32;
MTL::Size group_dims = MTL::Size(bd, 2, 1);
MTL::Size grid_dims = MTL::Size((O + bo - 1) / bo, B, 1);
compute_encoder.set_input_array(w, 0);
compute_encoder.set_input_array(scales, 1);
compute_encoder.set_input_array(biases, 2);
compute_encoder.set_input_array(x, 3);
compute_encoder.set_output_array(out, 4);
compute_encoder->setBytes(&D, sizeof(int), 5);
compute_encoder->setBytes(&O, sizeof(int), 6);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
// Route to the qmm_t kernel
else {
std::ostringstream kname;
std::string aligned_n = (O % 32) == 0 ? "true" : "false";
auto type_string = get_type_string(x.dtype());
kname << "qmm_t_" << type_string << "_gs_" << group_size_ << "_b_"
<< bits_ << "_alN_" << aligned_n;
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto template_def = get_template_definition(
kname.str(), "qmm_t", type_string, group_size_, bits_, aligned_n);
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
compute_encoder->setComputePipelineState(kernel);
int wn = 2;
int wm = 2;
int bm = 32;
int bn = 32;
int bk = 32;
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size((O + bn - 1) / bn, (B + bm - 1) / bm, 1);
compute_encoder.set_input_array(x, 0);
compute_encoder.set_input_array(w, 1);
compute_encoder.set_input_array(scales, 2);
compute_encoder.set_input_array(biases, 3);
compute_encoder.set_output_array(out, 4);
compute_encoder->setBytes(&B, sizeof(int), 5);
compute_encoder->setBytes(&O, sizeof(int), 6);
compute_encoder->setBytes(&D, sizeof(int), 7);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
} else {
// Route to the qvm kernel
if (B < 4) {
std::ostringstream kname;
auto type_string = get_type_string(x.dtype());
kname << "qvm_" << type_string << "_gs_" << group_size_ << "_b_" << bits_;
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto template_def = get_template_definition(
kname.str(), "qvm", type_string, group_size_, bits_);
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
compute_encoder->setComputePipelineState(kernel);
int bo = 64;
int bd = 32;
MTL::Size group_dims = MTL::Size(bd, 2, 1);
MTL::Size grid_dims = MTL::Size(O / bo, B, 1);
compute_encoder.set_input_array(x, 0);
compute_encoder.set_input_array(w, 1);
compute_encoder.set_input_array(scales, 2);
compute_encoder.set_input_array(biases, 3);
compute_encoder.set_output_array(out, 4);
compute_encoder->setBytes(&D, sizeof(int), 5);
compute_encoder->setBytes(&O, sizeof(int), 6);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
// Route to the qmm_n kernel
else {
std::ostringstream kname;
auto type_string = get_type_string(x.dtype());
kname << "qmm_n_" << type_string << "_gs_" << group_size_ << "_b_"
<< bits_;
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto template_def = get_template_definition(
kname.str(), "qmm_n", type_string, group_size_, bits_);
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
compute_encoder->setComputePipelineState(kernel);
int wn = 2;
int wm = 2;
int bm = 32;
int bn = 32;
int bk = 32;
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(O / bn, (B + bm - 1) / bm, 1);
if ((O % bn) != 0) {
std::ostringstream msg;
msg << "[quantized_matmul] The output size should be divisible by "
<< bn << " but received " << O << ".";
throw std::runtime_error(msg.str());
}
compute_encoder.set_input_array(x, 0);
compute_encoder.set_input_array(w, 1);
compute_encoder.set_input_array(scales, 2);
compute_encoder.set_input_array(biases, 3);
compute_encoder.set_output_array(out, 4);
compute_encoder->setBytes(&B, sizeof(int), 5);
compute_encoder->setBytes(&O, sizeof(int), 6);
compute_encoder->setBytes(&D, sizeof(int), 7);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
}
if (!copies.empty()) {
d.get_command_buffer(s.index)->addCompletedHandler(
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
}
}
void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 6);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
auto& x_pre = inputs[0];
auto& w_pre = inputs[1];
auto& scales_pre = inputs[2];
auto& biases_pre = inputs[3];
auto& lhs_indices = inputs[4];
auto& rhs_indices = inputs[5];
// TODO: collapse batch dims
auto& batch_shape = lhs_indices.shape();
int batch_ndims = batch_shape.size();
auto& lhs_strides = lhs_indices.strides();
auto& rhs_strides = rhs_indices.strides();
// Ensure that the last two dims are row contiguous.
// TODO: Check if we really need this for x as well...
std::vector<array> copies;
@@ -266,256 +65,327 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s_strides = scales.strides();
auto& b_strides = biases.strides();
std::string aligned_n = (O % 32) == 0 ? "true" : "false";
std::ostringstream kname;
auto type_string = get_type_string(x.dtype());
kname << name << "_" << type_string << "_gs_" << group_size << "_b_" << bits;
if (quad) {
kname << "_d_" << D;
}
if (aligned) {
kname << "_alN_" << aligned_n;
}
if (!gather) {
kname << "_batch_" << batched;
}
// Encode and dispatch kernel
std::string template_def;
if (quad) {
template_def = get_template_definition(
kname.str(), name, type_string, group_size, bits, D, batched);
} else if (aligned && !gather) {
template_def = get_template_definition(
kname.str(), name, type_string, group_size, bits, aligned_n, batched);
} else if (!gather && !aligned) {
template_def = get_template_definition(
kname.str(), name, type_string, group_size, bits, batched);
} else if (aligned && gather) {
template_def = get_template_definition(
kname.str(), name, type_string, group_size, bits, aligned_n);
} else {
template_def = get_template_definition(
kname.str(), name, type_string, group_size, bits);
}
auto& d = metal::device(s.device);
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(w, 0);
compute_encoder.set_input_array(scales, 1);
compute_encoder.set_input_array(biases, 2);
compute_encoder.set_input_array(x, 3);
compute_encoder.set_output_array(out, 4);
compute_encoder->setBytes(&D, sizeof(int), 5);
compute_encoder->setBytes(&O, sizeof(int), 6);
int offset = 7;
if (matrix) {
compute_encoder->setBytes(&B, sizeof(int), 7);
offset += 1;
}
if (batched || gather) {
compute_encoder->setBytes(&x_batch_ndims, sizeof(int), offset);
set_vector_bytes(compute_encoder, x_shape, offset + 1);
set_vector_bytes(compute_encoder, x_strides, offset + 2);
compute_encoder->setBytes(&w_batch_ndims, sizeof(int), offset + 3);
set_vector_bytes(compute_encoder, w_shape, offset + 4);
set_vector_bytes(compute_encoder, w_strides, offset + 5);
set_vector_bytes(compute_encoder, s_strides, offset + 6);
set_vector_bytes(compute_encoder, b_strides, offset + 7);
}
if (gather) {
auto& lhs_indices = inputs[4];
auto& rhs_indices = inputs[5];
// TODO: collapse batch dims
auto& batch_shape = lhs_indices.shape();
int batch_ndims = batch_shape.size();
auto& lhs_strides = lhs_indices.strides();
auto& rhs_strides = rhs_indices.strides();
compute_encoder->setBytes(&batch_ndims, sizeof(int), offset + 8);
set_vector_bytes(compute_encoder, batch_shape, offset + 9);
compute_encoder.set_input_array(lhs_indices, offset + 10);
compute_encoder.set_input_array(rhs_indices, offset + 11);
set_vector_bytes(compute_encoder, lhs_strides, offset + 12);
set_vector_bytes(compute_encoder, rhs_strides, offset + 13);
}
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
d.add_temporaries(std::move(copies), s.index);
}
void qvm_split_k(
const std::vector<array>& inputs,
array& out,
int group_size,
int bits,
int D,
int O,
int B,
int N,
const Stream& s) {
int split_k = D > 8192 ? 32 : 8;
int split_D = (D + split_k - 1) / split_k;
N *= split_k;
int bo = 64;
int bd = 32;
MTL::Size group_dims = MTL::Size(bd, 2, 1);
MTL::Size grid_dims = MTL::Size(O / bo, B, N);
auto& x_pre = inputs[0];
auto& w_pre = inputs[1];
auto& scales_pre = inputs[2];
auto& biases_pre = inputs[3];
// Ensure that the last two dims are row contiguous.
// TODO: Check if we really need this for x as well...
std::vector<array> copies;
auto ensure_row_contiguous_last_dims = [&copies, &s](const array& arr) {
auto stride_0 = arr.strides()[arr.ndim() - 2];
auto stride_1 = arr.strides()[arr.ndim() - 1];
if (stride_0 == arr.shape(-1) && stride_1 == 1) {
return arr;
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
copies.push_back(arr_copy);
return arr_copy;
}
};
auto x = ensure_row_contiguous_last_dims(x_pre);
auto w = ensure_row_contiguous_last_dims(w_pre);
auto scales = ensure_row_contiguous_last_dims(scales_pre);
auto biases = ensure_row_contiguous_last_dims(biases_pre);
int x_batch_ndims = x.ndim() - 2;
auto x_shape = x.shape();
auto x_strides = x.strides();
int w_batch_ndims = w.ndim() - 2;
auto w_shape = w.shape();
auto w_strides = w.strides();
auto s_strides = scales.strides();
auto b_strides = biases.strides();
// Add split_k dim with reshapes
x_shape.insert(x_shape.end() - 2, split_k);
x_shape.back() /= split_k;
x_strides.insert(x_strides.end() - 2, split_D);
x_strides[x.ndim() - 1] = split_D;
x_batch_ndims += 1;
w_shape.insert(w_shape.end() - 2, split_k);
w_shape[w.ndim() - 1] /= split_k;
w_strides.insert(w_strides.end() - 2, split_D * w.shape(-1));
w_batch_ndims += 1;
s_strides.insert(s_strides.end() - 2, split_D * scales.shape(-1));
b_strides.insert(b_strides.end() - 2, split_D * biases.shape(-1));
int final_block_size = D - (split_k - 1) * split_D;
auto& d = metal::device(s.device);
auto temp_shape = out.shape();
temp_shape.insert(temp_shape.end() - 2, split_k);
array intermediate(temp_shape, x.dtype(), nullptr, {});
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
d.add_temporary(intermediate, s.index);
std::ostringstream kname;
auto type_string = get_type_string(x.dtype());
kname << "qvm_split_k" << "_" << type_string << "_gs_" << group_size << "_b_"
<< bits << "_spk_" << split_k;
auto template_def = get_template_definition(
kname.str(), "qvm_split_k", type_string, group_size, bits, split_k);
// Encode and dispatch kernel
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(w, 0);
compute_encoder.set_input_array(scales, 1);
compute_encoder.set_input_array(biases, 2);
compute_encoder.set_input_array(x, 3);
compute_encoder.set_output_array(intermediate, 4);
compute_encoder->setBytes(&split_D, sizeof(int), 5);
compute_encoder->setBytes(&O, sizeof(int), 6);
compute_encoder->setBytes(&x_batch_ndims, sizeof(int), 7);
set_vector_bytes(compute_encoder, x_shape, 8);
set_vector_bytes(compute_encoder, x_strides, 9);
compute_encoder->setBytes(&w_batch_ndims, sizeof(int), 10);
set_vector_bytes(compute_encoder, w_shape, 11);
set_vector_bytes(compute_encoder, w_strides, 12);
set_vector_bytes(compute_encoder, s_strides, 13);
set_vector_bytes(compute_encoder, b_strides, 14);
compute_encoder->setBytes(&final_block_size, sizeof(int), 15);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
d.add_temporaries(std::move(copies), s.index);
int axis = intermediate.ndim() - 3;
ReductionPlan plan(
ReductionOpType::ContiguousStridedReduce,
{intermediate.shape(axis)},
{intermediate.strides(axis)});
strided_reduce_general_dispatch(
intermediate, out, "sum", plan, {axis}, compute_encoder, d, s);
}
void qmm_op(
const std::vector<array>& inputs,
array& out,
bool transpose,
int group_size,
int bits,
bool gather,
const Stream& s) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
MTL::Size group_dims;
MTL::Size grid_dims;
auto& x = inputs[0];
auto& w = inputs[1];
bool batched = !gather && (w.ndim() > 2 || !x.flags().row_contiguous);
int D = x.shape(-1);
int B = x.shape(-2);
int O = out.shape(-1);
int N = out.size() / B / O;
if (transpose_) {
// Route to the fast bs_qmv kernel that has no bounds checking
if (B < 6 && O % 8 == 0 && D % 512 == 0 && D >= 512) {
std::ostringstream kname;
auto type_string = get_type_string(x.dtype());
kname << "bs_qmv_fast_" << type_string << "_gs_" << group_size_ << "_b_"
<< bits_;
// For the unbatched W case, avoid `adjust_matrix_offsets`
// for a small performance gain.
int B = (batched || gather) ? x.shape(-2) : x.size() / D;
int N = (batched || gather) ? out.size() / B / O : 1;
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto template_def = get_template_definition(
kname.str(), "bs_qmv_fast", type_string, group_size_, bits_);
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
compute_encoder->setComputePipelineState(kernel);
std::string name = gather ? "bs_" : "";
bool matrix = false;
bool aligned = false;
bool quad = false;
if (transpose) {
if (B < 6 && (D == 128 || D == 64)) {
name += "qmv_quad";
constexpr int quads_per_simd = 8;
constexpr int results_per_quadgroup = 8;
int bo = quads_per_simd * results_per_quadgroup;
int simdgroup_size = 32;
group_dims = MTL::Size(simdgroup_size, 1, 1);
grid_dims = MTL::Size((O + bo - 1) / bo, B, N);
quad = true;
} else if (B < 6 && O % 8 == 0 && D % 512 == 0 && D >= 512) {
name += "qmv_fast";
int bo = 8;
int bd = 32;
MTL::Size group_dims = MTL::Size(bd, 2, 1);
MTL::Size grid_dims = MTL::Size(O / bo, B, N);
compute_encoder.set_input_array(w, 0);
compute_encoder.set_input_array(scales, 1);
compute_encoder.set_input_array(biases, 2);
compute_encoder.set_input_array(x, 3);
compute_encoder.set_input_array(lhs_indices, 4);
compute_encoder.set_input_array(rhs_indices, 5);
compute_encoder.set_output_array(out, 6);
compute_encoder->setBytes(&D, sizeof(int), 7);
compute_encoder->setBytes(&O, sizeof(int), 8);
compute_encoder->setBytes(&batch_ndims, sizeof(int), 9);
set_vector_bytes(compute_encoder, batch_shape, 10);
set_vector_bytes(compute_encoder, lhs_strides, 11);
set_vector_bytes(compute_encoder, rhs_strides, 12);
compute_encoder->setBytes(&x_batch_ndims, sizeof(int), 13);
set_vector_bytes(compute_encoder, x_shape, 14);
set_vector_bytes(compute_encoder, x_strides, 15);
compute_encoder->setBytes(&w_batch_ndims, sizeof(int), 16);
set_vector_bytes(compute_encoder, w_shape, 17);
set_vector_bytes(compute_encoder, w_strides, 18);
set_vector_bytes(compute_encoder, s_strides, 19);
set_vector_bytes(compute_encoder, b_strides, 20);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
else if (B < 6) {
std::ostringstream kname;
auto type_string = get_type_string(x.dtype());
kname << "bs_qmv_" << type_string << "_gs_" << group_size_ << "_b_"
<< bits_;
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto template_def = get_template_definition(
kname.str(), "bs_qmv", type_string, group_size_, bits_);
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
compute_encoder->setComputePipelineState(kernel);
group_dims = MTL::Size(bd, 2, 1);
grid_dims = MTL::Size(O / bo, B, N);
} else if (B < 6) {
name += "qmv";
int bo = 8;
int bd = 32;
MTL::Size group_dims = MTL::Size(bd, 2, 1);
MTL::Size grid_dims = MTL::Size((O + bo - 1) / bo, B, N);
compute_encoder.set_input_array(w, 0);
compute_encoder.set_input_array(scales, 1);
compute_encoder.set_input_array(biases, 2);
compute_encoder.set_input_array(x, 3);
compute_encoder.set_input_array(lhs_indices, 4);
compute_encoder.set_input_array(rhs_indices, 5);
compute_encoder.set_output_array(out, 6);
compute_encoder->setBytes(&D, sizeof(int), 7);
compute_encoder->setBytes(&O, sizeof(int), 8);
compute_encoder->setBytes(&batch_ndims, sizeof(int), 9);
set_vector_bytes(compute_encoder, batch_shape, 10);
set_vector_bytes(compute_encoder, lhs_strides, 11);
set_vector_bytes(compute_encoder, rhs_strides, 12);
compute_encoder->setBytes(&x_batch_ndims, sizeof(int), 13);
set_vector_bytes(compute_encoder, x_shape, 14);
set_vector_bytes(compute_encoder, x_strides, 15);
compute_encoder->setBytes(&w_batch_ndims, sizeof(int), 16);
set_vector_bytes(compute_encoder, w_shape, 17);
set_vector_bytes(compute_encoder, w_strides, 18);
set_vector_bytes(compute_encoder, s_strides, 19);
set_vector_bytes(compute_encoder, b_strides, 20);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
// Route to the bs_qmm_t
else {
std::ostringstream kname;
std::string aligned_n = (O % 32) == 0 ? "true" : "false";
auto type_string = get_type_string(out.dtype());
kname << "bs_qmm_t_" << type_string << "_gs_" << group_size_ << "_b_"
<< bits_ << "_alN_" << aligned_n;
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto template_def = get_template_definition(
kname.str(), "bs_qmm_t", type_string, group_size_, bits_, aligned_n);
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
compute_encoder->setComputePipelineState(kernel);
group_dims = MTL::Size(bd, 2, 1);
grid_dims = MTL::Size((O + bo - 1) / bo, B, N);
} else {
int wn = 2;
int wm = 2;
int bm = 32;
int bn = 32;
int bk = 32;
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size((O + bn - 1) / bn, (B + bm - 1) / bm, N);
compute_encoder.set_input_array(x, 0);
compute_encoder.set_input_array(w, 1);
compute_encoder.set_input_array(scales, 2);
compute_encoder.set_input_array(biases, 3);
compute_encoder.set_input_array(lhs_indices, 4);
compute_encoder.set_input_array(rhs_indices, 5);
compute_encoder.set_output_array(out, 6);
compute_encoder->setBytes(&B, sizeof(int), 7);
compute_encoder->setBytes(&O, sizeof(int), 8);
compute_encoder->setBytes(&D, sizeof(int), 9);
compute_encoder->setBytes(&batch_ndims, sizeof(int), 10);
set_vector_bytes(compute_encoder, batch_shape, 11);
set_vector_bytes(compute_encoder, lhs_strides, 12);
set_vector_bytes(compute_encoder, rhs_strides, 13);
compute_encoder->setBytes(&x_batch_ndims, sizeof(int), 14);
set_vector_bytes(compute_encoder, x_shape, 15);
set_vector_bytes(compute_encoder, x_strides, 16);
compute_encoder->setBytes(&w_batch_ndims, sizeof(int), 17);
set_vector_bytes(compute_encoder, w_shape, 18);
set_vector_bytes(compute_encoder, w_strides, 19);
set_vector_bytes(compute_encoder, s_strides, 20);
set_vector_bytes(compute_encoder, b_strides, 21);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
group_dims = MTL::Size(32, wn, wm);
grid_dims = MTL::Size((O + bn - 1) / bn, (B + bm - 1) / bm, N);
name += "qmm_t";
matrix = true;
aligned = true;
}
} else {
// Route to the bs_qvm kernel
if (B < 4) {
std::ostringstream kname;
auto type_string = get_type_string(out.dtype());
kname << "bs_qvm_" << type_string << "_gs_" << group_size_ << "_b_"
<< bits_;
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto template_def = get_template_definition(
kname.str(), "bs_qvm", type_string, group_size_, bits_);
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
compute_encoder->setComputePipelineState(kernel);
if (B < 4 && D >= 1024 && !gather) {
return qvm_split_k(inputs, out, group_size, bits, D, O, B, N, s);
} else if (B < 4) {
name += "qvm";
int bo = 64;
int bd = 32;
MTL::Size group_dims = MTL::Size(bd, 2, 1);
MTL::Size grid_dims = MTL::Size(O / bo, B, N);
compute_encoder.set_input_array(x, 0);
compute_encoder.set_input_array(w, 1);
compute_encoder.set_input_array(scales, 2);
compute_encoder.set_input_array(biases, 3);
compute_encoder.set_input_array(lhs_indices, 4);
compute_encoder.set_input_array(rhs_indices, 5);
compute_encoder.set_output_array(out, 6);
compute_encoder->setBytes(&D, sizeof(int), 7);
compute_encoder->setBytes(&O, sizeof(int), 8);
compute_encoder->setBytes(&batch_ndims, sizeof(int), 9);
set_vector_bytes(compute_encoder, batch_shape, 10);
set_vector_bytes(compute_encoder, lhs_strides, 11);
set_vector_bytes(compute_encoder, rhs_strides, 12);
compute_encoder->setBytes(&x_batch_ndims, sizeof(int), 13);
set_vector_bytes(compute_encoder, x_shape, 14);
set_vector_bytes(compute_encoder, x_strides, 15);
compute_encoder->setBytes(&w_batch_ndims, sizeof(int), 16);
set_vector_bytes(compute_encoder, w_shape, 17);
set_vector_bytes(compute_encoder, w_strides, 18);
set_vector_bytes(compute_encoder, s_strides, 19);
set_vector_bytes(compute_encoder, b_strides, 20);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
// Route to bs_qmm_n
else {
std::ostringstream kname;
auto type_string = get_type_string(out.dtype());
kname << "bs_qmm_n_" << type_string << "_gs_" << group_size_ << "_b_"
<< bits_;
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto template_def = get_template_definition(
kname.str(), "bs_qmm_n", type_string, group_size_, bits_);
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
compute_encoder->setComputePipelineState(kernel);
group_dims = MTL::Size(bd, 2, 1);
grid_dims = MTL::Size(O / bo, B, N);
} else {
name += "qmm_n";
int wn = 2;
int wm = 2;
int bm = 32;
int bn = 32;
int bk = 32;
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(O / bn, (B + bm - 1) / bm, N);
group_dims = MTL::Size(32, wn, wm);
grid_dims = MTL::Size(O / bn, (B + bm - 1) / bm, N);
matrix = true;
if ((O % bn) != 0) {
std::ostringstream msg;
msg << "[quantized_matmul] The output size should be divisible by "
<< bn << " but received " << O << ".";
throw std::runtime_error(msg.str());
}
compute_encoder.set_input_array(x, 0);
compute_encoder.set_input_array(w, 1);
compute_encoder.set_input_array(scales, 2);
compute_encoder.set_input_array(biases, 3);
compute_encoder.set_input_array(lhs_indices, 4);
compute_encoder.set_input_array(rhs_indices, 5);
compute_encoder.set_output_array(out, 6);
compute_encoder->setBytes(&B, sizeof(int), 7);
compute_encoder->setBytes(&O, sizeof(int), 8);
compute_encoder->setBytes(&D, sizeof(int), 9);
compute_encoder->setBytes(&batch_ndims, sizeof(int), 10);
set_vector_bytes(compute_encoder, batch_shape, 11);
set_vector_bytes(compute_encoder, lhs_strides, 12);
set_vector_bytes(compute_encoder, rhs_strides, 13);
compute_encoder->setBytes(&x_batch_ndims, sizeof(int), 14);
set_vector_bytes(compute_encoder, x_shape, 15);
set_vector_bytes(compute_encoder, x_strides, 16);
compute_encoder->setBytes(&w_batch_ndims, sizeof(int), 17);
set_vector_bytes(compute_encoder, w_shape, 18);
set_vector_bytes(compute_encoder, w_strides, 19);
set_vector_bytes(compute_encoder, s_strides, 20);
set_vector_bytes(compute_encoder, b_strides, 21);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
}
launch_qmm(
name,
inputs,
out,
group_size,
bits,
D,
O,
B,
N,
group_dims,
grid_dims,
batched,
matrix,
gather,
aligned,
quad,
s);
}
void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 4);
qmm_op(
inputs, out, transpose_, group_size_, bits_, /*gather=*/false, stream());
}
void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 6);
qmm_op(
inputs, out, transpose_, group_size_, bits_, /*gather=*/true, stream());
}
void fast::AffineQuantize::eval_gpu(
@@ -603,12 +473,7 @@ void fast::AffineQuantize::eval_gpu(
: MTL::Size(nthreads, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
if (!copies.empty()) {
d.get_command_buffer(s.index)->addCompletedHandler(
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
}
d.add_temporaries(std::move(copies), s.index);
}
} // namespace mlx::core

View File

@@ -141,6 +141,20 @@ struct ColReduceArgs {
ndim = shape.size();
}
/**
* Create the col reduce arguments for reducing the 1st axis of the row
* contiguous intermediate array.
*/
ColReduceArgs(const array& intermediate) {
assert(intermediate.flags().row_contiguous);
reduction_size = intermediate.shape(0);
reduction_stride = intermediate.size() / reduction_size;
non_col_reductions = 1;
reduce_ndim = 0;
ndim = 0;
}
void encode(CommandEncoder& compute_encoder) {
// Push 0s to avoid encoding empty vectors.
if (reduce_ndim == 0) {
@@ -231,8 +245,10 @@ void init_reduce(
CommandEncoder& compute_encoder,
metal::Device& d,
const Stream& s) {
auto kernel = get_reduce_init_kernel(
d, "init_reduce_" + op_name + type_to_name(out), out);
std::ostringstream kname;
const std::string func_name = "init_reduce";
kname << func_name << "_" << op_name << type_to_name(out);
auto kernel = get_reduce_init_kernel(d, kname.str(), func_name, op_name, out);
size_t nthreads = out.size();
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
@@ -251,8 +267,7 @@ void all_reduce_dispatch(
const std::string& op_name,
CommandEncoder& compute_encoder,
metal::Device& d,
const Stream& s,
std::vector<array>& copies) {
const Stream& s) {
// Set the kernel
std::ostringstream kname;
const std::string func_name = "all_reduce";
@@ -293,7 +308,7 @@ void all_reduce_dispatch(
// Allocate an intermediate tensor to hold results if needed
array intermediate({n_rows}, out.dtype(), nullptr, {});
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
copies.push_back(intermediate);
d.add_temporary(intermediate, s.index);
// 1st pass
size_t row_size = (in_size + n_rows - 1) / n_rows;
@@ -469,39 +484,11 @@ void strided_reduce_small(
// Figure out the grid dims
MTL::Size grid_dims, group_dims;
// Case 1: Small row small column
if (args.reduction_size * args.non_col_reductions < 64 &&
args.reduction_stride < 32) {
grid_dims = output_grid_for_col_reduce(out, args);
int threadgroup_size = (grid_dims.width > 128) ? 128 : grid_dims.width;
group_dims = MTL::Size(threadgroup_size, 1, 1);
}
// Prepare the arguments for the kernel
args.reduce_shape.push_back(args.reduction_size);
args.reduce_strides.push_back(args.reduction_stride);
args.reduce_ndim++;
// Case 2: Long row small column
else if (args.reduction_size * args.non_col_reductions < 32) {
auto out_grid_dims = output_grid_for_col_reduce(out, args);
int threads_x =
(args.reduction_stride + REDUCE_N_READS - 1) / REDUCE_N_READS;
int threadgroup_x = std::min(threads_x, 128);
grid_dims = MTL::Size(threads_x, out_grid_dims.width, out_grid_dims.height);
group_dims = MTL::Size(threadgroup_x, 1, 1);
}
// Case 3: Long row medium column
else {
args.reduce_shape.push_back(args.reduction_size);
args.reduce_strides.push_back(args.reduction_stride);
args.reduce_ndim++;
int simdgroups =
(args.reduction_stride + REDUCE_N_READS - 1) / REDUCE_N_READS;
int threadgroup_size = simdgroups * 32;
auto out_grid_dims = output_grid_for_col_reduce(out, args);
grid_dims =
MTL::Size(threadgroup_size, out_grid_dims.width, out_grid_dims.height);
group_dims = MTL::Size(threadgroup_size, 1, 1);
}
// Set the kernel
int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
std::ostringstream kname;
const std::string func_name = "col_reduce_small";
@@ -510,10 +497,113 @@ void strided_reduce_small(
get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n);
compute_encoder->setComputePipelineState(kernel);
const int n_reads = 4;
size_t reduction_stride_blocks =
(args.reduction_stride + n_reads - 1) / n_reads;
size_t total = args.reduction_size * args.non_col_reductions;
size_t threadgroup_x = std::min(reduction_stride_blocks, 32ul);
size_t threadgroup_y = std::min(
8ul,
std::min(kernel->maxTotalThreadsPerThreadgroup() / threadgroup_x, total));
group_dims = MTL::Size(threadgroup_x, threadgroup_y, 1);
grid_dims = output_grid_for_col_reduce(out, args);
grid_dims = MTL::Size(
(reduction_stride_blocks + threadgroup_x - 1) / threadgroup_x,
grid_dims.width,
grid_dims.height);
// Launch
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
args.encode(compute_encoder);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
void strided_reduce_longcolumn(
const array& in,
array& out,
const std::string& op_name,
ColReduceArgs& args,
CommandEncoder& compute_encoder,
metal::Device& d,
const Stream& s) {
size_t total_reduction_size = args.reduction_size * args.non_col_reductions;
size_t outer_blocks = 32;
if (total_reduction_size >= 32768) {
outer_blocks = 128;
}
// Prepare the temporary accumulator
std::vector<int> intermediate_shape;
intermediate_shape.reserve(out.ndim() + 1);
intermediate_shape.push_back(outer_blocks);
intermediate_shape.insert(
intermediate_shape.end(), out.shape().begin(), out.shape().end());
array intermediate(std::move(intermediate_shape), out.dtype(), nullptr, {});
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
d.add_temporary(intermediate, s.index);
// Prepare the arguments for the kernel
args.reduce_shape.push_back(args.reduction_size);
args.reduce_strides.push_back(args.reduction_stride);
args.reduce_ndim++;
// Figure out the grid dims
size_t out_size = out.size();
size_t threadgroup_x = args.reduction_stride;
size_t threadgroup_y =
(args.non_col_reductions * args.reduction_size + outer_blocks - 1) /
outer_blocks;
threadgroup_y = std::min(32ul, threadgroup_y);
auto out_grid_size = output_grid_for_col_reduce(out, args);
MTL::Size grid_dims(out_grid_size.width, out_grid_size.height, outer_blocks);
MTL::Size group_dims(threadgroup_x, threadgroup_y, 1);
// Set the kernel
int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
std::ostringstream kname;
const std::string func_name = "col_reduce_longcolumn";
kname << func_name << "_" << n << "_reduce_" << op_name << type_to_name(in);
auto kernel =
get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n);
compute_encoder->setComputePipelineState(kernel);
// Launch
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(intermediate, 1);
args.encode(compute_encoder);
compute_encoder->setBytes(&out_size, sizeof(size_t), 11);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
// Make the 2nd pass arguments and grid_dims
ColReduceArgs second_args(intermediate);
second_args.reduce_shape.push_back(outer_blocks);
second_args.reduce_strides.push_back(out.size());
second_args.reduce_ndim++;
int BN = 32;
grid_dims = MTL::Size(256 * ((out.size() + BN - 1) / BN), 1, 1);
group_dims = MTL::Size(256, 1, 1);
// Set the 2nd kernel
const std::string second_kernel = "col_reduce_looped_1_32_32_reduce_" +
op_name + type_to_name(intermediate);
kernel = get_reduce_kernel(
d,
second_kernel,
"col_reduce_looped",
op_name,
intermediate,
out,
1,
32,
32);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(intermediate, 0);
compute_encoder.set_output_array(out, 1);
second_args.encode(compute_encoder);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
@@ -532,9 +622,9 @@ void strided_reduce_looped(
// Figure out the grid dims
auto out_grid_size = output_grid_for_col_reduce(out, args);
int BN = (args.reduction_stride <= 1024) ? 32 : 128;
int BN = 32;
int BM = 1024 / BN;
int threadgroup_size = 4 * 32;
int threadgroup_size = 8 * 32;
MTL::Size grid_dims(
threadgroup_size * ((args.reduction_stride + BN - 1) / BN),
out_grid_size.width,
@@ -558,6 +648,87 @@ void strided_reduce_looped(
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
void strided_reduce_2pass(
const array& in,
array& out,
const std::string& op_name,
ColReduceArgs& args,
CommandEncoder& compute_encoder,
metal::Device& d,
const Stream& s) {
// Prepare the temporary accumulator
std::vector<int> intermediate_shape;
intermediate_shape.reserve(out.ndim() + 1);
intermediate_shape.push_back(32);
intermediate_shape.insert(
intermediate_shape.end(), out.shape().begin(), out.shape().end());
array intermediate(std::move(intermediate_shape), out.dtype(), nullptr, {});
intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
d.add_temporary(intermediate, s.index);
// Prepare the arguments for the kernel
args.reduce_shape.push_back(args.reduction_size);
args.reduce_strides.push_back(args.reduction_stride);
args.reduce_ndim++;
// Figure out the grid dims
size_t out_size = out.size() / args.reduction_stride;
auto out_grid_size = output_grid_for_col_reduce(out, args);
int outer_blocks = 32;
int BN = 32;
int BM = 1024 / BN;
int threadgroup_size = 8 * 32;
MTL::Size grid_dims(
threadgroup_size * ((args.reduction_stride + BN - 1) / BN),
out_grid_size.width * outer_blocks,
out_grid_size.height);
MTL::Size group_dims(threadgroup_size, 1, 1);
// Set the kernel
int n = (args.reduce_ndim < 5) ? std::max(1, args.reduce_ndim) : 0;
std::ostringstream kname;
const std::string func_name = "col_reduce_2pass";
kname << func_name << "_" << n << "_" << BM << "_" << BN << "_reduce_"
<< op_name << type_to_name(in);
auto kernel =
get_reduce_kernel(d, kname.str(), func_name, op_name, in, out, n, BM, BN);
compute_encoder->setComputePipelineState(kernel);
// Launch
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(intermediate, 1);
args.encode(compute_encoder);
compute_encoder->setBytes(&out_size, sizeof(size_t), 11);
compute_encoder.dispatchThreads(grid_dims, group_dims);
// Make the 2nd pass arguments and grid_dims
ColReduceArgs second_args(intermediate);
second_args.reduce_shape.push_back(outer_blocks);
second_args.reduce_strides.push_back(out.size());
second_args.reduce_ndim++;
grid_dims = MTL::Size(threadgroup_size * ((out.size() + BN - 1) / BN), 1, 1);
// Set the 2nd kernel
const std::string second_kernel = "col_reduce_looped_1_32_32_reduce_" +
op_name + type_to_name(intermediate);
kernel = get_reduce_kernel(
d,
second_kernel,
"col_reduce_looped",
op_name,
intermediate,
out,
1,
32,
32);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(intermediate, 0);
compute_encoder.set_output_array(out, 1);
second_args.encode(compute_encoder);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
void strided_reduce_general_dispatch(
const array& in,
array& out,
@@ -570,11 +741,23 @@ void strided_reduce_general_dispatch(
// Prepare the arguments for the kernel
ColReduceArgs args(in, plan, axes);
if (args.reduction_stride < 32 ||
args.reduction_size * args.non_col_reductions < 32) {
// Small column
if (args.reduction_size * args.non_col_reductions < 32) {
return strided_reduce_small(in, out, op_name, args, compute_encoder, d, s);
}
// Long column but small row
if (args.reduction_stride < 32 &&
args.reduction_size * args.non_col_reductions >= 1024) {
return strided_reduce_longcolumn(
in, out, op_name, args, compute_encoder, d, s);
}
if (args.reduction_size * args.non_col_reductions > 256 &&
out.size() / 32 < 1024) {
return strided_reduce_2pass(in, out, op_name, args, compute_encoder, d, s);
}
return strided_reduce_looped(in, out, op_name, args, compute_encoder, d, s);
}
@@ -620,7 +803,6 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
// Reduce
if (in.size() > 0) {
std::vector<array> copies;
ReductionPlan plan = get_reduction_plan(in, axes_);
// If it is a general reduce then copy the input to a contiguous array and
@@ -632,7 +814,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
if (plan.type == GeneralReduce) {
array in_copy(in.shape(), in.dtype(), nullptr, {});
copy_gpu(in, in_copy, CopyType::General, s);
copies.push_back(in_copy);
d.add_temporary(in_copy, s.index);
in = in_copy;
plan = get_reduction_plan(in, axes_);
}
@@ -640,7 +822,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
// Reducing over everything and the data is all there no broadcasting or
// slicing etc.
if (plan.type == ContiguousAllReduce) {
all_reduce_dispatch(in, out, op_name, compute_encoder, d, s, copies);
all_reduce_dispatch(in, out, op_name, compute_encoder, d, s);
}
// At least the last dimension is row contiguous and we are reducing over
@@ -659,13 +841,6 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
strided_reduce_general_dispatch(
in, out, op_name, plan, axes_, compute_encoder, d, s);
}
if (!copies.empty()) {
d.get_command_buffer(s.index)->addCompletedHandler(
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
}
}
// Nothing to reduce just initialize the output

View File

@@ -16,8 +16,7 @@ void all_reduce_dispatch(
const std::string& op_name,
CommandEncoder& compute_encoder,
metal::Device& d,
const Stream& s,
std::vector<array>& copies);
const Stream& s);
void row_reduce_general_dispatch(
const array& in,

View File

@@ -0,0 +1,99 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/metal/resident.h"
#include "mlx/backend/metal/metal_impl.h"
namespace mlx::core::metal {
ResidencySet::ResidencySet(MTL::Device* d) {
if (!d->supportsFamily(MTL::GPUFamilyMetal3)) {
return;
} else if (__builtin_available(macOS 15, iOS 18, *)) {
auto pool = new_scoped_memory_pool();
auto desc = MTL::ResidencySetDescriptor::alloc()->init();
NS::Error* error;
wired_set_ = d->newResidencySet(desc, &error);
desc->release();
if (!wired_set_) {
std::ostringstream msg;
msg << "[metal::Device] Unable to construct residency set.\n";
if (error) {
msg << error->localizedDescription()->utf8String() << "\n";
}
throw std::runtime_error(msg.str());
}
}
}
void ResidencySet::insert(MTL::Allocation* buf) {
if (!wired_set_) {
return;
}
if (wired_set_->allocatedSize() + buf->allocatedSize() <= capacity_) {
wired_set_->addAllocation(buf);
wired_set_->commit();
wired_set_->requestResidency();
} else {
unwired_set_.insert(buf);
}
}
void ResidencySet::erase(MTL::Allocation* buf) {
if (!wired_set_) {
return;
}
if (auto it = unwired_set_.find(buf); it != unwired_set_.end()) {
unwired_set_.erase(it);
} else {
wired_set_->removeAllocation(buf);
wired_set_->commit();
}
}
void ResidencySet::resize(size_t size) {
if (!wired_set_) {
return;
}
if (capacity_ == size) {
return;
}
capacity_ = size;
size_t current_size = wired_set_->allocatedSize();
if (current_size < size) {
// Add unwired allocations to the set
for (auto it = unwired_set_.begin(); it != unwired_set_.end();) {
auto buf_size = (*it)->allocatedSize();
if (current_size + buf_size > size) {
it++;
} else {
current_size += buf_size;
wired_set_->addAllocation(*it);
unwired_set_.erase(it++);
}
}
wired_set_->commit();
wired_set_->requestResidency();
} else if (current_size > size) {
// Remove wired allocations until under capacity
auto allocations = wired_set_->allAllocations();
auto num_allocations = wired_set_->allocationCount();
for (int i = 0; i < num_allocations && current_size > size; ++i) {
auto buf = static_cast<const MTL::Allocation*>(allocations->object(i));
wired_set_->removeAllocation(buf);
current_size -= buf->allocatedSize();
unwired_set_.insert(buf);
}
wired_set_->commit();
}
}
ResidencySet::~ResidencySet() {
if (wired_set_) {
wired_set_->release();
}
}
} // namespace mlx::core::metal

View File

@@ -0,0 +1,32 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/backend/metal/device.h"
namespace mlx::core::metal {
class ResidencySet {
public:
ResidencySet(MTL::Device* d);
~ResidencySet();
ResidencySet(const ResidencySet&) = delete;
ResidencySet& operator=(const ResidencySet&) = delete;
const MTL::ResidencySet* mtl_residency_set() {
return wired_set_;
}
void insert(MTL::Allocation* buf);
void erase(MTL::Allocation* buf);
void resize(size_t size);
private:
MTL::ResidencySet* wired_set_{nullptr};
std::unordered_set<const MTL::Allocation*> unwired_set_;
size_t capacity_{0};
};
} // namespace mlx::core::metal

View File

@@ -1,20 +1,13 @@
//
// scaled_dot_product_attention.cpp
// mlx
// Copyright © 2024 Apple Inc.
#include <algorithm>
#include <cassert>
#include <numeric>
#include <sstream>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h"
#include "mlx/backend/metal/metal.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/fast_primitives.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core::fast {
@@ -26,8 +19,7 @@ void sdpa_full_self_attention_metal(
const array& k,
const array& v,
const float alpha,
array& out,
std::vector<array>& temporaries) {
array& out) {
std::ostringstream kname_self_attention;
kname_self_attention << "steel_gemm_attention_";
@@ -148,130 +140,60 @@ void sdpa_full_self_attention_metal(
MTL::Size group_dims = MTL::Size(32, wm, wn);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
d.get_command_buffer(s.index)->addCompletedHandler(
[temporaries](MTL::CommandBuffer*) mutable { temporaries.clear(); });
return;
}
void sdpa_metal(
void sdpa_vector(
const Stream& s,
metal::Device& d,
const array& q,
const array& k,
const array& v,
const array& p_lse,
const array& p_rowmaxes,
const array& o_partial,
const uint heads,
const uint tile_size,
const uint n_tiles,
const float alpha,
array& out,
std::vector<array>& temporaries) {
std::ostringstream kname_partials;
float scale) {
// Set the kernel name
std::string kname;
kname.reserve(64);
kname += "sdpa_vector_";
kname += get_type_string(q.dtype());
kname += "_";
kname += std::to_string(q.shape(-1));
kname_partials << "fast_inference_sdpa_compute_partials_";
// Compute the necessary sizes
int gqa_factor = q.shape(1) / k.shape(1);
int N = k.shape(2);
int B = q.shape(0) * q.shape(1);
size_t k_stride = k.strides()[1];
size_t v_stride = v.strides()[1];
MTL::Size group_dims(1024, 1, 1);
MTL::Size grid_dims(1, B, 1);
std::ostringstream kname_reduce;
std::string delimiter = "_";
kname_reduce << "fast_inference_sdpa_reduce_tiles" + delimiter;
for (const auto& arr : {k, v, out}) {
if (arr.dtype() != q.dtype()) {
throw std::runtime_error(
"[ScaledDotProductAttention::eval_gpu]: expected matching dtypes for q,k,v,o");
}
}
if (q.dtype() == float32) {
kname_partials << "float" + delimiter;
kname_reduce << "float";
} else if (q.dtype() == float16) {
kname_partials << "half" + delimiter;
kname_reduce << "half";
} else {
throw std::runtime_error(
"[ScaledDotProductAttention::eval_gpu]: unexpected dtype found for queries: expected either float32 or float16.");
}
std::string kname_suffix_tile_size = std::to_string(tile_size) + delimiter;
uint nsimd = 8;
std::string kname_suffix_nsimdgroups = std::to_string(nsimd);
// maximum number of splits == 128 at the moment (reserved tile registers in
// reduction kernel). this is arbitrary and could be changed in the shader.
std::string kname_suffix = kname_suffix_tile_size + kname_suffix_nsimdgroups;
kname_partials << kname_suffix;
// Get the kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname_partials.str());
auto kernel = d.get_kernel(kname);
compute_encoder->setComputePipelineState(kernel);
constexpr const uint batch = 1;
MTL::Size grid_dims = MTL::Size(heads, n_tiles, batch);
MTL::Size group_dims = MTL::Size(32, nsimd, 1);
const uint64_t KV_sequence_length = k.shape(-2);
const uint query_sequence_length = q.shape(-2);
const uint n_q_heads = q.shape(1);
const uint n_kv_heads = k.shape(1);
MLXScaledDotProductAttentionParams params{
query_sequence_length, n_q_heads, n_kv_heads, n_tiles, alpha};
compute_encoder.set_input_array(q, 0);
// Set its arguments
compute_encoder.set_input_array(q.data_shared_ptr() == nullptr ? out : q, 0);
compute_encoder.set_input_array(k, 1);
compute_encoder.set_input_array(v, 2);
compute_encoder->setBytes(&KV_sequence_length, sizeof(KV_sequence_length), 3);
compute_encoder->setBytes(
&params, sizeof(MLXScaledDotProductAttentionParams), 4);
compute_encoder.set_input_array(o_partial, 5);
compute_encoder.set_input_array(p_lse, 6);
compute_encoder.set_input_array(p_rowmaxes, 7);
compute_encoder.set_output_array(out, 3);
compute_encoder->setBytes(&gqa_factor, sizeof(int), 4);
compute_encoder->setBytes(&N, sizeof(int), 5);
compute_encoder->setBytes(&k_stride, sizeof(size_t), 6);
compute_encoder->setBytes(&v_stride, sizeof(size_t), 7);
compute_encoder->setBytes(&scale, sizeof(float), 8);
constexpr const uint tgroupMemorySize = 32768;
compute_encoder->setThreadgroupMemoryLength(tgroupMemorySize, 0);
// Launch
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
{
auto kernel_accum = d.get_kernel(kname_reduce.str());
compute_encoder->setComputePipelineState(kernel_accum);
compute_encoder.set_input_array(o_partial, 0);
compute_encoder.set_input_array(p_lse, 1);
compute_encoder.set_input_array(p_rowmaxes, 2);
compute_encoder->setBytes(
&params, sizeof(MLXScaledDotProductAttentionParams), 3);
compute_encoder.set_output_array(out, 4);
MTL::Size grid_dims_reduce = MTL::Size(heads, 1, batch);
MTL::Size group_dims_reduce = MTL::Size(128, 1, 1);
compute_encoder.dispatchThreadgroups(grid_dims_reduce, group_dims_reduce);
d.get_command_buffer(s.index)->addCompletedHandler(
[temporaries](MTL::CommandBuffer*) mutable { temporaries.clear(); });
return;
}
}
} // namespace
void ScaledDotProductAttention::eval_gpu(
const std::vector<array>& inputs,
array& out) {
assert(inputs.size() >= 3);
if (!issubdtype(out.dtype(), floating)) {
throw std::runtime_error(
"[ScaledDotProductAttention] Does not yet support non-floating point types.");
}
assert(inputs.size() == 3);
if (inputs.size() == 4) {
out = fallback_(inputs)[0];
return;
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
@@ -279,84 +201,70 @@ void ScaledDotProductAttention::eval_gpu(
auto& k_pre = inputs[1];
auto& v_pre = inputs[2];
auto& o = out;
/////////////////////////////////////////////////////////////////////////////
// Init checks and prep
// Keep a vector with copies to be cleared in the completed buffer to release
// the arrays
std::vector<array> temporaries;
auto check_transpose = [&temporaries, &s](const array& arr) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (stx == arr.shape(-1) && sty == 1) {
return arr;
} else {
std::vector<array> copies;
// Define some copy functions to ensure the layout of the inputs is as
// expected.
auto copy_unless = [&copies, &s](auto predicate, const array& arr) {
if (!predicate(arr)) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
temporaries.push_back(arr_copy);
size_t stx = arr.shape(-1);
copies.push_back(arr_copy);
return arr_copy;
} else {
return arr;
}
};
auto q = check_transpose(q_pre);
auto k = check_transpose(k_pre);
auto v = check_transpose(v_pre);
// Checks if arr is fully row contiguous
auto is_contiguous = [](const array& arr) {
return arr.flags().row_contiguous;
};
const int heads = q.shape(-3);
// Returns true if the array is row contiguous except the sequence length
// dimension that can be sliced but with step=1.
auto is_contiguous_except_seq_len = [](const array& arr) {
auto& strides = arr.strides();
auto& shape = arr.shape();
return strides[3] == 1 && strides[2] == shape[3] &&
strides[0] == strides[1] * shape[1];
};
uint query_sequence_length = q.shape(-2);
if (query_sequence_length >= 16) {
return sdpa_full_self_attention_metal(
s, d, q, k, v, scale_, out, temporaries);
}
int tile_size = 64;
const int kv_seq_len = k.shape(-2);
if (kv_seq_len > 8000) {
tile_size = 128;
}
if (kv_seq_len > 16000) {
tile_size = 256;
}
if (kv_seq_len > 32000) {
tile_size = 512;
// Checks that the last two dims are row contiguous.
auto is_matrix_contiguous = [](const array& arr) {
auto& strides = arr.strides();
auto& shape = arr.shape();
return strides[3] == 1 && strides[2] == shape[3];
};
// We are in vector mode ie single query
if (q_pre.shape(2) == 1) {
auto q = copy_unless(is_contiguous, q_pre);
auto k = copy_unless(is_contiguous_except_seq_len, k_pre);
auto v = copy_unless(is_contiguous_except_seq_len, v_pre);
// Donate the query if possible
if (q.is_donatable()) {
o.move_shared_buffer(q);
} else {
o.set_data(allocator::malloc_or_wait(o.nbytes()));
}
sdpa_vector(s, d, q, k, v, o, scale_);
}
const int n_tiles = (kv_seq_len + tile_size - 1) / tile_size;
// Full attention mode
else {
auto q = copy_unless(is_matrix_contiguous, q_pre);
auto k = copy_unless(is_matrix_contiguous, k_pre);
auto v = copy_unless(is_matrix_contiguous, v_pre);
o.set_data(allocator::malloc_or_wait(o.nbytes()));
array o_partials(
{q.shape(-4), q.shape(-3), q.shape(-2), n_tiles * v.shape(-1)},
float32,
nullptr,
{});
o_partials.set_data(allocator::malloc_or_wait(o_partials.nbytes()));
sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o);
}
array p_lse(
{q.shape(-4), q.shape(-3), q.shape(-2), n_tiles}, float32, nullptr, {});
array p_rowmaxes(
{q.shape(-4), q.shape(-3), q.shape(-2), n_tiles}, float32, nullptr, {});
p_lse.set_data(allocator::malloc_or_wait(p_lse.nbytes()));
p_rowmaxes.set_data(allocator::malloc_or_wait(p_rowmaxes.nbytes()));
temporaries.push_back(p_lse);
temporaries.push_back(p_rowmaxes);
temporaries.push_back(o_partials);
return sdpa_metal(
s,
d,
q,
k,
v,
p_lse,
p_rowmaxes,
o_partials,
heads,
tile_size,
n_tiles,
scale_,
out,
temporaries);
d.add_temporaries(std::move(copies), s.index);
}
} // namespace mlx::core::fast

View File

@@ -14,19 +14,27 @@ namespace mlx::core {
void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
// Ensure contiguity
std::vector<array> copies;
auto in = inputs[0];
if (!in.flags().row_contiguous) {
if (in.flags().contiguous && in.strides()[axis_] != 0) {
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
out.move_shared_buffer(in);
} else {
out.set_data(
allocator::malloc_or_wait(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
}
} else {
array arr_copy(in.shape(), in.dtype(), nullptr, {});
copy_gpu(in, arr_copy, CopyType::General, s);
copies.push_back(arr_copy);
in = arr_copy;
out.move_shared_buffer(in);
}
bool contiguous = in.strides()[axis_] == 1;
@@ -61,7 +69,8 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
if (contiguous) {
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_input_array(
in.data_shared_ptr() == nullptr ? out : in, 0);
compute_encoder.set_output_array(out, 1);
size_t size = in.shape(axis_);
compute_encoder->setBytes(&size, sizeof(size_t), 2);
@@ -70,7 +79,6 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
int n_reads = (in.itemsize() <= 4) ? 4 : 2;
constexpr int simd_size = 32;
int elements_per_simd = n_reads * simd_size;
int thread_groups = in.size() / size;
int thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (size <= n_reads * 1024) {
thread_group_size =
@@ -82,38 +90,45 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
thread_group_size = std::min(
thread_group_size,
static_cast<int>(kernel->maxTotalThreadsPerThreadgroup()));
MTL::Size grid_dims = MTL::Size(thread_groups * thread_group_size, 1, 1);
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
auto tmp_grid_dims =
get_2d_grid_dims(in.shape(), in.strides(), /** divisor= */ size);
MTL::Size grid_dims(
thread_group_size, tmp_grid_dims.width, tmp_grid_dims.height);
MTL::Size group_dims(thread_group_size, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_input_array(
in.data_shared_ptr() == nullptr ? out : in, 0);
compute_encoder.set_output_array(out, 1);
size_t size = in.shape(axis_);
size_t stride = in.strides()[axis_];
int bm = 32;
int bn = 32;
size_t stride_blocks = (stride + bn - 1) / bn;
compute_encoder->setBytes(&size, sizeof(size_t), 2);
compute_encoder->setBytes(&stride, sizeof(size_t), 3);
compute_encoder->setBytes(&stride_blocks, sizeof(size_t), 4);
// Compute the thread grid
int n_reads = (in.itemsize() <= 4) ? 4 : 2;
int tile_x = 32;
int tile_y = 32;
int elements_per_tile_x = tile_x * n_reads;
int grid_y = in.size() / size / stride;
int grid_x = (stride + elements_per_tile_x - 1) / elements_per_tile_x;
MTL::Size grid_dims = MTL::Size(grid_x * tile_x, grid_y * tile_y, 1);
MTL::Size group_dims = MTL::Size(tile_x, tile_y, 1);
int n_simdgroups = bn / n_reads;
int thread_group_size = n_simdgroups * 32;
auto tmp_grid_dims = get_2d_grid_dims(
in.shape(), in.strides(), /** divisor= */ size * stride);
if (tmp_grid_dims.width * stride_blocks <= UINT_MAX) {
tmp_grid_dims.width *= stride_blocks;
} else {
tmp_grid_dims.height *= stride_blocks;
}
MTL::Size grid_dims(
thread_group_size, tmp_grid_dims.width, tmp_grid_dims.height);
MTL::Size group_dims(thread_group_size, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
if (!copies.empty()) {
auto command_buffer = d.get_command_buffer(s.index);
command_buffer->addCompletedHandler(
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
}
d.add_temporaries(std::move(copies), s.index);
}
} // namespace mlx::core

View File

@@ -88,12 +88,8 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder->setBytes(&axis_size, sizeof(int), 2);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
if (!copies.empty()) {
d.get_command_buffer(s.index)->addCompletedHandler(
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
}
d.add_temporaries(std::move(copies), s.index);
}
} // namespace mlx::core

View File

@@ -252,11 +252,7 @@ void multi_block_sort(
(axis == in.ndim() - 1) ? CopyType::Vector : CopyType::General,
s);
// Clear copies
d.get_command_buffer(s.index)->addCompletedHandler(
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
d.add_temporaries(std::move(copies), s.index);
}
void gpu_merge_sort(

View File

@@ -38,8 +38,7 @@ void ternary_op_gpu_inplace(
bool use_2d = out.data_size() > UINT_MAX;
auto ndim = shape.size();
int work_per_thread =
(topt == TernaryOpType::General && shape[ndim - 1] > 4) ? 4 : 1;
int work_per_thread = (topt == TernaryOpType::General) ? 4 : 1;
std::string kernel_name;
{
std::ostringstream kname;
@@ -73,6 +72,7 @@ void ternary_op_gpu_inplace(
compute_encoder.set_input_array(donate_c ? out : c, 2);
compute_encoder.set_output_array(out, 3);
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (topt == TernaryOpType::General) {
// Launch up to 3D grid of threads
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
@@ -94,7 +94,6 @@ void ternary_op_gpu_inplace(
compute_encoder->setBytes(strides_c.data(), ndim * sizeof(size_t), 6);
}
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size != 1024) {
throw std::runtime_error("[Metal::ternary] Must use 1024 sized block");
}
@@ -104,13 +103,12 @@ void ternary_op_gpu_inplace(
} else {
// Launch a 1D or 2D grid of threads
size_t nthreads = out.data_size();
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
}

View File

@@ -35,7 +35,7 @@ void unary_op_gpu_inplace(
};
auto [shape, strides] = maybe_collapse();
int ndim = shape.size();
int work_per_thread = (!contig && shape[ndim - 1] > 4) ? 4 : 1;
int work_per_thread = !contig ? 4 : 1;
size_t nthreads = contig ? in.data_size() : in.size();
bool use_2d = nthreads > UINT32_MAX;
std::string kernel_name;
@@ -44,12 +44,10 @@ void unary_op_gpu_inplace(
} else {
kernel_name = (work_per_thread == 4 ? "gn4" : "g");
}
kernel_name += "_" + op + type_to_name(out);
auto kernel = get_unary_kernel(d, kernel_name, out.dtype(), op);
kernel_name += "_" + op + type_to_name(in) + type_to_name(out);
auto kernel = get_unary_kernel(d, kernel_name, in.dtype(), out.dtype(), op);
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(in.shape(), in.strides())
: MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(
@@ -75,6 +73,8 @@ void unary_op_gpu_inplace(
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
}
@@ -124,11 +124,13 @@ UNARY_GPU(Erf)
UNARY_GPU(ErfInv)
UNARY_GPU(Exp)
UNARY_GPU(Expm1)
UNARY_GPU(Imag)
UNARY_GPU(Log1p)
UNARY_GPU(LogicalNot)
UNARY_GPU(Floor)
UNARY_GPU(Ceil)
UNARY_GPU(Negative)
UNARY_GPU(Real)
UNARY_GPU(Sigmoid)
UNARY_GPU(Sign)
UNARY_GPU(Sin)

View File

@@ -52,7 +52,7 @@ std::string type_to_name(const array& a) {
return tname;
}
MTL::Size get_block_dims(int dim0, int dim1, int dim2) {
MTL::Size get_block_dims(int dim0, int dim1, int dim2, int pow2 /* = 10 */) {
int pows[3] = {0, 0, 0};
int sum = 0;
while (true) {
@@ -76,7 +76,7 @@ MTL::Size get_block_dims(int dim0, int dim1, int dim2) {
pows[2]++;
sum++;
}
if (sum == presum || sum == 10) {
if (sum == presum || sum == pow2) {
break;
}
}
@@ -103,6 +103,54 @@ MTL::Size get_2d_grid_dims(
if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) {
throw std::runtime_error("Unable to safely factor shape.");
}
if (grid_y > grid_x) {
std::swap(grid_x, grid_y);
}
return MTL::Size(
static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
}
MTL::Size get_2d_grid_dims(
const std::vector<int>& shape,
const std::vector<size_t>& strides,
size_t divisor) {
// Compute the 2d grid dimensions such that the total size of the grid is
// divided by divisor.
size_t grid_x = 1;
size_t grid_y = 1;
for (int i = 0; i < shape.size(); ++i) {
if (strides[i] == 0) {
continue;
}
// No need to add this shape we can just remove it from the divisor.
if (divisor % shape[i] == 0) {
divisor /= shape[i];
continue;
}
if (grid_x * shape[i] < UINT32_MAX) {
grid_x *= shape[i];
} else {
grid_y *= shape[i];
}
if (divisor > 1) {
if (grid_x % divisor == 0) {
grid_x /= divisor;
divisor = 1;
} else if (grid_y % divisor == 0) {
grid_y /= divisor;
divisor = 1;
}
}
}
if (grid_y > UINT32_MAX || grid_x > UINT32_MAX || divisor > 1) {
throw std::runtime_error("Unable to safely factor shape.");
}
if (grid_y > grid_x) {
std::swap(grid_x, grid_y);
}
return MTL::Size(
static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
}

Some files were not shown because too many files have changed in this diff Show More