Compare commits

..

47 Commits

Author SHA1 Message Date
Awni Hannun
9231617eb3 Move to nanobind v2 (#1316) 2024-08-08 17:17:46 -07:00
Alex Barron
32668a7317 CPU mx.linalg.cholesky_inverse and mx.linalg.tri_inv (#1307)
* add cholesky inv + tri inv

* always run tri_inv on cpu

* consistent naming
2024-08-08 15:18:02 -07:00
Angelos Katharopoulos
780c197f95 Fix test tolerance and patch bump (#1315) 2024-08-08 14:51:09 -07:00
Angelos Katharopoulos
eb8819e91e Revert variance to be numerically stable (#1314) 2024-08-08 13:35:02 -07:00
Awni Hannun
30bbea2f08 Add gemv masked to JIT plus some fixes (#1310)
* add gemv masked to JIT plus some fixes

* some cleanup

* add utils

* fix

* fix 2

* more cleaning

* fix

* remove unused mps matmul support

* one more nit

* revert
2024-08-07 13:38:07 -07:00
Alex Barron
635ccd9e25 Add "edge" mode to mx.pad (#1309)
* Add edge padding mode

* fix pad in pooling

* string arg instead of enum
2024-08-06 11:23:10 -07:00
nicolov
8c9f0278b9 Add vmap to scatter (#1200)
* Add vmap to scatter

* updates

* vmap updates + a few more tests

* bug fix

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-08-05 20:12:27 -07:00
Awni Hannun
58d0e199e1 add bfloat conv for windograd (#1306)
* add bfloat conv for windograd

* accumulate in fp32

* accumulate in fp32

* accumulate in bf16
2024-08-05 15:51:13 -07:00
Awni Hannun
10b5835501 fix creating array from bf16 tensors in jax / torch (#1305) 2024-08-01 16:20:51 -07:00
Awni Hannun
6c8dd307eb faster group norm (#1304) 2024-08-01 12:49:23 -07:00
Awni Hannun
43ffdab172 fix rope and random (#1301)
* fix rope and random

* comment
2024-07-31 16:18:25 -07:00
Awni Hannun
40b6d67333 Fixes for large arrays with a few ops (#1299)
* fixes for large arrays with a few ops

* fix bug

* fix all of copy
2024-07-30 17:18:39 -07:00
Alex Barron
c52d1600f0 Fused Affine Quantize/Dequantize ops (#1282)
* Add fast affine dequantize

* add full quantize kernel

* fused kernel with scale/bias computation

* fix docstring

* fix no jit error

* fix test

* test fix

* reduce fast api to only affine_quantize
2024-07-29 15:11:38 -07:00
Awni Hannun
aa1d6cadad Fix docs latex build and nits (#1297)
* fix docs latex build and nits

* fix stub gen and try to clean up building
2024-07-29 11:44:06 -07:00
Atakan Tekparmak
6e06e3a904 feat: Added "tanh" option to GELU approximation (#1268) 2024-07-28 09:07:56 +02:00
Yaroslav
8cfb9fc0b8 Update requirements.txt (#1291) 2024-07-26 12:59:52 -07:00
Awni Hannun
7b456fd2c0 Array api (#1289)
* some updates for numpy 2.0 and array api

* some updates for numpy 2.0 and array api

* fix array api doc
2024-07-26 10:40:49 -07:00
Awni Hannun
e9e53856d2 patch bump (#1287) 2024-07-25 11:42:09 -07:00
Anton Belov
5029894662 [Issue #1187] Add nan_to_num function initial attempt (#1247)
* initial attempt, working with wrong types

* not compiling; mx.float16 and mx.bfloat16 tests added

* fix nan to num

* nit

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-07-25 09:57:37 -07:00
Awni Hannun
baf9fa5f42 Einsum (#1269)
* einsum initial

* fix comma break

* sum axis was wrong

* small cleanups

* python binding

* changed bindings to resemble numpy

* remove todo comment

* comment changes

* add count of operands/inputs

* fail fast if operands list is empty

* ignore comma if no output

* einsum path matching numpy

* getting somewhere with path

* remove print

* it passes the first test

* moved einsum tests to seperate file

* seperated einsum path

* moved einsum naive

* remove space from equation

* fast fail if no operands passed

* update tests and remove printf

* small cleanup

* some more cleanups

* removed python helper file

* ack

* utilize std for finding min in vector

* duplicate def

* remove the tuple as it was unreadable

* moved einsum_naive back to ops

* remaining isn't needed

* avoid creating another set

* cleanup

* greedy path, start of naive einsum

* more einsum

* fix some bugs

* some more fixes, tests pass

* benchmark

* some simplify

* fix einsum and test

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>

* add a bunch more tests and fix a bunch more bugs

* some docs nits

---------

Co-authored-by: dc-dc-dc <dgcruz983@gmail.com>
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-07-25 09:36:44 -07:00
Jagrit Digani
7f914365fd Fix GPU sort for large arrays (#1285)
* Fix GPU sort for large arrays
2024-07-24 14:37:10 -07:00
Paul Paczuski
ebd7135b50 Improve stability of BCE loss calculation for input probabilities close to or exactly 0 or 1 (#1280)
* Improve stability of BCE loss calculation

* Standardize comment

* Apply formatting with black via pre-commit

* Add usage recommendation to docstring

* Update python/mlx/nn/losses.py

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-07-24 08:38:22 -07:00
fgranqvist
50eff6a10a Implement sampling from laplace distribution. (#1279) 2024-07-24 15:15:37 +02:00
Alex Barron
c34a5ae7f7 Fix bfloat16 Hadamard (#1283)
* fix bfloat16 hadamard

* add scale

* review comments

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
2024-07-23 14:54:43 -07:00
Awni Hannun
e2aa6ec8ae some fixes (#1281) 2024-07-23 11:49:05 -07:00
toji
6768c6a54a Adding missing type hints (#1243)
* added type hints for `run`, `tree_map` and `tree_map_with_path`

* fix lint

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-07-23 07:29:38 -07:00
Tim Gymnich
6307d166eb Fix overflow / underflow handling for expm1f (#1278)
* Fix overflow / underflow handling for expm1f

* update tests
2024-07-23 07:29:06 -07:00
Awni Hannun
1fba87b0df Fix leak with multi-output primitives (#1274)
* fix leak with multi-output primitives

* hopefully an actual fix
2024-07-23 06:34:18 -07:00
Awni Hannun
df124e018a fix gguf (#1273)
* fix gguf

* comment
2024-07-18 07:35:35 -07:00
Cheng
2f83d6e4b7 Do not release buffers on exit (#1142) 2024-07-15 15:12:24 -07:00
Feng Shijie
987785d8d7 Fix typo and missing header (#1266) 2024-07-15 08:20:24 -07:00
Awni Hannun
8c01a7893b minor fix in optimizer + docs (#1264) 2024-07-12 12:18:02 -07:00
Awni Hannun
218047c75a docs fixes (#1263) 2024-07-11 15:59:07 -07:00
Alex Barron
d0da74209b version bump (#1260) 2024-07-11 11:17:55 -07:00
Angelos Katharopoulos
5c1fa64fb0 Custom transforms (#1246) 2024-07-10 18:00:01 -07:00
Alex Barron
a3c287354f Fast Hadamard Transform (#1249)
* Working hadamard for powers of 2

* working for m*2^k

* add scale and check contiguity

* add size check

* clean up

* fix test

* add grads + vmap

* gpu only

* skip on linux

* test typo

* add cpu impl

* remove gpu only tests

* fix linux build + add is_equivalent
2024-07-09 20:39:01 -07:00
Angelos Katharopoulos
03cf033f82 Fix reshape copy bug (#1253) 2024-07-07 21:37:00 -07:00
Alex Barron
bdb36c9a63 add zero vjps for bitwise ops and gather w.r.t. index (#1256) 2024-07-07 21:34:59 -07:00
Awni Hannun
20bb301195 CPU binary reduction + Nits (#1242)
* very minor nits

* reduce binary

* fix test
2024-06-28 13:50:42 -07:00
Awni Hannun
d6383a1c6a version bump (#1239) 2024-06-27 10:43:13 -07:00
Angelos Katharopoulos
b05bcfd27f Fixes segfault when compiling checkpointed functions (#1235) 2024-06-26 16:14:45 -07:00
Alex Barron
2615660e62 Fix strided sort bug (#1236)
* Use output strides in sort kernel

* fix zero strides bug
2024-06-26 14:32:11 -07:00
Awni Hannun
5b0af4cdb1 fix donation condition for compilation (#1237) 2024-06-26 09:04:05 -07:00
Jagrit Digani
8c2e15e6c8 Accelerate import updates for iOS (#1227)
* Update veclib and bnns includes to #include <Accelerate/Accelerate.h> for compatibility with ios

* Mark float literals in softmax.cpp to be float16_t for errors in ios

* Add arm neon vector operation guards

* Redirect to common backend for consistency
2024-06-26 09:01:50 -07:00
Awni Hannun
56c8a33439 Get metal version from xcode (#1228)
* get metal version from xcode

* typo

* fix
2024-06-26 07:02:11 -07:00
David Koski
4eef1e8a3e fix typo (#1215) 2024-06-24 13:36:35 -07:00
Alex Barron
95d11bda06 Fix NumPy 2.0 pickle test (#1221)
* fix numpy version <2 temporarily

* typo

* better fix

* Fix just for bfloat16

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
2024-06-23 05:47:22 -07:00
157 changed files with 7281 additions and 3195 deletions

View File

@@ -10,13 +10,14 @@ MLX was developed with contributions from the following individuals:
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`.
- Juarez Bochi: Fixed bug in cross attention.
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support.
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream`, safetensors support, `einsum`, and `einsum_path`.
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented pooling layers and ``Upsample``.
- Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops.
- Luca Arnaboldi: Added `Ceil` and `Floor` ops; implemented pickling, copy and deepcopy for mlx arrays.
- Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention`
- AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`. Comparison works for non array objects in `mlx.core.array`. Exception handling for invalid operations in `mlx.core.array`.
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
- Paul Paczuski: Improved stability of BCE loss calculation
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />

View File

@@ -24,7 +24,7 @@ option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
if(NOT MLX_VERSION)
set(MLX_VERSION 0.15.1)
set(MLX_VERSION 0.16.2)
endif()
# --------------------- Processor tests -------------------------
@@ -83,18 +83,17 @@ elseif (MLX_BUILD_METAL)
OUTPUT_VARIABLE MACOS_VERSION
COMMAND_ERROR_IS_FATAL ANY)
if (${MACOS_VERSION} LESS 14.0)
message(FATAL_ERROR "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON" )
endif()
message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}")
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18-beta.zip)
if (${MACOS_VERSION} GREATER_EQUAL 15.0)
set(MLX_METAL_VERSION METAL_3_2)
elseif (${MACOS_VERSION} GREATER_EQUAL 14.2)
set(MLX_METAL_VERSION METAL_3_1)
elseif (${MACOS_VERSION} GREATER_EQUAL 14.0)
set(MLX_METAL_VERSION METAL_3_0)
else()
message(FATAL_ERROR "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON" )
endif()
# Get the metal version
execute_process(
COMMAND zsh "-c" "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal -E -x metal -P - | tail -1 | tr -d '\n'"
OUTPUT_VARIABLE MLX_METAL_VERSION
COMMAND_ERROR_IS_FATAL ANY)
FetchContent_Declare(
metal_cpp
@@ -113,7 +112,7 @@ elseif (MLX_BUILD_METAL)
${FOUNDATION_LIB}
${QUARTZ_LIB})
add_compile_definitions(${MLX_METAL_VERSION})
add_compile_definitions("MLX_METAL_VERSION=${MLX_METAL_VERSION}")
endif()
if (MLX_BUILD_CPU)
@@ -170,11 +169,18 @@ if (MPI_FOUND)
execute_process(
COMMAND zsh "-c" "mpirun --version"
OUTPUT_VARIABLE MPI_VERSION
COMMAND_ERROR_IS_FATAL ANY
ERROR_QUIET
)
if (${MPI_VERSION} MATCHES ".*Open MPI.*")
target_include_directories(mlx PRIVATE ${MPI_INCLUDE_PATH})
elseif (MPI_VERSION STREQUAL "")
set(MPI_FOUND FALSE)
message(
WARNING
"MPI found but mpirun is not available. Building without MPI."
)
else()
set(MPI_FOUND FALSE)
message(
WARNING
"MPI which is not OpenMPI found. Building without MPI."

View File

@@ -0,0 +1,84 @@
# Copyright © 2024 Apple Inc.
import time
import mlx.core as mx
import numpy as np
def timeit(fn, its=100, args=[]):
for _ in range(5):
fn(*args)
tic = time.perf_counter()
for _ in range(its):
fn(*args)
toc = time.perf_counter()
return 1e3 * (toc - tic) / its
def time_little_einsum_path():
subscripts = "ik,kj->ij"
x = mx.ones((32, 32))
y = mx.ones((32, 32))
mx_time = timeit(mx.einsum_path, args=(subscripts, x, y))
x = np.array(x)
y = np.array(y)
np_time = timeit(np.einsum_path, args=(subscripts, x, y))
print("Timing little einsum path...")
print(f"MLX ... {mx_time:.3f} ms")
print(f"NumPy... {np_time:.3f} ms")
def time_big_einsum_path():
chars = list("abcdefgh")
char_to_dim = {c: v for v, c in enumerate(chars)}
num_inputs = 10
inputs = []
subscripts = []
for _ in range(num_inputs):
subscript = np.random.choice(chars, size=5, replace=False).tolist()
subscripts.append("".join(subscript))
inputs.append(np.ones(list(char_to_dim[c] for c in subscript)))
subscripts = ",".join(subscripts)
np_time = timeit(np.einsum_path, args=(subscripts, *inputs))
inputs = [mx.array(x) for x in inputs]
mx_time = timeit(mx.einsum_path, args=(subscripts, *inputs))
print("Timing big einsum path...")
print(f"MLX ... {mx_time:.3f} ms")
print(f"NumPy... {np_time:.3f} ms")
def time_attention():
def regular_attention(x):
# shape [batch, sequence, num_heads, head_dim]
queries, keys, values = x, x, x
scores = queries.transpose(0, 2, 1, 3) @ keys.transpose(0, 2, 3, 1)
scores = mx.softmax(scores, axis=-1)
output = (scores @ values.transpose(0, 2, 1, 3)).swapaxes(1, 2)
mx.eval(output)
def einsum_attention(x):
# shape [batch, sequence, num_heads, head_dim]
queries, keys, values = x, x, x
scores = mx.einsum("itjk,iujk->ijtu", queries, keys)
scores = mx.softmax(scores, axis=-1)
output = mx.einsum("ijtu,iujk->itjk", scores, values)
mx.eval(output)
x = mx.random.uniform(shape=(8, 512, 32, 128))
regular_time = timeit(regular_attention, args=(x,))
ein_time = timeit(einsum_attention, args=(x,))
print("Timing einsum attention...")
print(f"Regular ... {regular_time:.3f} ms")
print(f"Einsum ... {ein_time:.3f} ms")
if __name__ == "__main__":
time_little_einsum_path()
time_big_einsum_path()
time_attention()

View File

@@ -0,0 +1,70 @@
import argparse
import matplotlib
import mlx.core as mx
import numpy as np
from time_utils import measure_runtime
matplotlib.use("Agg")
import matplotlib.pyplot as plt
def had(x):
y = mx.hadamard_transform(x)
mx.eval(y)
def copy(x):
y = x + 1.0
mx.eval(y)
def run(dtype):
system_size = 2**26
outputs = {}
for test_fn in (had, copy):
for m in [1, 12, 20, 28]:
if test_fn == copy:
key = "copy"
elif m == 1:
key = "had_2^k"
else:
key = "had_m*2^k"
outputs.setdefault(key, {})
for k in range(7, 14):
n = m * 2**k
if n > 2**15:
continue
x_np = np.random.normal(size=(system_size // n, n)).astype(dtype)
x = mx.array(x_np)
runtime_ms = measure_runtime(test_fn, x=x)
bytes_per_gb = 1e9
ms_per_s = 1e3
bytes_per_had = np.dtype(x_np.dtype).itemsize * 2
bandwidth_gb = (
system_size * bytes_per_had / runtime_ms * ms_per_s / bytes_per_gb
)
print(n, bandwidth_gb)
outputs[key][n] = bandwidth_gb
colors = {
"copy": "black",
"had_2^k": "steelblue",
"had_m*2^k": "skyblue",
}
for key, output in outputs.items():
plt.scatter(output.keys(), output.values(), color=colors[key], label=key)
plt.title(f"MLX Hadamard Benchmark -- {dtype.__name__}")
plt.xlabel("N")
plt.ylabel("Bandwidth (GB/s)")
plt.legend()
plt.savefig(f"bench_{dtype.__name__}.png")
plt.clf()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--fp16", action="store_true")
args = parser.parse_args()
dtype = np.float16 if args.fp16 else np.float32
run(dtype)

View File

@@ -1,3 +1,4 @@
sphinx
breathe
sphinx-book-theme
mlx

View File

@@ -83,3 +83,15 @@ def setup(app):
# -- Options for LaTeX output ------------------------------------------------
latex_documents = [(main_doc, "MLX.tex", "MLX Documentation", author, "manual")]
latex_elements = {
"preamble": r"""
\usepackage{enumitem}
\setlistdepth{5}
\setlist[itemize,1]{label=$\bullet$}
\setlist[itemize,2]{label=$\bullet$}
\setlist[itemize,3]{label=$\bullet$}
\setlist[itemize,4]{label=$\bullet$}
\setlist[itemize,5]{label=$\bullet$}
\renewlist{itemize}{itemize}{5}
""",
}

View File

@@ -486,9 +486,8 @@ below.
std::ostringstream kname;
kname << "axpby_" << "general_" << type_to_name(out);
// Make sure the metal library is available and look for it
// in the same folder as this executable if needed
d.register_library("mlx_ext", metal::get_colocated_mtllib_path);
// Make sure the metal library is available
d.register_library("mlx_ext");
// Make a kernel from this metal library
auto kernel = d.get_kernel(kname.str(), "mlx_ext");

View File

@@ -15,7 +15,7 @@ module to concisely define the model architecture.
Attention layer
^^^^^^^^^^^^^^^^
We will start with the llama attention layer which notably uses the RoPE
We will start with the Llama attention layer which notably uses the RoPE
positional encoding. [1]_ In addition, our attention layer will optionally use a
key/value cache that will be concatenated with the provided keys and values to
support efficient inference.

View File

@@ -64,7 +64,7 @@ set:
Next, setup the problem parameters and load the data. To load the data, you need our
`mnist data loader
<https://github.com/ml-explore/mlx-examples/blob/main/mnist/mnist.py>`_, which
we will import as `mnist`.
we will import as ``mnist``.
.. code-block:: python

View File

@@ -70,36 +70,36 @@ To build and install the MLX python library from source, first, clone MLX from
git clone git@github.com:ml-explore/mlx.git mlx && cd mlx
Install `nanobind <https://nanobind.readthedocs.io/en/latest/>`_ with:
.. code-block:: shell
pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
Then simply build and install MLX using pip:
.. code-block:: shell
env CMAKE_BUILD_PARALLEL_LEVEL="" pip install .
CMAKE_BUILD_PARALLEL_LEVEL="" pip install .
For developing use an editable install:
For developing, install the package with development dependencies, and use an
editable install:
.. code-block:: shell
env CMAKE_BUILD_PARALLEL_LEVEL="" pip install -e .
CMAKE_BUILD_PARALLEL_LEVEL="" pip install -e ".[dev]"
To make sure the install is working run the tests with:
Once the development dependencies are installed, you can build faster with:
.. code-block:: shell
CMAKE_BUILD_PARALLEL_LEVEL="" python setup.py build_ext -j --inplace
Run the tests with:
.. code-block:: shell
pip install ".[testing]"
python -m unittest discover python/tests
Optional: Install stubs to enable auto completions and type checking from your IDE:
Optional: Install stubs to enable auto completions and type checking from your
IDE:
.. code-block:: shell
pip install ".[dev]"
python setup.py generate_stubs
C++ API
@@ -195,7 +195,7 @@ GGUF, you can do:
.. code-block:: shell
cmake ..
cmake .. \
-DCMAKE_BUILD_TYPE=MinSizeRel \
-DBUILD_SHARED_LIBS=ON \
-DMLX_BUILD_CPU=OFF \

View File

@@ -24,6 +24,7 @@ Array
array.any
array.argmax
array.argmin
array.conj
array.cos
array.cummax
array.cummin
@@ -57,3 +58,4 @@ Array
array.transpose
array.T
array.var
array.view

View File

@@ -9,7 +9,9 @@ Linear Algebra
:toctree: _autosummary
inv
tri_inv
norm
cholesky
cholesky_inv
qr
svd

View File

@@ -57,6 +57,8 @@ Operations
diagonal
divide
divmod
einsum
einsum_path
equal
erf
erfinv
@@ -72,6 +74,7 @@ Operations
gather_qmm
greater
greater_equal
hadamard_transform
identity
inner
isclose
@@ -103,6 +106,7 @@ Operations
minimum
moveaxis
multiply
nan_to_num
negative
not_equal
ones

View File

@@ -31,6 +31,41 @@ model's parameters and the **optimizer state**.
# Compute the new parameters but also the optimizer state.
mx.eval(model.parameters(), optimizer.state)
Saving and Loading
------------------
To serialize an optimizer, save its state. To load an optimizer, load and set
the saved state. Here's a simple example:
.. code-block:: python
import mlx.core as mx
from mlx.utils import tree_flatten, tree_unflatten
import mlx.optimizers as optim
optimizer = optim.Adam(learning_rate=1e-2)
# Perform some updates with the optimizer
model = {"w" : mx.zeros((5, 5))}
grads = {"w" : mx.ones((5, 5))}
optimizer.update(model, grads)
# Save the state
state = tree_flatten(optimizer.state)
mx.save_safetensors("optimizer.safetensors", dict(state))
# Later on, for example when loading from a checkpoint,
# recreate the optimizer and load the state
optimizer = optim.Adam(learning_rate=1e-2)
state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))
optimizer.state = state
Note, not every optimizer configuation parameter is saved in the state. For
example, for Adam the learning rate is saved but the ``betas`` and ``eps``
parameters are not. A good rule of thumb is if the parameter can be scheduled
then it will be included in the optimizer state.
.. toctree::
optimizers/optimizer

View File

@@ -44,3 +44,4 @@ we use a splittable version of Threefry, which is a counter-based PRNG.
split
truncated_normal
uniform
laplace

View File

@@ -10,6 +10,7 @@ Transforms
eval
compile
custom_function
disable_compile
enable_compile
grad

View File

@@ -249,9 +249,8 @@ void Axpby::eval_gpu(
kname << (contiguous_kernel ? "contiguous_" : "general_");
kname << type_to_name(out);
// Make sure the metal library is available and look for it
// in the same folder as this executable if needed
d.register_library("mlx_ext", metal::get_colocated_mtllib_path);
// Make sure the metal library is available
d.register_library("mlx_ext");
// Make a kernel from this metal library
auto kernel = d.get_kernel(kname.str(), "mlx_ext");

View File

@@ -1,4 +1,4 @@
setuptools>=42
cmake>=3.24
mlx>=0.9.0
nanobind@git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
mlx>=0.16.2
nanobind==2.0

View File

@@ -6,6 +6,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp

View File

@@ -17,6 +17,10 @@ bool in_tracing() {
return detail::InTracing::in_tracing();
}
bool retain_graph() {
return detail::RetainGraph::retain_graph();
}
} // namespace
array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
@@ -102,7 +106,7 @@ void array::eval() {
}
bool array::is_tracer() const {
return array_desc_->is_tracer && in_tracing();
return array_desc_->is_tracer && in_tracing() || retain_graph();
}
void array::set_data(allocator::Buffer buffer, deleter_t d) {
@@ -171,10 +175,11 @@ array::~array() {
return;
}
// Ignore arrays that will be detached
if (status() != array::Status::unscheduled) {
// Ignore arrays that might be detached during eval
if (status() == array::Status::scheduled) {
return;
}
// Break circular reference for non-detached arrays with siblings
if (auto n = siblings().size(); n > 0) {
bool do_detach = true;

View File

@@ -1,9 +1,9 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <cassert>
#include <Accelerate/Accelerate.h>
#include <simd/vector.h>
#include <vecLib/vDSP.h>
#include "mlx/backend/common/copy.h"
#include "mlx/primitives.h"

View File

@@ -2,8 +2,7 @@
#include <cassert>
#include <vecLib/BNNS/bnns.h>
#include <vecLib/cblas_new.h>
#include <Accelerate/Accelerate.h>
#include "mlx/backend/accelerate/utils.h"
#include "mlx/backend/common/copy.h"

View File

@@ -3,8 +3,7 @@
#include <cassert>
#include <cmath>
#include <vecLib/vDSP.h>
#include <vecLib/vForce.h>
#include <Accelerate/Accelerate.h>
#include "mlx/allocator.h"
#include "mlx/backend/common/binary.h"
@@ -37,7 +36,7 @@ DEFAULT(Ceil)
DEFAULT(Concatenate)
DEFAULT(Conjugate)
DEFAULT(Copy)
DEFAULT_MULTI(CustomVJP)
DEFAULT_MULTI(CustomTransforms)
DEFAULT_MULTI(Depends)
DEFAULT_MULTI(DivMod)
DEFAULT(NumberOfElements)
@@ -51,6 +50,7 @@ DEFAULT(GatherMM)
DEFAULT(GatherQMM)
DEFAULT(Greater)
DEFAULT(GreaterEqual)
DEFAULT(Hadamard)
DEFAULT(Less)
DEFAULT(LessEqual)
DEFAULT(Load)
@@ -102,7 +102,7 @@ void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& b = inputs[1];
if (a.dtype() == float32) {
binary(
binary_op<float>(
a,
b,
out,
@@ -117,7 +117,7 @@ void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
vDSP_vadd((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n);
});
} else if (a.dtype() == int32) {
binary(
binary_op<int>(
a,
b,
out,
@@ -132,7 +132,7 @@ void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
vDSP_vaddi((const int*)a, 1, (const int*)b, 1, (int*)o, 1, n);
});
} else {
binary(a, b, out, [](auto x, auto y) { return x + y; });
eval(inputs, out);
}
}
@@ -287,7 +287,7 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& b = inputs[1];
if (a.dtype() == int32) {
binary(
binary_op<int>(
a,
b,
out,
@@ -300,7 +300,7 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
vDSP_vdivi((const int*)b, 1, (const int*)a, 1, (int*)o, 1, n);
});
} else if (a.dtype() == float32) {
binary(
binary_op<float>(
a,
b,
out,
@@ -315,7 +315,7 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
vDSP_vdiv((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n);
});
} else {
binary(a, b, out, [](auto x, auto y) { return x / y; });
eval(inputs, out);
}
}
@@ -326,12 +326,8 @@ void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
set_unary_output_data(in, out);
auto size = in.data_size();
vvexpf(out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
} else if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, [](auto x) { return std::exp(x); });
} else {
throw std::invalid_argument(
"[exp] Cannot exponentiate elements in array"
" with non floating point type.");
eval(inputs, out);
}
}
@@ -393,12 +389,8 @@ void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
auto size = in.data_size();
vvlog1pf(
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
} else if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, [](auto x) { return std::log1p(x); });
} else {
throw std::invalid_argument(
"[log1p] Cannot compute log of elements in array with"
" non floating point type.");
eval(inputs, out);
}
}
@@ -408,7 +400,7 @@ void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& b = inputs[1];
if (a.dtype() == float32) {
binary(
binary_op<float>(
a,
b,
out,
@@ -423,7 +415,7 @@ void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
vDSP_vmul((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n);
});
} else {
binary(a, b, out, [](auto x, auto y) { return x * y; });
eval(inputs, out);
}
}
@@ -434,7 +426,7 @@ void Negative::eval_cpu(const std::vector<array>& inputs, array& out) {
set_unary_output_data(in, out);
vDSP_vneg(in.data<float>(), 1, out.data<float>(), 1, in.data_size());
} else {
unary(in, out, [](auto x) { return -x; });
eval(inputs, out);
}
}
@@ -521,7 +513,7 @@ void Square::eval_cpu(const std::vector<array>& inputs, array& out) {
auto size = in.data_size();
vDSP_vsq(in.data<float>(), 1, out.data<float>(), 1, size);
} else {
unary(in, out, [](auto x) { return x * x; });
eval(inputs, out);
}
}
@@ -547,7 +539,7 @@ void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& b = inputs[1];
if (a.dtype() == float32) {
binary(
binary_op<float>(
a,
b,
out,
@@ -565,7 +557,7 @@ void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
vDSP_vsub((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n);
});
} else if (a.dtype() == int32) {
binary(
binary_op<int>(
a,
b,
out,
@@ -577,7 +569,7 @@ void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
},
UseDefaultBinaryOp());
} else {
binary(a, b, out, [](auto x, auto y) { return x - y; });
eval(inputs, out);
}
}

View File

@@ -2,8 +2,8 @@
#include <cassert>
#include <Accelerate/Accelerate.h>
#include <simd/vector.h>
#include <vecLib/vDSP.h>
#include "mlx/backend/common/reduce.h"
#include "mlx/primitives.h"

View File

@@ -3,7 +3,10 @@
#include <cassert>
#include <limits>
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#include <arm_neon.h>
#endif
#include <simd/math.h>
#include <simd/vector.h>
@@ -53,25 +56,26 @@ inline simd_float16 simd_fast_exp(simd_float16 x) {
return (*(simd_float16*)&epart) * x;
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
/**
* The ARM neon equivalent of the fast exp above.
*/
inline float16x8_t neon_fast_exp(float16x8_t x) {
x = vmulq_f16(x, vdupq_n_f16(1.442695)); // multiply with log_2(e)
x = vmaxq_f16(x, vdupq_n_f16(-14)); // clamp under with -14
x = vminq_f16(x, vdupq_n_f16(14)); // clamp over with 14
x = vmulq_f16(x, vdupq_n_f16(float16_t(1.442695f))); // multiply with log_2(e)
x = vmaxq_f16(x, vdupq_n_f16(float16_t(-14.f))); // clamp under with -14
x = vminq_f16(x, vdupq_n_f16(float16_t(14.f))); // clamp over with 14
float16x8_t ipart = vrndmq_f16(vaddq_f16(x, vdupq_n_f16(0.5)));
float16x8_t ipart = vrndmq_f16(vaddq_f16(x, vdupq_n_f16(float16_t(0.5f))));
float16x8_t fpart = vsubq_f16(x, ipart);
x = vdupq_n_f16(1.535336188319500e-4f);
x = vfmaq_f16(vdupq_n_f16(1.339887440266574e-3f), x, fpart);
x = vfmaq_f16(vdupq_n_f16(1.339887440266574e-3f), x, fpart);
x = vfmaq_f16(vdupq_n_f16(9.618437357674640e-3f), x, fpart);
x = vfmaq_f16(vdupq_n_f16(5.550332471162809e-2f), x, fpart);
x = vfmaq_f16(vdupq_n_f16(2.402264791363012e-1f), x, fpart);
x = vfmaq_f16(vdupq_n_f16(6.931472028550421e-1f), x, fpart);
x = vfmaq_f16(vdupq_n_f16(1.000000000000000f), x, fpart);
x = vdupq_n_f16(float16_t(1.535336188319500e-4f));
x = vfmaq_f16(vdupq_n_f16(float16_t(1.339887440266574e-3f)), x, fpart);
x = vfmaq_f16(vdupq_n_f16(float16_t(1.339887440266574e-3f)), x, fpart);
x = vfmaq_f16(vdupq_n_f16(float16_t(9.618437357674640e-3f)), x, fpart);
x = vfmaq_f16(vdupq_n_f16(float16_t(5.550332471162809e-2f)), x, fpart);
x = vfmaq_f16(vdupq_n_f16(float16_t(2.402264791363012e-1f)), x, fpart);
x = vfmaq_f16(vdupq_n_f16(float16_t(6.931472028550421e-1f)), x, fpart);
x = vfmaq_f16(vdupq_n_f16(float16_t(1.000000000000000f)), x, fpart);
// generate 2**ipart in the floating point representation using integer
// bitshifting
@@ -107,53 +111,6 @@ inline float16_t neon_reduce_add(float16x8_t x) {
return vget_lane_f16(y, 0);
}
template <typename T, typename VT>
struct AccelerateSimdOps {
VT init(T a) {
return a;
}
VT load(const T* a) {
return *(VT*)a;
}
void store(T* dst, VT x) {
*(VT*)dst = x;
}
VT max(VT a, VT b) {
return simd_max(a, b);
}
VT exp(VT x) {
return simd_fast_exp(x);
}
VT add(VT a, VT b) {
return a + b;
}
VT sub(VT a, T b) {
return a - b;
}
VT mul(VT a, VT b) {
return a * b;
}
VT mul(VT a, T b) {
return a * b;
}
T reduce_max(VT x) {
return simd_reduce_max(x);
}
T reduce_add(VT x) {
return simd_reduce_add(x);
}
};
template <typename T, typename VT>
struct NeonFp16SimdOps {
VT init(T a) {
@@ -201,6 +158,55 @@ struct NeonFp16SimdOps {
}
};
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
template <typename T, typename VT>
struct AccelerateSimdOps {
VT init(T a) {
return a;
}
VT load(const T* a) {
return *(VT*)a;
}
void store(T* dst, VT x) {
*(VT*)dst = x;
}
VT max(VT a, VT b) {
return simd_max(a, b);
}
VT exp(VT x) {
return simd_fast_exp(x);
}
VT add(VT a, VT b) {
return a + b;
}
VT sub(VT a, T b) {
return a - b;
}
VT mul(VT a, VT b) {
return a * b;
}
VT mul(VT a, T b) {
return a * b;
}
T reduce_max(VT x) {
return simd_reduce_max(x);
}
T reduce_add(VT x) {
return simd_reduce_add(x);
}
};
template <typename T, typename AccT, typename VT, typename Ops, int N>
void softmax(const array& in, array& out) {
Ops ops;
@@ -362,12 +368,16 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
AccelerateSimdOps<float, simd_float16>,
16>(in, out);
} else {
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
softmax<
float16_t,
float16_t,
float16x8_t,
NeonFp16SimdOps<float16_t, float16x8_t>,
8>(in, out);
#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
eval(inputs, out); // Redirect to common backend for consistency
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
}
break;
case bfloat16:

View File

@@ -1,8 +1,8 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#pragma once
#include <vecLib/BNNS/bnns.h>
#include <Accelerate/Accelerate.h>
#include "mlx/dtype.h"
namespace mlx::core {

View File

@@ -42,6 +42,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
${CMAKE_CURRENT_SOURCE_DIR}/masked_mm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp

View File

@@ -66,7 +66,7 @@ void Copy::eval(const std::vector<array>& inputs, array& out) {
out.copy_shared_buffer(inputs[0]);
}
void CustomVJP::eval(
void CustomTransforms::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() > outputs.size());

View File

@@ -205,8 +205,8 @@ void compiled_allocate_outputs(
// - Donatable
// - Correct size
// - Not a constant
if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() &&
in.is_donatable() &&
if (in.flags().row_contiguous && in.size() == outputs[o].size() &&
in.itemsize() == outputs[o].itemsize() && in.is_donatable() &&
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
if (move_buffers) {
outputs[o].move_shared_buffer(

View File

@@ -4,6 +4,7 @@
#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/utils.h"
namespace mlx::core {
@@ -142,29 +143,31 @@ void copy_general(
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
int64_t i_offset) {
switch (src.ndim()) {
auto [new_shape, new_strides] = collapse_contiguous_dims(
data_shape, std::vector<std::vector<stride_t>>{i_strides});
switch (new_shape.size()) {
case 1:
copy_general_dim1<SrcT, DstT, stride_t>(
src, dst, data_shape, i_strides, i_offset);
src, dst, new_shape, new_strides[0], i_offset);
return;
case 2:
copy_general_dim2<SrcT, DstT, stride_t>(
src, dst, data_shape, i_strides, i_offset);
src, dst, new_shape, new_strides[0], i_offset);
return;
case 3:
copy_general_dim3<SrcT, DstT, stride_t>(
src, dst, data_shape, i_strides, i_offset);
src, dst, new_shape, new_strides[0], i_offset);
return;
case 4:
copy_general_dim4<SrcT, DstT, stride_t>(
src, dst, data_shape, i_strides, i_offset);
src, dst, new_shape, new_strides[0], i_offset);
return;
}
auto src_ptr = src.data<SrcT>() + i_offset;
auto dst_ptr = dst.data<DstT>();
for (size_t i = 0; i < dst.size(); ++i) {
stride_t src_elem = elem_to_loc(i, data_shape, i_strides);
stride_t src_elem = elem_to_loc(i, new_shape, new_strides[0]);
dst_ptr[i] = static_cast<DstT>(src_ptr[src_elem]);
}
}
@@ -195,10 +198,10 @@ inline void copy_general_general_dims(
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
const std::vector<stride_t>& o_strides,
stride_t i_offset,
stride_t o_offset) {
int64_t i_offset,
int64_t o_offset) {
if constexpr (D > 1) {
int axis = src.ndim() - D;
int axis = data_shape.size() - D;
auto stride_src = i_strides[axis];
auto stride_dst = o_strides[axis];
auto N = data_shape[axis];
@@ -209,7 +212,7 @@ inline void copy_general_general_dims(
o_offset += stride_dst;
}
} else {
int axis = src.ndim() - 1;
int axis = data_shape.size() - 1;
auto stride_src = i_strides[axis];
auto stride_dst = o_strides[axis];
auto N = data_shape[axis];
@@ -230,38 +233,76 @@ void copy_general_general(
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
const std::vector<stride_t>& o_strides,
stride_t i_offset,
stride_t o_offset) {
switch (src.ndim()) {
int64_t i_offset,
int64_t o_offset) {
auto [new_shape, new_strides] = collapse_contiguous_dims(
data_shape, std::vector<std::vector<stride_t>>{i_strides, o_strides});
switch (new_shape.size()) {
case 1:
copy_general_general_dims<SrcT, DstT, stride_t, 1>(
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
src,
dst,
new_shape,
new_strides[0],
new_strides[1],
i_offset,
o_offset);
return;
case 2:
copy_general_general_dims<SrcT, DstT, stride_t, 2>(
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
src,
dst,
new_shape,
new_strides[0],
new_strides[1],
i_offset,
o_offset);
return;
case 3:
copy_general_general_dims<SrcT, DstT, stride_t, 3>(
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
src,
dst,
new_shape,
new_strides[0],
new_strides[1],
i_offset,
o_offset);
return;
case 4:
copy_general_general_dims<SrcT, DstT, stride_t, 4>(
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
src,
dst,
new_shape,
new_strides[0],
new_strides[1],
i_offset,
o_offset);
return;
case 5:
copy_general_general_dims<SrcT, DstT, stride_t, 5>(
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
src,
dst,
new_shape,
new_strides[0],
new_strides[1],
i_offset,
o_offset);
return;
}
int size = std::accumulate(
data_shape.end() - 5, data_shape.end(), 1, std::multiplies<int>());
new_shape.end() - 5, new_shape.end(), 1, std::multiplies<int>());
for (int i = 0; i < src.size(); i += size) {
stride_t src_offset = i_offset + elem_to_loc(i, data_shape, i_strides);
stride_t dst_offset = o_offset + elem_to_loc(i, dst.shape(), o_strides);
stride_t src_offset = i_offset + elem_to_loc(i, new_shape, new_strides[0]);
stride_t dst_offset = o_offset + elem_to_loc(i, new_shape, new_strides[1]);
copy_general_general_dims<SrcT, DstT, stride_t, 5>(
src, dst, data_shape, i_strides, o_strides, src_offset, dst_offset);
src,
dst,
new_shape,
new_strides[0],
new_strides[1],
src_offset,
dst_offset);
}
}
@@ -444,8 +485,17 @@ void copy_inplace(
}
}
template <>
void copy_inplace<int64_t>(
template void copy_inplace<size_t>(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<size_t>& i_strides,
const std::vector<size_t>& o_strides,
int64_t i_offset,
int64_t o_offset,
CopyType ctype);
template void copy_inplace<int64_t>(
const array& src,
array& dst,
const std::vector<int>& data_shape,
@@ -453,24 +503,6 @@ void copy_inplace<int64_t>(
const std::vector<int64_t>& o_strides,
int64_t i_offset,
int64_t o_offset,
CopyType ctype) {
switch (ctype) {
case CopyType::General:
case CopyType::GeneralGeneral:
return copy_inplace_dispatch(
src,
dst,
ctype,
data_shape,
i_strides,
o_strides,
i_offset,
o_offset);
case CopyType::Scalar:
case CopyType::Vector:
return copy_inplace_dispatch(src, dst, ctype);
}
}
CopyType ctype);
} // namespace mlx::core

View File

@@ -52,7 +52,7 @@ DEFAULT(Convolution)
DEFAULT(Copy)
DEFAULT(Cos)
DEFAULT(Cosh)
DEFAULT_MULTI(CustomVJP)
DEFAULT_MULTI(CustomTransforms)
DEFAULT_MULTI(Depends)
DEFAULT(Divide)
DEFAULT(NumberOfElements)
@@ -68,6 +68,7 @@ DEFAULT(Full)
DEFAULT(Gather)
DEFAULT(Greater)
DEFAULT(GreaterEqual)
DEFAULT(Hadamard)
DEFAULT(Less)
DEFAULT(LessEqual)
DEFAULT(Load)

View File

@@ -0,0 +1,107 @@
// Copyright © 2024 Apple Inc.
#include <cassert>
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/hadamard.h"
#include "mlx/primitives.h"
namespace mlx::core {
// n = 2^k component
template <typename T>
void hadamard_n(array& out, int n, int m, float scale) {
for (int b = 0; b < out.size() / n; b++) {
size_t loc = b * n;
T* data_ptr = out.data<T>() + loc;
int h = 1;
int n_over_2 = n / 2;
while (h < n) {
for (int i = 0; i < n / 2; i++) {
int k = i & (h - 1);
int j = ((i - k) << 1) + k;
float x = *(data_ptr + j);
float y = *(data_ptr + j + h);
*(data_ptr + j) = x + y;
*(data_ptr + j + h) = x - y;
if (h == n_over_2) {
*(data_ptr + j) *= scale;
*(data_ptr + j + h) *= scale;
}
}
h <<= 1;
}
}
}
// m component
template <typename T>
void hadamard_m(array& out, int n, int m, float scale) {
auto h_matrices = hadamard_matrices();
auto& matrix = h_matrices[m];
auto start = 1;
auto end = matrix.find('\n', start);
std::vector<bool> hmat_vec;
while (end != std::string_view::npos) {
auto row = matrix.substr(start, end - start);
for (int i = 0; i < row.length(); i++) {
hmat_vec.push_back(row[i] == '+');
}
start = end + 1;
end = matrix.find('\n', start);
}
for (int b = 0; b < out.size() / m / n; b++) {
size_t loc = b * n * m;
T* data_ptr = out.data<T>() + loc;
for (int i = 0; i < n; i++) {
std::vector<float> out(m);
for (int j = 0; j < m; j++) {
for (int k = 0; k < m; k++) {
float x = *(data_ptr + i + k * n);
if (hmat_vec[k + j * m]) {
out[j] += x;
} else {
out[j] -= x;
}
}
}
for (int j = 0; j < m; j++) {
*(data_ptr + i + j * n) = out[j] * scale;
}
}
}
}
template <typename T>
void hadamard(array& out, int n, int m, float scale) {
float n_scale = m > 1 ? 1.0 : scale;
hadamard_n<T>(out, n, m, n_scale);
if (m > 1) {
hadamard_m<T>(out, n, m, scale);
}
}
void Hadamard::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
// Copy input to output
copy(in, out, CopyType::General);
int axis = out.ndim() - 1;
auto [n, m] = decompose_hadamard(out.shape(axis));
switch (in.dtype()) {
case float32:
return hadamard<float>(out, n, m, scale_);
case float16:
return hadamard<float16_t>(out, n, m, scale_);
case bfloat16:
return hadamard<bfloat16_t>(out, n, m, scale_);
default:
throw std::invalid_argument("[hadamard] Unsupported type.");
}
}
} // namespace mlx::core

View File

@@ -0,0 +1,105 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include <map>
#include "mlx/utils.h"
namespace mlx::core {
// From http://neilsloane.com/hadamard/
constexpr std::string_view h12 = R"(
+-++++++++++
--+-+-+-+-+-
+++-++----++
+---+--+-++-
+++++-++----
+-+---+--+-+
++--+++-++--
+--++---+--+
++----+++-++
+--+-++---+-
++++----+++-
+-+--+-++---
)";
constexpr std::string_view h20 = R"(
+----+----++--++-++-
-+----+---+++---+-++
--+----+---+++-+-+-+
---+----+---+++++-+-
----+----++--++-++-+
-+++++-----+--+++--+
+-+++-+---+-+--+++--
++-++--+---+-+--+++-
+++-+---+---+-+--+++
++++-----++--+-+--++
--++-+-++-+-----++++
---++-+-++-+---+-+++
+---++-+-+--+--++-++
++---++-+----+-+++-+
-++---++-+----+++++-
-+--+--++-+----+----
+-+-----++-+----+---
-+-+-+---+--+----+--
--+-+++------+----+-
+--+--++------+----+
)";
constexpr std::string_view h28 = R"(
+------++----++-+--+-+--++--
-+-----+++-----+-+--+-+--++-
--+-----+++---+-+-+----+--++
---+-----+++---+-+-+-+--+--+
----+-----+++---+-+-+++--+--
-----+-----++++--+-+--++--+-
------++----++-+--+-+--++--+
--++++-+-------++--+++-+--+-
---++++-+-----+-++--+-+-+--+
+---+++--+----++-++--+-+-+--
++---++---+----++-++--+-+-+-
+++---+----+----++-++--+-+-+
++++--------+-+--++-++--+-+-
-++++--------+++--++--+--+-+
-+-++-++--++--+--------++++-
+-+-++--+--++--+--------++++
-+-+-++--+--++--+----+---+++
+-+-+-++--+--+---+---++---++
++-+-+-++--+------+--+++---+
-++-+-+-++--+------+-++++---
+-++-+---++--+------+-++++--
-++--++-+-++-+++----++------
+-++--++-+-++-+++-----+-----
++-++---+-+-++-+++-----+----
-++-++-+-+-+-+--+++-----+---
--++-++++-+-+----+++-----+--
+--++-+-++-+-+----+++-----+-
++--++-+-++-+-+----++------+
)";
inline const std::map<int, std::string_view> hadamard_matrices() {
return {{12, h12}, {20, h20}, {28, h28}};
}
inline std::pair<int, int> decompose_hadamard(int n) {
// n = m*2^k
int m = 1;
if (!is_power_of_2(n)) {
auto h_matrices = hadamard_matrices();
for (auto [factor, _] : h_matrices) {
if (n % factor == 0) {
m = factor;
n /= factor;
break;
}
}
if (m == 1) {
throw std::invalid_argument(
"[hadamard] Only supports n = m*2^k where m in (1, 12, 20, 28).");
}
}
return {n, m};
}
} // namespace mlx::core

View File

@@ -10,9 +10,106 @@
#include <lapack.h>
#endif
// Wrapper to account for differences in
// LAPACK implementations (basically how to pass the 'uplo' string to fortran).
int strtri_wrapper(char uplo, char diag, float* matrix, int N) {
int info;
#ifdef LAPACK_FORTRAN_STRLEN_END
strtri_(
/* uplo = */ &uplo,
/* diag = */ &diag,
/* N = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* info = */ &info,
/* uplo_len = */ static_cast<size_t>(1),
/* diag_len = */ static_cast<size_t>(1));
#else
strtri_(
/* uplo = */ &uplo,
/* diag = */ &diag,
/* N = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* info = */ &info);
#endif
return info;
}
namespace mlx::core {
void inverse_impl(const array& a, array& inv) {
void general_inv(array& inv, int N, int i) {
int info;
auto ipiv = array::Data{allocator::malloc_or_wait(sizeof(int) * N)};
// Compute LU factorization.
sgetrf_(
/* m = */ &N,
/* n = */ &N,
/* a = */ inv.data<float>() + N * N * i,
/* lda = */ &N,
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: LU factorization failed with error code " << info;
throw std::runtime_error(ss.str());
}
static const int lwork_query = -1;
float workspace_size = 0;
// Compute workspace size.
sgetri_(
/* m = */ &N,
/* a = */ nullptr,
/* lda = */ &N,
/* ipiv = */ nullptr,
/* work = */ &workspace_size,
/* lwork = */ &lwork_query,
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: LU workspace calculation failed with error code "
<< info;
throw std::runtime_error(ss.str());
}
const int lwork = workspace_size;
auto scratch = array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
// Compute inverse.
sgetri_(
/* m = */ &N,
/* a = */ inv.data<float>() + N * N * i,
/* lda = */ &N,
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
/* work = */ static_cast<float*>(scratch.buffer.raw_ptr()),
/* lwork = */ &lwork,
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: inversion failed with error code " << info;
throw std::runtime_error(ss.str());
}
}
void tri_inv(array& inv, int N, int i, bool upper) {
const char uplo = upper ? 'L' : 'U';
const char diag = 'N';
int info = strtri_wrapper(uplo, diag, inv.data<float>() + N * N * i, N);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: triangular inversion failed with error code " << info;
throw std::runtime_error(ss.str());
}
}
void inverse_impl(const array& a, array& inv, bool tri, bool upper) {
// Lapack uses the column-major convention. We take advantage of the following
// identity to avoid transposing (see
// https://math.stackexchange.com/a/340234):
@@ -24,63 +121,11 @@ void inverse_impl(const array& a, array& inv) {
const int N = a.shape(-1);
const size_t num_matrices = a.size() / (N * N);
int info;
auto ipiv = array::Data{allocator::malloc_or_wait(sizeof(int) * N)};
for (int i = 0; i < num_matrices; i++) {
// Compute LU factorization.
sgetrf_(
/* m = */ &N,
/* n = */ &N,
/* a = */ inv.data<float>() + N * N * i,
/* lda = */ &N,
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: LU factorization failed with error code " << info;
throw std::runtime_error(ss.str());
}
static const int lwork_query = -1;
float workspace_size = 0;
// Compute workspace size.
sgetri_(
/* m = */ &N,
/* a = */ nullptr,
/* lda = */ &N,
/* ipiv = */ nullptr,
/* work = */ &workspace_size,
/* lwork = */ &lwork_query,
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: LU workspace calculation failed with error code "
<< info;
throw std::runtime_error(ss.str());
}
const int lwork = workspace_size;
auto scratch =
array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
// Compute inverse.
sgetri_(
/* m = */ &N,
/* a = */ inv.data<float>() + N * N * i,
/* lda = */ &N,
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
/* work = */ static_cast<float*>(scratch.buffer.raw_ptr()),
/* lwork = */ &lwork,
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: inversion failed with error code " << info;
throw std::runtime_error(ss.str());
if (tri) {
tri_inv(inv, N, i, upper);
} else {
general_inv(inv, N, i);
}
}
}
@@ -89,7 +134,7 @@ void Inverse::eval(const std::vector<array>& inputs, array& output) {
if (inputs[0].dtype() != float32) {
throw std::runtime_error("[Inverse::eval] only supports float32.");
}
inverse_impl(inputs[0], output);
inverse_impl(inputs[0], output, tri_, upper_);
}
} // namespace mlx::core

View File

@@ -405,7 +405,17 @@ void Reshape::eval(const std::vector<array>& inputs, array& out) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : CopyType::General);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto out_strides = make_contiguous_strides<size_t>(in.shape());
copy_inplace<size_t>(
in,
out,
in.shape(),
in.strides(),
out_strides,
0,
0,
CopyType::General);
} else {
shared_buffer_reshape(in, out_strides, out);
}

View File

@@ -113,14 +113,14 @@ void sort(const array& in, array& out, int axis) {
axis = axis < 0 ? axis + in.ndim() : axis;
size_t n_rows = in.size() / in.shape(axis);
auto remaining_shape = in.shape();
auto remaining_shape = out.shape();
remaining_shape.erase(remaining_shape.begin() + axis);
auto remaining_strides = in.strides();
auto remaining_strides = out.strides();
remaining_strides.erase(remaining_strides.begin() + axis);
size_t axis_stride = in.strides()[axis];
int axis_size = in.shape(axis);
size_t axis_stride = out.strides()[axis];
int axis_size = out.shape(axis);
// Perform sorting in place
for (int i = 0; i < n_rows; i++) {
@@ -143,34 +143,42 @@ void argsort(const array& in, array& out, int axis) {
axis = axis < 0 ? axis + in.ndim() : axis;
size_t n_rows = in.size() / in.shape(axis);
auto remaining_shape = in.shape();
remaining_shape.erase(remaining_shape.begin() + axis);
auto in_remaining_shape = in.shape();
in_remaining_shape.erase(in_remaining_shape.begin() + axis);
auto remaining_strides = in.strides();
remaining_strides.erase(remaining_strides.begin() + axis);
auto in_remaining_strides = in.strides();
in_remaining_strides.erase(in_remaining_strides.begin() + axis);
size_t axis_stride = in.strides()[axis];
auto out_remaining_shape = out.shape();
out_remaining_shape.erase(out_remaining_shape.begin() + axis);
auto out_remaining_strides = out.strides();
out_remaining_strides.erase(out_remaining_strides.begin() + axis);
size_t in_stride = in.strides()[axis];
size_t out_stride = out.strides()[axis];
int axis_size = in.shape(axis);
// Perform sorting
for (int i = 0; i < n_rows; i++) {
size_t loc = elem_to_loc(i, remaining_shape, remaining_strides);
const T* data_ptr = in.data<T>() + loc;
IdxT* idx_ptr = out.data<IdxT>() + loc;
size_t in_loc = elem_to_loc(i, in_remaining_shape, in_remaining_strides);
size_t out_loc = elem_to_loc(i, out_remaining_shape, out_remaining_strides);
const T* data_ptr = in.data<T>() + in_loc;
IdxT* idx_ptr = out.data<IdxT>() + out_loc;
StridedIterator st_(idx_ptr, axis_stride, 0);
StridedIterator ed_(idx_ptr, axis_stride, axis_size);
StridedIterator st_(idx_ptr, out_stride, 0);
StridedIterator ed_(idx_ptr, out_stride, axis_size);
// Initialize with iota
std::iota(st_, ed_, IdxT(0));
// Sort according to vals
StridedIterator st(idx_ptr, axis_stride, 0);
StridedIterator ed(idx_ptr, axis_stride, axis_size);
StridedIterator st(idx_ptr, out_stride, 0);
StridedIterator ed(idx_ptr, out_stride, axis_size);
std::stable_sort(st, ed, [data_ptr, axis_stride](IdxT a, IdxT b) {
auto v1 = data_ptr[a * axis_stride];
auto v2 = data_ptr[b * axis_stride];
std::stable_sort(st, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
auto v1 = data_ptr[a * in_stride];
auto v2 = data_ptr[b * in_stride];
return v1 < v2 || (v1 == v2 && a < b);
});
}

View File

@@ -29,6 +29,15 @@ inline size_t elem_to_loc(int elem, const array& a) {
return elem_to_loc(elem, a.shape(), a.strides());
}
template <typename stride_t>
std::vector<stride_t> make_contiguous_strides(const std::vector<int>& shape) {
std::vector<stride_t> strides(shape.size(), 1);
for (int i = shape.size() - 1; i > 0; i--) {
strides[i - 1] = strides[i] * shape[i];
}
return strides;
}
// Collapse dims that are contiguous to possibly route to a better kernel
// e.g. for x = transpose(array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 2, 2}), {2, 0, 1})
// should return {{2, 4}, {{1, 2}}}.

View File

@@ -18,7 +18,7 @@ function(make_jit_source SRC_FILE)
${CMAKE_C_COMPILER}
${PROJECT_SOURCE_DIR}
${SRC_FILE}
"-D${MLX_METAL_VERSION}"
"-DMLX_METAL_VERSION=${MLX_METAL_VERSION}"
DEPENDS make_compiled_preamble.sh
kernels/${SRC_FILE}.h
${ARGN}
@@ -52,6 +52,7 @@ make_jit_source(
)
make_jit_source(scatter)
make_jit_source(gather)
make_jit_source(hadamard)
if (MLX_METAL_JIT)
target_sources(
@@ -113,6 +114,7 @@ if (MLX_METAL_JIT)
kernels/steel/conv/loaders/loader_general.h
)
make_jit_source(quantized)
make_jit_source(gemv_masked)
else()
target_sources(
mlx
@@ -132,6 +134,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp
@@ -147,6 +150,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
)
if (NOT MLX_METAL_PATH)

View File

@@ -242,8 +242,17 @@ void MetalAllocator::free(Buffer buffer) {
}
MetalAllocator& allocator() {
static MetalAllocator allocator_;
return allocator_;
// By creating the |allocator_| on heap, the destructor of MetalAllocator will
// not be called on exit and all the buffers will be leaked. This is necessary
// because releasing buffers can take more than 30sec when the program holds a
// lot of RAM (for example inferencing a LLM), and it would feel frozen to
// users when exiting.
// TODO(zcbenz): Consider using the `base::NoDestructor` class from Chromium
// when applying this pattern to more places, or when introducing sanitizers
// to MLX.
// https://source.chromium.org/chromium/chromium/src/+/main:base/no_destructor.h
static MetalAllocator* allocator_ = new MetalAllocator;
return *allocator_;
}
size_t set_cache_limit(size_t limit) {

View File

@@ -21,10 +21,43 @@ namespace mlx::core {
constexpr int MAX_BINARY_SPECIALIZED_DIMS = 5;
std::string get_kernel_name(
BinaryOpType bopt,
const std::string& op,
const array& a,
bool use_2d,
int ndim) {
std::ostringstream kname;
switch (bopt) {
case BinaryOpType::ScalarScalar:
kname << "ss";
break;
case BinaryOpType::ScalarVector:
kname << (use_2d ? "sv2" : "sv");
break;
case BinaryOpType::VectorScalar:
kname << (use_2d ? "vs2" : "vs");
break;
case BinaryOpType::VectorVector:
kname << (use_2d ? "vv2" : "vv");
break;
case BinaryOpType::General:
kname << "g";
if (ndim <= MAX_BINARY_SPECIALIZED_DIMS) {
kname << ndim;
} else {
kname << "n";
}
break;
}
kname << op << type_to_name(a);
return kname.str();
}
void binary_op_gpu_inplace(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::string op,
const std::string& op,
const Stream& s) {
auto& a = inputs[0];
auto& b = inputs[1];
@@ -41,35 +74,8 @@ void binary_op_gpu_inplace(
auto& strides_b = strides[1];
auto& strides_out = strides[2];
std::string kernel_name;
{
std::ostringstream kname;
switch (bopt) {
case BinaryOpType::ScalarScalar:
kname << "ss";
break;
case BinaryOpType::ScalarVector:
kname << "sv";
break;
case BinaryOpType::VectorScalar:
kname << "vs";
break;
case BinaryOpType::VectorVector:
kname << "vv";
break;
case BinaryOpType::General:
kname << "g";
if (shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
kname << shape.size();
} else {
kname << "n";
}
break;
}
kname << op << type_to_name(a);
kernel_name = kname.str();
}
bool use_2d = out.data_size() > UINT32_MAX;
std::string kernel_name = get_kernel_name(bopt, op, a, use_2d, shape.size());
auto& d = metal::device(s.device);
auto kernel =
@@ -117,9 +123,11 @@ void binary_op_gpu_inplace(
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
// Launch a 1D grid of threads
// Launch a 1D or 2D grid of threads
size_t nthreads = out.data_size();
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
MTL::Size grid_dims = use_2d
? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides())
: MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
@@ -132,7 +140,7 @@ void binary_op_gpu_inplace(
void binary_op_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::string op,
const std::string& op,
const Stream& s) {
assert(inputs.size() == 2);
auto& a = inputs[0];
@@ -146,7 +154,7 @@ void binary_op_gpu(
void binary_op_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::string op) {
const std::string& op) {
auto& s = outputs[0].primitive().stream();
binary_op_gpu(inputs, outputs, op, s);
}
@@ -154,7 +162,7 @@ void binary_op_gpu(
void binary_op_gpu_inplace(
const std::vector<array>& inputs,
array& out,
const std::string op,
const std::string& op,
const Stream& s) {
auto& a = inputs[0];
auto& b = inputs[1];
@@ -169,35 +177,8 @@ void binary_op_gpu_inplace(
auto& strides_b = strides[1];
auto& strides_out = strides[2];
std::string kernel_name;
{
std::ostringstream kname;
switch (bopt) {
case BinaryOpType::ScalarScalar:
kname << "ss";
break;
case BinaryOpType::ScalarVector:
kname << "sv";
break;
case BinaryOpType::VectorScalar:
kname << "vs";
break;
case BinaryOpType::VectorVector:
kname << "vv";
break;
case BinaryOpType::General:
kname << "g";
if (shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
kname << shape.size();
} else {
kname << "n";
}
break;
}
kname << op << type_to_name(a);
kernel_name = kname.str();
}
bool use_2d = out.data_size() > UINT32_MAX;
std::string kernel_name = get_kernel_name(bopt, op, a, use_2d, shape.size());
auto& d = metal::device(s.device);
auto kernel = get_binary_kernel(d, kernel_name, a.dtype(), out.dtype(), op);
@@ -237,10 +218,11 @@ void binary_op_gpu_inplace(
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
// Launch a 1D grid of threads
size_t nthreads =
bopt == BinaryOpType::General ? out.size() : out.data_size();
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
// Launch a 1D or 2D grid of threads
size_t nthreads = out.data_size();
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
@@ -253,7 +235,7 @@ void binary_op_gpu_inplace(
void binary_op_gpu(
const std::vector<array>& inputs,
array& out,
const std::string op,
const std::string& op,
const Stream& s) {
assert(inputs.size() == 2);
auto& a = inputs[0];
@@ -266,7 +248,7 @@ void binary_op_gpu(
void binary_op_gpu(
const std::vector<array>& inputs,
array& out,
const std::string op) {
const std::string& op) {
auto& s = out.primitive().stream();
binary_op_gpu(inputs, out, op, s);
}

View File

@@ -9,25 +9,25 @@ namespace mlx::core {
void binary_op_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::string op,
const std::string& op,
const Stream& s);
void binary_op_gpu(
const std::vector<array>& inputs,
array& out,
const std::string op,
const std::string& op,
const Stream& s);
void binary_op_gpu_inplace(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::string op,
const std::string& op,
const Stream& s);
void binary_op_gpu_inplace(
const std::vector<array>& inputs,
array& out,
const std::string op,
const std::string& op,
const Stream& s);
} // namespace mlx::core

View File

@@ -64,16 +64,17 @@ void copy_gpu_inplace(
auto& strides_in_ = strides[0];
auto& strides_out_ = strides[1];
bool use_2d = out.data_size() > UINT32_MAX;
auto& d = metal::device(s.device);
std::string kernel_name;
{
std::ostringstream kname;
switch (ctype) {
case CopyType::Scalar:
kname << "s";
kname << (use_2d ? "s2" : "s");
break;
case CopyType::Vector:
kname << "v";
kname << (use_2d ? "v2" : "v");
break;
case CopyType::General:
kname << "g";
@@ -139,7 +140,8 @@ void copy_gpu_inplace(
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
size_t nthreads = out.data_size();
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;

View File

@@ -14,7 +14,6 @@
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/metal.h"
#include "mlx/backend/metal/metal_impl.h"
#include "mlx/backend/metal/mps/gemm.h"
#include "mlx/backend/metal/utils.h"
namespace fs = std::filesystem;
@@ -30,15 +29,29 @@ constexpr int MAX_DISPATCHES_PER_ENCODER = 2;
constexpr const char* default_mtllib_path = METAL_PATH;
constexpr auto get_metal_version() {
#if defined METAL_3_2
#if (MLX_METAL_VERSION >= 320)
return MTL::LanguageVersion3_2;
#elif defined METAL_3_1
#elif (MLX_METAL_VERSION >= 310)
return MTL::LanguageVersion3_1;
#else
return MTL::LanguageVersion3_0;
#endif
}
std::string get_colocated_mtllib_path(const std::string& lib_name) {
Dl_info info;
std::string mtllib_path;
std::string lib_ext = lib_name + ".metallib";
int success = dladdr((void*)get_colocated_mtllib_path, &info);
if (success) {
auto mtllib = fs::path(info.dli_fname).remove_filename() / lib_ext;
mtllib_path = mtllib.c_str();
}
return mtllib_path;
}
auto load_device() {
auto devices = MTL::CopyAllDevices();
auto device = static_cast<MTL::Device*>(devices->object(0))
@@ -126,6 +139,49 @@ MTL::Library* load_library(
} // namespace
CommandEncoder::CommandEncoder(MTL::CommandBuffer* cbuf) : cbuf(cbuf) {
enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
enc->retain();
}
CommandEncoder::~CommandEncoder() {
enc->endEncoding();
enc->release();
}
void CommandEncoder::set_input_array(
const array& a,
int idx,
int64_t offset /* = 0 */) {
auto r_buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
if (auto it = outputs.find(r_buf); it != outputs.end()) {
// Insert a barrier
enc->memoryBarrier(&r_buf, 1);
// Remove the output
outputs.erase(it);
}
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
auto base_offset = a.data<char>() -
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
base_offset += offset;
enc->setBuffer(a_buf, base_offset, idx);
}
void CommandEncoder::set_output_array(
array& a,
int idx,
int64_t offset /* = 0 */) {
// Add barriers before adding the output to the output set
set_input_array(a, idx, offset);
auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
if (concurrent) {
concurrent_outputs.insert(buf);
} else {
outputs.insert(buf);
}
}
void CommandEncoder::dispatchThreadgroups(
MTL::Size grid_dims,
MTL::Size group_dims) {
@@ -255,13 +311,9 @@ void Device::register_library(
}
}
void Device::register_library(
const std::string& lib_name,
const std::function<std::string(const std::string&)>& lib_path_func) {
void Device::register_library(const std::string& lib_name) {
if (auto it = library_map_.find(lib_name); it == library_map_.end()) {
std::string new_lib_path = lib_path_func(lib_name);
auto new_lib = load_library(device_, lib_name, new_lib_path.c_str());
library_map_.insert({lib_name, new_lib});
register_library(lib_name, get_colocated_mtllib_path(lib_name));
}
}
@@ -271,7 +323,7 @@ MTL::Library* Device::get_library_cache_(const std::string& lib_name) {
if (auto it = library_map_.find(lib_name); it != library_map_.end()) {
mtl_lib = it->second;
} else { // Look for metallib alongside library
register_library(lib_name);
register_library(lib_name, get_colocated_mtllib_path(lib_name));
mtl_lib = library_map_[lib_name];
}

View File

@@ -9,38 +9,16 @@
#include <unordered_map>
#include <unordered_set>
#include <dlfcn.h>
#include <filesystem>
#include "mlx/array.h"
#include "mlx/device.h"
namespace fs = std::filesystem;
namespace mlx::core::metal {
inline std::string get_colocated_mtllib_path(const std::string& lib_name) {
Dl_info info;
std::string mtllib_path;
std::string lib_ext = lib_name + ".metallib";
int success = dladdr((void*)get_colocated_mtllib_path, &info);
if (success) {
auto mtllib = fs::path(info.dli_fname).remove_filename() / lib_ext;
mtllib_path = mtllib.c_str();
}
return mtllib_path;
}
using MTLFCList =
std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;
struct CommandEncoder {
CommandEncoder(MTL::CommandBuffer* cbuf) : cbuf(cbuf) {
enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
enc->retain();
};
CommandEncoder(MTL::CommandBuffer* cbuf);
CommandEncoder(const CommandEncoder&) = delete;
CommandEncoder& operator=(const CommandEncoder&) = delete;
@@ -63,34 +41,8 @@ struct CommandEncoder {
return enc;
}
void set_input_array(const array& a, int idx, int64_t offset = 0) {
auto r_buf =
static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
if (auto it = outputs.find(r_buf); it != outputs.end()) {
// Insert a barrier
enc->memoryBarrier(&r_buf, 1);
// Remove the output
outputs.erase(it);
}
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
auto base_offset = a.data<char>() -
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
base_offset += offset;
enc->setBuffer(a_buf, base_offset, idx);
}
void set_output_array(array& a, int idx, int64_t offset = 0) {
// Add barriers before adding the output to the output set
set_input_array(a, idx, offset);
auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
if (concurrent) {
concurrent_outputs.insert(buf);
} else {
outputs.insert(buf);
}
}
void set_input_array(const array& a, int idx, int64_t offset = 0);
void set_output_array(array& a, int idx, int64_t offset = 0);
void dispatchThreadgroups(MTL::Size grid_dims, MTL::Size group_dims);
void dispatchThreads(MTL::Size grid_dims, MTL::Size group_dims);
@@ -98,10 +50,7 @@ struct CommandEncoder {
return ConcurrentContext(*this);
}
~CommandEncoder() {
enc->endEncoding();
enc->release();
}
~CommandEncoder();
private:
void maybe_split();
@@ -136,10 +85,8 @@ class Device {
void register_library(
const std::string& lib_name,
const std::string& lib_path);
void register_library(
const std::string& lib_name,
const std::function<std::string(const std::string&)>& lib_path_func =
get_colocated_mtllib_path);
void register_library(const std::string& lib_name);
MTL::Library* get_library(const std::string& name);

View File

@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2024 Apple Inc.
#include <cassert>
#include <complex>
#include <map>
@@ -12,8 +12,7 @@
#include "mlx/backend/metal/slicing.h"
#include "mlx/backend/metal/unary.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/mlx.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core {
@@ -785,10 +784,9 @@ void nd_fft_op(
fft_op(in_arr, out_arr, axis, inverse, step_real, inplace, s);
}
std::vector<array> copies = {temp1, temp2};
auto& d = metal::device(s.device);
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
[temp_arrs](MTL::CommandBuffer*) mutable { temp_arrs.clear(); });
}
void FFT::eval_gpu(const std::vector<array>& inputs, array& out) {

View File

@@ -0,0 +1,203 @@
// Copyright © 2024 Apple Inc.
#include <map>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/common/hadamard.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
constexpr int MAX_HADAMARD_THREADS_PER_GROUP = 256;
constexpr int MAX_HADAMARD_BYTES = 32768; // 32KB
std::string gen_hadamard_codelet(int m) {
// Generate a O(m^2) hadamard codelet for a given M
// using the hadamard matrices above
//
// e.g. m = 2
// METAL_FUNC void hadamard_m(thread float *x) {
// float tmp[2];
// tmp[0] = + x[0] + x[1];
// tmp[1] = + x[0] - x[1];
// for (int i = 0; i < 2; i++) { x[i] = tmp[i]; }
// }
//
auto h_matrices = hadamard_matrices();
auto& matrix = h_matrices[m];
std::ostringstream source;
source << "METAL_FUNC void hadamard_radix_m(thread float *x) {" << std::endl;
if (m == 1) {
source << "}" << std::endl;
return source.str();
}
source << " float tmp[" << m << "];" << std::endl;
auto start = 1;
auto end = matrix.find('\n', start);
int index = 0;
while (end != std::string_view::npos) {
source << " tmp[" << index << "] = ";
auto row = matrix.substr(start, end - start);
for (int i = 0; i < row.length(); i++) {
source << " " << row[i] << " x[" << i << "]";
}
source << ";" << std::endl;
start = end + 1;
end = matrix.find('\n', start);
index++;
}
source << " for (int i = 0; i < " << m << "; i++) { x[i] = tmp[i]; }"
<< std::endl;
source << "}" << std::endl;
return source.str();
}
void launch_hadamard(
const array& in,
array& out,
int batch_size,
int threads_per,
const std::string kernel_name,
float scale,
const Stream& s) {
auto& d = metal::device(s.device);
const auto& lib_name = kernel_name.substr(1);
auto lib = d.get_library(lib_name);
auto kernel = d.get_kernel(kernel_name, lib);
assert(threads_per <= kernel->maxTotalThreadsPerThreadgroup());
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(&scale, sizeof(float), 2);
MTL::Size group_dims = MTL::Size(1, threads_per, 1);
MTL::Size grid_dims = MTL::Size(batch_size, threads_per, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = stream();
auto& in = inputs[0];
std::vector<array> copies;
// Only support the last axis for now
int axis = in.ndim() - 1;
auto check_input = [&copies, &s](const array& x) {
// TODO(alexbarron) pass strides to kernel to relax this constraint
bool no_copy = x.flags().row_contiguous;
if (no_copy) {
return x;
} else {
copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
copy_gpu(x, copies.back(), CopyType::General, s);
return copies.back();
}
};
const array& in_contiguous = check_input(in);
if (in_contiguous.is_donatable()) {
out.move_shared_buffer(in_contiguous);
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
auto [n, m] = decompose_hadamard(in.shape(axis));
if (n * (int)size_of(in.dtype()) > MAX_HADAMARD_BYTES) {
throw std::invalid_argument(
"[hadamard] For n = m*2^k, 2^k > 8192 for FP32 or 2^k > 16384 for FP16/BF16 NYI");
}
int max_radix = std::min(n, 16);
// Use read_width 2 for m = 28 to avoid register spilling
int read_width = (n == 2 || m == 28) ? 2 : 4;
std::ostringstream kname;
kname << "hadamard_" << n * m << "_" << type_to_name(out);
auto kernel_name = kname.str();
auto& d = metal::device(s.device);
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
auto codelet = gen_hadamard_codelet(m);
kernel_source << metal::utils() << codelet << metal::hadamard();
kernel_source << get_template_definition(
"n" + kernel_name,
"hadamard_n",
get_type_string(in.dtype()),
n,
max_radix,
read_width);
kernel_source << get_template_definition(
"m" + kernel_name,
"hadamard_m",
get_type_string(in.dtype()),
n,
m,
read_width);
lib = d.get_library(lib_name, kernel_source.str());
}
int batch_size = in.size() / n;
int threads_per = n / max_radix;
if (m > 1) {
// When m is greater than 1, we decompose the
// computation into two uploads to the GPU:
//
// e.g. len(x) = 12*4 = 48, m = 12, n = 4
//
// y = h48 @ x
//
// Upload 1:
// tmp = a.reshape(12, 4) @ h4
//
// Upload 2:
// y = h12 @ tmp
array temp(in.shape(), in.dtype(), nullptr, {});
temp.set_data(allocator::malloc_or_wait(temp.nbytes()));
copies.push_back(temp);
launch_hadamard(
in_contiguous,
temp,
batch_size,
threads_per,
"n" + kernel_name,
1.0,
s);
// Metal sometimes reports 256 max threads per group for hadamard_m kernel
threads_per = std::min(n / read_width, MAX_HADAMARD_THREADS_PER_GROUP);
batch_size = in.size() / m / read_width / threads_per;
launch_hadamard(
temp, out, batch_size, threads_per, "m" + kernel_name, scale_, s);
} else {
launch_hadamard(
in_contiguous,
out,
batch_size,
threads_per,
"n" + kernel_name,
scale_,
s);
}
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}
} // namespace mlx::core

View File

@@ -0,0 +1,25 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view gemv_masked_kernel = R"(
template [[host_name("{name}")]] [[kernel]] void
gemv_{trans}masked<{itype}, {outm_t}, {opm_t}, {bm}, {bn}, {sm}, {sn}, {tm}, {tn}, {nc}>(
const device {itype}* mat [[buffer(0)]],
const device {itype}* in_vec [[buffer(1)]],
device {itype}* out_vec [[buffer(3)]],
const constant int& in_vec_size [[buffer(4)]],
const constant int& out_vec_size [[buffer(5)]],
const constant int& marix_ld [[buffer(6)]],
const constant int& batch_ndim [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]],
const constant size_t* vector_batch_stride [[buffer(11)]],
const constant size_t* matrix_batch_stride [[buffer(12)]],
const device {outm_t}* out_mask [[buffer(20)]],
const device {opm_t}* mat_mask [[buffer(21)]],
const device {opm_t}* vec_mask [[buffer(22)]],
const constant int* mask_strides [[buffer(23)]],
const constant size_t* mask_batch_strides [[buffer(24)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]);
)";

View File

@@ -18,6 +18,7 @@ const char* binary();
const char* binary_two();
const char* copy();
const char* fft();
const char* hadamard();
const char* quantized();
const char* ternary();
const char* scan();
@@ -32,5 +33,6 @@ const char* steel_gemm_splitk();
const char* conv();
const char* steel_conv();
const char* steel_conv_general();
const char* gemv_masked();
} // namespace mlx::core::metal

View File

@@ -1,81 +0,0 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view block_sort_kernels = R"(
template [[host_name("carg_{0}")]] [[kernel]] void
block_sort<{1}, {2}, true, {3}, {4}>(
const device {1}* inp [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant int& size_sorted_axis [[buffer(2)]],
const constant int& stride_sorted_axis [[buffer(3)]],
const constant int& stride_segment_axis [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
template [[host_name("ncarg_{0}")]] [[kernel]] void
block_sort_nc<{1}, {2}, true, {3}, {4}>(
const device {1}* inp [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant int& size_sorted_axis [[buffer(2)]],
const constant int& stride_sorted_axis [[buffer(3)]],
const constant int& nc_dim [[buffer(4)]],
const device int* nc_shape [[buffer(5)]],
const device size_t* nc_strides [[buffer(6)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
template [[host_name("c_{0}")]] [[kernel]] void
block_sort<{1}, {2}, false, {3}, {4}>(
const device {1}* inp [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant int& size_sorted_axis [[buffer(2)]],
const constant int& stride_sorted_axis [[buffer(3)]],
const constant int& stride_segment_axis [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
template [[host_name("nc_{0}")]] [[kernel]] void
block_sort_nc<{1}, {2}, false, {3}, {4}>(
const device {1}* inp [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant int& size_sorted_axis [[buffer(2)]],
const constant int& stride_sorted_axis [[buffer(3)]],
const constant int& nc_dim [[buffer(4)]],
const device int* nc_shape [[buffer(5)]],
const device size_t* nc_strides [[buffer(6)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
)";
constexpr std::string_view multiblock_sort_kernels = R"(
template [[host_name("sort_{0}")]] [[kernel]] void
mb_block_sort<{1}, {2}, true, {3}, {4}>(
const device {1}* inp [[buffer(0)]],
device {1}* out_vals [[buffer(1)]],
device {2}* out_idxs [[buffer(2)]],
const constant int& size_sorted_axis [[buffer(3)]],
const constant int& stride_sorted_axis [[buffer(4)]],
const constant int& nc_dim [[buffer(5)]],
const device int* nc_shape [[buffer(6)]],
const device size_t* nc_strides [[buffer(7)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
template [[host_name("partition_{0}")]] [[kernel]] void
mb_block_partition<{1}, {2}, true, {3}, {4}>(
device {2}* block_partitions [[buffer(0)]],
const device {1}* dev_vals [[buffer(1)]],
const device {2}* dev_idxs [[buffer(2)]],
const constant int& size_sorted_axis [[buffer(3)]],
const constant int& merge_tiles [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 tgp_dims [[threads_per_threadgroup]]);
template [[host_name("merge_{0}")]] [[kernel]] void
mb_block_merge<{1}, {2}, true, {3}, {4}>(
const device {2}* block_partitions [[buffer(0)]],
const device {1}* dev_vals_in [[buffer(1)]],
const device {2}* dev_idxs_in [[buffer(2)]],
device {1}* dev_vals_out [[buffer(3)]],
device {2}* dev_idxs_out [[buffer(4)]],
const constant int& size_sorted_axis [[buffer(5)]],
const constant int& merge_tiles [[buffer(6)]],
const constant int& num_tiles [[buffer(7)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
)";

View File

@@ -4,11 +4,11 @@
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/jit/arange.h"
#include "mlx/backend/metal/jit/copy.h"
#include "mlx/backend/metal/jit/gemv_masked.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/jit/reduce.h"
#include "mlx/backend/metal/jit/scan.h"
#include "mlx/backend/metal/jit/softmax.h"
#include "mlx/backend/metal/jit/sort.h"
#include "mlx/backend/metal/jit/steel_conv.h"
#include "mlx/backend/metal/jit/steel_gemm.h"
#include "mlx/backend/metal/kernels.h"
@@ -51,10 +51,12 @@ MTL::ComputePipelineState* get_unary_kernel(
std::ostringstream kernel_source;
auto u_def = get_template_definition(
"v" + lib_name, "unary_v", get_type_string(out_type), op);
auto u2_def = get_template_definition(
"v2" + lib_name, "unary_v2", get_type_string(out_type), op);
auto g_def = get_template_definition(
"g" + lib_name, "unary_g", get_type_string(out_type), op);
kernel_source << metal::utils() << metal::unary_ops() << metal::unary()
<< u_def << g_def;
<< u_def << u2_def << g_def;
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
@@ -71,6 +73,9 @@ void add_binary_kernels(
{"vs", "binary_vs"},
{"sv", "binary_sv"},
{"vv", "binary_vv"},
{"vs2", "binary_vs2"},
{"sv2", "binary_sv2"},
{"vv2", "binary_vv2"},
{"g1", "binary_g_nd1"},
{"g2", "binary_g_nd2"},
{"g3", "binary_g_nd3"},
@@ -147,6 +152,7 @@ MTL::ComputePipelineState* get_ternary_kernel(
std::ostringstream kernel_source;
const std::map<std::string, std::string> kernel_types = {
{"v", "ternary_v"},
{"v2", "ternary_v2"},
{"g", "ternary_g"},
{"g1", "ternary_g_nd1"},
{"g2", "ternary_g_nd2"},
@@ -251,14 +257,29 @@ MTL::ComputePipelineState* get_sort_kernel(
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::sort()
<< fmt::format(
block_sort_kernels,
lib_name,
get_type_string(in.dtype()),
get_type_string(out.dtype()),
bn,
tn);
auto in_type = get_type_string(in.dtype());
auto out_type = get_type_string(out.dtype());
kernel_source << metal::utils() << metal::sort();
for (bool is_argsort : {true, false}) {
std::string bool_string = is_argsort ? "true" : "false";
std::string func_string = is_argsort ? "carg_" : "c_";
kernel_source << get_template_definition(
func_string + lib_name,
"block_sort",
in_type,
out_type,
bool_string,
bn,
tn);
kernel_source << get_template_definition(
"n" + func_string + lib_name,
"block_sort_nc",
in_type,
out_type,
bool_string,
bn,
tn);
}
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
@@ -275,14 +296,21 @@ MTL::ComputePipelineState* get_mb_sort_kernel(
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::sort()
<< fmt::format(
multiblock_sort_kernels,
lib_name,
get_type_string(in.dtype()),
get_type_string(idx.dtype()),
bn,
tn);
kernel_source << metal::utils() << metal::sort();
std::vector<std::pair<std::string, std::string>> kernel_types = {
{"sort_", "mb_block_sort"},
{"partition_", "mb_block_partition"},
{"merge_", "mb_block_merge"}};
for (auto [name, func] : kernel_types) {
kernel_source << get_template_definition(
name + lib_name,
func,
get_type_string(in.dtype()),
get_type_string(idx.dtype()),
"true",
bn,
tn);
}
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
@@ -475,6 +503,49 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_gemv_masked_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out,
const std::optional<array>& mask_out,
const std::optional<array>& mask_op,
bool transpose_mat,
int bm,
int bn,
int sm,
int sn,
int tm,
int tn,
bool contiguous) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
auto out_mask_type = mask_out.has_value()
? get_type_string((*mask_out).dtype())
: "nomask_t";
auto op_mask_type =
mask_op.has_value() ? get_type_string((*mask_op).dtype()) : "nomask_t";
kernel_source << metal::utils() << metal::gemv_masked()
<< fmt::format(
gemv_masked_kernel,
"name"_a = lib_name,
"itype"_a = get_type_string(out.dtype()),
"outm_t"_a = out_mask_type,
"opm_t"_a = op_mask_type,
"bm"_a = bm,
"bn"_a = bn,
"sm"_a = sm,
"sn"_a = sn,
"tm"_a = tm,
"tn"_a = tn,
"trans"_a = transpose_mat ? "t_" : "",
"nc"_a = contiguous ? "0" : "1");
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_steel_conv_kernel(
metal::Device& d,
const std::string& kernel_name,

View File

@@ -151,6 +151,21 @@ MTL::ComputePipelineState* get_steel_conv_kernel(
int n_channel_specialization,
bool small_filter);
MTL::ComputePipelineState* get_gemv_masked_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out,
const std::optional<array>& mask_out,
const std::optional<array>& mask_op,
bool transpose_mat,
int bm,
int bn,
int sm,
int sn,
int tm,
int tn,
bool contiguous);
MTL::ComputePipelineState* get_steel_conv_general_kernel(
metal::Device& d,
const std::string& kernel_name,

View File

@@ -4,11 +4,12 @@ set(
bf16_math.h
complex.h
defines.h
expm1f.h
utils.h
)
function(build_kernel_base TARGET SRCFILE DEPS)
set(METAL_FLAGS -Wall -Wextra -fno-fast-math -D${MLX_METAL_VERSION})
set(METAL_FLAGS -Wall -Wextra -fno-fast-math)
if(MLX_METAL_DEBUG)
set(METAL_FLAGS ${METAL_FLAGS}
-gline-tables-only
@@ -37,7 +38,6 @@ endfunction(build_kernel)
build_kernel(arg_reduce)
build_kernel(conv steel/conv/params.h)
build_kernel(gemv steel/utils.h)
build_kernel(gemv_masked steel/utils.h)
build_kernel(layer_norm)
build_kernel(random)
build_kernel(rms_norm)
@@ -120,6 +120,7 @@ build_kernel(
steel/gemm/kernels/steel_gemm_splitk
${STEEL_HEADERS}
)
build_kernel(gemv_masked steel/utils.h)
endif()

View File

@@ -6,7 +6,7 @@
using namespace metal;
#if defined METAL_3_1 || defined METAL_3_2 || (__METAL_VERSION__ >= 310)
#if (MLX_METAL_VERSION >= 310) || (__METAL_VERSION__ >= 310)
typedef bfloat bfloat16_t;

View File

@@ -369,7 +369,7 @@ instantiate_metal_math_funcs(
return static_cast<otype>(__metal_simd_xor(static_cast<ctype>(data))); \
}
#if defined METAL_3_1 || defined METAL_3_2 || (__METAL_VERSION__ >= 310)
#if (MLX_METAL_VERSION >= 310) || (__METAL_VERSION__ >= 310)
#define bfloat16_to_uint16(x) as_type<uint16_t>(x)
#define uint16_to_bfloat16(x) as_type<bfloat16_t>(x)

View File

@@ -36,6 +36,39 @@ template <typename T, typename U, typename Op>
c[index] = Op()(a[index], b[index]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_sv2(
device const T* a,
device const T* b,
device U* c,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
c[offset] = Op()(a[0], b[offset]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_vs2(
device const T* a,
device const T* b,
device U* c,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
c[offset] = Op()(a[offset], b[0]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_vv2(
device const T* a,
device const T* b,
device U* c,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
c[offset] = Op()(a[offset], b[offset]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd1(
device const T* a,

View File

@@ -14,6 +14,9 @@
instantiate_kernel("sv" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vs" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vv" #op #tname, binary_vv, itype, otype, op) \
instantiate_kernel("sv2" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2" #op #tname, binary_vv2, itype, otype, op) \
instantiate_kernel("gn" #op #tname, binary_g, itype, otype, op) \
instantiate_kernel("g1" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g2" #op #tname, binary_g_nd2, itype, otype, op) \

View File

@@ -48,6 +48,48 @@ template <typename T, typename U, typename Op>
d[index] = out[1];
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_sv2(
device const T* a,
device const T* b,
device U* c,
device U* d,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
auto out = Op()(a[0], b[offset]);
c[offset] = out[0];
d[offset] = out[1];
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_vs2(
device const T* a,
device const T* b,
device U* c,
device U* d,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
auto out = Op()(a[offset], b[0]);
c[offset] = out[0];
d[offset] = out[1];
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_vv2(
device const T* a,
device const T* b,
device U* c,
device U* d,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
auto out = Op()(a[offset], b[offset]);
c[offset] = out[0];
d[offset] = out[1];
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd1(
device const T* a,

View File

@@ -12,6 +12,9 @@
instantiate_kernel("sv" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vs" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vv" #op #tname, binary_vv, itype, otype, op) \
instantiate_kernel("sv2" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2" #op #tname, binary_vv2, itype, otype, op) \
instantiate_kernel("gn" #op #tname, binary_g, itype, otype, op) \
instantiate_kernel("g1" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g2" #op #tname, binary_g_nd2, itype, otype, op) \

View File

@@ -344,12 +344,12 @@ winograd_conv_2d_weight_transform(
const short sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
// Initialize G matrix
simdgroup_matrix<T, 8, 8> G;
simdgroup_matrix<float, 8, 8> G;
G.thread_elements()[0] = WGT::wt_transform[sm][sn];
G.thread_elements()[1] = WGT::wt_transform[sm][sn + 1];
// Initialize Gt matrix
simdgroup_matrix<T, 8, 8> Gt;
simdgroup_matrix<float, 8, 8> Gt;
Gt.thread_elements()[0] = WGT::wt_transform[sn][sm];
Gt.thread_elements()[1] = WGT::wt_transform[sn + 1][sm];
@@ -381,15 +381,15 @@ winograd_conv_2d_weight_transform(
threadgroup_barrier(mem_flags::mem_threadgroup);
// Do transform and store the result
for (int c = 0; c < BC; ++c) {
simdgroup_matrix<T, 8, 8> g;
simdgroup_matrix<float, 8, 8> g;
g.thread_elements()[0] =
sm < R && sn < R ? Ws[simd_group_id][sm][sn][c] : T(0);
g.thread_elements()[1] =
sm < R && sn + 1 < R ? Ws[simd_group_id][sm][sn + 1][c] : T(0);
simdgroup_matrix<T, 8, 8> g_out = (G * g) * Gt;
wt_out_0[c * O] = g_out.thread_elements()[0];
wt_out_1[c * O] = g_out.thread_elements()[1];
simdgroup_matrix<float, 8, 8> g_out = (G * g) * Gt;
wt_out_0[c * O] = static_cast<T>(g_out.thread_elements()[0]);
wt_out_1[c * O] = static_cast<T>(g_out.thread_elements()[1]);
}
wt_in += BC;
@@ -433,12 +433,12 @@ winograd_conv_2d_input_transform(
const short sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
// Initialize B matrix
simdgroup_matrix<T, 8, 8> B;
simdgroup_matrix<float, 8, 8> B;
B.thread_elements()[0] = WGT::in_transform[sm][sn];
B.thread_elements()[1] = WGT::in_transform[sm][sn + 1];
// Initialize Bt matrix
simdgroup_matrix<T, 8, 8> Bt;
simdgroup_matrix<float, 8, 8> Bt;
Bt.thread_elements()[0] = WGT::in_transform[sn][sm];
Bt.thread_elements()[1] = WGT::in_transform[sn + 1][sm];
@@ -493,13 +493,13 @@ winograd_conv_2d_input_transform(
threadgroup_barrier(mem_flags::mem_threadgroup);
// Do transform and store the result
for (int c = simd_group_id; c < BC; c += N_SIMD_GROUPS) {
simdgroup_matrix<T, 8, 8> I;
simdgroup_matrix<float, 8, 8> I;
I.thread_elements()[0] = Is[sm][sn][c];
I.thread_elements()[1] = Is[sm][sn + 1][c];
simdgroup_matrix<T, 8, 8> I_out = (Bt * I) * B;
inp_out_0[c] = I_out.thread_elements()[0];
inp_out_1[c] = I_out.thread_elements()[1];
simdgroup_matrix<float, 8, 8> I_out = (Bt * I) * B;
inp_out_0[c] = static_cast<T>(I_out.thread_elements()[0]);
inp_out_1[c] = static_cast<T>(I_out.thread_elements()[1]);
}
inp_in += BC;
@@ -543,12 +543,12 @@ winograd_conv_2d_output_transform(
const short sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
// Initialize A matrix
simdgroup_matrix<T, 8, 8> B;
simdgroup_matrix<float, 8, 8> B;
B.thread_elements()[0] = WGT::out_transform[sm][sn];
B.thread_elements()[1] = WGT::out_transform[sm][sn + 1];
// Initialize At matrix
simdgroup_matrix<T, 8, 8> Bt;
simdgroup_matrix<float, 8, 8> Bt;
Bt.thread_elements()[0] = WGT::out_transform[sn][sm];
Bt.thread_elements()[1] = WGT::out_transform[sn + 1][sm];
@@ -597,16 +597,16 @@ winograd_conv_2d_output_transform(
threadgroup_barrier(mem_flags::mem_threadgroup);
// Do transform and store the result
for (int c = simd_group_id; c < BO; c += N_SIMD_GROUPS) {
simdgroup_matrix<T, 8, 8> O_mat;
simdgroup_matrix<float, 8, 8> O_mat;
O_mat.thread_elements()[0] = out_in_0[c];
O_mat.thread_elements()[1] = out_in_1[c];
simdgroup_matrix<T, 8, 8> O_out = (Bt * (O_mat * B));
simdgroup_matrix<float, 8, 8> O_out = (Bt * (O_mat * B));
if ((sm < M) && (sn < M)) {
Os[sm][sn][c] = O_out.thread_elements()[0];
Os[sm][sn][c] = static_cast<T>(O_out.thread_elements()[0]);
}
if ((sm < M) && ((sn + 1) < M)) {
Os[sm][sn + 1][c] = O_out.thread_elements()[1];
Os[sm][sn + 1][c] = static_cast<T>(O_out.thread_elements()[1]);
}
}
@@ -650,4 +650,5 @@ winograd_conv_2d_output_transform(
// clang-format off
instantiate_winograd_conv_2d(float32, float);
instantiate_winograd_conv_2d(bfloat16, bfloat16_t);
instantiate_winograd_conv_2d(float16, half); // clang-format on

View File

@@ -16,6 +16,26 @@ template <typename T, typename U>
dst[index] = static_cast<U>(src[index]);
}
template <typename T, typename U>
[[kernel]] void copy_s2(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
dst[offset] = static_cast<U>(src[0]);
}
template <typename T, typename U>
[[kernel]] void copy_v2(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
dst[offset] = static_cast<U>(src[offset]);
}
template <typename T, typename U>
[[kernel]] void copy_g_nd1(
device const T* src [[buffer(0)]],

View File

@@ -5,95 +5,23 @@
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/copy.h"
#define instantiate_copy(name, itype, otype, ctype) \
template [[host_name(name)]] [[kernel]] void copy_##ctype<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
uint index [[thread_position_in_grid]]);
#define instantiate_copy_g_dim(name, itype, otype, dims) \
template [[host_name("g" #dims "_" name)]] [[kernel]] void \
copy_g_nd<itype, otype, dims>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \
constant const int64_t* src_strides [[buffer(3)]], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
template [[host_name("gg" #dims "_" name)]] [[kernel]] void \
copy_gg_nd<itype, otype, dims>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int64_t* dst_strides [[buffer(4)]], \
uint3 index [[thread_position_in_grid]]);
#define instantiate_copy_g_nd(name, itype, otype) \
template [[host_name("g1_" name)]] [[kernel]] void copy_g_nd1<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t& src_stride [[buffer(3)]], \
uint index [[thread_position_in_grid]]); \
template [[host_name("g2_" name)]] [[kernel]] void copy_g_nd2<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]); \
template [[host_name("g3_" name)]] [[kernel]] void copy_g_nd3<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
template [[host_name("gg1_" name )]] [[kernel]] void \
copy_gg_nd1<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t& src_stride [[buffer(3)]], \
constant const int64_t& dst_stride [[buffer(4)]], \
uint index [[thread_position_in_grid]]); \
template [[host_name("gg2_" name)]] [[kernel]] void \
copy_gg_nd2<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int64_t* dst_strides [[buffer(4)]], \
uint2 index [[thread_position_in_grid]]); \
template [[host_name("gg3_" name)]] [[kernel]] void \
copy_gg_nd3<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int64_t* dst_strides [[buffer(4)]], \
uint3 index [[thread_position_in_grid]]); \
instantiate_copy_g_dim(name, itype, otype, 4) \
instantiate_copy_g_dim(name, itype, otype, 5)
#define instantiate_copy_g(name, itype, otype) \
template [[host_name("g_" name)]] [[kernel]] void copy_g<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int& ndim [[buffer(5)]], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
template [[host_name("gg_" name)]] [[kernel]] void copy_gg<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int64_t* dst_strides [[buffer(4)]], \
constant const int& ndim [[buffer(5)]], \
uint3 index [[thread_position_in_grid]]);
#define instantiate_copy_all(tname, itype, otype) \
instantiate_copy("s_copy" #tname, itype, otype, s) \
instantiate_copy("v_copy" #tname, itype, otype, v) \
instantiate_copy_g("copy" #tname, itype, otype) \
instantiate_copy_g_nd("copy" #tname, itype, otype)
instantiate_kernel("s_copy" #tname, copy_s, itype, otype) \
instantiate_kernel("v_copy" #tname, copy_v, itype, otype) \
instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype) \
instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype) \
instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype) \
instantiate_kernel("g2_copy" #tname, copy_g_nd2, itype, otype) \
instantiate_kernel("g3_copy" #tname, copy_g_nd3, itype, otype) \
instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, itype, otype) \
instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype) \
instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype) \
instantiate_kernel("g4_copy" #tname, copy_g_nd, itype, otype, 4) \
instantiate_kernel("g5_copy" #tname, copy_g_nd, itype, otype, 5) \
instantiate_kernel("gg4_copy" #tname, copy_gg_nd, itype, otype, 4) \
instantiate_kernel("gg5_copy" #tname, copy_gg_nd, itype, otype, 5) \
instantiate_kernel("g_copy" #tname, copy_g, itype, otype) \
instantiate_kernel("gg_copy" #tname, copy_gg, itype, otype)
#define instantiate_copy_itype(itname, itype) \
instantiate_copy_all(itname ##bool_, itype, bool) \

View File

@@ -83,6 +83,7 @@ float expm1f(float a) {
r = expm1f_scaled_unchecked(a, 1.0f);
/* handle severe overflow and underflow */
if (abs(a - 1.0f) > 88.0f) {
r = pow(2, a);
r = fma(r, r, -1.0f);
}
return r;

View File

@@ -0,0 +1,819 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/kernels/steel/utils.h"
using namespace metal;
#define MLX_MTL_CONST static constant constexpr const
#define MLX_MTL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
struct _NoMask {
char x;
constexpr METAL_FUNC operator bool() {
return true;
}
constexpr METAL_FUNC operator bool() const threadgroup {
return true;
}
constexpr METAL_FUNC operator bool() const device {
return true;
}
constexpr METAL_FUNC operator bool() const constant {
return true;
}
};
typedef struct _NoMask nomask_t;
template <typename OutT, typename InT = OutT>
struct ScaleOp {
OutT scale;
METAL_FUNC OutT apply(InT x) const {
return static_cast<OutT>(x) * scale;
}
};
template <
typename T,
typename out_mask_t,
typename op_mask_t,
const int BM, /* Threadgroup rows (in simdgroups) */
const int BN, /* Threadgroup cols (in simdgroups) */
const int SM, /* Simdgroup rows (in threads) */
const int SN, /* Simdgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN> /* Thread cols (in elements) */
struct GEMVKernel {
MLX_MTL_CONST int threadsM = BM * SM;
MLX_MTL_CONST int threadsN = BN * SN;
MLX_MTL_CONST int blockM = threadsM * TM;
MLX_MTL_CONST int blockN = threadsN * TN;
static_assert(SM * SN == 32, "simdgroup can only have 32 threads");
static_assert(
SN == 8 || SN == 16 || SN == 32,
"gemv block must have a width of 8, 16, or 32");
static_assert(blockN >= blockM, "Masked gemv must have blockN >= blockM");
MLX_MTL_CONST bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
MLX_MTL_CONST bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;
MLX_MTL_CONST bool has_mul_operand_mask =
has_operand_mask && !metal::is_same_v<op_mask_t, bool>;
MLX_MTL_CONST bool has_mul_output_mask =
has_output_mask && !metal::is_same_v<out_mask_t, bool>;
// - The matrix of size (M = out_vec_size, K = in_vec_size) is divided up
// into blocks of (blockM, blockN) divided among threadgroups
// - Every thread works on a block of (TM, TN)
// - We assume each threadgroup has (threadsN, threadsM, 1) threads
//
// 1. A thread loads TN elements each from mat along TM rows
// and the corresponding scalar from the vector
// 2. The thread then multiplies and adds to accumulate its local result for
// the block
// 3. At the end, each thread has accumulated results over all blocks across
// the rows. These are then summed up across the threadgroup
// 4. Each threadgroup writes its accumulated blockM outputs
//
// Edge case handling:
// - The threadgroup with the largest tid has blocks that exceed the matrix
// * The blocks that start outside the matrix are never read (thread results
// remain zero)
// * The last thread that partially overlaps with the matrix is shifted
// inwards such that the thread block fits exactly in the matrix
MLX_MTL_CONST short tgp_mem_size = BN > 1 ? BN*(blockM + TM) : 0;
MLX_MTL_CONST bool needs_tgp_reduction = BN > 1;
static METAL_FUNC void
load_unsafe(const device T* src, thread T dst[TN], const int src_offset = 0) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
dst[tn] = src[src_offset + tn];
}
}
static METAL_FUNC void load_safe(
const device T* src,
thread T dst[TN],
const int src_offset = 0,
const int src_size = TN) {
if (src_offset + TN <= src_size) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
dst[tn] = src[src_offset + tn];
}
} else { // Edgecase
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
dst[tn] = src_offset + tn < src_size ? src[src_offset + tn] : 0;
}
}
}
static METAL_FUNC void run(
const device T* mat [[buffer(0)]],
const device T* in_vec [[buffer(1)]],
device T* out_vec [[buffer(3)]],
const constant int& in_vec_size [[buffer(4)]],
const constant int& out_vec_size [[buffer(5)]],
const constant int& matrix_ld [[buffer(6)]],
const device out_mask_t* out_mask [[buffer(20)]],
const device op_mask_t* mat_mask [[buffer(21)]],
const device op_mask_t* vec_mask [[buffer(22)]],
const constant int* mask_strides [[buffer(23)]],
threadgroup T* tgp_memory [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
// Appease compiler
(void)lid;
// Thread local accumulation results
thread T result[TM] = {0};
thread T inter[TN];
thread T v_coeff[TN];
const int thrM = SN != 32 ? simd_lid / SN : 0;
const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid);
const int sgN = BN != 1 ? (simd_gid % BN) : 0;
const int simdM = BN != 1 ? SM * (simd_gid / BN) : int(SM * simd_gid);
const int simdN = BN != 1 ? SN * (simd_gid % BN) : 0;
int bm = (simdM + thrM) * TM;
int bn = (simdN + thrN) * TN;
// Block position
int out_row = tid.x * blockM + bm;
// Exit simdgroup if rows out of bound
if (out_row >= out_vec_size)
return;
// Adjust tail simdgroup to ensure in bound reads
out_row = out_row + TM <= out_vec_size ? out_row : out_vec_size - TM;
// Prepare mask offsets
const constant int* out_mask_strides = mask_strides;
const constant int* mat_mask_strides =
mask_strides + (has_output_mask ? 2 : 0);
const constant int* vec_mask_strides =
mat_mask_strides + (has_operand_mask ? 2 : 0);
const int m_block_idx = blockN > blockM ? out_row / blockN : int(tid.x);
const int out_mask_offset =
!has_output_mask ? 0 : m_block_idx * out_mask_strides[1];
int mat_mask_offset =
!has_operand_mask ? 0 : m_block_idx * mat_mask_strides[1];
int vec_mask_offset = 0;
const int mat_mask_step = !has_operand_mask ? 0 : mat_mask_strides[0];
const int vec_mask_step = !has_operand_mask ? 0 : vec_mask_strides[1];
T out_scale{1};
// Check output mask
if (has_output_mask) {
auto mask_out = out_mask[out_mask_offset];
// Write zeros and return if mask is 0
if (!mask_out) {
if (simdN == 0 && thrN == 0) {
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
out_vec[out_row + tm] = T(0.);
}
}
return;
}
// Store scalar if multiplicative mask
if (has_mul_output_mask) {
out_scale = T(mask_out);
}
}
// Advance matrix
mat += out_row * matrix_ld;
// Prepare for loop
constexpr const uniform<int> loop_stride = make_uniform(blockN);
const uniform<int> in_size = make_uniform(in_vec_size);
const uniform<int> n_iter = in_size / loop_stride;
const uniform<int> last_iter = loop_stride * n_iter;
const uniform<int> leftover = in_size - last_iter;
// Loop over in_vec in blocks of blockN
for (int i = 0; i < n_iter; ++i) {
if (!has_operand_mask ||
(bool(mat_mask[mat_mask_offset]) &&
bool(vec_mask[vec_mask_offset]))) {
T block_scale{1};
if (has_mul_operand_mask) {
block_scale =
T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
}
load_unsafe(in_vec, v_coeff, bn);
// Apply scale
if (has_mul_operand_mask) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
v_coeff[tn] *= block_scale;
}
}
// Per thread work loop
int mat_offset = 0;
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
// Load for the row
load_unsafe(mat, inter, mat_offset + bn);
// Accumulate results
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
result[tm] += inter[tn] * v_coeff[tn];
}
mat_offset += matrix_ld;
}
}
bn += blockN;
mat_mask_offset += mat_mask_step;
vec_mask_offset += vec_mask_step;
}
if (leftover > 0 &&
(!has_operand_mask ||
(bool(mat_mask[mat_mask_offset]) &&
bool(vec_mask[vec_mask_offset])))) {
T block_scale{1};
if (has_mul_operand_mask) {
block_scale =
T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
}
load_safe(in_vec, v_coeff, bn, in_size);
// Apply scale
if (has_mul_operand_mask) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
v_coeff[tn] *= block_scale;
}
}
// Per thread work loop
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
// Load for the row
load_safe(&mat[tm * matrix_ld], inter, bn, in_size);
// Accumulate results
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
result[tm] += inter[tn] * v_coeff[tn];
}
}
}
// Apply out scale
if (has_mul_output_mask) {
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
result[tm] *= out_scale;
}
}
// Simdgroup accumulations
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
MLX_MTL_PRAGMA_UNROLL
for (ushort sn = (SN / 2); sn >= 1; sn >>= 1) {
result[tm] += simd_shuffle_down(result[tm], sn);
}
}
// Threadgroup accumulation results
if (needs_tgp_reduction) {
threadgroup T* tgp_results = tgp_memory + sgN * (blockM + TM) + bm;
if (thrN == 0) {
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
tgp_results[tm] = result[tm];
}
threadgroup_barrier(mem_flags::mem_none);
if (sgN == 0) {
MLX_MTL_PRAGMA_UNROLL
for (int sgn = 1; sgn < BN; sgn++) {
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
result[tm] += tgp_results[sgn * (blockM + TM) + tm];
}
}
}
}
}
// Write outputs
if (simdN == 0 && thrN == 0) {
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
out_vec[out_row + tm] = result[tm];
}
}
}
};
///////////////////////////////////////////////////////////////////////////////
/// Vector matrix multiplication
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
typename out_mask_t,
typename op_mask_t,
const int BM, /* Threadgroup rows (in simdgroups) */
const int BN, /* Threadgroup cols (in simdgroups) */
const int SM, /* Simdgroup rows (in threads) */
const int SN, /* Simdgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN> /* Thread cols (in elements) */
struct GEMVTKernel {
MLX_MTL_CONST int threadsM = BM * SM;
MLX_MTL_CONST int threadsN = BN * SN;
MLX_MTL_CONST int blockM = threadsM * TM;
MLX_MTL_CONST int blockN = threadsN * TN;
static_assert(SM * SN == 32, "simdgroup can only have 32 threads");
MLX_MTL_CONST bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
MLX_MTL_CONST bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;
MLX_MTL_CONST bool has_mul_operand_mask =
has_operand_mask && !metal::is_same_v<op_mask_t, bool>;
MLX_MTL_CONST bool has_mul_output_mask =
has_output_mask && !metal::is_same_v<out_mask_t, bool>;
// - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up
// into blocks of (blockM, blockN) divided among threadgroups
// - Every thread works on a block of (TM, TN)
// - We assume each threadgroup has (threadsN, threadsM, 1) threads
//
// 1. A thread loads TN elements each from mat along TM contiguous rows
// and the corresponding scalar from the vector
// 2. The thread then accumulates its local result for the block
// 3. At the end, each thread has accumulated results over all blocks across
// the rows. These are then summed up across the threadgroup
// 4. Each threadgroup writes its accumulated BN * TN outputs
//
// Edge case handling:
// - The threadgroup with the largest tid has blocks that exceed the matrix
// * The blocks that start outside the matrix are never read (thread results
// remain zero)
// * The last thread that partially overlaps with the matrix is shifted
// inwards such that the thread block fits exactly in the matrix
MLX_MTL_CONST short tgp_mem_size = BM > 1 ? BM*(blockN + TN) : 0;
MLX_MTL_CONST bool needs_tgp_reduction = BM > 1;
static METAL_FUNC void run(
const device T* mat [[buffer(0)]],
const device T* in_vec [[buffer(1)]],
device T* out_vec [[buffer(3)]],
const constant int& in_vec_size [[buffer(4)]],
const constant int& out_vec_size [[buffer(5)]],
const constant int& marix_ld [[buffer(6)]],
const device out_mask_t* out_mask [[buffer(20)]],
const device op_mask_t* mat_mask [[buffer(21)]],
const device op_mask_t* vec_mask [[buffer(22)]],
const constant int* mask_strides [[buffer(23)]],
threadgroup T* tgp_memory [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
// Appease compiler
(void)lid;
// Thread local accumulation results
T result[TN] = {0};
T inter[TN];
T v_coeff[TM];
const int thrM = SN != 32 ? simd_lid / SN : 0;
const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid);
const int sgM = BN != 1 ? (simd_gid / BN) : int(simd_gid);
const int sgN = BN != 1 ? (simd_gid % BN) : 0;
const int simdM = SM * sgM;
const int simdN = SN * sgN;
int cm = (simdM + thrM);
int cn = (simdN + thrN);
int bm = cm * TM;
int bn = cn * TN;
int out_col = tid.x * blockN + bn;
// Prepare mask offsets
const constant int* out_mask_strides = mask_strides;
const constant int* mat_mask_strides =
out_mask_strides + (has_output_mask ? 2 : 0);
const constant int* vec_mask_strides =
mat_mask_strides + (has_operand_mask ? 2 : 0);
const int n_block_idx = blockM > blockN ? out_col / blockM : int(tid.x);
const int out_mask_offset =
!has_output_mask ? 0 : n_block_idx; // * out_mask_strides[0];
int mat_mask_offset =
!has_operand_mask ? 0 : n_block_idx * mat_mask_strides[0];
int vec_mask_offset = 0;
const int mat_mask_step = !has_operand_mask ? 0 : mat_mask_strides[1];
const int vec_mask_step = !has_operand_mask ? 0 : vec_mask_strides[0];
T out_scale{1};
// Check output mask
if (has_output_mask) {
auto mask_out = out_mask[out_mask_offset];
// Write zeros and return if mask is 0
if (!mask_out) {
if (cm == 0 && out_col < out_vec_size) {
if (out_col + TN <= out_vec_size) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
out_vec[out_col + tn] = T(0.);
}
} else {
for (int tn = 0; tn < TN && (out_col + tn) < out_vec_size; tn++) {
out_vec[out_col + tn] = T(0.);
}
}
}
return;
}
// Store scalar if multiplicative mask
if (has_mul_output_mask) {
out_scale = T(mask_out);
}
}
// Prepare for loop
constexpr const uniform<int> loop_stride = make_uniform(blockM);
const uniform<int> in_size = make_uniform(in_vec_size);
const uniform<int> n_iter = in_size / loop_stride;
const uniform<int> last_iter = loop_stride * n_iter;
const uniform<int> leftover = in_size - last_iter;
// Edgecase handling
if (out_col < out_vec_size) {
out_col = (out_col + TN) <= out_vec_size ? out_col : out_vec_size - TN;
// Per thread accumulation main loop
for (int i = 0; i < n_iter; ++i) {
// Adding a threadgroup_barrier improves performance slightly
// This is possibly it may help exploit cache better
threadgroup_barrier(mem_flags::mem_none);
if (!has_operand_mask ||
(bool(mat_mask[mat_mask_offset]) &&
bool(vec_mask[vec_mask_offset]))) {
T block_scale{1};
if (has_mul_operand_mask) {
block_scale =
T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
}
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
v_coeff[tm] = in_vec[bm + tm];
}
// Apply scale
if (has_mul_operand_mask) {
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
v_coeff[tm] *= block_scale;
}
}
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
for (int tn = 0; tn < TN; tn++) {
inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
}
for (int tn = 0; tn < TN; tn++) {
result[tn] += v_coeff[tm] * inter[tn];
}
}
}
bm += blockM;
mat_mask_offset += mat_mask_step;
vec_mask_offset += vec_mask_step;
}
if (leftover > 0 &&
(!has_operand_mask ||
(bool(mat_mask[mat_mask_offset]) &&
bool(vec_mask[vec_mask_offset])))) {
T block_scale{1};
if (has_mul_operand_mask) {
block_scale =
T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
}
for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) {
v_coeff[tm] = in_vec[bm + tm];
if (has_mul_operand_mask) {
v_coeff[tm] *= block_scale;
}
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
}
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
result[tn] += v_coeff[tm] * inter[tn];
}
}
}
}
// Apply out scale
if (has_mul_output_mask) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
result[tn] *= out_scale;
}
}
// Simdgroup accumulations
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
MLX_MTL_PRAGMA_UNROLL
for (ushort sm = (SM / 2); sm >= 1; sm >>= 1) {
result[tn] += simd_shuffle_down(result[tn], SN * sm);
}
}
// Threadgroup accumulation results
if (needs_tgp_reduction) {
threadgroup T* tgp_results = tgp_memory + sgM * (blockN + TN) + bn;
if (thrM == 0) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
tgp_results[tn] = result[tn];
}
threadgroup_barrier(mem_flags::mem_none);
if (sgM == 0) {
MLX_MTL_PRAGMA_UNROLL
for (int sgm = 1; sgm < BM; sgm++) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
result[tn] += tgp_results[sgm * (blockN + TN) + tn];
}
}
}
}
}
// Threadgroup accumulation and writing out results
if (cm == 0 && out_col < out_vec_size) {
MLX_MTL_PRAGMA_UNROLL
for (int j = 0; j < TN; j++) {
out_vec[out_col + j] = result[j];
}
}
}
};
///////////////////////////////////////////////////////////////////////////////
/// Matrix vector multiplication
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
typename out_mask_t,
typename op_mask_t,
const int BM, /* Threadgroup rows (in simdgroups) */
const int BN, /* Threadgroup cols (in simdgroups) */
const int SM, /* Simdgroup rows (in threads) */
const int SN, /* Simdgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN, /* Thread cols (in elements) */
const bool kDoNCBatch> /* Batch ndim > 1 */
[[kernel, max_total_threads_per_threadgroup(BM* BN * 32)]] void gemv_masked(
const device T* mat [[buffer(0)]],
const device T* in_vec [[buffer(1)]],
device T* out_vec [[buffer(3)]],
const constant int& in_vec_size [[buffer(4)]],
const constant int& out_vec_size [[buffer(5)]],
const constant int& marix_ld [[buffer(6)]],
const constant int& batch_ndim [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]],
const constant size_t* vector_batch_stride [[buffer(11)]],
const constant size_t* matrix_batch_stride [[buffer(12)]],
const device out_mask_t* out_mask [[buffer(20)]],
const device op_mask_t* mat_mask [[buffer(21)]],
const device op_mask_t* vec_mask [[buffer(22)]],
const constant int* mask_strides [[buffer(23)]],
const constant size_t* mask_batch_strides [[buffer(24)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using gemv_kernel =
GEMVKernel<T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN>;
threadgroup T tgp_memory
[gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
constexpr bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
constexpr bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;
// Update batch offsets
if (kDoNCBatch) {
in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);
mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);
if (has_output_mask) {
out_mask +=
elem_to_loc(tid.z, batch_shape, mask_batch_strides, batch_ndim);
mask_batch_strides += batch_ndim;
}
if (has_operand_mask) {
const constant size_t* mask_strides_mat = mask_batch_strides;
const constant size_t* mask_strides_vec = mask_strides_mat + batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim);
mat_mask += batch_offsets.x;
vec_mask += batch_offsets.y;
}
} else {
in_vec += tid.z * vector_batch_stride[0];
mat += tid.z * matrix_batch_stride[0];
if (has_output_mask) {
out_mask += tid.z * mask_batch_strides[0];
mask_batch_strides += batch_ndim;
}
if (has_operand_mask) {
mat_mask += tid.z * mask_batch_strides[0];
vec_mask += tid.z * mask_batch_strides[batch_ndim];
}
}
out_vec += tid.z * out_vec_size;
gemv_kernel::run(
mat,
in_vec,
out_vec,
in_vec_size,
out_vec_size,
marix_ld,
out_mask,
mat_mask,
vec_mask,
mask_strides,
gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory,
tid,
lid,
simd_gid,
simd_lid);
}
///////////////////////////////////////////////////////////////////////////////
/// Vector matrix multiplication
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
typename out_mask_t,
typename op_mask_t,
const int BM, /* Threadgroup rows (in simdgroups) */
const int BN, /* Threadgroup cols (in simdgroups) */
const int SM, /* Simdgroup rows (in threads) */
const int SN, /* Simdgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN, /* Thread cols (in elements) */
const bool kDoNCBatch> /* Batch ndim > 1 */
[[kernel, max_total_threads_per_threadgroup(BM* BN * 32)]] void gemv_t_masked(
const device T* mat [[buffer(0)]],
const device T* in_vec [[buffer(1)]],
device T* out_vec [[buffer(3)]],
const constant int& in_vec_size [[buffer(4)]],
const constant int& out_vec_size [[buffer(5)]],
const constant int& marix_ld [[buffer(6)]],
const constant int& batch_ndim [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]],
const constant size_t* vector_batch_stride [[buffer(11)]],
const constant size_t* matrix_batch_stride [[buffer(12)]],
const device out_mask_t* out_mask [[buffer(20)]],
const device op_mask_t* mat_mask [[buffer(21)]],
const device op_mask_t* vec_mask [[buffer(22)]],
const constant int* mask_strides [[buffer(23)]],
const constant size_t* mask_batch_strides [[buffer(24)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using gemv_kernel =
GEMVTKernel<T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN>;
threadgroup T tgp_memory
[gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
constexpr bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
constexpr bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;
// Update batch offsets
if (kDoNCBatch) {
in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);
mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);
if (has_output_mask) {
out_mask +=
elem_to_loc(tid.z, batch_shape, mask_batch_strides, batch_ndim);
mask_batch_strides += batch_ndim;
}
if (has_operand_mask) {
const constant size_t* mask_strides_mat = mask_batch_strides;
const constant size_t* mask_strides_vec = mask_strides_mat + batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim);
mat_mask += batch_offsets.x;
vec_mask += batch_offsets.y;
}
} else {
in_vec += tid.z * vector_batch_stride[0];
mat += tid.z * matrix_batch_stride[0];
if (has_output_mask) {
out_mask += tid.z * mask_batch_strides[0];
mask_batch_strides += batch_ndim;
}
if (has_operand_mask) {
mat_mask += tid.z * mask_batch_strides[0];
vec_mask += tid.z * mask_batch_strides[batch_ndim];
}
}
out_vec += tid.z * out_vec_size;
gemv_kernel::run(
mat,
in_vec,
out_vec,
in_vec_size,
out_vec_size,
marix_ld,
out_mask,
mat_mask,
vec_mask,
mask_strides,
gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory,
tid,
lid,
simd_gid,
simd_lid);
}

View File

@@ -1,5 +1,6 @@
// Copyright © 2023-2024 Apple Inc.
// clang-format off
#include <metal_simdgroup>
#include <metal_stdlib>
@@ -7,726 +8,7 @@
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/utils.h"
using namespace metal;
///////////////////////////////////////////////////////////////////////////////
/// Matrix vector multiplication
///////////////////////////////////////////////////////////////////////////////
#define MLX_MTL_CONST static constant constexpr const
struct _NoMask {
char x;
constexpr METAL_FUNC operator bool() {
return true;
}
constexpr METAL_FUNC operator bool() const threadgroup {
return true;
}
constexpr METAL_FUNC operator bool() const device {
return true;
}
constexpr METAL_FUNC operator bool() const constant {
return true;
}
};
typedef struct _NoMask nomask_t;
template <typename OutT, typename InT = OutT>
struct ScaleOp {
OutT scale;
METAL_FUNC OutT apply(InT x) const {
return static_cast<OutT>(x) * scale;
}
};
template <
typename T,
typename out_mask_t,
typename op_mask_t,
const int BM, /* Threadgroup rows (in simdgroups) */
const int BN, /* Threadgroup cols (in simdgroups) */
const int SM, /* Simdgroup rows (in threads) */
const int SN, /* Simdgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN> /* Thread cols (in elements) */
struct GEMVKernel {
MLX_MTL_CONST int threadsM = BM * SM;
MLX_MTL_CONST int threadsN = BN * SN;
MLX_MTL_CONST int blockM = threadsM * TM;
MLX_MTL_CONST int blockN = threadsN * TN;
static_assert(SM * SN == 32, "simdgroup can only have 32 threads");
static_assert(
SN == 8 || SN == 16 || SN == 32,
"gemv block must have a width of 8, 16, or 32");
static_assert(blockN >= blockM, "Masked gemv must have blockN >= blockM");
MLX_MTL_CONST bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
MLX_MTL_CONST bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;
MLX_MTL_CONST bool has_mul_operand_mask =
has_operand_mask && !metal::is_same_v<op_mask_t, bool>;
MLX_MTL_CONST bool has_mul_output_mask =
has_output_mask && !metal::is_same_v<out_mask_t, bool>;
// - The matrix of size (M = out_vec_size, K = in_vec_size) is divided up
// into blocks of (blockM, blockN) divided among threadgroups
// - Every thread works on a block of (TM, TN)
// - We assume each threadgroup has (threadsN, threadsM, 1) threads
//
// 1. A thread loads TN elements each from mat along TM rows
// and the corresponding scalar from the vector
// 2. The thread then multiplies and adds to accumulate its local result for
// the block
// 3. At the end, each thread has accumulated results over all blocks across
// the rows. These are then summed up across the threadgroup
// 4. Each threadgroup writes its accumulated blockM outputs
//
// Edge case handling:
// - The threadgroup with the largest tid has blocks that exceed the matrix
// * The blocks that start outside the matrix are never read (thread results
// remain zero)
// * The last thread that partially overlaps with the matrix is shifted
// inwards such that the thread block fits exactly in the matrix
MLX_MTL_CONST short tgp_mem_size = BN > 1 ? BN*(blockM + TM) : 0;
MLX_MTL_CONST bool needs_tgp_reduction = BN > 1;
static METAL_FUNC void
load_unsafe(const device T* src, thread T dst[TN], const int src_offset = 0) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
dst[tn] = src[src_offset + tn];
}
}
static METAL_FUNC void load_safe(
const device T* src,
thread T dst[TN],
const int src_offset = 0,
const int src_size = TN) {
if (src_offset + TN <= src_size) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
dst[tn] = src[src_offset + tn];
}
} else { // Edgecase
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
dst[tn] = src_offset + tn < src_size ? src[src_offset + tn] : 0;
}
}
}
static METAL_FUNC void run(
const device T* mat [[buffer(0)]],
const device T* in_vec [[buffer(1)]],
device T* out_vec [[buffer(3)]],
const constant int& in_vec_size [[buffer(4)]],
const constant int& out_vec_size [[buffer(5)]],
const constant int& matrix_ld [[buffer(6)]],
const device out_mask_t* out_mask [[buffer(20)]],
const device op_mask_t* mat_mask [[buffer(21)]],
const device op_mask_t* vec_mask [[buffer(22)]],
const constant int* mask_strides [[buffer(23)]],
threadgroup T* tgp_memory [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
// Appease compiler
(void)lid;
// Thread local accumulation results
thread T result[TM] = {0};
thread T inter[TN];
thread T v_coeff[TN];
const int thrM = SN != 32 ? simd_lid / SN : 0;
const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid);
const int sgN = BN != 1 ? (simd_gid % BN) : 0;
const int simdM = BN != 1 ? SM * (simd_gid / BN) : int(SM * simd_gid);
const int simdN = BN != 1 ? SN * (simd_gid % BN) : 0;
int bm = (simdM + thrM) * TM;
int bn = (simdN + thrN) * TN;
// Block position
int out_row = tid.x * blockM + bm;
// Exit simdgroup if rows out of bound
if (out_row >= out_vec_size)
return;
// Adjust tail simdgroup to ensure in bound reads
out_row = out_row + TM <= out_vec_size ? out_row : out_vec_size - TM;
// Prepare mask offsets
const constant int* out_mask_strides = mask_strides;
const constant int* mat_mask_strides =
mask_strides + (has_output_mask ? 2 : 0);
const constant int* vec_mask_strides =
mat_mask_strides + (has_operand_mask ? 2 : 0);
const int m_block_idx = blockN > blockM ? out_row / blockN : int(tid.x);
const int out_mask_offset =
!has_output_mask ? 0 : m_block_idx * out_mask_strides[1];
int mat_mask_offset =
!has_operand_mask ? 0 : m_block_idx * mat_mask_strides[1];
int vec_mask_offset = 0;
const int mat_mask_step = !has_operand_mask ? 0 : mat_mask_strides[0];
const int vec_mask_step = !has_operand_mask ? 0 : vec_mask_strides[1];
T out_scale{1};
// Check output mask
if (has_output_mask) {
auto mask_out = out_mask[out_mask_offset];
// Write zeros and return if mask is 0
if (!mask_out) {
if (simdN == 0 && thrN == 0) {
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
out_vec[out_row + tm] = T(0.);
}
}
return;
}
// Store scalar if multiplicative mask
if (has_mul_output_mask) {
out_scale = T(mask_out);
}
}
// Advance matrix
mat += out_row * matrix_ld;
// Prepare for loop
constexpr const uniform<int> loop_stride = make_uniform(blockN);
const uniform<int> in_size = make_uniform(in_vec_size);
const uniform<int> n_iter = in_size / loop_stride;
const uniform<int> last_iter = loop_stride * n_iter;
const uniform<int> leftover = in_size - last_iter;
// Loop over in_vec in blocks of blockN
for (int i = 0; i < n_iter; ++i) {
if (!has_operand_mask ||
(bool(mat_mask[mat_mask_offset]) &&
bool(vec_mask[vec_mask_offset]))) {
T block_scale{1};
if (has_mul_operand_mask) {
block_scale =
T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
}
load_unsafe(in_vec, v_coeff, bn);
// Apply scale
if (has_mul_operand_mask) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
v_coeff[tn] *= block_scale;
}
}
// Per thread work loop
int mat_offset = 0;
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
// Load for the row
load_unsafe(mat, inter, mat_offset + bn);
// Accumulate results
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
result[tm] += inter[tn] * v_coeff[tn];
}
mat_offset += matrix_ld;
}
}
bn += blockN;
mat_mask_offset += mat_mask_step;
vec_mask_offset += vec_mask_step;
}
if (leftover > 0 &&
(!has_operand_mask ||
(bool(mat_mask[mat_mask_offset]) &&
bool(vec_mask[vec_mask_offset])))) {
T block_scale{1};
if (has_mul_operand_mask) {
block_scale =
T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
}
load_safe(in_vec, v_coeff, bn, in_size);
// Apply scale
if (has_mul_operand_mask) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
v_coeff[tn] *= block_scale;
}
}
// Per thread work loop
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
// Load for the row
load_safe(&mat[tm * matrix_ld], inter, bn, in_size);
// Accumulate results
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
result[tm] += inter[tn] * v_coeff[tn];
}
}
}
// Apply out scale
if (has_mul_output_mask) {
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
result[tm] *= out_scale;
}
}
// Simdgroup accumulations
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
MLX_MTL_PRAGMA_UNROLL
for (ushort sn = (SN / 2); sn >= 1; sn >>= 1) {
result[tm] += simd_shuffle_down(result[tm], sn);
}
}
// Threadgroup accumulation results
if (needs_tgp_reduction) {
threadgroup T* tgp_results = tgp_memory + sgN * (blockM + TM) + bm;
if (thrN == 0) {
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
tgp_results[tm] = result[tm];
}
threadgroup_barrier(mem_flags::mem_none);
if (sgN == 0) {
MLX_MTL_PRAGMA_UNROLL
for (int sgn = 1; sgn < BN; sgn++) {
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
result[tm] += tgp_results[sgn * (blockM + TM) + tm];
}
}
}
}
}
// Write outputs
if (simdN == 0 && thrN == 0) {
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
out_vec[out_row + tm] = result[tm];
}
}
}
};
///////////////////////////////////////////////////////////////////////////////
/// Vector matrix multiplication
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
typename out_mask_t,
typename op_mask_t,
const int BM, /* Threadgroup rows (in simdgroups) */
const int BN, /* Threadgroup cols (in simdgroups) */
const int SM, /* Simdgroup rows (in threads) */
const int SN, /* Simdgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN> /* Thread cols (in elements) */
struct GEMVTKernel {
MLX_MTL_CONST int threadsM = BM * SM;
MLX_MTL_CONST int threadsN = BN * SN;
MLX_MTL_CONST int blockM = threadsM * TM;
MLX_MTL_CONST int blockN = threadsN * TN;
static_assert(SM * SN == 32, "simdgroup can only have 32 threads");
MLX_MTL_CONST bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
MLX_MTL_CONST bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;
MLX_MTL_CONST bool has_mul_operand_mask =
has_operand_mask && !metal::is_same_v<op_mask_t, bool>;
MLX_MTL_CONST bool has_mul_output_mask =
has_output_mask && !metal::is_same_v<out_mask_t, bool>;
// - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up
// into blocks of (blockM, blockN) divided among threadgroups
// - Every thread works on a block of (TM, TN)
// - We assume each threadgroup has (threadsN, threadsM, 1) threads
//
// 1. A thread loads TN elements each from mat along TM contiguous rows
// and the corresponding scalar from the vector
// 2. The thread then accumulates its local result for the block
// 3. At the end, each thread has accumulated results over all blocks across
// the rows. These are then summed up across the threadgroup
// 4. Each threadgroup writes its accumulated BN * TN outputs
//
// Edge case handling:
// - The threadgroup with the largest tid has blocks that exceed the matrix
// * The blocks that start outside the matrix are never read (thread results
// remain zero)
// * The last thread that partially overlaps with the matrix is shifted
// inwards such that the thread block fits exactly in the matrix
MLX_MTL_CONST short tgp_mem_size = BM > 1 ? BM*(blockN + TN) : 0;
MLX_MTL_CONST bool needs_tgp_reduction = BM > 1;
static METAL_FUNC void run(
const device T* mat [[buffer(0)]],
const device T* in_vec [[buffer(1)]],
device T* out_vec [[buffer(3)]],
const constant int& in_vec_size [[buffer(4)]],
const constant int& out_vec_size [[buffer(5)]],
const constant int& marix_ld [[buffer(6)]],
const device out_mask_t* out_mask [[buffer(20)]],
const device op_mask_t* mat_mask [[buffer(21)]],
const device op_mask_t* vec_mask [[buffer(22)]],
const constant int* mask_strides [[buffer(23)]],
threadgroup T* tgp_memory [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
// Appease compiler
(void)lid;
// Thread local accumulation results
T result[TN] = {0};
T inter[TN];
T v_coeff[TM];
const int thrM = SN != 32 ? simd_lid / SN : 0;
const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid);
const int sgM = BN != 1 ? (simd_gid / BN) : int(simd_gid);
const int sgN = BN != 1 ? (simd_gid % BN) : 0;
const int simdM = SM * sgM;
const int simdN = SN * sgN;
int cm = (simdM + thrM);
int cn = (simdN + thrN);
int bm = cm * TM;
int bn = cn * TN;
int out_col = tid.x * blockN + bn;
// Prepare mask offsets
const constant int* out_mask_strides = mask_strides;
const constant int* mat_mask_strides =
out_mask_strides + (has_output_mask ? 2 : 0);
const constant int* vec_mask_strides =
mat_mask_strides + (has_operand_mask ? 2 : 0);
const int n_block_idx = blockM > blockN ? out_col / blockM : int(tid.x);
const int out_mask_offset =
!has_output_mask ? 0 : n_block_idx; // * out_mask_strides[0];
int mat_mask_offset =
!has_operand_mask ? 0 : n_block_idx * mat_mask_strides[0];
int vec_mask_offset = 0;
const int mat_mask_step = !has_operand_mask ? 0 : mat_mask_strides[1];
const int vec_mask_step = !has_operand_mask ? 0 : vec_mask_strides[0];
T out_scale{1};
// Check output mask
if (has_output_mask) {
auto mask_out = out_mask[out_mask_offset];
// Write zeros and return if mask is 0
if (!mask_out) {
if (cm == 0 && out_col < out_vec_size) {
if (out_col + TN <= out_vec_size) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
out_vec[out_col + tn] = T(0.);
}
} else {
for (int tn = 0; tn < TN && (out_col + tn) < out_vec_size; tn++) {
out_vec[out_col + tn] = T(0.);
}
}
}
return;
}
// Store scalar if multiplicative mask
if (has_mul_output_mask) {
out_scale = T(mask_out);
}
}
// Prepare for loop
constexpr const uniform<int> loop_stride = make_uniform(blockM);
const uniform<int> in_size = make_uniform(in_vec_size);
const uniform<int> n_iter = in_size / loop_stride;
const uniform<int> last_iter = loop_stride * n_iter;
const uniform<int> leftover = in_size - last_iter;
// Edgecase handling
if (out_col < out_vec_size) {
out_col = (out_col + TN) <= out_vec_size ? out_col : out_vec_size - TN;
// Per thread accumulation main loop
for (int i = 0; i < n_iter; ++i) {
// Adding a threadgroup_barrier improves performance slightly
// This is possibly it may help exploit cache better
threadgroup_barrier(mem_flags::mem_none);
if (!has_operand_mask ||
(bool(mat_mask[mat_mask_offset]) &&
bool(vec_mask[vec_mask_offset]))) {
T block_scale{1};
if (has_mul_operand_mask) {
block_scale =
T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
}
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
v_coeff[tm] = in_vec[bm + tm];
}
// Apply scale
if (has_mul_operand_mask) {
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
v_coeff[tm] *= block_scale;
}
}
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
for (int tn = 0; tn < TN; tn++) {
inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
}
for (int tn = 0; tn < TN; tn++) {
result[tn] += v_coeff[tm] * inter[tn];
}
}
}
bm += blockM;
mat_mask_offset += mat_mask_step;
vec_mask_offset += vec_mask_step;
}
if (leftover > 0 &&
(!has_operand_mask ||
(bool(mat_mask[mat_mask_offset]) &&
bool(vec_mask[vec_mask_offset])))) {
T block_scale{1};
if (has_mul_operand_mask) {
block_scale =
T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
}
for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) {
v_coeff[tm] = in_vec[bm + tm];
if (has_mul_operand_mask) {
v_coeff[tm] *= block_scale;
}
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
}
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
result[tn] += v_coeff[tm] * inter[tn];
}
}
}
}
// Apply out scale
if (has_mul_output_mask) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
result[tn] *= out_scale;
}
}
// Simdgroup accumulations
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
MLX_MTL_PRAGMA_UNROLL
for (ushort sm = (SM / 2); sm >= 1; sm >>= 1) {
result[tn] += simd_shuffle_down(result[tn], SN * sm);
}
}
// Threadgroup accumulation results
if (needs_tgp_reduction) {
threadgroup T* tgp_results = tgp_memory + sgM * (blockN + TN) + bn;
if (thrM == 0) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
tgp_results[tn] = result[tn];
}
threadgroup_barrier(mem_flags::mem_none);
if (sgM == 0) {
MLX_MTL_PRAGMA_UNROLL
for (int sgm = 1; sgm < BM; sgm++) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
result[tn] += tgp_results[sgm * (blockN + TN) + tn];
}
}
}
}
}
// Threadgroup accumulation and writing out results
if (cm == 0 && out_col < out_vec_size) {
MLX_MTL_PRAGMA_UNROLL
for (int j = 0; j < TN; j++) {
out_vec[out_col + j] = result[j];
}
}
}
};
///////////////////////////////////////////////////////////////////////////////
/// Matrix vector multiplication
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
typename out_mask_t,
typename op_mask_t,
const int BM, /* Threadgroup rows (in simdgroups) */
const int BN, /* Threadgroup cols (in simdgroups) */
const int SM, /* Simdgroup rows (in threads) */
const int SN, /* Simdgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN, /* Thread cols (in elements) */
const bool kDoNCBatch> /* Batch ndim > 1 */
[[kernel, max_total_threads_per_threadgroup(BM* BN * 32)]] void gemv_masked(
const device T* mat [[buffer(0)]],
const device T* in_vec [[buffer(1)]],
device T* out_vec [[buffer(3)]],
const constant int& in_vec_size [[buffer(4)]],
const constant int& out_vec_size [[buffer(5)]],
const constant int& marix_ld [[buffer(6)]],
const constant int& batch_ndim [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]],
const constant size_t* vector_batch_stride [[buffer(11)]],
const constant size_t* matrix_batch_stride [[buffer(12)]],
const device out_mask_t* out_mask [[buffer(20)]],
const device op_mask_t* mat_mask [[buffer(21)]],
const device op_mask_t* vec_mask [[buffer(22)]],
const constant int* mask_strides [[buffer(23)]],
const constant size_t* mask_batch_strides [[buffer(24)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using gemv_kernel =
GEMVKernel<T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN>;
threadgroup T tgp_memory
[gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
constexpr bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
constexpr bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;
// Update batch offsets
if (kDoNCBatch) {
in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);
mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);
if (has_output_mask) {
out_mask +=
elem_to_loc(tid.z, batch_shape, mask_batch_strides, batch_ndim);
mask_batch_strides += batch_ndim;
}
if (has_operand_mask) {
const constant size_t* mask_strides_mat = mask_batch_strides;
const constant size_t* mask_strides_vec = mask_strides_mat + batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim);
mat_mask += batch_offsets.x;
vec_mask += batch_offsets.y;
}
} else {
in_vec += tid.z * vector_batch_stride[0];
mat += tid.z * matrix_batch_stride[0];
if (has_output_mask) {
out_mask += tid.z * mask_batch_strides[0];
mask_batch_strides += batch_ndim;
}
if (has_operand_mask) {
mat_mask += tid.z * mask_batch_strides[0];
vec_mask += tid.z * mask_batch_strides[batch_ndim];
}
}
out_vec += tid.z * out_vec_size;
gemv_kernel::run(
mat,
in_vec,
out_vec,
in_vec_size,
out_vec_size,
marix_ld,
out_mask,
mat_mask,
vec_mask,
mask_strides,
gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory,
tid,
lid,
simd_gid,
simd_lid);
}
#include "mlx/backend/metal/kernels/gemv_masked.h"
#define instantiate_gemv_helper( \
outm_n, outm_t, opm_n, opm_t, name, itype, bm, bn, sm, sn, tm, tn, nc) \
@@ -754,7 +36,6 @@ template <
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
// clang-format off
#define instantiate_gemv_base(name, itype, bm, bn, sm, sn, tm, tn, nc) \
instantiate_gemv_helper(bool_, bool, bool_, bool, name, itype, bm, bn, sm, sn, tm, tn, nc) \
instantiate_gemv_helper(name, itype, name, itype, name, itype, bm, bn, sm, sn, tm, tn, nc) \
@@ -763,125 +44,23 @@ template <
instantiate_gemv_helper(nomask, nomask_t, name, itype, name, itype, bm, bn, sm, sn, tm, tn, nc) \
instantiate_gemv_helper(nomask, nomask_t, bool_, bool, name, itype, bm, bn, sm, sn, tm, tn, nc) \
instantiate_gemv_helper(bool_, bool, nomask, nomask_t, name, itype, bm, bn, sm, sn, tm, tn, nc) \
instantiate_gemv_helper(name, itype, nomask, nomask_t, name, itype, bm, bn, sm, sn, tm, tn, nc) // clang-format on
instantiate_gemv_helper(name, itype, nomask, nomask_t, name, itype, bm, bn, sm, sn, tm, tn, nc)
// clang-format off
#define instantiate_gemv(name, itype, bm, bn, sm, sn, tm, tn) \
#define instantiate_gemv(name, itype, bm, bn, sm, sn, tm, tn) \
instantiate_gemv_base(name, itype, bm, bn, sm, sn, tm, tn, 0) \
instantiate_gemv_base(name, itype, bm, bn, sm, sn, tm, tn, 1) // clang-format on
instantiate_gemv_base(name, itype, bm, bn, sm, sn, tm, tn, 1)
// clang-format off
#define instantiate_gemv_blocks(name, itype) \
instantiate_gemv(name, itype, 2, 1, 4, 8, 1, 4) \
instantiate_gemv(name, itype, 2, 1, 4, 8, 4, 4) \
instantiate_gemv(name, itype, 2, 1, 2, 16, 1, 4) \
instantiate_gemv(name, itype, 2, 1, 2, 16, 4, 4) \
instantiate_gemv(name, itype, 4, 1, 2, 16, 4, 4) // clang-format on
instantiate_gemv(name, itype, 4, 1, 2, 16, 4, 4)
instantiate_gemv_blocks(float32, float);
instantiate_gemv_blocks(float16, half);
instantiate_gemv_blocks(bfloat16, bfloat16_t);
///////////////////////////////////////////////////////////////////////////////
/// Vector matrix multiplication
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
typename out_mask_t,
typename op_mask_t,
const int BM, /* Threadgroup rows (in simdgroups) */
const int BN, /* Threadgroup cols (in simdgroups) */
const int SM, /* Simdgroup rows (in threads) */
const int SN, /* Simdgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN, /* Thread cols (in elements) */
const bool kDoNCBatch> /* Batch ndim > 1 */
[[kernel, max_total_threads_per_threadgroup(BM* BN * 32)]] void gemv_t_masked(
const device T* mat [[buffer(0)]],
const device T* in_vec [[buffer(1)]],
device T* out_vec [[buffer(3)]],
const constant int& in_vec_size [[buffer(4)]],
const constant int& out_vec_size [[buffer(5)]],
const constant int& marix_ld [[buffer(6)]],
const constant int& batch_ndim [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]],
const constant size_t* vector_batch_stride [[buffer(11)]],
const constant size_t* matrix_batch_stride [[buffer(12)]],
const device out_mask_t* out_mask [[buffer(20)]],
const device op_mask_t* mat_mask [[buffer(21)]],
const device op_mask_t* vec_mask [[buffer(22)]],
const constant int* mask_strides [[buffer(23)]],
const constant size_t* mask_batch_strides [[buffer(24)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using gemv_kernel =
GEMVTKernel<T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN>;
threadgroup T tgp_memory
[gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
constexpr bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
constexpr bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;
// Update batch offsets
if (kDoNCBatch) {
in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);
mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);
if (has_output_mask) {
out_mask +=
elem_to_loc(tid.z, batch_shape, mask_batch_strides, batch_ndim);
mask_batch_strides += batch_ndim;
}
if (has_operand_mask) {
const constant size_t* mask_strides_mat = mask_batch_strides;
const constant size_t* mask_strides_vec = mask_strides_mat + batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim);
mat_mask += batch_offsets.x;
vec_mask += batch_offsets.y;
}
} else {
in_vec += tid.z * vector_batch_stride[0];
mat += tid.z * matrix_batch_stride[0];
if (has_output_mask) {
out_mask += tid.z * mask_batch_strides[0];
mask_batch_strides += batch_ndim;
}
if (has_operand_mask) {
mat_mask += tid.z * mask_batch_strides[0];
vec_mask += tid.z * mask_batch_strides[batch_ndim];
}
}
out_vec += tid.z * out_vec_size;
gemv_kernel::run(
mat,
in_vec,
out_vec,
in_vec_size,
out_vec_size,
marix_ld,
out_mask,
mat_mask,
vec_mask,
mask_strides,
gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory,
tid,
lid,
simd_gid,
simd_lid);
}
#define instantiate_gemv_t_helper( \
outm_n, outm_t, opm_n, opm_t, name, itype, bm, bn, sm, sn, tm, tn, nc) \
template [[host_name("gemv_t_outmask_" #outm_n "_opmask_" #opm_n "_" #name \
@@ -908,7 +87,6 @@ template <
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
// clang-format off
#define instantiate_gemv_t_base(name, itype, bm, bn, sm, sn, tm, tn, nc) \
instantiate_gemv_t_helper(bool_, bool, bool_, bool, name, itype, bm, bn, sm, sn, tm, tn, nc) \
instantiate_gemv_t_helper(name, itype, name, itype, name, itype, bm, bn, sm, sn, tm, tn, nc) \
@@ -917,23 +95,20 @@ template <
instantiate_gemv_t_helper(nomask, nomask_t, name, itype, name, itype, bm, bn, sm, sn, tm, tn, nc) \
instantiate_gemv_t_helper(nomask, nomask_t, bool_, bool, name, itype, bm, bn, sm, sn, tm, tn, nc) \
instantiate_gemv_t_helper(bool_, bool, nomask, nomask_t, name, itype, bm, bn, sm, sn, tm, tn, nc) \
instantiate_gemv_t_helper(name, itype, nomask, nomask_t, name, itype, bm, bn, sm, sn, tm, tn, nc) // clang-format on
instantiate_gemv_t_helper(name, itype, nomask, nomask_t, name, itype, bm, bn, sm, sn, tm, tn, nc)
// clang-format off
#define instantiate_gemv_t(name, itype, bm, bn, sm, sn, tm, tn) \
instantiate_gemv_t_base(name, itype, bm, bn, sm, sn, tm, tn, 0) \
instantiate_gemv_t_base(name, itype, bm, bn, sm, sn, tm, tn, 1) // clang-format on
instantiate_gemv_t_base(name, itype, bm, bn, sm, sn, tm, tn, 1)
// clang-format off
#define instantiate_gemv_t_blocks(name, itype) \
instantiate_gemv_t(name, itype, 1, 1, 8, 4, 4, 1) \
instantiate_gemv_t(name, itype, 1, 2, 8, 4, 4, 4) \
instantiate_gemv_t(name, itype, 1, 1, 8, 4, 8, 1) \
instantiate_gemv_t(name, itype, 1, 1, 8, 4, 8, 4) \
instantiate_gemv_t(name, itype, 1, 2, 8, 4, 8, 4) \
instantiate_gemv_t(name, itype, 1, 4, 8, 4, 8, 4) // clang-format on
instantiate_gemv_t(name, itype, 1, 4, 8, 4, 8, 4)
// clang-format off
instantiate_gemv_t_blocks(float32, float);
instantiate_gemv_t_blocks(float16, half);
instantiate_gemv_t_blocks(bfloat16, bfloat16_t); // clang-format on
instantiate_gemv_t_blocks(bfloat16, bfloat16_t); // clang-format on

View File

@@ -0,0 +1,167 @@
// Copyright © 2024 Apple Inc.
#include <metal_common>
#include <metal_compute>
#include "mlx/backend/metal/kernels/steel/defines.h"
using namespace metal;
// Thread local Hadamard transform for 2^R
template <short R>
METAL_FUNC void radix_func(thread float* x) {
constexpr short logR = __builtin_ctz(R);
short h = 1;
STEEL_PRAGMA_UNROLL
for (short s = 0; s < logR; s++) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < R / 2; i++) {
short k = i & (h - 1);
short j = ((i - k) << 1) + k;
float a = x[j];
float b = x[j + h];
x[j] = a + b;
x[j + h] = a - b;
}
h <<= 1;
}
}
template <typename T, int N, int max_radix, int read_width>
[[kernel]] void hadamard_n(
const device T* in [[buffer(0)]],
device T* out [[buffer(1)]],
constant const float& scale,
uint3 elem [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]) {
// Compute a Hadamard transform of size N = 2^k
//
// Equivalent to:
// from scipy.linalg import hadamard
// y = hadamard(len(x)) @ x
constexpr short num_threads = N / max_radix;
constexpr short logN = __builtin_ctz(N);
constexpr short logR = __builtin_ctz(max_radix);
constexpr short num_steps = logN / logR;
constexpr short logFinal = logN % logR;
constexpr short final_radix = 1 << (logFinal);
int batch_idx = elem.x * N;
short i = elem.y;
threadgroup T buf[N];
// Read values from device
STEEL_PRAGMA_UNROLL
for (short j = 0; j < max_radix / read_width; j++) {
short index = j * read_width * num_threads + i * read_width;
STEEL_PRAGMA_UNROLL
for (short r = 0; r < read_width; r++) {
buf[index + r] = in[batch_idx + index + r];
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
float x[max_radix];
short h = 1;
STEEL_PRAGMA_UNROLL
for (short s = 0; s < num_steps; s++) {
short k = i & (h - 1);
short j = ((i - k) << logR) + k;
STEEL_PRAGMA_UNROLL
for (short r = 0; r < max_radix; r++) {
x[r] = buf[j + h * r];
}
radix_func<max_radix>(x);
STEEL_PRAGMA_UNROLL
for (short r = 0; r < max_radix; r++) {
buf[j + h * r] = T(x[r]);
}
h <<= logR;
threadgroup_barrier(mem_flags::mem_threadgroup);
}
// Do the final radix
// e.g. max_radix = 16
// N = 1024 = 16 * 16 * 4
if (final_radix > 1) {
// Each thread does multiple butterflies
STEEL_PRAGMA_UNROLL
for (int t = 0; t < max_radix / final_radix; t++) {
short index = i + t * num_threads;
short k = index & (h - 1);
short j = ((index - k) << logFinal) + k;
STEEL_PRAGMA_UNROLL
for (short r = 0; r < final_radix; r++) {
x[r] = buf[j + h * r];
}
radix_func<final_radix>(x);
STEEL_PRAGMA_UNROLL
for (short r = 0; r < final_radix; r++) {
buf[j + h * r] = T(x[r]);
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
// Write values to device
STEEL_PRAGMA_UNROLL
for (short j = 0; j < max_radix / read_width; j++) {
short index = j * read_width * num_threads + i * read_width;
STEEL_PRAGMA_UNROLL
for (short r = 0; r < read_width; r++) {
out[batch_idx + index + r] = T(buf[index + r] * scale);
}
}
}
template <typename T, int N, int M, int read_width>
[[kernel]] void hadamard_m(
const device T* in [[buffer(0)]],
device T* out [[buffer(1)]],
constant const float& scale,
uint3 elem [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]) {
// Compute a Hadamard transform of size M
// using a naive O(M^2) codelet.
//
// This kernel is the second stage in the computation
// of a Hadamard transform of size M*N where N = 2^k.
int index = elem.x * grid.y + elem.y;
short i = index % (N / read_width);
int batch_idx = index / (N / read_width) * M * N;
float x[read_width][M];
STEEL_PRAGMA_UNROLL
for (short c = 0; c < M; c++) {
STEEL_PRAGMA_UNROLL
for (short r = 0; r < read_width; r++) {
x[r][c] = in[batch_idx + c * N + i * read_width + r];
}
}
STEEL_PRAGMA_UNROLL
for (short r = 0; r < read_width; r++) {
// This function is JIT compiled for M
// using the Hadamard matrix strings in `metal/hadamard.cpp`
hadamard_radix_m(x[r]);
}
// Write back to device
STEEL_PRAGMA_UNROLL
for (short c = 0; c < M; c++) {
STEEL_PRAGMA_UNROLL
for (short r = 0; r < read_width; r++) {
out[batch_idx + c * N + i * read_width + r] = T(x[r][c] * scale);
}
}
}

View File

@@ -34,7 +34,7 @@ template <typename T, int N_READS = RMS_N_READS>
threadgroup float local_mean[1];
threadgroup float local_normalizer[1];
x += gid * axis_size + lid * N_READS;
x += gid * size_t(axis_size) + lid * N_READS;
w += w_stride * lid * N_READS;
b += b_stride * lid * N_READS;
@@ -89,7 +89,7 @@ template <typename T, int N_READS = RMS_N_READS>
float normalizer = local_normalizer[0];
// Write the outputs
out += gid * axis_size + lid * N_READS;
out += gid * size_t(axis_size) + lid * N_READS;
if (lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
thread_x[i] = (thread_x[i] - mean) * normalizer;
@@ -131,7 +131,7 @@ template <typename T, int N_READS = RMS_N_READS>
threadgroup float local_mean[1];
threadgroup float local_normalizer[1];
x += gid * axis_size + lid * N_READS;
x += gid * size_t(axis_size) + lid * N_READS;
w += w_stride * lid * N_READS;
b += b_stride * lid * N_READS;
@@ -188,7 +188,7 @@ template <typename T, int N_READS = RMS_N_READS>
float normalizer = local_normalizer[0];
// Write the outputs
out += gid * axis_size + lid * N_READS;
out += gid * size_t(axis_size) + lid * N_READS;
for (uint r = 0; r < axis_size; r += lsize * N_READS) {
if (r + lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
@@ -223,8 +223,8 @@ template <typename T, int N_READS = RMS_N_READS>
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
// Advance the input pointers
x += gid * axis_size + lid * N_READS;
g += gid * axis_size + lid * N_READS;
x += gid * size_t(axis_size) + lid * N_READS;
g += gid * size_t(axis_size) + lid * N_READS;
w += w_stride * lid * N_READS;
// Allocate registers for the computation and accumulators
@@ -321,8 +321,8 @@ template <typename T, int N_READS = RMS_N_READS>
float normalizer2 = normalizer * normalizer;
// Write the outputs
gx += gid * axis_size + lid * N_READS;
gw += gid * axis_size + lid * N_READS;
gx += gid * size_t(axis_size) + lid * N_READS;
gw += gid * size_t(axis_size) + lid * N_READS;
if (lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
thread_x[i] = (thread_x[i] - mean) * normalizer;
@@ -360,8 +360,8 @@ template <typename T, int N_READS = RMS_N_READS>
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
// Advance the input pointers
x += gid * axis_size + lid * N_READS;
g += gid * axis_size + lid * N_READS;
x += gid * size_t(axis_size) + lid * N_READS;
g += gid * size_t(axis_size) + lid * N_READS;
w += w_stride * lid * N_READS;
// Allocate registers for the accumulators
@@ -457,8 +457,8 @@ template <typename T, int N_READS = RMS_N_READS>
float normalizer2 = normalizer * normalizer;
// Write the outputs
gx += gid * axis_size + lid * N_READS;
gw += gid * axis_size + lid * N_READS;
gx += gid * size_t(axis_size) + lid * N_READS;
gw += gid * size_t(axis_size) + lid * N_READS;
for (uint r = 0; r < axis_size; r += lsize * N_READS) {
if (r + lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {

View File

@@ -690,12 +690,12 @@ METAL_FUNC void qvm_impl(
template <
typename T,
const int BM,
const int BK,
const int BN,
const int group_size,
const int bits,
const bool aligned_N>
const bool aligned_N,
const int BM = 32,
const int BK = 32,
const int BN = 32>
METAL_FUNC void qmm_t_impl(
const device T* x,
const device uint32_t* w,
@@ -812,11 +812,11 @@ METAL_FUNC void qmm_t_impl(
template <
typename T,
const int BM,
const int BK,
const int BN,
const int group_size,
const int bits>
const int bits,
const int BM = 32,
const int BK = 32,
const int BN = 32>
METAL_FUNC void qmm_n_impl(
const device T* x,
const device uint32_t* w,
@@ -1099,7 +1099,7 @@ template <
threadgroup T Xs[BM * BK_padded];
threadgroup T Ws[BN * BK_padded];
qmm_t_impl<T, BM, BK, BN, group_size, bits, aligned_N>(
qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
}
@@ -1131,7 +1131,7 @@ template <
threadgroup T Xs[BM * BK_padded];
threadgroup T Ws[BK * BN_padded];
qmm_n_impl<T, BM, BK, BN, group_size, bits>(
qmm_n_impl<T, group_size, bits, BM, BK, BN>(
x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
}
@@ -1382,7 +1382,7 @@ template <
s_strides,
b_strides,
tid);
qmm_t_impl<T, BM, BK, BN, group_size, bits, aligned_N>(
qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
}
@@ -1450,6 +1450,147 @@ template <
s_strides,
b_strides,
tid);
qmm_n_impl<T, BM, BK, BN, group_size, bits>(
qmm_n_impl<T, group_size, bits, BM, BK, BN>(
x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
}
template <typename T, const int group_size, const int bits>
[[kernel]] void affine_quantize(
const device T* w [[buffer(0)]],
device uint8_t* out [[buffer(1)]],
device T* scales [[buffer(2)]],
device T* biases [[buffer(3)]],
uint index [[thread_position_in_grid]]) {
constexpr T eps = T(1e-7);
constexpr int simd_size = 32;
constexpr int uint8_bits = 8;
constexpr T n_bins = (1 << bits) - 1;
constexpr int packs_per_int = uint8_bits / bits;
constexpr int values_per_reduce = group_size / simd_size;
constexpr int writes_per_reduce = packs_per_int / values_per_reduce;
constexpr int writes_per_pack =
writes_per_reduce > 1 ? 1 : values_per_reduce / packs_per_int;
static_assert(
group_size % simd_size == 0,
"Group size must be divisible by simd size.");
int in_index = index * values_per_reduce;
int out_index = index * writes_per_pack;
T w_thread[values_per_reduce];
T w_min = Limits<T>::max;
T w_max = 0;
#pragma clang loop unroll(full)
for (int i = 0; i < values_per_reduce; i++) {
T val = w[in_index + i];
w_thread[i] = val;
w_min = min(w_min, val);
w_max = max(w_max, val);
}
w_min = simd_min(w_min);
w_max = simd_max(w_max);
T scale = max((w_max - w_min) / n_bins, eps);
bool side = abs(w_min) > abs(w_max);
scale = side ? scale : -scale;
T edge = side ? w_min : w_max;
T q0 = round(edge / scale);
bool at_zero = q0 == 0.0f;
scale = at_zero ? scale : edge / q0;
T bias = at_zero ? T(0) : edge;
// Write out the scales and biases
int gindex = in_index / group_size;
if (in_index % group_size == 0) {
scales[gindex] = scale;
biases[gindex] = bias;
}
uint8_t output = 0;
#pragma clang loop unroll(full)
for (int i = 0; i < values_per_reduce; i++) {
uint8_t val = min(round((w_thread[i] - bias) / scale), n_bins);
if (bits == 8) {
output = val;
} else {
output += val << (bits * (i % packs_per_int));
}
if (packs_per_int < values_per_reduce &&
i % packs_per_int == packs_per_int - 1) {
out[out_index + i / packs_per_int] = output;
output = 0;
} else {
#pragma clang loop unroll(full)
for (int j = 0; j < writes_per_reduce - 1; j++) {
uint8_t sval = simd_shuffle_down(val, j + 1);
output += sval << (bits * (values_per_reduce + j + i));
}
}
}
if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) {
out[out_index / writes_per_reduce] = output;
}
}
template <typename T, const int group_size, const int bits>
[[kernel]] void affine_quantize_scales_biases(
const device T* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
device uint8_t* out [[buffer(3)]],
uint index [[thread_position_in_grid]]) {
constexpr int uint8_bits = 8;
constexpr int packs_per_int = uint8_bits / bits;
constexpr T n_bins = (1 << bits) - 1;
int in_index = index * packs_per_int;
int gindex = in_index / group_size;
T scale = scales[gindex];
T bias = biases[gindex];
uint8_t output = 0;
#pragma clang loop unroll(full)
for (int i = 0; i < packs_per_int; i++) {
uint8_t val = min(round((w[in_index + i] - bias) / scale), n_bins);
if (bits == 8) {
output = val;
} else {
output += val << (bits * i);
}
}
out[index] = output;
}
template <typename T, const int group_size, const int bits>
[[kernel]] void affine_dequantize(
const device uint8_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
device T* out [[buffer(3)]],
uint index [[thread_position_in_grid]]) {
constexpr int uint8_bits = 8;
constexpr int packs_per_int = uint8_bits / bits;
int oindex = index * packs_per_int;
int gindex = oindex / group_size;
T scale = scales[gindex];
T bias = biases[gindex];
uint val = w[index];
#pragma clang loop unroll(full)
for (int i = 0; i < packs_per_int; i++) {
uint8_t d;
if (bits == 2) {
d = (val >> (bits * i)) & 0x03;
} else if (bits == 4) {
d = (val >> (bits * i)) & 0x0f;
} else if (bits == 8) {
d = val;
}
out[oindex + i] = scale * d + bias;
}
}

View File

@@ -5,241 +5,67 @@
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/quantized.h"
#define instantiate_qmv_fast(itype, group_size, bits) \
#define instantiate_quantized(name, type, group_size, bits) \
instantiate_kernel( \
"qmv_" #itype "_gs_" #group_size "_b_" #bits "_fast", \
qmv_fast, \
itype, \
#name "_" #type "_gs_" #group_size "_b_" #bits, \
name, \
type, \
group_size, \
bits)
#define instantiate_qmv_fast_types(group_size, bits) \
instantiate_qmv_fast(float, group_size, bits) \
instantiate_qmv_fast(float16_t, group_size, bits) \
instantiate_qmv_fast(bfloat16_t, group_size, bits)
#define instantiate_quantized_types(name, group_size, bits) \
instantiate_quantized(name, float, group_size, bits) \
instantiate_quantized(name, float16_t, group_size, bits) \
instantiate_quantized(name, bfloat16_t, group_size, bits)
instantiate_qmv_fast_types(128, 2)
instantiate_qmv_fast_types(128, 4)
instantiate_qmv_fast_types(128, 8)
instantiate_qmv_fast_types( 64, 2)
instantiate_qmv_fast_types( 64, 4)
instantiate_qmv_fast_types( 64, 8)
instantiate_qmv_fast_types( 32, 2)
instantiate_qmv_fast_types( 32, 4)
instantiate_qmv_fast_types( 32, 8)
#define instantiate_quantized_groups(name, bits) \
instantiate_quantized_types(name, 128, bits) \
instantiate_quantized_types(name, 64, bits) \
instantiate_quantized_types(name, 32, bits)
#define instantiate_qmv(itype, group_size, bits) \
instantiate_kernel( \
"qmv_" #itype "_gs_" #group_size "_b_" #bits, \
qmv, \
itype, \
group_size, \
bits)
#define instantiate_quantized_all(name) \
instantiate_quantized_groups(name, 2) \
instantiate_quantized_groups(name, 4) \
instantiate_quantized_groups(name, 8)
#define instantiate_qmv_types(group_size, bits) \
instantiate_qmv(float, group_size, bits) \
instantiate_qmv(float16_t, group_size, bits) \
instantiate_qmv(bfloat16_t, group_size, bits)
instantiate_quantized_all(qmv_fast)
instantiate_quantized_all(qmv)
instantiate_quantized_all(qvm)
instantiate_quantized_all(qmm_n)
instantiate_quantized_all(bs_qmv_fast)
instantiate_quantized_all(bs_qmv)
instantiate_quantized_all(bs_qvm)
instantiate_quantized_all(bs_qmm_n)
instantiate_quantized_all(affine_quantize)
instantiate_quantized_all(affine_quantize_scales_biases)
instantiate_quantized_all(affine_dequantize)
instantiate_qmv_types(128, 2)
instantiate_qmv_types(128, 4)
instantiate_qmv_types(128, 8)
instantiate_qmv_types( 64, 2)
instantiate_qmv_types( 64, 4)
instantiate_qmv_types( 64, 8)
instantiate_qmv_types( 32, 2)
instantiate_qmv_types( 32, 4)
instantiate_qmv_types( 32, 8)
#define instantiate_quantized_aligned(name, type, group_size, bits, aligned) \
instantiate_kernel( \
#name "_" #type "_gs_" #group_size "_b_" #bits "_alN_" #aligned, \
name, \
type, \
group_size, \
bits, \
aligned)
#define instantiate_qvm(itype, group_size, bits) \
instantiate_kernel( \
"qvm_" #itype "_gs_" #group_size "_b_" #bits, \
qvm, \
itype, \
group_size, \
bits)
#define instantiate_quantized_types_aligned(name, group_size, bits) \
instantiate_quantized_aligned(name, float, group_size, bits, true) \
instantiate_quantized_aligned(name, float16_t, group_size, bits, true) \
instantiate_quantized_aligned(name, bfloat16_t, group_size, bits, true) \
instantiate_quantized_aligned(name, float, group_size, bits, false) \
instantiate_quantized_aligned(name, float16_t, group_size, bits, false) \
instantiate_quantized_aligned(name, bfloat16_t, group_size, bits, false)
#define instantiate_qvm_types(group_size, bits) \
instantiate_qvm(float, group_size, bits) \
instantiate_qvm(float16_t, group_size, bits) \
instantiate_qvm(bfloat16_t, group_size, bits)
#define instantiate_quantized_groups_aligned(name, bits) \
instantiate_quantized_types_aligned(name, 128, bits) \
instantiate_quantized_types_aligned(name, 64, bits) \
instantiate_quantized_types_aligned(name, 32, bits)
instantiate_qvm_types(128, 2)
instantiate_qvm_types(128, 4)
instantiate_qvm_types(128, 8)
instantiate_qvm_types( 64, 2)
instantiate_qvm_types( 64, 4)
instantiate_qvm_types( 64, 8)
instantiate_qvm_types( 32, 2)
instantiate_qvm_types( 32, 4)
instantiate_qvm_types( 32, 8)
#define instantiate_quantized_all_aligned(name) \
instantiate_quantized_groups_aligned(name, 2) \
instantiate_quantized_groups_aligned(name, 4) \
instantiate_quantized_groups_aligned(name, 8) \
#define instantiate_qmm_t(itype, group_size, bits, aligned_N) \
instantiate_kernel( \
"qmm_t_" #itype "_gs_" #group_size "_b_" #bits "_alN_" #aligned_N, \
qmm_t, \
itype, \
group_size, \
bits, \
aligned_N)
#define instantiate_qmm_t_types(group_size, bits) \
instantiate_qmm_t(float, group_size, bits, false) \
instantiate_qmm_t(float16_t, group_size, bits, false) \
instantiate_qmm_t(bfloat16_t, group_size, bits, false) \
instantiate_qmm_t(float, group_size, bits, true) \
instantiate_qmm_t(float16_t, group_size, bits, true) \
instantiate_qmm_t(bfloat16_t, group_size, bits, true)
instantiate_qmm_t_types(128, 2)
instantiate_qmm_t_types(128, 4)
instantiate_qmm_t_types(128, 8)
instantiate_qmm_t_types( 64, 2)
instantiate_qmm_t_types( 64, 4)
instantiate_qmm_t_types( 64, 8)
instantiate_qmm_t_types( 32, 2)
instantiate_qmm_t_types( 32, 4)
instantiate_qmm_t_types( 32, 8)
#define instantiate_qmm_n(itype, group_size, bits) \
instantiate_kernel( \
"qmm_n_" #itype "_gs_" #group_size "_b_" #bits, \
qmm_n, \
itype, \
group_size, \
bits)
#define instantiate_qmm_n_types(group_size, bits) \
instantiate_qmm_n(float, group_size, bits) \
instantiate_qmm_n(float16_t, group_size, bits) \
instantiate_qmm_n(bfloat16_t, group_size, bits)
instantiate_qmm_n_types(128, 2)
instantiate_qmm_n_types(128, 4)
instantiate_qmm_n_types(128, 8)
instantiate_qmm_n_types( 64, 2)
instantiate_qmm_n_types( 64, 4)
instantiate_qmm_n_types( 64, 8)
instantiate_qmm_n_types( 32, 2)
instantiate_qmm_n_types( 32, 4)
instantiate_qmm_n_types( 32, 8)
#define instantiate_bs_qmv_fast(itype, group_size, bits) \
instantiate_kernel( \
"bs_qmv_" #itype "_gs_" #group_size "_b_" #bits "_fast", \
bs_qmv_fast, \
itype, \
group_size, \
bits)
#define instantiate_bs_qmv_fast_types(group_size, bits) \
instantiate_bs_qmv_fast(float, group_size, bits) \
instantiate_bs_qmv_fast(float16_t, group_size, bits) \
instantiate_bs_qmv_fast(bfloat16_t, group_size, bits)
instantiate_bs_qmv_fast_types(128, 2)
instantiate_bs_qmv_fast_types(128, 4)
instantiate_bs_qmv_fast_types(128, 8)
instantiate_bs_qmv_fast_types( 64, 2)
instantiate_bs_qmv_fast_types( 64, 4)
instantiate_bs_qmv_fast_types( 64, 8)
instantiate_bs_qmv_fast_types( 32, 2)
instantiate_bs_qmv_fast_types( 32, 4)
instantiate_bs_qmv_fast_types( 32, 8)
#define instantiate_bs_qmv(itype, group_size, bits) \
instantiate_kernel( \
"bs_qmv_" #itype "_gs_" #group_size "_b_" #bits, \
bs_qmv, \
itype, \
group_size, \
bits)
#define instantiate_bs_qmv_types(group_size, bits) \
instantiate_bs_qmv(float, group_size, bits) \
instantiate_bs_qmv(float16_t, group_size, bits) \
instantiate_bs_qmv(bfloat16_t, group_size, bits)
instantiate_bs_qmv_types(128, 2)
instantiate_bs_qmv_types(128, 4)
instantiate_bs_qmv_types(128, 8)
instantiate_bs_qmv_types( 64, 2)
instantiate_bs_qmv_types( 64, 4)
instantiate_bs_qmv_types( 64, 8)
instantiate_bs_qmv_types( 32, 2)
instantiate_bs_qmv_types( 32, 4)
instantiate_bs_qmv_types( 32, 8)
#define instantiate_bs_qvm(itype, group_size, bits) \
instantiate_kernel( \
"bs_qvm_" #itype "_gs_" #group_size "_b_" #bits, \
bs_qvm, \
itype, \
group_size, \
bits)
#define instantiate_bs_qvm_types(group_size, bits) \
instantiate_bs_qvm(float, group_size, bits) \
instantiate_bs_qvm(float16_t, group_size, bits) \
instantiate_bs_qvm(bfloat16_t, group_size, bits)
instantiate_bs_qvm_types(128, 2)
instantiate_bs_qvm_types(128, 4)
instantiate_bs_qvm_types(128, 8)
instantiate_bs_qvm_types( 64, 2)
instantiate_bs_qvm_types( 64, 4)
instantiate_bs_qvm_types( 64, 8)
instantiate_bs_qvm_types( 32, 2)
instantiate_bs_qvm_types( 32, 4)
instantiate_bs_qvm_types( 32, 8)
#define instantiate_bs_qmm_t(itype, group_size, bits, aligned_N) \
instantiate_kernel( \
"bs_qmm_t_" #itype "_gs_" #group_size "_b_" #bits "_alN_" #aligned_N, \
bs_qmm_t, \
itype, \
group_size, \
bits, \
aligned_N)
#define instantiate_bs_qmm_t_types(group_size, bits) \
instantiate_bs_qmm_t(float, group_size, bits, false) \
instantiate_bs_qmm_t(float16_t, group_size, bits, false) \
instantiate_bs_qmm_t(bfloat16_t, group_size, bits, false) \
instantiate_bs_qmm_t(float, group_size, bits, true) \
instantiate_bs_qmm_t(float16_t, group_size, bits, true) \
instantiate_bs_qmm_t(bfloat16_t, group_size, bits, true)
instantiate_bs_qmm_t_types(128, 2)
instantiate_bs_qmm_t_types(128, 4)
instantiate_bs_qmm_t_types(128, 8)
instantiate_bs_qmm_t_types( 64, 2)
instantiate_bs_qmm_t_types( 64, 4)
instantiate_bs_qmm_t_types( 64, 8)
instantiate_bs_qmm_t_types( 32, 2)
instantiate_bs_qmm_t_types( 32, 4)
instantiate_bs_qmm_t_types( 32, 8)
#define instantiate_bs_qmm_n(itype, group_size, bits) \
instantiate_kernel( \
"bs_qmm_n_" #itype "_gs_" #group_size "_b_" #bits, \
bs_qmm_n, \
itype, \
group_size, \
bits)
#define instantiate_bs_qmm_n_types(group_size, bits) \
instantiate_bs_qmm_n(float, group_size, bits) \
instantiate_bs_qmm_n(float16_t, group_size, bits) \
instantiate_bs_qmm_n(bfloat16_t, group_size, bits)
instantiate_bs_qmm_n_types(128, 2)
instantiate_bs_qmm_n_types(128, 4)
instantiate_bs_qmm_n_types(128, 8)
instantiate_bs_qmm_n_types( 64, 2)
instantiate_bs_qmm_n_types( 64, 4)
instantiate_bs_qmm_n_types( 64, 8)
instantiate_bs_qmm_n_types( 32, 2)
instantiate_bs_qmm_n_types( 32, 4)
instantiate_bs_qmm_n_types( 32, 8) // clang-format on
instantiate_quantized_all_aligned(qmm_t)
instantiate_quantized_all_aligned(bs_qmm_t) // clang-format on

View File

@@ -43,20 +43,22 @@ rbits threefry2x32_hash(const thread uint2& key, uint2 count) {
auto half_size = grid_dim.y - odd;
out += index.x * bytes_per_key;
bool drop_last = odd && (index.y == half_size);
auto count = uint2(index.y, drop_last ? 0 : index.y + grid_dim.y);
auto bits = threefry2x32_hash(key, count);
auto bits = threefry2x32_hash(
key, uint2(index.y, drop_last ? 0 : index.y + grid_dim.y));
size_t idx = size_t(index.y) << 2;
for (int i = 0; i < 4; ++i) {
out[4 * count.x + i] = bits.bytes[0][i];
out[idx + i] = bits.bytes[0][i];
}
if (!drop_last) {
idx = (drop_last ? 0 : size_t(index.y) + grid_dim.y) << 2;
if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) {
int edge_bytes = (bytes_per_key % 4);
for (int i = 0; i < edge_bytes; ++i) {
out[4 * count.y + i] = bits.bytes[1][i];
out[idx + i] = bits.bytes[1][i];
}
} else {
for (int i = 0; i < 4; ++i) {
out[4 * count.y + i] = bits.bytes[1][i];
out[idx + i] = bits.bytes[1][i];
}
}
}
@@ -77,22 +79,24 @@ rbits threefry2x32_hash(const thread uint2& key, uint2 count) {
auto k2_elem = elem_to_loc(kidx + 1, key_shape, key_strides, ndim);
auto key = uint2(keys[k1_elem], keys[k2_elem]);
auto half_size = grid_dim.y - odd;
out += index.x * bytes_per_key;
out += size_t(index.x) * bytes_per_key;
bool drop_last = odd && (index.y == half_size);
auto count = uint2(index.y, drop_last ? 0 : index.y + grid_dim.y);
auto bits = threefry2x32_hash(key, count);
auto bits = threefry2x32_hash(
key, uint2(index.y, drop_last ? 0 : index.y + grid_dim.y));
size_t idx = size_t(index.y) << 2;
for (int i = 0; i < 4; ++i) {
out[4 * count.x + i] = bits.bytes[0][i];
out[idx + i] = bits.bytes[0][i];
}
if (!drop_last) {
idx = (drop_last ? 0 : size_t(index.y) + grid_dim.y) << 2;
if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) {
int edge_bytes = (bytes_per_key % 4);
for (int i = 0; i < edge_bytes; ++i) {
out[4 * count.y + i] = bits.bytes[1][i];
out[idx + i] = bits.bytes[1][i];
}
} else {
for (int i = 0; i < 4; ++i) {
out[4 * count.y + i] = bits.bytes[1][i];
out[idx + i] = bits.bytes[1][i];
}
}
}

View File

@@ -24,7 +24,7 @@ template <typename T, int N_READS = RMS_N_READS>
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
float acc = 0;
x += gid * axis_size + lid * N_READS;
x += gid * size_t(axis_size) + lid * N_READS;
w += w_stride * lid * N_READS;
if (lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
@@ -62,7 +62,7 @@ template <typename T, int N_READS = RMS_N_READS>
threadgroup_barrier(mem_flags::mem_threadgroup);
// Write the outputs
out += gid * axis_size + lid * N_READS;
out += gid * size_t(axis_size) + lid * N_READS;
if (lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
out[i] = w[w_stride * i] * static_cast<T>(x[i] * local_inv_mean[0]);
@@ -92,7 +92,7 @@ template <typename T, int N_READS = RMS_N_READS>
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
float acc = 0;
x += gid * axis_size + lid * N_READS;
x += gid * size_t(axis_size) + lid * N_READS;
w += w_stride * lid * N_READS;
for (uint r = 0; r < axis_size; r += lsize * N_READS) {
if (r + lid * N_READS + N_READS <= axis_size) {
@@ -132,7 +132,7 @@ template <typename T, int N_READS = RMS_N_READS>
threadgroup_barrier(mem_flags::mem_threadgroup);
// Write the outputs
out += gid * axis_size + lid * N_READS;
out += gid * size_t(axis_size) + lid * N_READS;
for (uint r = 0; r < axis_size; r += lsize * N_READS) {
if (r + lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
@@ -165,8 +165,8 @@ template <typename T, int N_READS = RMS_N_READS>
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
// Advance the input pointers
x += gid * axis_size + lid * N_READS;
g += gid * axis_size + lid * N_READS;
x += gid * size_t(axis_size) + lid * N_READS;
g += gid * size_t(axis_size) + lid * N_READS;
w += w_stride * lid * N_READS;
// Allocate registers for the computation and accumulators
@@ -233,8 +233,8 @@ template <typename T, int N_READS = RMS_N_READS>
float normalizer3 = normalizer * normalizer * normalizer;
// Write the outputs
gx += gid * axis_size + lid * N_READS;
gw += gid * axis_size + lid * N_READS;
gx += gid * size_t(axis_size) + lid * N_READS;
gw += gid * size_t(axis_size) + lid * N_READS;
if (lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
gx[i] = static_cast<T>(
@@ -270,8 +270,8 @@ template <typename T, int N_READS = RMS_N_READS>
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
// Advance the input pointers
x += gid * axis_size + lid * N_READS;
g += gid * axis_size + lid * N_READS;
x += gid * size_t(axis_size) + lid * N_READS;
g += gid * size_t(axis_size) + lid * N_READS;
w += w_stride * lid * N_READS;
// Allocate registers for the accumulators
@@ -337,8 +337,8 @@ template <typename T, int N_READS = RMS_N_READS>
float normalizer3 = normalizer * normalizer * normalizer;
// Write the outputs
gx += gid * axis_size + lid * N_READS;
gw += gid * axis_size + lid * N_READS;
gx += gid * size_t(axis_size) + lid * N_READS;
gw += gid * size_t(axis_size) + lid * N_READS;
for (uint r = 0; r < axis_size; r += lsize * N_READS) {
if (r + lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {

View File

@@ -6,36 +6,17 @@
#include "mlx/backend/metal/kernels/utils.h"
template <typename T, bool traditional, bool forward>
[[kernel]] void rope(
[[kernel]] void rope_single(
const device T* in [[buffer(0)]],
device T* out [[buffer(1)]],
constant const size_t strides[3],
constant const size_t out_strides[3],
constant const int& offset,
constant const float& base,
constant const float& scale,
uint3 pos [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]) {
// Compute the input and output indices
uint in_index_1, in_index_2;
uint out_index_1, out_index_2;
if (traditional) {
out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] +
pos.z * out_strides[0];
out_index_2 = out_index_1 + 1;
in_index_1 =
2 * pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0];
in_index_2 = in_index_1 + strides[2];
} else {
out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] +
pos.z * out_strides[0];
out_index_2 = out_index_1 + grid.x * out_strides[2];
in_index_1 = pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0];
in_index_2 = in_index_1 + grid.x * strides[2];
}
constant const size_t& stride,
uint2 pos [[thread_position_in_grid]],
uint2 grid [[threads_per_grid]]) {
// Figure out L and d.
float L = scale * static_cast<float>(pos.y + offset);
float L = scale * static_cast<float>(offset);
float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
// Compute costheta, sintheta
@@ -43,6 +24,21 @@ template <typename T, bool traditional, bool forward>
float costheta = metal::fast::cos(theta);
float sintheta = metal::fast::sin(theta);
// Compute the input and output indices
uint in_index_1, in_index_2;
uint out_index_1, out_index_2;
if (traditional) {
out_index_1 = 2 * pos.x + pos.y * stride;
out_index_2 = out_index_1 + 1;
in_index_1 = 2 * pos.x + pos.y * stride;
in_index_2 = in_index_1 + 1;
} else {
out_index_1 = pos.x + pos.y * stride;
out_index_2 = out_index_1 + grid.x;
in_index_1 = pos.x + pos.y * stride;
in_index_2 = in_index_1 + grid.x;
}
// Read and write the output
float x1 = static_cast<float>(in[in_index_1]);
float x2 = static_cast<float>(in[in_index_2]);
@@ -59,19 +55,97 @@ template <typename T, bool traditional, bool forward>
out[out_index_2] = static_cast<T>(rx2);
}
#define instantiate_rope(name, type, traditional, forward) \
template [[host_name("rope_" #name)]] [[kernel]] void \
rope<type, traditional, forward>( \
const device type* in [[buffer(0)]], \
device type* out [[buffer(1)]], \
constant const size_t strides[3], \
constant const size_t out_strides[3], \
constant const int& offset, \
constant const float& base, \
constant const float& scale, \
uint3 pos [[thread_position_in_grid]], \
template <typename T, bool traditional, bool forward, int N = 4>
[[kernel]] void rope(
const device T* in [[buffer(0)]],
device T* out [[buffer(1)]],
constant const int& offset,
constant const float& base,
constant const float& scale,
constant const size_t strides[3],
constant const size_t out_strides[3],
constant const size_t& n_batch,
uint3 pos [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]) {
// Figure out L and d.
float L = scale * static_cast<float>(pos.y + offset);
float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
// Compute costheta, sintheta
float theta = L * metal::exp2(-d * base);
float costheta = metal::fast::cos(theta);
float sintheta = metal::fast::sin(theta);
// Compute the input and output indices
size_t in_index_1, in_index_2;
size_t out_index_1, out_index_2;
if (traditional) {
out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] +
N * pos.z * out_strides[0];
out_index_2 = out_index_1 + 1;
in_index_1 =
2 * pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
in_index_2 = in_index_1 + strides[2];
} else {
out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] +
N * pos.z * out_strides[0];
out_index_2 = out_index_1 + grid.x * out_strides[2];
in_index_1 =
pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
in_index_2 = in_index_1 + grid.x * strides[2];
}
for (int i = 0; i < N && pos.z * N + i < n_batch; ++i) {
// Read and write the output
float x1 = static_cast<float>(in[in_index_1]);
float x2 = static_cast<float>(in[in_index_2]);
float rx1;
float rx2;
if (forward) {
rx1 = x1 * costheta - x2 * sintheta;
rx2 = x1 * sintheta + x2 * costheta;
} else {
rx1 = x2 * sintheta + x1 * costheta;
rx2 = x2 * costheta - x1 * sintheta;
}
out[out_index_1] = static_cast<T>(rx1);
out[out_index_2] = static_cast<T>(rx2);
in_index_1 += strides[0];
in_index_2 += strides[0];
out_index_1 += out_strides[0];
out_index_2 += out_strides[0];
}
}
#define instantiate_rope_g(name, type, traditional, forward) \
template [[host_name("rope_" #name)]] [[kernel]] void \
rope<type, traditional, forward>( \
const device type* in [[buffer(0)]], \
device type* out [[buffer(1)]], \
constant const int& offset, \
constant const float& base, \
constant const float& scale, \
constant const size_t strides[3], \
constant const size_t out_strides[3], \
constant const size_t& n_batch, \
uint3 pos [[thread_position_in_grid]], \
uint3 grid [[threads_per_grid]]);
#define instantiate_rope_s(name, type, traditional, forward) \
template [[host_name("rope_single_" #name)]] [[kernel]] void \
rope_single<type, traditional, forward>( \
const device type* in [[buffer(0)]], \
device type* out [[buffer(1)]], \
constant const int& offset, \
constant const float& base, \
constant const float& scale, \
constant const size_t& stride, \
uint2 pos [[thread_position_in_grid]], \
uint2 grid [[threads_per_grid]]);
#define instantiate_rope(name, type, traditional, forward) \
instantiate_rope_s(name, type, traditional, forward) \
instantiate_rope_g(name, type, traditional, forward)
// clang-format off
instantiate_rope(traditional_float16, half, true, true)
instantiate_rope(traditional_bfloat16, bfloat16_t, true, true)
@@ -84,4 +158,4 @@ instantiate_rope(vjp_traditional_bfloat16, bfloat16_t, true, false)
instantiate_rope(vjp_traditional_float32, float, true, false)
instantiate_rope(vjp_float16, half, false, false)
instantiate_rope(vjp_bfloat16, bfloat16_t, false, false)
instantiate_rope(vjp_float32, float, false, false) // clang-format on
instantiate_rope(vjp_float32, float, false, false) // clang-format on

View File

@@ -25,7 +25,7 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
AccT ld[N_READS];
in += gid * axis_size + lid * N_READS;
in += gid * size_t(axis_size) + lid * N_READS;
if (lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
ld[i] = AccT(in[i]);
@@ -83,7 +83,7 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
normalizer = 1 / local_normalizer[0];
// Normalize and write to the output
out += gid * axis_size + lid * N_READS;
out += gid * size_t(axis_size) + lid * N_READS;
if (lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
out[i] = T(ld[i] * normalizer);
@@ -107,7 +107,7 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
uint lsize [[threads_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
in += gid * axis_size;
in += gid * size_t(axis_size);
constexpr int SIMD_SIZE = 32;
@@ -170,7 +170,7 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
// Finally given the normalizer and max value we can directly write the
// softmax output
out += gid * axis_size;
out += gid * size_t(axis_size);
for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize));
r++) {
int offset = r * lsize * N_READS + lid * N_READS;

View File

@@ -235,19 +235,21 @@ struct KernelMergeSort {
const device T* inp,
device U* out,
const constant int& size_sorted_axis,
const constant int& stride_sorted_axis,
const constant int& stride_segment_axis,
const constant int& in_stride_sorted_axis,
const constant int& out_stride_sorted_axis,
const constant int& in_stride_segment_axis,
const constant int& out_stride_segment_axis,
threadgroup val_t* tgp_vals,
threadgroup idx_t* tgp_idxs,
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
// tid.y tells us the segment index
inp += tid.y * stride_segment_axis;
out += tid.y * stride_segment_axis;
inp += tid.y * in_stride_segment_axis;
out += tid.y * out_stride_segment_axis;
// Copy into threadgroup memory
for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) {
tgp_vals[i] = i < size_sorted_axis ? inp[i * stride_sorted_axis]
tgp_vals[i] = i < size_sorted_axis ? inp[i * in_stride_sorted_axis]
: val_t(CompareOp::init);
if (ARG_SORT) {
tgp_idxs[i] = i;
@@ -264,9 +266,9 @@ struct KernelMergeSort {
// Write output
for (int i = lid.x; i < size_sorted_axis; i += BLOCK_THREADS) {
if (ARG_SORT) {
out[i * stride_sorted_axis] = tgp_idxs[i];
out[i * out_stride_sorted_axis] = tgp_idxs[i];
} else {
out[i * stride_sorted_axis] = tgp_vals[i];
out[i * out_stride_sorted_axis] = tgp_vals[i];
}
}
}
@@ -282,8 +284,10 @@ template <
const device T* inp [[buffer(0)]],
device U* out [[buffer(1)]],
const constant int& size_sorted_axis [[buffer(2)]],
const constant int& stride_sorted_axis [[buffer(3)]],
const constant int& stride_segment_axis [[buffer(4)]],
const constant int& in_stride_sorted_axis [[buffer(3)]],
const constant int& out_stride_sorted_axis [[buffer(4)]],
const constant int& in_stride_segment_axis [[buffer(5)]],
const constant int& out_stride_segment_axis [[buffer(6)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using sort_kernel =
@@ -298,8 +302,10 @@ template <
inp,
out,
size_sorted_axis,
stride_sorted_axis,
stride_segment_axis,
in_stride_sorted_axis,
out_stride_sorted_axis,
in_stride_segment_axis,
out_stride_segment_axis,
tgp_vals,
tgp_idxs,
tid,
@@ -310,8 +316,10 @@ template <
inp,
out,
size_sorted_axis,
stride_sorted_axis,
stride_segment_axis,
in_stride_sorted_axis,
out_stride_sorted_axis,
in_stride_segment_axis,
out_stride_segment_axis,
tgp_vals,
nullptr,
tid,
@@ -331,10 +339,12 @@ template <
const device T* inp [[buffer(0)]],
device U* out [[buffer(1)]],
const constant int& size_sorted_axis [[buffer(2)]],
const constant int& stride_sorted_axis [[buffer(3)]],
const constant int& nc_dim [[buffer(4)]],
const device int* nc_shape [[buffer(5)]],
const device size_t* nc_strides [[buffer(6)]],
const constant int& in_stride_sorted_axis [[buffer(3)]],
const constant int& out_stride_sorted_axis [[buffer(4)]],
const constant int& nc_dim [[buffer(5)]],
const device int* nc_shape [[buffer(6)]],
const device size_t* in_nc_strides [[buffer(7)]],
const device size_t* out_nc_strides [[buffer(8)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using sort_kernel =
@@ -342,9 +352,10 @@ template <
using val_t = typename sort_kernel::val_t;
using idx_t = typename sort_kernel::idx_t;
auto block_idx = elem_to_loc(tid.y, nc_shape, nc_strides, nc_dim);
inp += block_idx;
out += block_idx;
auto in_block_idx = elem_to_loc(tid.y, nc_shape, in_nc_strides, nc_dim);
auto out_block_idx = elem_to_loc(tid.y, nc_shape, out_nc_strides, nc_dim);
inp += in_block_idx;
out += out_block_idx;
if (ARG_SORT) {
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
@@ -353,7 +364,9 @@ template <
inp,
out,
size_sorted_axis,
stride_sorted_axis,
in_stride_sorted_axis,
out_stride_sorted_axis,
zero_helper,
zero_helper,
tgp_vals,
tgp_idxs,
@@ -365,7 +378,9 @@ template <
inp,
out,
size_sorted_axis,
stride_sorted_axis,
in_stride_sorted_axis,
out_stride_sorted_axis,
zero_helper,
zero_helper,
tgp_vals,
nullptr,
@@ -507,13 +522,13 @@ template <
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD>
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void
mb_block_partition(
[[kernel]] void mb_block_partition(
device idx_t* block_partitions [[buffer(0)]],
const device val_t* dev_vals [[buffer(1)]],
const device idx_t* dev_idxs [[buffer(2)]],
const constant int& size_sorted_axis [[buffer(3)]],
const constant int& merge_tiles [[buffer(4)]],
const constant int& n_blocks [[buffer(5)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 tgp_dims [[threads_per_threadgroup]]) {
@@ -528,23 +543,29 @@ mb_block_partition(
dev_vals += tid.y * size_sorted_axis;
dev_idxs += tid.y * size_sorted_axis;
// Find location in merge step
int merge_group = lid.x / merge_tiles;
int merge_lane = lid.x % merge_tiles;
for (int i = lid.x; i <= n_blocks; i += tgp_dims.x) {
// Find location in merge step
int merge_group = i / merge_tiles;
int merge_lane = i % merge_tiles;
int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;
int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;
int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;
int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;
int A_st = min(size_sorted_axis, sort_st);
int A_ed = min(size_sorted_axis, sort_st + sort_sz / 2);
int B_st = A_ed;
int B_ed = min(size_sorted_axis, B_st + sort_sz / 2);
int A_st = min(size_sorted_axis, sort_st);
int A_ed = min(size_sorted_axis, sort_st + sort_sz / 2);
int B_st = A_ed;
int B_ed = min(size_sorted_axis, B_st + sort_sz / 2);
int partition_at = min(B_ed - A_st, sort_kernel::N_PER_BLOCK * merge_lane);
int partition = sort_kernel::merge_partition(
dev_vals + A_st, dev_vals + B_st, A_ed - A_st, B_ed - B_st, partition_at);
int partition_at = min(B_ed - A_st, sort_kernel::N_PER_BLOCK * merge_lane);
int partition = sort_kernel::merge_partition(
dev_vals + A_st,
dev_vals + B_st,
A_ed - A_st,
B_ed - B_st,
partition_at);
block_partitions[lid.x] = A_st + partition;
block_partitions[i] = A_st + partition;
}
}
template <

View File

@@ -10,28 +10,10 @@
#define instantiate_block_sort( \
name, itname, itype, otname, otype, arg_sort, bn, tn) \
template [[host_name("c" #name "_" #itname "_" #otname "_bn" #bn \
"_tn" #tn)]] [[kernel]] void \
block_sort<itype, otype, arg_sort, bn, tn>( \
const device itype* inp [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant int& size_sorted_axis [[buffer(2)]], \
const constant int& stride_sorted_axis [[buffer(3)]], \
const constant int& stride_segment_axis [[buffer(4)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]); \
template [[host_name("nc" #name "_" #itname "_" #otname "_bn" #bn "_tn" #tn \
)]] [[kernel]] void \
block_sort_nc<itype, otype, arg_sort, bn, tn>( \
const device itype* inp [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant int& size_sorted_axis [[buffer(2)]], \
const constant int& stride_sorted_axis [[buffer(3)]], \
const constant int& nc_dim [[buffer(4)]], \
const device int* nc_shape [[buffer(5)]], \
const device size_t* nc_strides [[buffer(6)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
instantiate_kernel("c" #name "_" #itname "_" #otname "_bn" #bn "_tn" #tn, \
block_sort, itype, otype, arg_sort, bn, tn) \
instantiate_kernel("nc" #name "_" #itname "_" #otname "_bn" #bn "_tn" #tn, \
block_sort_nc, itype, otype, arg_sort, bn, tn)
#define instantiate_arg_block_sort_base(itname, itype, bn, tn) \
instantiate_block_sort( \
@@ -69,43 +51,12 @@ instantiate_block_sort_long(int64, int64_t)
#define instantiate_multi_block_sort( \
vtname, vtype, itname, itype, arg_sort, bn, tn) \
template [[host_name("sort_mbsort_" #vtname "_" #itname "_bn" #bn \
"_tn" #tn)]] [[kernel]] void \
mb_block_sort<vtype, itype, arg_sort, bn, tn>( \
const device vtype* inp [[buffer(0)]], \
device vtype* out_vals [[buffer(1)]], \
device itype* out_idxs [[buffer(2)]], \
const constant int& size_sorted_axis [[buffer(3)]], \
const constant int& stride_sorted_axis [[buffer(4)]], \
const constant int& nc_dim [[buffer(5)]], \
const device int* nc_shape [[buffer(6)]], \
const device size_t* nc_strides [[buffer(7)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]); \
template [[host_name("partition_mbsort_" #vtname "_" #itname "_bn" #bn \
"_tn" #tn)]] [[kernel]] void \
mb_block_partition<vtype, itype, arg_sort, bn, tn>( \
device itype * block_partitions [[buffer(0)]], \
const device vtype* dev_vals [[buffer(1)]], \
const device itype* dev_idxs [[buffer(2)]], \
const constant int& size_sorted_axis [[buffer(3)]], \
const constant int& merge_tiles [[buffer(4)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 tgp_dims [[threads_per_threadgroup]]); \
template [[host_name("merge_mbsort_" #vtname "_" #itname "_bn" #bn \
"_tn" #tn)]] [[kernel]] void \
mb_block_merge<vtype, itype, arg_sort, bn, tn>( \
const device itype* block_partitions [[buffer(0)]], \
const device vtype* dev_vals_in [[buffer(1)]], \
const device itype* dev_idxs_in [[buffer(2)]], \
device vtype* dev_vals_out [[buffer(3)]], \
device itype* dev_idxs_out [[buffer(4)]], \
const constant int& size_sorted_axis [[buffer(5)]], \
const constant int& merge_tiles [[buffer(6)]], \
const constant int& num_tiles [[buffer(7)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
instantiate_kernel("sort_mbsort_" #vtname "_" #itname "_bn" #bn "_tn" #tn, \
mb_block_sort, vtype, itype, arg_sort, bn, tn) \
instantiate_kernel("partition_mbsort_" #vtname "_" #itname "_bn" #bn "_tn" #tn, \
mb_block_partition, vtype, itype, arg_sort, bn, tn) \
instantiate_kernel("merge_mbsort_" #vtname "_" #itname "_bn" #bn "_tn" #tn, \
mb_block_merge, vtype, itype, arg_sort, bn, tn)
#define instantiate_multi_block_sort_base(vtname, vtype) \
instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 512, 8)

View File

@@ -10,6 +10,18 @@ template <typename T, typename Op>
d[index] = Op()(a[index], b[index], c[index]);
}
template <typename T, typename Op>
[[kernel]] void ternary_v2(
device const bool* a,
device const T* b,
device const T* c,
device T* d,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
d[offset] = Op()(a[offset], b[offset], c[offset]);
}
template <typename T, typename Op>
[[kernel]] void ternary_g_nd1(
device const bool* a,

View File

@@ -11,6 +11,7 @@
#define instantiate_ternary_all(op, tname, type) \
instantiate_kernel("v_" #op #tname, ternary_v, type, op) \
instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \
instantiate_kernel("g_" #op #tname, ternary_g, type, op) \
instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op) \
instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op) \

View File

@@ -8,6 +8,16 @@ template <typename T, typename Op>
out[index] = Op()(in[index]);
}
template <typename T, typename Op>
[[kernel]] void unary_v2(
device const T* in,
device T* out,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
out[offset] = Op()(in[offset]);
}
template <typename T, typename Op>
[[kernel]] void unary_g(
device const T* in,

View File

@@ -5,8 +5,9 @@
#include "mlx/backend/metal/kernels/unary_ops.h"
#include "mlx/backend/metal/kernels/unary.h"
#define instantiate_unary_all(op, tname, type) \
instantiate_kernel("v" #op #tname, unary_v, type, op) \
#define instantiate_unary_all(op, tname, type) \
instantiate_kernel("v" #op #tname, unary_v, type, op) \
instantiate_kernel("v2" #op #tname, unary_v2, type, op) \
instantiate_kernel("g" #op #tname, unary_g, type, op)
#define instantiate_unary_float(op) \

View File

@@ -11,187 +11,14 @@
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/steel/gemm/params.h"
#include "mlx/backend/metal/matmul.h"
#include "mlx/backend/metal/mps/gemm.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core {
///////////////////////////////////////////////////////////////////////////////
// MPS Matmul fallback
///////////////////////////////////////////////////////////////////////////////
namespace {
bool use_mps() {
auto get_val = []() {
if (const char* buff_str = std::getenv("MLX_USE_MPS")) {
return std::string(buff_str) != "OFF";
} else {
return false;
}
};
static bool use_mps_ = get_val();
return use_mps_;
}
#define MAX_OPS_PER_BUFFER max_ops_per_buffer()
inline void mps_matmul(
const Stream& s,
metal::Device& d,
const array& a,
const array& b,
array& out,
int M,
int N,
int K,
int batch_size_out,
int lda,
int ldb,
bool transpose_a,
bool transpose_b,
std::vector<array>& copies,
float alpha = 1.0f,
float beta = 0.0f) {
MPS::DataType mps_dtype = MPS::DataTypeFloat32;
if (out.dtype() == float16) {
mps_dtype = MPS::DataTypeFloat16;
} else if (out.dtype() == bfloat16) {
mps_dtype = MPS::DataTypeBFloat16;
}
// Used batched MPSMatrixMultiplication if batch_size_out > 1
// We only accept the following cases:
// 1. Both a, b have batch_size_out matrices worth of data
// 2. Only one of a or b has batch_size_out matrices worth of data and
// the other has matrix worth of data
// The matrix dimensions of a and b are sure to be regularly strided
if (batch_size_out > 1) {
// No broadcasting defaults
auto batch_size_a = a.data_size() / (M * K);
auto batch_size_b = b.data_size() / (K * N);
auto matrix_stride_a = M * K;
auto matrix_stride_b = K * N;
auto matrix_stride_out = M * N;
// At this point, batch_size_a, batch_size_b show the number of matrices
// in data, no broadcasted strides considered
if (batch_size_out == std::max(batch_size_a, batch_size_b)) {
// Handle simple broadcasting
if (std::min(batch_size_a, batch_size_b) == 1) {
matrix_stride_a = (batch_size_a == 1) ? 0 : matrix_stride_a;
matrix_stride_b = (batch_size_b == 1) ? 0 : matrix_stride_b;
batch_size_a = batch_size_out;
batch_size_b = batch_size_out;
}
// Only proceed if broadcasting between a and b is simple
// At this point, batch_size_a, batch_size_b show the number of matrices
// after broadcasting
if (batch_size_a == batch_size_b) {
auto a_desc = MPS::MatrixDescriptor::matrixDescriptor(
(M * K) / lda,
lda,
batch_size_a,
lda * a.itemsize(),
(matrix_stride_a * a.itemsize()),
mps_dtype);
auto b_desc = MPS::MatrixDescriptor::matrixDescriptor(
(K * N) / ldb,
ldb,
batch_size_b,
ldb * b.itemsize(),
(matrix_stride_b * b.itemsize()),
mps_dtype);
auto out_desc = MPS::MatrixDescriptor::matrixDescriptor(
M,
N,
batch_size_out,
N * out.itemsize(),
matrix_stride_out * out.itemsize(),
mps_dtype);
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
auto a_mat = MPS::Matrix::alloc()->init(a_buf, a_desc);
auto b_buf = static_cast<const MTL::Buffer*>(b.buffer().ptr());
auto b_mat = MPS::Matrix::alloc()->init(b_buf, b_desc);
auto out_buf = static_cast<MTL::Buffer*>(out.buffer().ptr());
auto out_mat = MPS::Matrix::alloc()->init(out_buf, out_desc);
auto kernel = MPS::MatrixMultiplication::alloc()->init(
d.mtl_device(), transpose_a, transpose_b, M, N, K, alpha, beta);
auto command_buffer = d.get_command_buffer(s.index);
kernel->setBatchSize(batch_size_out);
kernel->setBatchStart(0);
kernel->encodeToCommandBuffer(command_buffer, a_mat, b_mat, out_mat);
command_buffer->addCompletedHandler(
[a_mat, b_mat, out_mat, kernel, copies](
MTL::CommandBuffer*) mutable {
a_mat->release();
b_mat->release();
out_mat->release();
kernel->release();
copies.clear();
});
return;
}
}
}
// Schedule as many calls to MPSMatrixMultiplication as needed otherwise
auto a_desc = MPS::MatrixDescriptor::matrixDescriptor(
a.data_size() / lda, lda, lda * a.itemsize(), mps_dtype);
auto b_desc = MPS::MatrixDescriptor::matrixDescriptor(
b.data_size() / ldb, ldb, ldb * b.itemsize(), mps_dtype);
auto out_desc = MPS::MatrixDescriptor::matrixDescriptor(
batch_size_out * M, N, N * out.itemsize(), mps_dtype);
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
auto a_mat = MPS::Matrix::alloc()->init(a_buf, a_desc);
auto b_buf = static_cast<const MTL::Buffer*>(b.buffer().ptr());
auto b_mat = MPS::Matrix::alloc()->init(b_buf, b_desc);
auto out_buf = static_cast<MTL::Buffer*>(out.buffer().ptr());
auto out_mat = MPS::Matrix::alloc()->init(out_buf, out_desc);
auto kernel = MPS::MatrixMultiplication::alloc()->init(
d.mtl_device(), transpose_a, transpose_b, M, N, K, alpha, beta);
auto command_buffer = d.get_command_buffer(s.index);
for (int i = 0; i < batch_size_out; ++i) {
auto a_row = elem_to_loc(M * K * i, a.shape(), a.strides()) / lda;
auto b_row = elem_to_loc(K * N * i, b.shape(), b.strides()) / ldb;
kernel->setLeftMatrixOrigin({a_row, 0, 0});
kernel->setRightMatrixOrigin({b_row, 0, 0});
kernel->setResultMatrixOrigin({i * static_cast<size_t>(M), 0, 0});
kernel->encodeToCommandBuffer(command_buffer, a_mat, b_mat, out_mat);
}
command_buffer->addCompletedHandler(
[a_mat, b_mat, out_mat, kernel, copies](MTL::CommandBuffer*) mutable {
a_mat->release();
b_mat->release();
out_mat->release();
kernel->release();
copies.clear();
});
}
inline auto collapse_batches(const array& a, const array& b) {
// Get and check the shape for the batched dims
std::vector<int> A_bshape{a.shape().begin(), a.shape().end() - 2};
@@ -860,26 +687,6 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
/////////////////////////////////////////////////////////////////////////////
// Gemm specialization
if (use_mps()) {
d.end_encoding(s.index);
return mps_matmul(
s,
d,
a,
b,
out,
M,
N,
K,
batch_size_out,
a_cols,
b_cols,
a_transposed,
b_transposed,
copies);
}
return steel_matmul(
/* const Stream& s = */ s,
/* metal::Device& d = */ d,
@@ -1529,8 +1336,22 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
kname << "_nc" << !contiguous_kernel;
// Encode and dispatch kernel
auto kernel = get_gemv_masked_kernel(
d,
kname.str(),
out,
has_out_mask ? std::optional<array>{inputs[2]} : std::nullopt,
has_op_mask ? std::optional<array>{inputs.back()} : std::nullopt,
transpose_mat,
bm,
bn,
sm,
sn,
tm,
tn,
contiguous_kernel);
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;

View File

@@ -1,14 +1,6 @@
// Copyright © 2023 Apple Inc.
#include <algorithm>
#include <cassert>
#include <sstream>
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/mps/gemm.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/utils.h"
namespace mlx::core {
@@ -48,4 +40,4 @@ void steel_matmul(
std::vector<size_t> A_batch_stride = {},
std::vector<size_t> B_batch_stride = {});
} // namespace mlx::core
} // namespace mlx::core

View File

@@ -75,6 +75,9 @@ std::function<void()> make_task(array arr, bool signal) {
if (!arr.is_tracer()) {
arr.detach();
}
for (auto& out : outputs) {
out.set_status(array::Status::available);
}
if (signal || d.get_command_buffer_ops(s.index) >= MAX_OPS_PER_BUFFER) {
d.end_encoding(s.index);

View File

@@ -1,370 +0,0 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include <Metal/Metal.hpp>
#define _MPS_PRIVATE_CLS(symbol) (MTL::Private::Class::s_k##symbol)
#define _MPS_PRIVATE_SEL(accessor) (MTL::Private::Selector::s_k##accessor)
namespace MTL::Private::Class {
_MTL_PRIVATE_DEF_CLS(MPSMatrixDescriptor);
_MTL_PRIVATE_DEF_CLS(MPSMatrix);
_MTL_PRIVATE_DEF_CLS(MPSVectorDescriptor);
_MTL_PRIVATE_DEF_CLS(MPSVector);
_MTL_PRIVATE_DEF_CLS(MPSKernel);
_MTL_PRIVATE_DEF_CLS(MPSMatrixMultiplication);
_MTL_PRIVATE_DEF_CLS(MPSMatrixVectorMultiplication);
} // namespace MTL::Private::Class
namespace MTL::Private::Selector {
_MTL_PRIVATE_DEF_SEL(
matrixDescriptorWithRows_columns_rowBytes_dataType,
"matrixDescriptorWithRows:columns:rowBytes:dataType:");
_MTL_PRIVATE_DEF_SEL(
matrixDescriptorWithRows_columns_matrices_rowBytes_matrixBytes_dataType,
"matrixDescriptorWithRows:columns:matrices:rowBytes:matrixBytes:dataType:");
_MTL_PRIVATE_DEF_SEL(rows, "rows");
_MTL_PRIVATE_DEF_SEL(initWithBuffer_descriptor, "initWithBuffer:descriptor:");
_MTL_PRIVATE_DEF_SEL(
initWithDevice_,
"initWithDevice:transposeLeft:transposeRight:"
"resultRows:resultColumns:interiorColumns:alpha:beta:");
_MTL_PRIVATE_DEF_SEL(
encodeToCommandBuffer_leftMatrix_rightMatrix_resultMatrix,
"encodeToCommandBuffer:leftMatrix:rightMatrix:resultMatrix:");
_MTL_PRIVATE_DEF_SEL(setLeftMatrixOrigin_, "setLeftMatrixOrigin:");
_MTL_PRIVATE_DEF_SEL(setRightMatrixOrigin_, "setRightMatrixOrigin:");
_MTL_PRIVATE_DEF_SEL(setResultMatrixOrigin_, "setResultMatrixOrigin:");
_MTL_PRIVATE_DEF_SEL(setBatchStart_, "setBatchStart:");
_MTL_PRIVATE_DEF_SEL(setBatchSize_, "setBatchSize:");
_MTL_PRIVATE_DEF_SEL(
vectorDescriptorWithLength_dataType,
"vectorDescriptorWithLength:dataType:");
_MTL_PRIVATE_DEF_SEL(
vectorDescriptorWithLength_vectors_vectorBytes_dataType,
"vectorDescriptorWithLength:vectors:vectorBytes:dataType:");
_MTL_PRIVATE_DEF_SEL(
initWithDevice_transpose_rows_columns_alpha_beta,
"initWithDevice:transpose:rows:columns:alpha:beta:");
_MTL_PRIVATE_DEF_SEL(
encodeToCommandBuffer_inputMatrix_inputVector_resultVector,
"encodeToCommandBuffer:inputMatrix:inputVector:resultVector:");
} // namespace MTL::Private::Selector
namespace MPS {
typedef enum DataType : uint32_t {
DataTypeFloatBit = 0x10000000,
DataTypeAlternateEncodingBit = 0x80000000,
DataTypeFloat16 = DataTypeFloatBit | 16,
DataTypeFloat32 = DataTypeFloatBit | 32,
DataTypeBFloat16 = DataTypeAlternateEncodingBit | DataTypeFloat16
} DataType;
class MatrixDescriptor : public NS::Copying<MatrixDescriptor> {
public:
static class MatrixDescriptor* matrixDescriptor(
NS::UInteger rows,
NS::UInteger columns,
NS::UInteger rowBytes,
NS::UInteger dataType);
static class MatrixDescriptor* matrixDescriptor(
NS::UInteger rows,
NS::UInteger columns,
NS::UInteger matrices,
NS::UInteger rowBytes,
NS::UInteger matrixBytes,
NS::UInteger dataType);
NS::UInteger rows() const;
};
class Matrix : public NS::Referencing<Matrix> {
public:
static class Matrix* alloc();
Matrix* init(MTL::Buffer* buffer, MatrixDescriptor* descriptor);
Matrix* init(const MTL::Buffer* buffer, MatrixDescriptor* descriptor);
};
class Kernel : public NS::Referencing<Kernel> {
public:
NS::String* label() const;
MTL::Device* device() const;
};
class MatrixMultiplication
: public NS::Referencing<MatrixMultiplication, Kernel> {
public:
static class MatrixMultiplication* alloc();
MatrixMultiplication* init(
MTL::Device* device,
bool transposeLeft,
bool transposeRight,
NS::UInteger resultRows,
NS::UInteger resultColumns,
NS::UInteger interiorColumns,
double alpha,
double beta);
void encodeToCommandBuffer(
MTL::CommandBuffer* commandBuffer,
Matrix* leftMatrix,
Matrix* rightMatrix,
Matrix* resultMatrix);
void setLeftMatrixOrigin(MTL::Origin origin);
void setRightMatrixOrigin(MTL::Origin origin);
void setResultMatrixOrigin(MTL::Origin origin);
void setBatchStart(NS::UInteger batchStart);
void setBatchSize(NS::UInteger batchSize);
};
class VectorDescriptor : public NS::Copying<VectorDescriptor> {
public:
static class VectorDescriptor* vectorDescriptor(
NS::UInteger length,
NS::UInteger dataType);
static class VectorDescriptor* vectorDescriptor(
NS::UInteger length,
NS::UInteger vectors,
NS::UInteger vectorBytes,
NS::UInteger dataType);
};
class Vector : public NS::Referencing<Vector> {
public:
static class Vector* alloc();
Vector* init(MTL::Buffer* buffer, VectorDescriptor* descriptor);
Vector* init(const MTL::Buffer* buffer, VectorDescriptor* descriptor);
};
class MatrixVectorMultiplication
: public NS::Referencing<MatrixVectorMultiplication, Kernel> {
public:
static class MatrixVectorMultiplication* alloc();
MatrixVectorMultiplication* init(
MTL::Device* device,
bool transpose,
NS::UInteger rows,
NS::UInteger columns,
double alpha,
double beta);
void encodeToCommandBuffer(
MTL::CommandBuffer* commandBuffer,
Matrix* inputMatrix,
Vector* inputVector,
Vector* resultVector);
};
_MTL_INLINE MatrixDescriptor* MatrixDescriptor::matrixDescriptor(
NS::UInteger rows,
NS::UInteger columns,
NS::UInteger rowBytes,
NS::UInteger dataType) {
return Object::sendMessage<MatrixDescriptor*>(
_MPS_PRIVATE_CLS(MPSMatrixDescriptor),
_MPS_PRIVATE_SEL(matrixDescriptorWithRows_columns_rowBytes_dataType),
rows,
columns,
rowBytes,
dataType);
}
_MTL_INLINE MatrixDescriptor* MatrixDescriptor::matrixDescriptor(
NS::UInteger rows,
NS::UInteger columns,
NS::UInteger matrices,
NS::UInteger rowBytes,
NS::UInteger matrixBytes,
NS::UInteger dataType) {
return Object::sendMessage<MatrixDescriptor*>(
_MPS_PRIVATE_CLS(MPSMatrixDescriptor),
_MPS_PRIVATE_SEL(
matrixDescriptorWithRows_columns_matrices_rowBytes_matrixBytes_dataType),
rows,
columns,
matrices,
rowBytes,
matrixBytes,
dataType);
}
_MTL_INLINE NS::UInteger MatrixDescriptor::rows() const {
return Object::sendMessage<NS::UInteger>(this, _MPS_PRIVATE_SEL(rows));
}
_MTL_INLINE Matrix* Matrix::alloc() {
return NS::Object::alloc<Matrix>(_MPS_PRIVATE_CLS(MPSMatrix));
}
_MTL_INLINE Matrix* Matrix::init(
MTL::Buffer* buffer,
MatrixDescriptor* descriptor) {
return Object::sendMessage<Matrix*>(
this, _MPS_PRIVATE_SEL(initWithBuffer_descriptor), buffer, descriptor);
}
_MTL_INLINE Matrix* Matrix::init(
const MTL::Buffer* buffer,
MatrixDescriptor* descriptor) {
return init(const_cast<MTL::Buffer*>(buffer), descriptor);
}
_MTL_INLINE NS::String* Kernel::label() const {
return Object::sendMessage<NS::String*>(this, _MPS_PRIVATE_SEL(label));
}
_MTL_INLINE MTL::Device* Kernel::device() const {
return Object::sendMessage<MTL::Device*>(this, _MPS_PRIVATE_SEL(device));
}
_MTL_INLINE MatrixMultiplication* MatrixMultiplication::alloc() {
return NS::Object::alloc<MatrixMultiplication>(
_MPS_PRIVATE_CLS(MPSMatrixMultiplication));
}
_MTL_INLINE MatrixMultiplication* MatrixMultiplication::init(
MTL::Device* device,
bool transposeLeft,
bool transposeRight,
NS::UInteger resultRows,
NS::UInteger resultColumns,
NS::UInteger interiorColumns,
double alpha,
double beta) {
return Object::sendMessage<MatrixMultiplication*>(
this,
_MPS_PRIVATE_SEL(initWithDevice_),
device,
transposeLeft,
transposeRight,
resultRows,
resultColumns,
interiorColumns,
alpha,
beta);
}
_MTL_INLINE void MatrixMultiplication::encodeToCommandBuffer(
MTL::CommandBuffer* commandBuffer,
Matrix* leftMatrix,
Matrix* rightMatrix,
Matrix* resultMatrix) {
return Object::sendMessage<void>(
this,
_MPS_PRIVATE_SEL(
encodeToCommandBuffer_leftMatrix_rightMatrix_resultMatrix),
commandBuffer,
leftMatrix,
rightMatrix,
resultMatrix);
}
_MTL_INLINE void MatrixMultiplication::setLeftMatrixOrigin(MTL::Origin origin) {
Object::sendMessage<void>(
this, _MPS_PRIVATE_SEL(setLeftMatrixOrigin_), origin);
}
_MTL_INLINE void MatrixMultiplication::setRightMatrixOrigin(
MTL::Origin origin) {
Object::sendMessage<void>(
this, _MPS_PRIVATE_SEL(setRightMatrixOrigin_), origin);
}
_MTL_INLINE void MatrixMultiplication::setResultMatrixOrigin(
MTL::Origin origin) {
Object::sendMessage<void>(
this, _MPS_PRIVATE_SEL(setResultMatrixOrigin_), origin);
}
_MTL_INLINE void MatrixMultiplication::setBatchStart(NS::UInteger batchStart) {
Object::sendMessage<void>(this, _MPS_PRIVATE_SEL(setBatchStart_), batchStart);
}
_MTL_INLINE void MatrixMultiplication::setBatchSize(NS::UInteger batchSize) {
Object::sendMessage<void>(this, _MPS_PRIVATE_SEL(setBatchSize_), batchSize);
}
_MTL_INLINE VectorDescriptor* VectorDescriptor::vectorDescriptor(
NS::UInteger length,
NS::UInteger dataType) {
return Object::sendMessage<VectorDescriptor*>(
_MPS_PRIVATE_CLS(MPSVectorDescriptor),
_MPS_PRIVATE_SEL(vectorDescriptorWithLength_dataType),
length,
dataType);
}
_MTL_INLINE VectorDescriptor* VectorDescriptor::vectorDescriptor(
NS::UInteger length,
NS::UInteger vectors,
NS::UInteger vectorBytes,
NS::UInteger dataType) {
return Object::sendMessage<VectorDescriptor*>(
_MPS_PRIVATE_CLS(MPSVectorDescriptor),
_MPS_PRIVATE_SEL(vectorDescriptorWithLength_vectors_vectorBytes_dataType),
length,
vectors,
vectorBytes,
dataType);
}
_MTL_INLINE Vector* Vector::alloc() {
return NS::Object::alloc<Vector>(_MPS_PRIVATE_CLS(MPSVector));
}
_MTL_INLINE Vector* Vector::init(
MTL::Buffer* buffer,
VectorDescriptor* descriptor) {
return Object::sendMessage<Vector*>(
this, _MPS_PRIVATE_SEL(initWithBuffer_descriptor), buffer, descriptor);
}
_MTL_INLINE Vector* Vector::init(
const MTL::Buffer* buffer,
VectorDescriptor* descriptor) {
return init(const_cast<MTL::Buffer*>(buffer), descriptor);
}
_MTL_INLINE MatrixVectorMultiplication* MatrixVectorMultiplication::alloc() {
return NS::Object::alloc<MatrixVectorMultiplication>(
_MPS_PRIVATE_CLS(MPSMatrixVectorMultiplication));
}
_MTL_INLINE MatrixVectorMultiplication* MatrixVectorMultiplication::init(
MTL::Device* device,
bool transpose,
NS::UInteger rows,
NS::UInteger columns,
double alpha,
double beta) {
return Object::sendMessage<MatrixVectorMultiplication*>(
this,
_MPS_PRIVATE_SEL(initWithDevice_transpose_rows_columns_alpha_beta),
device,
transpose,
rows,
columns,
alpha,
beta);
}
_MTL_INLINE void MatrixVectorMultiplication::encodeToCommandBuffer(
MTL::CommandBuffer* commandBuffer,
Matrix* inputMatrix,
Vector* inputVector,
Vector* resultVector) {
return Object::sendMessage<void>(
this,
_MPS_PRIVATE_SEL(
encodeToCommandBuffer_inputMatrix_inputVector_resultVector),
commandBuffer,
inputMatrix,
inputVector,
resultVector);
}
} // namespace MPS

View File

@@ -169,6 +169,23 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
return d.get_kernel(kernel_name);
}
MTL::ComputePipelineState* get_gemv_masked_kernel(
metal::Device& d,
const std::string& kernel_name,
const array&,
const std::optional<array>&,
const std::optional<array>&,
bool,
int,
int,
int,
int,
int,
int,
bool) {
return d.get_kernel(kernel_name);
}
MTL::ComputePipelineState* get_steel_conv_kernel(
metal::Device& d,
const std::string& kernel_name,

View File

@@ -171,7 +171,7 @@ void Copy::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void CustomVJP::eval_gpu(
void CustomTransforms::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
eval(inputs, outputs);
@@ -273,7 +273,18 @@ void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
copy_gpu(in, out, CopyType::General);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto out_strides = make_contiguous_strides<size_t>(in.shape());
copy_gpu_inplace(
in,
out,
in.shape(),
in.strides(),
out_strides,
0,
0,
CopyType::General,
stream());
} else {
shared_buffer_reshape(in, out_strides, out);
}

View File

@@ -7,6 +7,7 @@
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/fast_primitives.h"
#include "mlx/primitives.h"
namespace mlx::core {
@@ -47,8 +48,8 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
if (B < 6 && O % 8 == 0 && D % 512 == 0 && D >= 512) {
std::ostringstream kname;
auto type_string = get_type_string(x.dtype());
kname << "qmv_" << type_string << "_gs_" << group_size_ << "_b_" << bits_
<< "_fast";
kname << "qmv_fast_" << type_string << "_gs_" << group_size_ << "_b_"
<< bits_;
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
@@ -270,8 +271,8 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
if (B < 6 && O % 8 == 0 && D % 512 == 0 && D >= 512) {
std::ostringstream kname;
auto type_string = get_type_string(x.dtype());
kname << "bs_qmv_" << type_string << "_gs_" << group_size_ << "_b_"
<< bits_ << "_fast";
kname << "bs_qmv_fast_" << type_string << "_gs_" << group_size_ << "_b_"
<< bits_;
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
@@ -513,4 +514,82 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
}
}
void fast::AffineQuantize::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
bool compute_scale_bias = inputs.size() == 1;
auto& w_pre = inputs[0];
auto& out = outputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
std::vector<array> copies;
auto ensure_row_contiguous = [&copies, &s](const array& arr) {
if (arr.flags().row_contiguous) {
return arr;
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
copies.push_back(arr_copy);
return arr_copy;
}
};
auto w = ensure_row_contiguous(w_pre);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_input_array(w, 0);
if (!compute_scale_bias) {
auto& scales_pre = inputs[1];
auto& biases_pre = inputs[2];
auto scales = ensure_row_contiguous(scales_pre);
auto biases = ensure_row_contiguous(biases_pre);
compute_encoder.set_input_array(scales, 1);
compute_encoder.set_input_array(biases, 2);
compute_encoder.set_output_array(out, 3);
} else {
auto& scales = outputs[1];
auto& biases = outputs[2];
scales.set_data(allocator::malloc_or_wait(scales.nbytes()));
biases.set_data(allocator::malloc_or_wait(biases.nbytes()));
compute_encoder.set_output_array(out, 1);
compute_encoder.set_output_array(scales, 2);
compute_encoder.set_output_array(biases, 3);
}
std::ostringstream kname;
auto type_string = dequantize_ ? get_type_string(out.dtype())
: get_type_string(w_pre.dtype());
auto kernel_func = "affine_quantize_scales_biases";
if (dequantize_) {
kernel_func = "affine_dequantize";
} else if (compute_scale_bias) {
kernel_func = "affine_quantize";
}
kname << kernel_func << "_" << type_string << "_gs_" << group_size_ << "_b_"
<< bits_;
auto template_def = get_template_definition(
kname.str(), kernel_func, type_string, group_size_, bits_);
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
compute_encoder->setComputePipelineState(kernel);
// Treat uint32 as uint8 in kernel
constexpr int uint8_per_uint32 = 4;
constexpr int simd_size = 32;
int packs_per_int = 8 / bits_;
int per_thread = compute_scale_bias ? group_size_ / simd_size : packs_per_int;
size_t nthreads =
dequantize_ ? w.size() * uint8_per_uint32 : w.size() / per_thread;
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
auto group_dims = MTL::Size(thread_group_size, 1, 1);
auto grid_dims = MTL::Size(nthreads, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}
} // namespace mlx::core

View File

@@ -5,6 +5,8 @@
namespace mlx::core::fast {
constexpr int n_per_thread = 4;
void RoPE::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
@@ -62,8 +64,11 @@ void RoPE::eval_gpu(
out_strides[1] = out.strides()[ndim - 2];
out_strides[2] = out.strides()[ndim - 1];
// Special case for inference (single time step and contiguous)
bool single = in.flags().row_contiguous && (mat_size == in.shape(-1));
std::ostringstream kname;
kname << "rope_" << (forward_ ? "" : "vjp_")
kname << "rope_" << (single ? "single_" : "") << (forward_ ? "" : "vjp_")
<< (traditional_ ? "traditional_" : "") << type_to_name(in);
auto kernel = d.get_kernel(kname.str());
auto& compute_encoder = d.get_command_encoder(s.index);
@@ -72,18 +77,28 @@ void RoPE::eval_gpu(
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(donated ? out : in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(&strides, 3 * sizeof(size_t), 2);
compute_encoder->setBytes(&out_strides, 3 * sizeof(size_t), 3);
compute_encoder->setBytes(&offset_, sizeof(int), 4);
compute_encoder->setBytes(&base, sizeof(float), 5);
compute_encoder->setBytes(&scale_, sizeof(float), 6);
compute_encoder->setBytes(&offset_, sizeof(int), 2);
compute_encoder->setBytes(&base, sizeof(float), 3);
compute_encoder->setBytes(&scale_, sizeof(float), 4);
int dim0 = dims_ / 2;
int dim1 = in.shape(-2);
int dim2 = in.size() / mat_size;
auto group_dims = get_block_dims(dim0, dim1, dim2);
auto grid_dims = MTL::Size(dim0, dim1, dim2);
compute_encoder.dispatchThreads(grid_dims, group_dims);
size_t n_batch = in.size() / mat_size;
if (single) {
compute_encoder->setBytes(&out_strides[1], sizeof(size_t), 5);
uint32_t dim0 = dims_ / 2;
auto group_dims = get_block_dims(dim0, n_batch, 1);
auto grid_dims = MTL::Size(dim0, n_batch, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
compute_encoder->setBytes(&strides, 3 * sizeof(size_t), 5);
compute_encoder->setBytes(&out_strides, 3 * sizeof(size_t), 6);
compute_encoder->setBytes(&n_batch, sizeof(size_t), 7);
uint32_t dim0 = dims_ / 2;
uint32_t dim1 = in.shape(-2);
uint32_t dim2 = (n_batch + n_per_thread - 1) / n_per_thread;
auto group_dims = get_block_dims(dim0, dim1, dim2);
auto grid_dims = MTL::Size(dim0, dim1, dim2);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
}
} // namespace mlx::core::fast

View File

@@ -1,5 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include "mlx/backend/metal/copy.h"

View File

@@ -24,8 +24,11 @@ void single_block_sort(
// Prepare shapes
int n_rows = in.size() / in.shape(axis);
std::vector<size_t> nc_str = in.strides();
nc_str.erase(nc_str.begin() + axis);
std::vector<size_t> in_nc_str = in.strides();
in_nc_str.erase(in_nc_str.begin() + axis);
std::vector<size_t> out_nc_str = out.strides();
out_nc_str.erase(out_nc_str.begin() + axis);
std::vector<int> nc_shape = in.shape();
nc_shape.erase(nc_shape.begin() + axis);
@@ -33,21 +36,28 @@ void single_block_sort(
int nc_dim = nc_shape.size();
int size_sorted_axis = in.shape(axis);
int stride_sorted_axis = in.strides()[axis];
int stride_segment_axis = *std::min_element(nc_str.begin(), nc_str.end());
int in_stride_sorted_axis = in.strides()[axis];
int out_stride_sorted_axis = out.strides()[axis];
int in_stride_segment_axis =
*std::min_element(in_nc_str.begin(), in_nc_str.end());
int out_stride_segment_axis =
*std::min_element(out_nc_str.begin(), out_nc_str.end());
// Check if remaining strides are contiguous
bool contiguous_write = true;
if (axis != in.ndim() - 1 && axis != 0) {
for (int i = 0; i < nc_str.size() - 1; ++i) {
size_t expected = nc_str[i + 1] * nc_str[i + 1];
contiguous_write &= (nc_str[i] == expected);
}
}
// We can only use the contiguous kernel if the sorted axis
// has the largest or smallest stride.
// We also need the input to be contiguous
bool contiguous = in.flags().contiguous;
auto check_strides = [](array x, int sort_stride) {
int min_stride = *std::min_element(x.strides().begin(), x.strides().end());
int max_stride = *std::max_element(x.strides().begin(), x.strides().end());
return sort_stride == min_stride || sort_stride == max_stride;
};
contiguous &= check_strides(in, in_stride_sorted_axis);
contiguous &= check_strides(out, out_stride_sorted_axis);
// Prepare kernel name
std::ostringstream kname;
kname << (contiguous_write ? "c" : "nc");
kname << (contiguous ? "c" : "nc");
if (argsort) {
kname << "arg";
}
@@ -64,14 +74,17 @@ void single_block_sort(
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 2);
compute_encoder->setBytes(&stride_sorted_axis, sizeof(int), 3);
compute_encoder->setBytes(&in_stride_sorted_axis, sizeof(int), 3);
compute_encoder->setBytes(&out_stride_sorted_axis, sizeof(int), 4);
if (contiguous_write) {
compute_encoder->setBytes(&stride_segment_axis, sizeof(int), 4);
if (contiguous) {
compute_encoder->setBytes(&in_stride_segment_axis, sizeof(int), 5);
compute_encoder->setBytes(&out_stride_segment_axis, sizeof(int), 6);
} else {
compute_encoder->setBytes(&nc_dim, sizeof(int), 4);
compute_encoder->setBytes(nc_shape.data(), nc_dim * sizeof(int), 5);
compute_encoder->setBytes(nc_str.data(), nc_dim * sizeof(size_t), 6);
compute_encoder->setBytes(&nc_dim, sizeof(int), 5);
compute_encoder->setBytes(nc_shape.data(), nc_dim * sizeof(int), 6);
compute_encoder->setBytes(in_nc_str.data(), nc_dim * sizeof(size_t), 7);
compute_encoder->setBytes(out_nc_str.data(), nc_dim * sizeof(size_t), 8);
}
MTL::Size group_dims = MTL::Size(bn, 1, 1);
@@ -164,6 +177,8 @@ void multi_block_sort(
array dev_vals_out = dev_vals_1;
array dev_idxs_out = dev_idxs_1;
int n_thr_per_group = (n_blocks + 1) < 1024 ? (n_blocks + 1) : 1024;
for (int merge_tiles = 2; (merge_tiles / 2) < n_blocks; merge_tiles *= 2) {
dev_vals_in = ping ? dev_vals_1 : dev_vals_0;
dev_idxs_in = ping ? dev_idxs_1 : dev_idxs_0;
@@ -186,8 +201,9 @@ void multi_block_sort(
compute_encoder.set_input_array(dev_idxs_in, 2);
compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 3);
compute_encoder->setBytes(&merge_tiles, sizeof(int), 4);
compute_encoder->setBytes(&n_blocks, sizeof(int), 5);
MTL::Size group_dims = MTL::Size(n_blocks + 1, 1, 1);
MTL::Size group_dims = MTL::Size(n_thr_per_group, 1, 1);
MTL::Size grid_dims = MTL::Size(1, n_rows, 1);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);

View File

@@ -32,6 +32,7 @@ void ternary_op_gpu_inplace(
auto& strides_c = strides[2];
auto& strides_out = strides[3];
bool use_2d = out.data_size() > UINT_MAX;
std::string kernel_name;
{
std::ostringstream kname;
@@ -40,6 +41,8 @@ void ternary_op_gpu_inplace(
if (shape.size() <= MAX_TERNARY_SPECIALIZED_DIMS) {
kname << shape.size();
}
} else if (use_2d) {
kname << "v2";
} else {
kname << "v";
}

View File

@@ -25,11 +25,14 @@ void unary_op_gpu_inplace(
auto& d = metal::device(s.device);
std::string kernel_name = (contig ? "v" : "g") + op + type_to_name(out);
size_t nthreads = contig ? in.data_size() : in.size();
bool use_2d = nthreads > UINT32_MAX;
std::string kernel_name =
(contig ? (use_2d ? "v2" : "v") : "g") + op + type_to_name(out);
auto kernel = get_unary_kernel(d, kernel_name, out.dtype(), op);
size_t nthreads = contig ? in.data_size() : in.size();
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(in.shape(), in.strides())
: MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;

116
mlx/backend/metal/utils.cpp Normal file
View File

@@ -0,0 +1,116 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/utils.h"
using namespace mlx;
namespace mlx::core {
std::string type_to_name(const array& a) {
std::string tname;
switch (a.dtype()) {
case bool_:
tname = "bool_";
break;
case uint8:
tname = "uint8";
break;
case uint16:
tname = "uint16";
break;
case uint32:
tname = "uint32";
break;
case uint64:
tname = "uint64";
break;
case int8:
tname = "int8";
break;
case int16:
tname = "int16";
break;
case int32:
tname = "int32";
break;
case int64:
tname = "int64";
break;
case float16:
tname = "float16";
break;
case float32:
tname = "float32";
break;
case bfloat16:
tname = "bfloat16";
break;
case complex64:
tname = "complex64";
break;
}
return tname;
}
MTL::Size get_block_dims(int dim0, int dim1, int dim2) {
int pows[3] = {0, 0, 0};
int sum = 0;
while (true) {
int presum = sum;
// Check all the pows
if (dim0 >= (1 << (pows[0] + 1))) {
pows[0]++;
sum++;
}
if (sum == 10) {
break;
}
if (dim1 >= (1 << (pows[1] + 1))) {
pows[1]++;
sum++;
}
if (sum == 10) {
break;
}
if (dim2 >= (1 << (pows[2] + 1))) {
pows[2]++;
sum++;
}
if (sum == presum || sum == 10) {
break;
}
}
return MTL::Size{1ul << pows[0], 1ul << pows[1], 1ul << pows[2]};
}
MTL::Size get_2d_grid_dims(
const std::vector<int>& shape,
const std::vector<size_t>& strides) {
// Dims with strides of 0 are ignored as they
// correspond to broadcasted dimensions
size_t grid_x = 1;
size_t grid_y = 1;
for (int i = 0; i < shape.size(); ++i) {
if (strides[i] == 0) {
continue;
}
if (grid_x * shape[i] < UINT32_MAX) {
grid_x *= shape[i];
} else {
grid_y *= shape[i];
}
}
if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) {
throw std::runtime_error("Unable to safely factor shape.");
}
return MTL::Size(
static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
}
std::string get_primitive_string(Primitive* primitive) {
std::ostringstream op_t;
primitive->print(op_t);
return op_t.str();
}
} // namespace mlx::core

View File

@@ -8,8 +8,6 @@
namespace mlx::core {
namespace {
using metal::CommandEncoder;
template <typename T>
@@ -27,82 +25,22 @@ set_vector_bytes(CommandEncoder& enc, const std::vector<T>& vec, int idx) {
return set_vector_bytes(enc, vec, vec.size(), idx);
}
std::string type_to_name(const array& a) {
std::string tname;
switch (a.dtype()) {
case bool_:
tname = "bool_";
break;
case uint8:
tname = "uint8";
break;
case uint16:
tname = "uint16";
break;
case uint32:
tname = "uint32";
break;
case uint64:
tname = "uint64";
break;
case int8:
tname = "int8";
break;
case int16:
tname = "int16";
break;
case int32:
tname = "int32";
break;
case int64:
tname = "int64";
break;
case float16:
tname = "float16";
break;
case float32:
tname = "float32";
break;
case bfloat16:
tname = "bfloat16";
break;
case complex64:
tname = "complex64";
break;
}
return tname;
}
std::string type_to_name(const array& a);
MTL::Size get_block_dims(int dim0, int dim1, int dim2) {
int pows[3] = {0, 0, 0};
int sum = 0;
while (true) {
int presum = sum;
// Check all the pows
if (dim0 >= (1 << (pows[0] + 1))) {
pows[0]++;
sum++;
}
if (sum == 10) {
break;
}
if (dim1 >= (1 << (pows[1] + 1))) {
pows[1]++;
sum++;
}
if (sum == 10) {
break;
}
if (dim2 >= (1 << (pows[2] + 1))) {
pows[2]++;
sum++;
}
if (sum == presum || sum == 10) {
break;
}
}
return MTL::Size{1ul << pows[0], 1ul << pows[1], 1ul << pows[2]};
}
// Compute the thread block dimensions which fit the given
// input dimensions.
// - The thread block dimensions will be powers of two
// - The thread block size will be less than 1024
MTL::Size get_block_dims(int dim0, int dim1, int dim2);
// Computes a 2D grid where each element is < UINT_MAX
// Assumes:
// - overall size (product of non-broadcasted dimensions) is < UINT_MAX^2
// - shape and strides correspond to a contiguous (no holes) but
// possibly broadcasted array
MTL::Size get_2d_grid_dims(
const std::vector<int>& shape,
const std::vector<size_t>& strides);
inline NS::String* make_string(std::ostringstream& os) {
std::string string = os.str();
@@ -130,23 +68,6 @@ inline void debug_set_primitive_buffer_label(
#endif
}
bool is_power_of_2(int n) {
return ((n & (n - 1)) == 0) && n != 0;
}
int next_power_of_2(int n) {
if (is_power_of_2(n)) {
return n;
}
return pow(2, std::ceil(std::log2(n)));
}
std::string get_primitive_string(Primitive* primitive) {
std::ostringstream op_t;
primitive->print(op_t);
return op_t.str();
}
} // namespace
std::string get_primitive_string(Primitive* primitive);
} // namespace mlx::core

View File

@@ -42,7 +42,7 @@ NO_CPU(Convolution)
NO_CPU(Copy)
NO_CPU(Cos)
NO_CPU(Cosh)
NO_CPU_MULTI(CustomVJP)
NO_CPU_MULTI(CustomTransforms)
NO_CPU_MULTI(Depends)
NO_CPU(Divide)
NO_CPU_MULTI(DivMod)
@@ -61,6 +61,7 @@ NO_CPU(GatherMM)
NO_CPU(GatherQMM)
NO_CPU(Greater)
NO_CPU(GreaterEqual)
NO_CPU(Hadamard)
NO_CPU(Less)
NO_CPU(LessEqual)
NO_CPU(Load)

View File

@@ -43,7 +43,7 @@ NO_GPU(Convolution)
NO_GPU(Copy)
NO_GPU(Cos)
NO_GPU(Cosh)
NO_GPU_MULTI(CustomVJP)
NO_GPU_MULTI(CustomTransforms)
NO_GPU_MULTI(Depends)
NO_GPU(Divide)
NO_GPU_MULTI(DivMod)
@@ -62,6 +62,7 @@ NO_GPU(GatherMM)
NO_GPU(GatherQMM)
NO_GPU(Greater)
NO_GPU(GreaterEqual)
NO_GPU(Hadamard)
NO_GPU(Less)
NO_GPU(LessEqual)
NO_GPU(Load)
@@ -117,6 +118,7 @@ NO_GPU_MULTI(RMSNorm)
NO_GPU_MULTI(RMSNormVJP)
NO_GPU_MULTI(RoPE)
NO_GPU(ScaledDotProductAttention)
NO_GPU_MULTI(AffineQuantize)
} // namespace fast
} // namespace mlx::core

View File

@@ -266,6 +266,10 @@ class CompilerCache {
cache_.erase(fun_id);
}
void clear() {
cache_.clear();
}
private:
CompilerCache() {
// Make sure the allocator is fully
@@ -859,6 +863,10 @@ void compile_erase(std::uintptr_t fun_id) {
detail::compiler_cache().erase(fun_id);
}
void compile_clear_cache() {
detail::compiler_cache().clear();
}
} // namespace detail
std::function<std::vector<array>(const std::vector<array>&)> compile(

View File

@@ -1,11 +1,8 @@
// Copyright © 2023-2024 Apple Inc.
#include <cstdint>
#include <sstream>
#include <vector>
#include "mlx/dtype.h"
#include "mlx/utils.h"
namespace mlx::core {
@@ -178,67 +175,4 @@ bool issubdtype(const Dtype::Category& a, const Dtype::Category& b) {
[static_cast<uint32_t>(b)];
}
// Array protocol typestring for Dtype
std::string dtype_to_array_protocol(const Dtype& t) {
std::ostringstream r;
if (size_of(t) > 1)
r << (is_big_endian() ? ">" : "<");
else
r << "|";
r << kindof(t) << (int)size_of(t);
return r.str();
}
// Dtype from array protocol type string
Dtype dtype_from_array_protocol(std::string_view t) {
if (t.length() == 2 || t.length() == 3) {
std::string_view r = t.length() == 3 ? t.substr(1, 2) : t;
if (r == "V2") {
return bfloat16;
}
uint8_t size = r[1] - '0';
switch (r[0]) {
case 'b': {
if (size == 1)
return bool_;
}
case 'i': {
if (size == 1)
return int8;
else if (size == 2)
return int16;
else if (size == 4)
return int32;
else if (size == 8)
return int64;
}
case 'u': {
if (size == 1)
return uint8;
else if (size == 2)
return uint16;
else if (size == 4)
return uint32;
else if (size == 8)
return uint64;
}
case 'f': {
if (size == 2)
return float16;
else if (size == 4)
return float32;
}
case 'c': {
return complex64;
}
}
}
throw std::invalid_argument(
"[from_str] Invalid array protocol type-string: " + std::string(t));
}
} // namespace mlx::core

View File

@@ -4,8 +4,6 @@
#include <complex>
#include <cstdint>
#include <ostream>
#include <string>
#include "mlx/types/complex.h"
#include "mlx/types/half_types.h"
@@ -103,9 +101,4 @@ struct TypeToDtype {
operator Dtype();
};
// Array protocol typestring for Dtype
std::string dtype_to_array_protocol(const Dtype& t);
// Dtype from array protocol type string
Dtype dtype_from_array_protocol(std::string_view t);
} // namespace mlx::core

859
mlx/einsum.cpp Normal file
View File

@@ -0,0 +1,859 @@
// Copyright © 2024 Apple Inc.
#include <numeric>
#include <sstream>
#include <unordered_map>
#include <unordered_set>
#include "mlx/einsum.h"
#include "mlx/ops.h"
namespace mlx::core {
namespace {
// The MLX einsum implementation is based on NumPy (which is based on
// opt_einsum):
// https://github.com/numpy/numpy/blob/1d49c7f7ff527c696fc26ab2278ad51632a66660/numpy/_core/einsumfunc.py#L743
// https://github.com/dgasmith/opt_einsum
using CharSet = std::unordered_set<char>;
// A helper struct to hold the string and set
// representation of a subscript to avoid needing
// to recompute the set
struct Subscript {
Subscript(std::string str, CharSet set)
: str(std::move(str)), set(std::move(set)) {};
std::string str;
CharSet set;
};
struct PathInfo {
size_t naive_cost;
size_t naive_scaling;
size_t optimized_cost;
size_t optimized_scaling;
size_t largest_term;
};
struct PathNode {
PathNode(
std::vector<Subscript> inputs,
Subscript output,
std::vector<int> positions)
: inputs(std::move(inputs)),
output(std::move(output)),
positions(std::move(positions)) {};
std::vector<Subscript> inputs;
Subscript output;
std::vector<int> positions;
};
// Parse the comma separated subscripts into a vector of strings. If the
// output subscripts are missing they are inferred.
//
// For example:
// "ij,jk -> ik" becomes {{"ij", "jk"}, "ik"}
// "ij,jk" becomes {{"ij", "jk"}, "ik"}
std::pair<std::vector<std::string>, std::string> parse(std::string subscripts) {
std::string lhs, rhs;
// Start by removing all white space
subscripts.erase(
std::remove(subscripts.begin(), subscripts.end(), ' '), subscripts.end());
if (auto pos = subscripts.find("->"); pos != std::string::npos) {
// Explicit mode
lhs = subscripts.substr(0, pos);
rhs = subscripts.substr(pos + 2);
} else {
// Implicit mode:
// - repeats are summed
// - remaining output axes are ordered alphabetically
lhs = subscripts;
std::unordered_map<char, int> temp;
for (auto& c : subscripts) {
if (c == ',') {
continue;
}
auto inserted = temp.insert({c, 0});
inserted.first->second++;
}
for (auto& k : temp) {
if (k.second == 1) {
rhs += k.first;
}
}
std::sort(rhs.begin(), rhs.end());
}
std::vector<std::string> input_list;
std::stringstream ss(lhs);
std::string token;
while (getline(ss, token, ',')) {
input_list.push_back(token);
}
return {input_list, rhs};
}
// Check if two sets are disjoint
bool disjoint(const CharSet& x, const CharSet& y) {
for (auto& c : x) {
if (y.find(c) != y.end()) {
return false;
}
}
return true;
}
template <typename T>
size_t term_size(const T& term, std::unordered_map<char, int> dict) {
size_t size = 1;
for (auto c : term) {
size *= dict[c];
}
return size;
}
size_t flop_count(
const CharSet& term,
bool inner,
int num_terms,
std::unordered_map<char, int> dict) {
size_t size = term_size(term, dict);
auto op_factor = 1;
if ((num_terms - 1) > op_factor) {
op_factor = num_terms - 1;
}
if (inner) {
op_factor += 1;
}
return size * op_factor;
}
std::pair<size_t, int> compute_cost_and_scaling(
const std::vector<Subscript>& inputs,
const Subscript& output,
std::unordered_map<char, int> dim_map) {
CharSet contractions;
for (auto& in : inputs) {
contractions.insert(in.set.begin(), in.set.end());
}
bool inner = false;
for (auto c : contractions) {
if (output.set.find(c) == output.set.end()) {
inner = true;
break;
}
}
auto cost = flop_count(contractions, inner, inputs.size(), dim_map);
return {cost, contractions.size()};
}
std::tuple<std::vector<PathNode>, size_t, int> greedy_path(
std::vector<Subscript> inputs,
const Subscript& output,
std::unordered_map<char, int> dim_map,
size_t cost_limit,
size_t memory_limit) {
// Helper struct for building the greedy path
struct Contraction {
Contraction(
size_t size,
size_t cost,
CharSet output,
int dims,
int x,
int y)
: size(size),
cost(cost),
output(std::move(output)),
dims(dims),
x(x),
y(y) {};
int64_t size; // Size difference, can be negative
size_t cost;
CharSet output;
int dims; // Number of dimensions in the contraction
int x;
int y;
};
// Start by iterating over all possible combinations
std::vector<std::pair<int, int>> pos_pairs;
for (int i = 0; i < inputs.size(); ++i) {
for (int j = i + 1; j < inputs.size(); ++j) {
pos_pairs.emplace_back(i, j);
}
}
std::vector<PathNode> path;
std::vector<Contraction> possible_contractions;
size_t path_cost = 0;
int path_scaling = 0;
auto num_in = inputs.size();
for (int i = 0; i < num_in - 1; ++i) {
auto add_contraction = [&](int p1, int p2) {
CharSet new_term;
CharSet contractions(inputs[p1].set.begin(), inputs[p1].set.end());
contractions.insert(inputs[p2].set.begin(), inputs[p2].set.end());
for (int i = 0; i < inputs.size(); i++) {
if (i == p1 || i == p2) {
continue;
}
auto& in = inputs[i].set;
for (auto c : in) {
if (contractions.find(c) != contractions.end()) {
new_term.insert(c);
}
}
}
for (auto c : output.set) {
if (contractions.find(c) != contractions.end()) {
new_term.insert(c);
}
}
// Ignore if:
// - The size of the new result is greater than the memory limit
// - The cost is larger than the naive cost
auto new_size = term_size(new_term, dim_map);
if (new_size > memory_limit) {
return;
}
int64_t removed_size = term_size(inputs[p1].set, dim_map) +
term_size(inputs[p2].set, dim_map) - new_size;
bool inner = contractions.size() > new_term.size();
auto cost = flop_count(contractions, inner, 2, dim_map);
if (path_cost + cost > cost_limit) {
return;
}
possible_contractions.emplace_back(
removed_size, cost, std::move(new_term), contractions.size(), p1, p2);
};
for (auto& [p1, p2] : pos_pairs) {
// Ignore outer products
if (!disjoint(inputs[p1].set, inputs[p2].set)) {
add_contraction(p1, p2);
}
}
// If there's nothing in the contraction list,
// go over the pairs again without ignoring outer products
if (possible_contractions.empty()) {
for (auto& [p1, p2] : pos_pairs) {
add_contraction(p1, p2);
}
}
if (possible_contractions.empty()) {
// Default to naive einsum for the remaining inputs
std::vector<int> positions(inputs.size());
std::iota(positions.begin(), positions.end(), 0);
auto [cost, scale] = compute_cost_and_scaling(inputs, output, dim_map);
path.emplace_back(std::move(inputs), output, std::move(positions));
path_cost += cost;
path_scaling = std::max(scale, path_scaling);
break;
}
// Find the best contraction
auto& best = *std::min_element(
possible_contractions.begin(),
possible_contractions.end(),
[](const auto& x, const auto& y) {
return x.size > y.size || (x.size == y.size && x.cost < y.cost);
});
path_scaling = std::max(best.dims, path_scaling);
// Construct the output subscripts
std::string out_str(best.output.begin(), best.output.end());
// TODO, sorting by dimension size seems suboptimal?
std::sort(out_str.begin(), out_str.end(), [&dim_map](auto x, auto y) {
return dim_map[x] < dim_map[y];
});
Subscript new_output(std::move(out_str), std::move(best.output));
// Add the chosen contraction to the path
{
std::vector<Subscript> in_terms;
in_terms.push_back(std::move(inputs[best.x]));
in_terms.push_back(std::move(inputs[best.y]));
path.emplace_back(
std::move(in_terms), new_output, std::vector<int>{best.x, best.y});
}
// Remove used terms
inputs.erase(inputs.begin() + best.y);
inputs.erase(inputs.begin() + best.x);
// Add the new result
inputs.push_back(std::move(new_output));
// Update the existing contractions based on the selected one
std::vector<Contraction> updated_contractions;
for (auto& contraction : possible_contractions) {
// Drop contractions which contain either selected term
if (contraction.x == best.x || contraction.x == best.y ||
contraction.y == best.x || contraction.y == best.y) {
continue;
}
// Update the positions of other contractions
int x =
contraction.x - (contraction.x > best.x) - (contraction.x > best.y);
int y =
contraction.y - (contraction.y > best.x) - (contraction.y > best.y);
contraction.x = x;
contraction.y = y;
updated_contractions.push_back(std::move(contraction));
}
pos_pairs.clear();
for (int i = 0; i < inputs.size() - 1; ++i) {
pos_pairs.emplace_back(i, inputs.size() - 1);
}
path_cost += best.cost;
possible_contractions = std::move(updated_contractions);
}
return {path, path_cost, path_scaling};
}
// Assumes inputs have already have had repeats and single axis sums collapsed
bool can_dot(const std::vector<Subscript>& inputs, const Subscript& output) {
if (inputs.size() != 2) {
return false;
}
for (auto c : inputs[0].set) {
// Use batched tensordot if anything is being contracted
if (output.set.find(c) == output.set.end()) {
return true;
}
}
return false;
}
array batch_tensordot(
array a,
array b,
std::vector<int> a_contract,
std::vector<int> a_batch,
std::vector<int> a_concat,
std::vector<int> b_contract,
std::vector<int> b_batch,
std::vector<int> b_concat,
StreamOrDevice s) {
// Broadcast contracting dimensions
{
auto a_shape = a.shape();
auto b_shape = b.shape();
for (int i = 0; i < a_contract.size(); ++i) {
auto d = std::max(a.shape(a_contract[i]), b.shape(b_contract[i]));
a_shape[a_contract[i]] = d;
b_shape[b_contract[i]] = d;
}
a = broadcast_to(a, a_shape, s);
b = broadcast_to(b, b_shape, s);
}
auto transpose_reshape = [&s](
const array& x,
const std::vector<int>& i,
const std::vector<int>& j,
const std::vector<int>& k) {
std::vector<int> reorder(i.begin(), i.end());
reorder.insert(reorder.end(), j.begin(), j.end());
reorder.insert(reorder.end(), k.begin(), k.end());
int size1 = 1;
for (auto s : j) {
size1 *= x.shape(s);
}
int size2 = 1;
for (auto s : k) {
size2 *= x.shape(s);
}
std::vector<int> shape;
for (auto ax : i) {
shape.push_back(x.shape(ax));
}
shape.push_back(size1);
shape.push_back(size2);
return reshape(transpose(x, reorder, s), std::move(shape), s);
};
std::vector<int> out_shape;
for (auto ax : a_batch) {
out_shape.push_back(a.shape(ax));
}
for (auto ax : a_concat) {
out_shape.push_back(a.shape(ax));
}
for (auto ax : b_concat) {
out_shape.push_back(b.shape(ax));
}
a = transpose_reshape(a, a_batch, a_concat, a_contract);
b = transpose_reshape(b, b_batch, b_contract, b_concat);
return reshape(matmul(a, b, s), std::move(out_shape), s);
}
// Collapse repeated subscripts and return the resulting array. The subscript
// is also updated in place. For example:
// - Given an input with shape (4, 4) and subscript "ii", returns
// the diagonal of shape (4,) and updates the subscript to "i".
// - Given an input with shape (4, 2, 4, 2) and subscript "ijij",
// returns an output with shape (4, 2) and updates the subscript
// to "ij".
array collapse_repeats(array in, Subscript& subscript, StreamOrDevice s) {
// Build a list of (repeat chars, num repeats)
auto& str = subscript.str;
std::vector<std::pair<char, int>> repeats;
std::string new_str;
{
std::string repeat_str;
std::string no_repeat_str;
std::unordered_map<char, int> counts;
for (int i = 0; i < str.size(); ++i) {
auto [it, _] = counts.insert({str[i], 0});
it->second++;
}
for (auto& v : counts) {
if (v.second > 1) {
repeats.emplace_back(v.first, v.second);
repeat_str += v.first;
}
}
for (auto& c : str) {
if (counts[c] == 1) {
no_repeat_str += c;
}
}
new_str = repeat_str + no_repeat_str;
}
// Build the inputs for gather
auto slice_sizes = in.shape();
std::vector<int> axes;
std::vector<array> indices;
int n_expand = repeats.size();
for (auto [c, v] : repeats) {
for (int i = 0; i < str.size(); ++i) {
if (str[i] == c) {
slice_sizes[i] = 1;
axes.push_back(i);
}
}
std::vector<int> idx_shape(n_expand--, 1);
idx_shape[0] = in.shape(axes.back());
auto idx = reshape(arange(in.shape(axes.back()), s), idx_shape, s);
for (int i = 0; i < v; ++i) {
indices.push_back(idx);
}
}
in = gather(in, indices, axes, slice_sizes, s);
// Update subscript string with removed dups
str = new_str;
// Squeeze singleton dimensions left over from the gather
for (auto& ax : axes) {
ax += indices[0].ndim();
}
return squeeze(in, axes, s);
}
// Collapse repeat indices and sum single dimensions.
// For example:
// - "aa" becomes "a"
// - "ij,jk->k" becoms "j,jk->k"
void preprocess_einsum_inputs(
std::vector<Subscript>& inputs,
const Subscript& output,
const std::vector<int>& positions,
std::vector<array>& operands,
StreamOrDevice s) {
// Collapse repeat indices
for (int i = 0; i < inputs.size(); ++i) {
auto& in = inputs[i];
if (in.set.size() < in.str.size()) {
operands[positions[i]] = collapse_repeats(operands[positions[i]], in, s);
}
}
// Sum indices that are only in a single input
{
std::unordered_map<char, int> counts;
for (auto& in : inputs) {
for (auto c : in.set) {
auto inserted = counts.insert({c, 0});
inserted.first->second++;
}
}
for (auto c : output.set) {
auto inserted = counts.insert({c, 0});
inserted.first->second++;
}
for (int i = 0; i < inputs.size(); ++i) {
auto& in = inputs[i];
std::vector<int> sum_axes;
for (int ax = 0; ax < in.str.size(); ++ax) {
if (counts[in.str[ax]] == 1) {
sum_axes.push_back(ax);
}
}
if (!sum_axes.empty()) {
operands[positions[i]] =
sum(operands[positions[i]], sum_axes, false, s);
}
for (auto it = sum_axes.rbegin(); it != sum_axes.rend(); ++it) {
in.set.erase(in.str[*it]);
in.str.erase(in.str.begin() + *it);
}
}
}
}
array einsum_naive(
std::vector<Subscript> inputs,
const Subscript& output,
const std::vector<int>& positions,
std::vector<array> operands,
StreamOrDevice s) {
// Map each character to an axis
std::unordered_map<char, int> char_to_ax;
for (auto& in : inputs) {
for (auto c : in.str) {
char_to_ax.insert({c, char_to_ax.size()});
}
}
// Expand and transpose inputs as needed
for (int i = 0; i < inputs.size(); ++i) {
int pos = positions[i];
auto& op = operands[pos];
// Add missing dimensions at the end
if (op.ndim() != char_to_ax.size()) {
auto shape = op.shape();
shape.insert(shape.end(), char_to_ax.size() - shape.size(), 1);
op = reshape(op, std::move(shape), s);
}
// Transpose:
// - Build a vector of (char, ax) pairs for the current input
// - Sort the vector by the canonical axis in char_to_ax
// - Extract the sorted axis to get transpose order
std::vector<std::pair<char, int>> str_ax;
for (auto c : inputs[i].str) {
str_ax.emplace_back(c, str_ax.size());
}
for (auto [c, ax] : char_to_ax) {
if (inputs[i].set.find(c) == inputs[i].set.end()) {
str_ax.emplace_back(c, str_ax.size());
}
}
std::sort(
str_ax.begin(),
str_ax.end(),
[&char_to_ax](const auto& x, const auto& y) {
return char_to_ax[x.first] < char_to_ax[y.first];
});
// Skip the transpose if not needed
if (std::is_sorted(
str_ax.begin(), str_ax.end(), [](const auto& x, const auto& y) {
return x.second < y.second;
})) {
continue;
}
std::vector<int> reorder;
for (auto [c, ax] : str_ax) {
reorder.push_back(ax);
}
op = transpose(op, reorder, s);
}
// Multiply and sum
auto out = operands[positions[0]];
for (int i = 1; i < positions.size(); ++i) {
out = multiply(out, operands[positions[i]], s);
}
std::vector<int> sum_axes;
for (auto [c, ax] : char_to_ax) {
if (output.set.find(c) == output.set.end()) {
sum_axes.push_back(ax);
}
}
if (!sum_axes.empty()) {
out = sum(out, sum_axes, false, s);
}
// Transpose output if needed
std::vector<int> reorder;
for (auto c : output.str) {
reorder.push_back(char_to_ax[c]);
}
for (auto& r : reorder) {
int offset = 0;
for (auto s : sum_axes) {
if (r > s) {
offset++;
}
}
r -= offset;
}
return transpose(out, reorder, s);
}
std::pair<std::vector<PathNode>, PathInfo> einsum_path_helper(
const std::string& subscripts,
const std::vector<array>& operands,
const std::string& fn_name) {
if (operands.size() == 0) {
std::ostringstream msg;
msg << "[" << fn_name << "] At least one operand is required.";
throw std::invalid_argument(msg.str());
}
auto [in_subscripts, out_subscript] = parse(subscripts);
if (operands.size() != in_subscripts.size()) {
std::ostringstream msg;
msg << "[" << fn_name << "] Number of operands, " << operands.size()
<< ", does not match number of input subscripts, "
<< in_subscripts.size();
throw std::invalid_argument(msg.str());
}
auto check_letters = [&](const auto& subscript) {
for (auto c : subscript) {
if (!isalpha(c)) {
std::ostringstream msg;
msg << "[" << fn_name << "] Subscripts must be letters, but got '" << c
<< "'.";
throw std::invalid_argument(msg.str());
}
}
};
for (auto& in : in_subscripts) {
check_letters(in);
}
check_letters(out_subscript);
CharSet out_set(out_subscript.begin(), out_subscript.end());
if (out_set.size() != out_subscript.size()) {
std::ostringstream msg;
msg << "[" << fn_name << "] Repeat indices not allowed in output.";
throw std::invalid_argument(msg.str());
}
Subscript output(out_subscript, std::move(out_set));
std::unordered_map<char, int> dim_map;
std::vector<Subscript> inputs;
for (int i = 0; i < in_subscripts.size(); ++i) {
auto& in = in_subscripts[i];
CharSet in_set(in.begin(), in.end());
inputs.emplace_back(in, in_set);
if (in.size() != operands[i].ndim()) {
std::ostringstream msg;
msg << "[" << fn_name << "] Invalid number of subscripts " << in.size()
<< " for input " << i << " with " << operands[i].ndim()
<< " dimensions.";
throw std::invalid_argument(msg.str());
}
// Check repeat subscripts are valid
if (in_set.size() < in.size()) {
std::unordered_map<char, int> local_dims;
for (int j = 0; j < in.size(); ++j) {
auto dim = operands[i].shape(j);
auto inserted = local_dims.insert({in[j], dim});
if (!inserted.second) {
if (inserted.first->second != dim) {
std::ostringstream msg;
msg << "[" << fn_name << "] Dimensions of repeated subscripts "
<< "do not have the same size (" << inserted.first->second
<< " != " << dim << ").";
throw std::invalid_argument(msg.str());
}
}
}
}
for (int j = 0; j < in.size(); j++) {
auto c = in[j];
auto dim = operands[i].shape(j);
auto inserted = dim_map.insert({c, dim});
auto& in_dim = inserted.first->second;
if (dim != 1 && in_dim != 1 && in_dim != dim) {
std::ostringstream msg;
msg << "[" << fn_name << "] Cannot broadcast dimension " << j
<< " of input " << i << " with shape " << operands[i].shape()
<< " to size " << in_dim << ".";
throw std::invalid_argument(msg.str());
}
// Ensure the broadcasted size is used
in_dim = std::max(in_dim, dim);
}
}
size_t max_size = term_size(out_subscript, dim_map);
for (auto& in : in_subscripts) {
max_size = std::max(max_size, term_size(in, dim_map));
}
PathInfo path_info;
// Get the full naive cost
std::tie(path_info.naive_cost, path_info.naive_scaling) =
compute_cost_and_scaling(inputs, output, dim_map);
// Calculate the path
std::vector<PathNode> path;
if (inputs.size() <= 2) {
std::vector<int> positions(in_subscripts.size());
std::iota(positions.begin(), positions.end(), 0);
path.emplace_back(
std::move(inputs), std::move(output), std::move(positions));
} else {
std::tie(path, path_info.optimized_cost, path_info.optimized_scaling) =
greedy_path(inputs, output, dim_map, path_info.naive_cost, max_size);
// Set the final output subscript to the actual output
path.back().output = std::move(output);
}
return {path, path_info};
}
} // namespace
std::pair<std::vector<std::vector<int>>, std::string> einsum_path(
const std::string& subscripts,
const std::vector<array>& operands) {
auto [path, path_info] =
einsum_path_helper(subscripts, operands, "einsum_path");
std::vector<std::vector<int>> pos_path;
for (auto& p : path) {
pos_path.push_back(p.positions);
}
std::ostringstream path_print;
path_print << " Complete contraction: " << subscripts << "\n"
<< " Naive scaling: " << path_info.naive_scaling << "\n"
<< " Optimized scaling: " << path_info.optimized_scaling
<< "\n"
<< " Naive FLOP count: " << path_info.naive_cost << "\n"
<< " Optimized FLOP count: " << path_info.optimized_cost << "\n";
// TODO add more info here
return {pos_path, path_print.str()};
}
array einsum(
const std::string& subscripts,
const std::vector<array>& operands,
StreamOrDevice s /* = {} */) {
auto [path, path_info] = einsum_path_helper(subscripts, operands, "einsum");
auto inputs = operands;
for (auto& node : path) {
preprocess_einsum_inputs(
node.inputs, node.output, node.positions, inputs, s);
if (can_dot(node.inputs, node.output)) {
auto& in_a = node.inputs[0];
auto& in_b = node.inputs[1];
auto& out = node.output;
std::vector<int> a_contract;
std::vector<int> a_batch;
std::vector<int> a_concat;
for (int i = 0; i < in_a.str.size(); ++i) {
auto c = in_a.str[i];
if (out.set.find(c) == out.set.end()) {
// Not in the output, contraction
a_contract.push_back(i);
} else if (in_b.set.find(c) != in_b.set.end()) {
// Not a contraction but in both inputs, batch dim
a_batch.push_back(i);
} else {
// Not a batch dim or contract dim, so concat dim
a_concat.push_back(i);
}
}
std::vector<int> b_contract;
std::vector<int> b_batch;
std::vector<int> b_concat;
for (auto a_i : a_contract) {
b_contract.push_back(in_b.str.find(in_a.str[a_i]));
}
for (auto a_i : a_batch) {
b_batch.push_back(in_b.str.find(in_a.str[a_i]));
}
for (int i = 0; i < in_b.str.size(); ++i) {
auto c = in_b.str[i];
if (out.set.find(c) != out.set.end() &&
in_a.set.find(c) == in_a.set.end()) {
b_concat.push_back(i);
}
}
auto& a = inputs[node.positions[0]];
auto& b = inputs[node.positions[1]];
std::unordered_map<char, int> char_map;
for (auto i : a_batch) {
char_map.insert({in_a.str[i], char_map.size()});
}
for (auto i : a_concat) {
char_map.insert({in_a.str[i], char_map.size()});
}
for (auto i : b_concat) {
char_map.insert({in_b.str[i], char_map.size()});
}
inputs.emplace_back(batch_tensordot(
a,
b,
std::move(a_contract),
std::move(a_batch),
std::move(a_concat),
std::move(b_contract),
std::move(b_batch),
std::move(b_concat),
s));
std::vector<int> reorder;
for (auto c : node.output.str) {
reorder.push_back(char_map[c]);
}
inputs.back() = transpose(inputs.back(), reorder, s);
} else {
inputs.emplace_back(
einsum_naive(node.inputs, node.output, node.positions, inputs, s));
}
// Positions are always sorted increasing, so start from the back
for (auto it = node.positions.rbegin(); it != node.positions.rend(); ++it) {
inputs.erase(inputs.begin() + *it);
}
}
return inputs.front();
}
} // namespace mlx::core

22
mlx/einsum.h Normal file
View File

@@ -0,0 +1,22 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include <string>
#include <tuple>
#include <vector>
#include "mlx/array.h"
#include "mlx/utils.h"
namespace mlx::core {
std::pair<std::vector<std::vector<int>>, std::string> einsum_path(
const std::string& subscripts,
const std::vector<array>& operands);
array einsum(
const std::string& subscripts,
const std::vector<array>& operands,
StreamOrDevice s = {});
} // namespace mlx::core

View File

@@ -18,7 +18,7 @@ std::vector<array> Custom::vjp(
auto [_, vjps] = mlx::core::vjp(fallback_, primals, cotangents);
std::vector<array> vjp_outs;
for (int i = 0, j = 0; i < vjps.size(); ++i) {
if (i < argnums.size() && i == argnums[j]) {
if (j < argnums.size() && i == argnums[j]) {
vjp_outs.push_back(vjps[i]);
j++;
}
@@ -30,15 +30,16 @@ std::vector<array> Custom::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
auto [_, jvps] = mlx::core::jvp(fallback_, primals, tangents);
std::vector<array> jvp_outs;
for (int i = 0, j = 0; i < jvps.size(); ++i) {
if (i < argnums.size() && i == argnums[j]) {
jvp_outs.push_back(jvps[i]);
j++;
std::vector<array> all_tangents;
for (int i = 0, j = 0; i < primals.size(); i++) {
if (j < argnums.size() && i == argnums[j]) {
all_tangents.emplace_back(tangents[j++]);
} else {
all_tangents.emplace_back(zeros_like(primals[i]));
}
}
return jvp_outs;
auto [_, jvps] = mlx::core::jvp(fallback_, primals, all_tangents);
return jvps;
}
std::pair<std::vector<array>, std::vector<int>> Custom::vmap(
@@ -609,4 +610,253 @@ bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const {
return needs_mask_ == a_other.needs_mask_ && scale_ == a_other.scale_;
}
array pack_and_quantize(
array& packed_w,
const array& scales,
const array& biases,
int group_size,
int bits,
const Stream& s) {
int el_per_int = 32 / bits;
array zero(0, packed_w.dtype());
array n_bins((1 << bits) - 1, packed_w.dtype()); // 2**bits - 1
array shifts = power(array(2, uint32), arange(0, 32, bits, uint32, s), s);
packed_w = astype(
clip(
round(divide(subtract(packed_w, biases, s), scales, s), s),
zero,
n_bins),
uint32);
packed_w = reshape(packed_w, {packed_w.shape(0), -1, el_per_int}, s);
packed_w = sum(
multiply(packed_w, shifts, s), /* axis= */ 2, /* keepdims= */ false, s);
return packed_w;
}
std::tuple<array, array, array>
affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) {
auto s = to_stream(s_);
if (group_size != 32 && group_size != 64 && group_size != 128) {
std::ostringstream msg;
msg << "[quantize] The requested group size " << group_size
<< " is not supported. The supported group sizes are 64 and 128.";
throw std::invalid_argument(msg.str());
}
if (bits != 2 && bits != 4 && bits != 8) {
std::ostringstream msg;
msg << "[quantize] The requested number of bits " << bits
<< " is not supported. The supported bits are 2, 4 and 8.";
throw std::invalid_argument(msg.str());
}
if (w.ndim() < 2) {
std::ostringstream msg;
msg << "[quantize] The matrix to be quantized must have at least 2 dimension "
<< "but it has only " << w.ndim() << ".";
throw std::invalid_argument(msg.str());
}
if ((w.shape(-1) % group_size) != 0) {
std::ostringstream msg;
msg << "[quantize] The last dimension of the matrix needs to be divisible by "
<< "the quantization group size " << group_size
<< ". However the provided " << " matrix has shape " << w.shape();
throw std::invalid_argument(msg.str());
}
int el_per_int = 32 / bits;
if (w.shape(-1) < 32 * el_per_int) {
std::ostringstream msg;
msg << "[quantize] The feature dimension (2nd dimension of the matrix) is "
<< "too small for quantization. We support >=512 for 2 bits, "
<< ">= 256 for 4 bits and >= 128 for 8 bits. The provided matrix has "
<< "shape " << w.shape() << ".";
throw std::invalid_argument(msg.str());
}
auto fallback = [group_size, bits, el_per_int, s](
const std::vector<array>& inputs) -> std::vector<array> {
auto& w = inputs[0];
auto wshape = w.shape();
wshape.back() = -1;
array zero(0, w.dtype());
array n_bins((1 << bits) - 1, w.dtype()); // 2**bits - 1
array eps(1e-7, w.dtype());
array packed_w = reshape(w, {-1, w.shape(-1) / group_size, group_size}, s);
array w_max = max(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
array w_min = min(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
array mask = greater(abs(w_min, s), abs(w_max, s), s);
array scales =
maximum(divide(subtract(w_max, w_min, s), n_bins, s), eps, s);
scales = where(mask, scales, negative(scales), s);
array edge = where(mask, w_min, w_max, s);
array q0 = round(divide(edge, scales, s), s);
scales = where(not_equal(q0, zero, s), divide(edge, q0, s), scales);
array biases = where(equal(q0, zero, s), zero, edge);
packed_w = pack_and_quantize(packed_w, scales, biases, group_size, bits, s);
return {
reshape(packed_w, wshape, s),
reshape(scales, wshape, s),
reshape(biases, wshape, s),
};
};
std::vector<array> outputs;
if (s.device == Device::gpu) {
auto wq_shape = w.shape();
wq_shape.back() = w.shape(-1) / el_per_int;
auto sshape = w.shape();
sshape.back() = w.shape(-1) / group_size;
outputs = array::make_arrays(
{wq_shape, sshape, sshape},
{uint32, w.dtype(), w.dtype()},
std::make_shared<AffineQuantize>(s, fallback, group_size, bits, false),
{w});
} else {
outputs = fallback({w});
}
return {outputs[0], outputs[1], outputs[2]};
}
array affine_quantize(
const array& w,
const array& scales,
const array& biases,
int group_size,
int bits,
StreamOrDevice s_) {
auto s = to_stream(s_);
int el_per_int = 32 / bits;
auto fallback = [group_size, bits, el_per_int, s](
const std::vector<array>& inputs) -> std::vector<array> {
auto& w = inputs[0];
auto scales = expand_dims(inputs[1], -1, s);
auto biases = expand_dims(inputs[2], -1, s);
auto wshape = w.shape();
wshape.back() = -1;
array packed_w = reshape(w, {-1, w.shape(-1) / group_size, group_size}, s);
packed_w = pack_and_quantize(packed_w, scales, biases, group_size, bits, s);
return {reshape(packed_w, wshape, s)};
};
if (s.device == Device::gpu) {
auto out_shape = w.shape();
out_shape.back() = w.shape(-1) / el_per_int;
return array(
out_shape,
uint32,
std::make_shared<AffineQuantize>(s, fallback, group_size, bits, false),
{w, scales, biases});
}
return fallback({w, scales, biases})[0];
}
array affine_dequantize(
const array& w,
const array& scales,
const array& biases,
int group_size,
int bits,
StreamOrDevice s_) {
if (bits <= 0) {
std::ostringstream msg;
msg << "[dequantize] Invalid value for bits: " << bits;
throw std::invalid_argument(msg.str());
}
if (group_size <= 0) {
std::ostringstream msg;
msg << "[dequantize] Invalid value for group_size: " << group_size;
throw std::invalid_argument(msg.str());
}
if (w.ndim() < 2 || scales.ndim() < 2 || biases.ndim() < 2) {
std::ostringstream msg;
msg << "[quantize] The matrix to be quantized must have at least 2 dimension "
<< "but it has only " << w.ndim() << ".";
throw std::invalid_argument(msg.str());
}
auto wshape = w.shape();
auto sshape = scales.shape();
auto bshape = biases.shape();
wshape.back() = -1;
sshape.back() = -1;
bshape.back() = -1;
if (wshape != sshape || wshape != bshape) {
throw std::invalid_argument(
"[dequantize] Shape of scales and biases does not match the matrix");
}
if (w.dtype() != uint32) {
throw std::invalid_argument(
"[dequantize] The matrix should be given as a uint32");
}
// Packing into uint32
int el_per_int = 32 / bits;
if (w.shape(-1) * el_per_int != scales.shape(-1) * group_size) {
std::ostringstream msg;
msg << "[dequantize] Shape of scales and biases does not match the matrix "
<< "given the quantization parameters. Provided matrix of shape "
<< w.shape() << " and scales/biases of shape " << scales.shape()
<< " with group_size=" << group_size << " and bits=" << bits << ".";
throw std::invalid_argument(msg.str());
}
auto s = to_stream(s_);
auto fallback =
[&wshape, &sshape, &scales, &biases, group_size, bits, el_per_int, s](
const std::vector<array>& inputs) -> std::vector<array> {
auto& w = inputs[0];
auto& scales = inputs[1];
auto& biases = inputs[2];
std::vector<array> parts;
for (int start = 0; start < 32; start += bits) {
int shift_left = 32 - (start + bits);
int shift_right = shift_left + start;
parts.push_back(expand_dims(
right_shift(
left_shift(w, array(32 - (start + bits), uint32), s),
array(32 - bits, uint32),
s),
-1,
s));
}
array w_full = concatenate(parts, -1, s);
// Dequantize
wshape.push_back(group_size);
w_full = reshape(w_full, wshape, s);
w_full = multiply(w_full, expand_dims(scales, -1, s), s);
w_full = add(w_full, expand_dims(biases, -1, s), s);
w_full = reshape(w_full, sshape, s);
return {w_full};
};
if (s.device == Device::gpu) {
auto out_shape = w.shape();
out_shape.back() = w.shape(-1) * el_per_int;
return array(
out_shape,
scales.dtype(),
std::make_shared<AffineQuantize>(s, fallback, group_size, bits, true),
{w, scales, biases});
}
return fallback({w, scales, biases})[0];
}
} // namespace mlx::core::fast

Some files were not shown because too many files have changed in this diff Show More