Compare commits

..

47 Commits

Author SHA1 Message Date
Awni Hannun
9231617eb3 Move to nanobind v2 (#1316) 2024-08-08 17:17:46 -07:00
Alex Barron
32668a7317 CPU mx.linalg.cholesky_inverse and mx.linalg.tri_inv (#1307)
* add cholesky inv + tri inv

* always run tri_inv on cpu

* consistent naming
2024-08-08 15:18:02 -07:00
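A rough usage sketch for the routines added above (assuming the Python bindings follow the existing mx.linalg convention of running on a CPU stream, and that the keyword is `upper`, mirroring the `upper_` flag visible in the C++ diff further down; treat the exact signatures as illustrative):

import mlx.core as mx

a = mx.array([[4.0, 0.0], [2.0, 1.0]])                    # lower-triangular matrix
a_inv = mx.linalg.tri_inv(a, upper=False, stream=mx.cpu)  # triangular inverse on the CPU

spd = a @ a.T                                             # symmetric positive-definite matrix
l = mx.linalg.cholesky(spd, stream=mx.cpu)                # lower Cholesky factor
spd_inv = mx.linalg.cholesky_inv(l, stream=mx.cpu)        # inverse of spd from its factor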
Angelos Katharopoulos
780c197f95 Fix test tolerance and patch bump (#1315) 2024-08-08 14:51:09 -07:00
Angelos Katharopoulos
eb8819e91e Revert variance to be numerically stable (#1314) 2024-08-08 13:35:02 -07:00
Awni Hannun
30bbea2f08 Add gemv masked to JIT plus some fixes (#1310)
* add gemv masked to JIT plus some fixes

* some cleanup

* add utils

* fix

* fix 2

* more cleaning

* fix

* remove unused mps matmul support

* one more nit

* revert
2024-08-07 13:38:07 -07:00
Alex Barron
635ccd9e25 Add "edge" mode to mx.pad (#1309)
* Add edge padding mode

* fix pad in pooling

* string arg instead of enum
2024-08-06 11:23:10 -07:00
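The new padding mode is selected with a string argument (per the commit notes); a minimal sketch of the intended call, assuming mx.pad keeps its NumPy-like signature:

import mlx.core as mx

x = mx.array([[1, 2], [3, 4]])
y = mx.pad(x, 1, mode="edge")   # replicate border values instead of zero padding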
nicolov
8c9f0278b9 Add vmap to scatter (#1200)
* Add vmap to scatter

* updates

* vmap updates + a few more tests

* bug fix

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-08-05 20:12:27 -07:00
Awni Hannun
58d0e199e1 add bfloat conv for Winograd (#1306)
* add bfloat conv for Winograd


* accumulate in fp32

* accumulate in fp32

* accumulate in bf16
2024-08-05 15:51:13 -07:00
Awni Hannun
10b5835501 fix creating array from bf16 tensors in jax / torch (#1305) 2024-08-01 16:20:51 -07:00
Awni Hannun
6c8dd307eb faster group norm (#1304) 2024-08-01 12:49:23 -07:00
Awni Hannun
43ffdab172 fix rope and random (#1301)
* fix rope and random

* comment
2024-07-31 16:18:25 -07:00
Awni Hannun
40b6d67333 Fixes for large arrays with a few ops (#1299)
* fixes for large arrays with a few ops

* fix bug

* fix all of copy
2024-07-30 17:18:39 -07:00
Alex Barron
c52d1600f0 Fused Affine Quantize/Dequantize ops (#1282)
* Add fast affine dequantize

* add full quantize kernel

* fused kernel with scale/bias computation

* fix docstring

* fix no jit error

* fix test

* test fix

* reduce fast api to only affine_quantize
2024-07-29 15:11:38 -07:00
Awni Hannun
aa1d6cadad Fix docs latex build and nits (#1297)
* fix docs latex build and nits

* fix stub gen and try to clean up building
2024-07-29 11:44:06 -07:00
Atakan Tekparmak
6e06e3a904 feat: Added "tanh" option to GELU approximation (#1268) 2024-07-28 09:07:56 +02:00
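For reference, the "tanh" option named above selects the standard tanh-based GELU approximation; a sketch of that formula written with plain mlx ops (an illustration, not the library's internal implementation):

import math
import mlx.core as mx

def gelu_tanh_approx(x):
    # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    return 0.5 * x * (1 + mx.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x**3)))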
Yaroslav
8cfb9fc0b8 Update requirements.txt (#1291) 2024-07-26 12:59:52 -07:00
Awni Hannun
7b456fd2c0 Array api (#1289)
* some updates for numpy 2.0 and array api

* some updates for numpy 2.0 and array api

* fix array api doc
2024-07-26 10:40:49 -07:00
Awni Hannun
e9e53856d2 patch bump (#1287) 2024-07-25 11:42:09 -07:00
Anton Belov
5029894662 [Issue #1187] Add nan_to_num function initial attempt (#1247)
* initial attempt, working with wrong types

* not compiling; mx.float16 and mx.bfloat16 tests added

* fix nan to num

* nit

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-07-25 09:57:37 -07:00
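A hedged usage sketch for the function added above, assuming the binding mirrors NumPy's nan_to_num keywords:

import mlx.core as mx

x = mx.array([float("nan"), float("inf"), -float("inf"), 1.0])
y = mx.nan_to_num(x, nan=0.0, posinf=1e6, neginf=-1e6)  # replace NaN/inf with finite values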
Awni Hannun
baf9fa5f42 Einsum (#1269)
* einsum initial

* fix comma break

* sum axis was wrong

* small cleanups

* python binding

* changed bindings to resemble numpy

* remove todo comment

* comment changes

* add count of operands/inputs

* fail fast if operands list is empty

* ignore comma if no output

* einsum path matching numpy

* getting somewhere with path

* remove print

* it passes the first test

* moved einsum tests to separate file

* separated einsum path

* moved einsum naive

* remove space from equation

* fast fail if no operands passed

* update tests and remove printf

* small cleanup

* some more cleanups

* removed python helper file

* ack

* utilize std for finding min in vector

* duplicate def

* remove the tuple as it was unreadable

* moved einsum_naive back to ops

* remaining isn't needed

* avoid creating another set

* cleanup

* greedy path, start of naive einsum

* more einsum

* fix some bugs

* some more fixes, tests pass

* benchmark

* some simplify

* fix einsum and test

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>

* add a bunch more tests and fix a bunch more bugs

* some docs nits

---------

Co-authored-by: dc-dc-dc <dgcruz983@gmail.com>
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
2024-07-25 09:36:44 -07:00
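The user-facing entry points added by this commit are mx.einsum and mx.einsum_path, bound to resemble NumPy (per the commit notes above); a minimal sketch:

import mlx.core as mx

a = mx.random.uniform(shape=(8, 16))
b = mx.random.uniform(shape=(16, 4))
c = mx.einsum("ij,jk->ik", a, b)          # matrix multiply expressed as an einsum
path = mx.einsum_path("ij,jk->ik", a, b)  # contraction order, NumPy-style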
Jagrit Digani
7f914365fd Fix GPU sort for large arrays (#1285)
* Fix GPU sort for large arrays
2024-07-24 14:37:10 -07:00
Paul Paczuski
ebd7135b50 Improve stability of BCE loss calculation for input probabilities close to or exactly 0 or 1 (#1280)
* Improve stability of BCE loss calculation

* Standardize comment

* Apply formatting with black via pre-commit

* Add usage recommendation to docstring

* Update python/mlx/nn/losses.py

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-07-24 08:38:22 -07:00
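The instability addressed above is the usual one: probabilities that round to exactly 0 or 1 send log(p) or log(1 - p) to -inf. A common mitigation, and likely what the docstring recommendation points at, is to compute the loss from logits with the log-sum-exp identity; a generic sketch of that identity (an illustration, not the PR's code):

import mlx.core as mx

def bce_with_logits(logits, targets):
    # Stable form of -[y*log(sigmoid(z)) + (1-y)*log(1-sigmoid(z))]
    return mx.maximum(logits, 0) - logits * targets + mx.log1p(mx.exp(-mx.abs(logits)))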
fgranqvist
50eff6a10a Implement sampling from laplace distribution. (#1279) 2024-07-24 15:15:37 +02:00
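For context, the Laplace distribution is typically sampled with the inverse-CDF trick; a sketch under that assumption (the built-in entry point added here is mx.random.laplace, which may use a different construction internally):

import mlx.core as mx

def laplace_via_inverse_cdf(shape, loc=0.0, scale=1.0):
    # u ~ U(-1/2, 1/2);  x = loc - scale * sign(u) * log(1 - 2|u|)
    u = mx.random.uniform(low=-0.5, high=0.5, shape=shape)
    return loc - scale * mx.sign(u) * mx.log(1 - 2 * mx.abs(u))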
Alex Barron
c34a5ae7f7 Fix bfloat16 Hadamard (#1283)
* fix bfloat16 hadamard

* add scale

* review comments

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
2024-07-23 14:54:43 -07:00
Awni Hannun
e2aa6ec8ae some fixes (#1281) 2024-07-23 11:49:05 -07:00
toji
6768c6a54a Adding missing type hints (#1243)
* added type hints for `run`, `tree_map` and `tree_map_with_path`

* fix lint

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-07-23 07:29:38 -07:00
Tim Gymnich
6307d166eb Fix overflow / underflow handling for expm1f (#1278)
* Fix overflow / underflow handling for expm1f

* update tests
2024-07-23 07:29:06 -07:00
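The underlying issue is the usual loss of precision in exp(x) - 1 for small |x| (and overflow for large x); a tiny NumPy illustration of why a dedicated expm1 matters:

import numpy as np

x = np.float32(1e-8)
print(np.exp(x) - 1.0)  # 0.0 in float32: 1 + 1e-8 rounds to 1, cancelling the result
print(np.expm1(x))      # ~1e-08: expm1 keeps the leading term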
Awni Hannun
1fba87b0df Fix leak with multi-output primitives (#1274)
* fix leak with multi-output primitives

* hopefully an actual fix
2024-07-23 06:34:18 -07:00
Awni Hannun
df124e018a fix gguf (#1273)
* fix gguf

* comment
2024-07-18 07:35:35 -07:00
Cheng
2f83d6e4b7 Do not release buffers on exit (#1142) 2024-07-15 15:12:24 -07:00
Feng Shijie
987785d8d7 Fix typo and missing header (#1266) 2024-07-15 08:20:24 -07:00
Awni Hannun
8c01a7893b minor fix in optimizer + docs (#1264) 2024-07-12 12:18:02 -07:00
Awni Hannun
218047c75a docs fixes (#1263) 2024-07-11 15:59:07 -07:00
Alex Barron
d0da74209b version bump (#1260) 2024-07-11 11:17:55 -07:00
Angelos Katharopoulos
5c1fa64fb0 Custom transforms (#1246) 2024-07-10 18:00:01 -07:00
Alex Barron
a3c287354f Fast Hadamard Transform (#1249)
* Working hadamard for powers of 2

* working for m*2^k

* add scale and check contiguity

* add size check

* clean up

* fix test

* add grads + vmap

* gpu only

* skip on linux

* test typo

* add cpu impl

* remove gpu only tests

* fix linux build + add is_equivalent
2024-07-09 20:39:01 -07:00
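The op added here, mx.hadamard_transform, is exercised by the benchmark file later in this diff; a minimal usage sketch (the optional scale argument is mentioned in the commit notes; its default is not spelled out here):

import mlx.core as mx

x = mx.random.normal(shape=(4, 8))   # last axis must have size m * 2^k with m in {1, 12, 20, 28}
y = mx.hadamard_transform(x)         # transform applied along the last axis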
Angelos Katharopoulos
03cf033f82 Fix reshape copy bug (#1253) 2024-07-07 21:37:00 -07:00
Alex Barron
bdb36c9a63 add zero vjps for bitwise ops and gather w.r.t. index (#1256) 2024-07-07 21:34:59 -07:00
Awni Hannun
20bb301195 CPU binary reduction + Nits (#1242)
* very minor nits

* reduce binary

* fix test
2024-06-28 13:50:42 -07:00
Awni Hannun
d6383a1c6a version bump (#1239) 2024-06-27 10:43:13 -07:00
Angelos Katharopoulos
b05bcfd27f Fixes segfault when compiling checkpointed functions (#1235) 2024-06-26 16:14:45 -07:00
Alex Barron
2615660e62 Fix strided sort bug (#1236)
* Use output strides in sort kernel

* fix zero strides bug
2024-06-26 14:32:11 -07:00
Awni Hannun
5b0af4cdb1 fix donation condition for compilation (#1237) 2024-06-26 09:04:05 -07:00
Jagrit Digani
8c2e15e6c8 Accelerate import updates for iOS (#1227)
* Update veclib and bnns includes to #include <Accelerate/Accelerate.h> for compatibility with ios

* Mark float literals in softmax.cpp to be float16_t for errors in ios

* Add arm neon vector operation guards

* Redirect to common backend for consistency
2024-06-26 09:01:50 -07:00
Awni Hannun
56c8a33439 Get metal version from xcode (#1228)
* get metal version from xcode

* typo

* fix
2024-06-26 07:02:11 -07:00
David Koski
4eef1e8a3e fix typo (#1215) 2024-06-24 13:36:35 -07:00
Alex Barron
95d11bda06 Fix NumPy 2.0 pickle test (#1221)
* fix numpy version <2 temporarily

* typo

* better fix

* Fix just for bfloat16

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
2024-06-23 05:47:22 -07:00
157 changed files with 7281 additions and 3195 deletions

View File

@@ -10,13 +10,14 @@ MLX was developed with contributions from the following individuals:
 - Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`.
 - Juarez Bochi: Fixed bug in cross attention.
 - Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
-- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support.
+- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream`, safetensors support, `einsum`, and `einsum_path`.
 - Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented pooling layers and ``Upsample``.
 - Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops.
 - Luca Arnaboldi: Added `Ceil` and `Floor` ops; implemented pickling, copy and deepcopy for mlx arrays.
 - Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention`
 - AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`. Comparison works for non array objects in `mlx.core.array`. Exception handling for invalid operations in `mlx.core.array`.
 - Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
+- Paul Paczuski: Improved stability of BCE loss calculation
 <a href="https://github.com/ml-explore/mlx/graphs/contributors">
 <img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />

View File

@@ -24,7 +24,7 @@ option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
 option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
 if(NOT MLX_VERSION)
-  set(MLX_VERSION 0.15.1)
+  set(MLX_VERSION 0.16.2)
 endif()
 # --------------------- Processor tests -------------------------
@@ -83,18 +83,17 @@ elseif (MLX_BUILD_METAL)
     OUTPUT_VARIABLE MACOS_VERSION
     COMMAND_ERROR_IS_FATAL ANY)
+  if (${MACOS_VERSION} LESS 14.0)
+    message(FATAL_ERROR "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON" )
+  endif()
   message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}")
   set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18-beta.zip)
-  if (${MACOS_VERSION} GREATER_EQUAL 15.0)
-    set(MLX_METAL_VERSION METAL_3_2)
-  elseif (${MACOS_VERSION} GREATER_EQUAL 14.2)
-    set(MLX_METAL_VERSION METAL_3_1)
-  elseif (${MACOS_VERSION} GREATER_EQUAL 14.0)
-    set(MLX_METAL_VERSION METAL_3_0)
-  else()
-    message(FATAL_ERROR "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON" )
-  endif()
+  # Get the metal version
+  execute_process(
+    COMMAND zsh "-c" "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal -E -x metal -P - | tail -1 | tr -d '\n'"
+    OUTPUT_VARIABLE MLX_METAL_VERSION
+    COMMAND_ERROR_IS_FATAL ANY)
   FetchContent_Declare(
     metal_cpp
@@ -113,7 +112,7 @@ elseif (MLX_BUILD_METAL)
     ${FOUNDATION_LIB}
     ${QUARTZ_LIB})
-  add_compile_definitions(${MLX_METAL_VERSION})
+  add_compile_definitions("MLX_METAL_VERSION=${MLX_METAL_VERSION}")
 endif()
 if (MLX_BUILD_CPU)
@@ -170,11 +169,18 @@ if (MPI_FOUND)
   execute_process(
     COMMAND zsh "-c" "mpirun --version"
     OUTPUT_VARIABLE MPI_VERSION
-    COMMAND_ERROR_IS_FATAL ANY
+    ERROR_QUIET
   )
   if (${MPI_VERSION} MATCHES ".*Open MPI.*")
     target_include_directories(mlx PRIVATE ${MPI_INCLUDE_PATH})
+  elseif (MPI_VERSION STREQUAL "")
+    set(MPI_FOUND FALSE)
+    message(
+      WARNING
+      "MPI found but mpirun is not available. Building without MPI."
+    )
   else()
+    set(MPI_FOUND FALSE)
     message(
       WARNING
       "MPI which is not OpenMPI found. Building without MPI."

View File

@@ -0,0 +1,84 @@
# Copyright © 2024 Apple Inc.
import time
import mlx.core as mx
import numpy as np
def timeit(fn, its=100, args=[]):
    for _ in range(5):
        fn(*args)
    tic = time.perf_counter()
    for _ in range(its):
        fn(*args)
    toc = time.perf_counter()
    return 1e3 * (toc - tic) / its

def time_little_einsum_path():
    subscripts = "ik,kj->ij"
    x = mx.ones((32, 32))
    y = mx.ones((32, 32))
    mx_time = timeit(mx.einsum_path, args=(subscripts, x, y))
    x = np.array(x)
    y = np.array(y)
    np_time = timeit(np.einsum_path, args=(subscripts, x, y))
    print("Timing little einsum path...")
    print(f"MLX ... {mx_time:.3f} ms")
    print(f"NumPy... {np_time:.3f} ms")

def time_big_einsum_path():
    chars = list("abcdefgh")
    char_to_dim = {c: v for v, c in enumerate(chars)}
    num_inputs = 10
    inputs = []
    subscripts = []
    for _ in range(num_inputs):
        subscript = np.random.choice(chars, size=5, replace=False).tolist()
        subscripts.append("".join(subscript))
        inputs.append(np.ones(list(char_to_dim[c] for c in subscript)))
    subscripts = ",".join(subscripts)
    np_time = timeit(np.einsum_path, args=(subscripts, *inputs))
    inputs = [mx.array(x) for x in inputs]
    mx_time = timeit(mx.einsum_path, args=(subscripts, *inputs))
    print("Timing big einsum path...")
    print(f"MLX ... {mx_time:.3f} ms")
    print(f"NumPy... {np_time:.3f} ms")

def time_attention():
    def regular_attention(x):
        # shape [batch, sequence, num_heads, head_dim]
        queries, keys, values = x, x, x
        scores = queries.transpose(0, 2, 1, 3) @ keys.transpose(0, 2, 3, 1)
        scores = mx.softmax(scores, axis=-1)
        output = (scores @ values.transpose(0, 2, 1, 3)).swapaxes(1, 2)
        mx.eval(output)

    def einsum_attention(x):
        # shape [batch, sequence, num_heads, head_dim]
        queries, keys, values = x, x, x
        scores = mx.einsum("itjk,iujk->ijtu", queries, keys)
        scores = mx.softmax(scores, axis=-1)
        output = mx.einsum("ijtu,iujk->itjk", scores, values)
        mx.eval(output)

    x = mx.random.uniform(shape=(8, 512, 32, 128))
    regular_time = timeit(regular_attention, args=(x,))
    ein_time = timeit(einsum_attention, args=(x,))
    print("Timing einsum attention...")
    print(f"Regular ... {regular_time:.3f} ms")
    print(f"Einsum ... {ein_time:.3f} ms")

if __name__ == "__main__":
    time_little_einsum_path()
    time_big_einsum_path()
    time_attention()

View File

@@ -0,0 +1,70 @@
import argparse
import matplotlib
import mlx.core as mx
import numpy as np
from time_utils import measure_runtime
matplotlib.use("Agg")
import matplotlib.pyplot as plt
def had(x):
    y = mx.hadamard_transform(x)
    mx.eval(y)

def copy(x):
    y = x + 1.0
    mx.eval(y)

def run(dtype):
    system_size = 2**26
    outputs = {}
    for test_fn in (had, copy):
        for m in [1, 12, 20, 28]:
            if test_fn == copy:
                key = "copy"
            elif m == 1:
                key = "had_2^k"
            else:
                key = "had_m*2^k"
            outputs.setdefault(key, {})
            for k in range(7, 14):
                n = m * 2**k
                if n > 2**15:
                    continue
                x_np = np.random.normal(size=(system_size // n, n)).astype(dtype)
                x = mx.array(x_np)
                runtime_ms = measure_runtime(test_fn, x=x)
                bytes_per_gb = 1e9
                ms_per_s = 1e3
                bytes_per_had = np.dtype(x_np.dtype).itemsize * 2
                bandwidth_gb = (
                    system_size * bytes_per_had / runtime_ms * ms_per_s / bytes_per_gb
                )
                print(n, bandwidth_gb)
                outputs[key][n] = bandwidth_gb

    colors = {
        "copy": "black",
        "had_2^k": "steelblue",
        "had_m*2^k": "skyblue",
    }
    for key, output in outputs.items():
        plt.scatter(output.keys(), output.values(), color=colors[key], label=key)
    plt.title(f"MLX Hadamard Benchmark -- {dtype.__name__}")
    plt.xlabel("N")
    plt.ylabel("Bandwidth (GB/s)")
    plt.legend()
    plt.savefig(f"bench_{dtype.__name__}.png")
    plt.clf()

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--fp16", action="store_true")
    args = parser.parse_args()
    dtype = np.float16 if args.fp16 else np.float32
    run(dtype)

View File

@@ -1,3 +1,4 @@
 sphinx
 breathe
 sphinx-book-theme
+mlx

View File

@@ -83,3 +83,15 @@ def setup(app):
 # -- Options for LaTeX output ------------------------------------------------
 latex_documents = [(main_doc, "MLX.tex", "MLX Documentation", author, "manual")]
latex_elements = {
"preamble": r"""
\usepackage{enumitem}
\setlistdepth{5}
\setlist[itemize,1]{label=$\bullet$}
\setlist[itemize,2]{label=$\bullet$}
\setlist[itemize,3]{label=$\bullet$}
\setlist[itemize,4]{label=$\bullet$}
\setlist[itemize,5]{label=$\bullet$}
\renewlist{itemize}{itemize}{5}
""",
}

View File

@@ -486,9 +486,8 @@ below.
   std::ostringstream kname;
   kname << "axpby_" << "general_" << type_to_name(out);
-  // Make sure the metal library is available and look for it
-  // in the same folder as this executable if needed
-  d.register_library("mlx_ext", metal::get_colocated_mtllib_path);
+  // Make sure the metal library is available
+  d.register_library("mlx_ext");
   // Make a kernel from this metal library
   auto kernel = d.get_kernel(kname.str(), "mlx_ext");

View File

@@ -15,7 +15,7 @@ module to concisely define the model architecture.
 Attention layer
 ^^^^^^^^^^^^^^^^
-We will start with the llama attention layer which notably uses the RoPE
+We will start with the Llama attention layer which notably uses the RoPE
 positional encoding. [1]_ In addition, our attention layer will optionally use a
 key/value cache that will be concatenated with the provided keys and values to
 support efficient inference.

View File

@@ -64,7 +64,7 @@ set:
 Next, setup the problem parameters and load the data. To load the data, you need our
 `mnist data loader
 <https://github.com/ml-explore/mlx-examples/blob/main/mnist/mnist.py>`_, which
-we will import as `mnist`.
+we will import as ``mnist``.
 .. code-block:: python

View File

@@ -70,36 +70,36 @@ To build and install the MLX python library from source, first, clone MLX from
 git clone git@github.com:ml-explore/mlx.git mlx && cd mlx
-Install `nanobind <https://nanobind.readthedocs.io/en/latest/>`_ with:
-.. code-block:: shell
-   pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
 Then simply build and install MLX using pip:
 .. code-block:: shell
-   env CMAKE_BUILD_PARALLEL_LEVEL="" pip install .
+   CMAKE_BUILD_PARALLEL_LEVEL="" pip install .
-For developing use an editable install:
+For developing, install the package with development dependencies, and use an
+editable install:
 .. code-block:: shell
-   env CMAKE_BUILD_PARALLEL_LEVEL="" pip install -e .
+   CMAKE_BUILD_PARALLEL_LEVEL="" pip install -e ".[dev]"
-To make sure the install is working run the tests with:
+Once the development dependencies are installed, you can build faster with:
+.. code-block:: shell
+   CMAKE_BUILD_PARALLEL_LEVEL="" python setup.py build_ext -j --inplace
+Run the tests with:
 .. code-block:: shell
-   pip install ".[testing]"
    python -m unittest discover python/tests
-Optional: Install stubs to enable auto completions and type checking from your IDE:
+Optional: Install stubs to enable auto completions and type checking from your
+IDE:
 .. code-block:: shell
-   pip install ".[dev]"
    python setup.py generate_stubs
 C++ API
@@ -195,7 +195,7 @@ GGUF, you can do:
 .. code-block:: shell
-   cmake ..
+   cmake .. \
     -DCMAKE_BUILD_TYPE=MinSizeRel \
     -DBUILD_SHARED_LIBS=ON \
     -DMLX_BUILD_CPU=OFF \

View File

@@ -24,6 +24,7 @@ Array
 array.any
 array.argmax
 array.argmin
+array.conj
 array.cos
 array.cummax
 array.cummin
@@ -57,3 +58,4 @@ Array
 array.transpose
 array.T
 array.var
+array.view

View File

@@ -9,7 +9,9 @@ Linear Algebra
 :toctree: _autosummary
 inv
+tri_inv
 norm
 cholesky
+cholesky_inv
 qr
 svd

View File

@@ -57,6 +57,8 @@ Operations
 diagonal
 divide
 divmod
+einsum
+einsum_path
 equal
 erf
 erfinv
@@ -72,6 +74,7 @@ Operations
 gather_qmm
 greater
 greater_equal
+hadamard_transform
 identity
 inner
 isclose
@@ -103,6 +106,7 @@ Operations
 minimum
 moveaxis
 multiply
+nan_to_num
 negative
 not_equal
 ones

View File

@@ -31,6 +31,41 @@ model's parameters and the **optimizer state**.
 # Compute the new parameters but also the optimizer state.
 mx.eval(model.parameters(), optimizer.state)
Saving and Loading
------------------
To serialize an optimizer, save its state. To load an optimizer, load and set
the saved state. Here's a simple example:
.. code-block:: python
import mlx.core as mx
from mlx.utils import tree_flatten, tree_unflatten
import mlx.optimizers as optim
optimizer = optim.Adam(learning_rate=1e-2)
# Perform some updates with the optimizer
model = {"w" : mx.zeros((5, 5))}
grads = {"w" : mx.ones((5, 5))}
optimizer.update(model, grads)
# Save the state
state = tree_flatten(optimizer.state)
mx.save_safetensors("optimizer.safetensors", dict(state))
# Later on, for example when loading from a checkpoint,
# recreate the optimizer and load the state
optimizer = optim.Adam(learning_rate=1e-2)
state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))
optimizer.state = state
Note, not every optimizer configuration parameter is saved in the state. For
example, for Adam the learning rate is saved but the ``betas`` and ``eps``
parameters are not. A good rule of thumb is if the parameter can be scheduled
then it will be included in the optimizer state.
 .. toctree::
 optimizers/optimizer

View File

@@ -44,3 +44,4 @@ we use a splittable version of Threefry, which is a counter-based PRNG.
 split
 truncated_normal
 uniform
+laplace

View File

@@ -10,6 +10,7 @@ Transforms
 eval
 compile
+custom_function
 disable_compile
 enable_compile
 grad

View File

@@ -249,9 +249,8 @@ void Axpby::eval_gpu(
   kname << (contiguous_kernel ? "contiguous_" : "general_");
   kname << type_to_name(out);
-  // Make sure the metal library is available and look for it
-  // in the same folder as this executable if needed
-  d.register_library("mlx_ext", metal::get_colocated_mtllib_path);
+  // Make sure the metal library is available
+  d.register_library("mlx_ext");
   // Make a kernel from this metal library
   auto kernel = d.get_kernel(kname.str(), "mlx_ext");

View File

@@ -1,4 +1,4 @@
 setuptools>=42
 cmake>=3.24
-mlx>=0.9.0
+mlx>=0.16.2
-nanobind@git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
+nanobind==2.0

View File

@@ -6,6 +6,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp

View File

@@ -17,6 +17,10 @@ bool in_tracing() {
   return detail::InTracing::in_tracing();
 }
+bool retain_graph() {
+  return detail::RetainGraph::retain_graph();
+}
 } // namespace
 array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
@@ -102,7 +106,7 @@ void array::eval() {
 }
 bool array::is_tracer() const {
-  return array_desc_->is_tracer && in_tracing();
+  return array_desc_->is_tracer && in_tracing() || retain_graph();
 }
 void array::set_data(allocator::Buffer buffer, deleter_t d) {
@@ -171,10 +175,11 @@ array::~array() {
     return;
   }
-  // Ignore arrays that will be detached
-  if (status() != array::Status::unscheduled) {
+  // Ignore arrays that might be detached during eval
+  if (status() == array::Status::scheduled) {
     return;
   }
   // Break circular reference for non-detached arrays with siblings
   if (auto n = siblings().size(); n > 0) {
     bool do_detach = true;

View File

@@ -1,9 +1,9 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.
 #include <cassert>
+#include <Accelerate/Accelerate.h>
 #include <simd/vector.h>
-#include <vecLib/vDSP.h>
 #include "mlx/backend/common/copy.h"
 #include "mlx/primitives.h"

View File

@@ -2,8 +2,7 @@
 #include <cassert>
-#include <vecLib/BNNS/bnns.h>
-#include <vecLib/cblas_new.h>
+#include <Accelerate/Accelerate.h>
 #include "mlx/backend/accelerate/utils.h"
 #include "mlx/backend/common/copy.h"

View File

@@ -3,8 +3,7 @@
#include <cassert> #include <cassert>
#include <cmath> #include <cmath>
#include <vecLib/vDSP.h> #include <Accelerate/Accelerate.h>
#include <vecLib/vForce.h>
#include "mlx/allocator.h" #include "mlx/allocator.h"
#include "mlx/backend/common/binary.h" #include "mlx/backend/common/binary.h"
@@ -37,7 +36,7 @@ DEFAULT(Ceil)
DEFAULT(Concatenate) DEFAULT(Concatenate)
DEFAULT(Conjugate) DEFAULT(Conjugate)
DEFAULT(Copy) DEFAULT(Copy)
DEFAULT_MULTI(CustomVJP) DEFAULT_MULTI(CustomTransforms)
DEFAULT_MULTI(Depends) DEFAULT_MULTI(Depends)
DEFAULT_MULTI(DivMod) DEFAULT_MULTI(DivMod)
DEFAULT(NumberOfElements) DEFAULT(NumberOfElements)
@@ -51,6 +50,7 @@ DEFAULT(GatherMM)
DEFAULT(GatherQMM) DEFAULT(GatherQMM)
DEFAULT(Greater) DEFAULT(Greater)
DEFAULT(GreaterEqual) DEFAULT(GreaterEqual)
DEFAULT(Hadamard)
DEFAULT(Less) DEFAULT(Less)
DEFAULT(LessEqual) DEFAULT(LessEqual)
DEFAULT(Load) DEFAULT(Load)
@@ -102,7 +102,7 @@ void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& b = inputs[1]; auto& b = inputs[1];
if (a.dtype() == float32) { if (a.dtype() == float32) {
binary( binary_op<float>(
a, a,
b, b,
out, out,
@@ -117,7 +117,7 @@ void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
vDSP_vadd((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n); vDSP_vadd((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n);
}); });
} else if (a.dtype() == int32) { } else if (a.dtype() == int32) {
binary( binary_op<int>(
a, a,
b, b,
out, out,
@@ -132,7 +132,7 @@ void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
vDSP_vaddi((const int*)a, 1, (const int*)b, 1, (int*)o, 1, n); vDSP_vaddi((const int*)a, 1, (const int*)b, 1, (int*)o, 1, n);
}); });
} else { } else {
binary(a, b, out, [](auto x, auto y) { return x + y; }); eval(inputs, out);
} }
} }
@@ -287,7 +287,7 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& b = inputs[1]; auto& b = inputs[1];
if (a.dtype() == int32) { if (a.dtype() == int32) {
binary( binary_op<int>(
a, a,
b, b,
out, out,
@@ -300,7 +300,7 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
vDSP_vdivi((const int*)b, 1, (const int*)a, 1, (int*)o, 1, n); vDSP_vdivi((const int*)b, 1, (const int*)a, 1, (int*)o, 1, n);
}); });
} else if (a.dtype() == float32) { } else if (a.dtype() == float32) {
binary( binary_op<float>(
a, a,
b, b,
out, out,
@@ -315,7 +315,7 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
vDSP_vdiv((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n); vDSP_vdiv((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n);
}); });
} else { } else {
binary(a, b, out, [](auto x, auto y) { return x / y; }); eval(inputs, out);
} }
} }
@@ -326,12 +326,8 @@ void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
set_unary_output_data(in, out); set_unary_output_data(in, out);
auto size = in.data_size(); auto size = in.data_size();
vvexpf(out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size)); vvexpf(out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
} else if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, [](auto x) { return std::exp(x); });
} else { } else {
throw std::invalid_argument( eval(inputs, out);
"[exp] Cannot exponentiate elements in array"
" with non floating point type.");
} }
} }
@@ -393,12 +389,8 @@ void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
auto size = in.data_size(); auto size = in.data_size();
vvlog1pf( vvlog1pf(
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size)); out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
} else if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, [](auto x) { return std::log1p(x); });
} else { } else {
throw std::invalid_argument( eval(inputs, out);
"[log1p] Cannot compute log of elements in array with"
" non floating point type.");
} }
} }
@@ -408,7 +400,7 @@ void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& b = inputs[1]; auto& b = inputs[1];
if (a.dtype() == float32) { if (a.dtype() == float32) {
binary( binary_op<float>(
a, a,
b, b,
out, out,
@@ -423,7 +415,7 @@ void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
vDSP_vmul((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n); vDSP_vmul((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n);
}); });
} else { } else {
binary(a, b, out, [](auto x, auto y) { return x * y; }); eval(inputs, out);
} }
} }
@@ -434,7 +426,7 @@ void Negative::eval_cpu(const std::vector<array>& inputs, array& out) {
set_unary_output_data(in, out); set_unary_output_data(in, out);
vDSP_vneg(in.data<float>(), 1, out.data<float>(), 1, in.data_size()); vDSP_vneg(in.data<float>(), 1, out.data<float>(), 1, in.data_size());
} else { } else {
unary(in, out, [](auto x) { return -x; }); eval(inputs, out);
} }
} }
@@ -521,7 +513,7 @@ void Square::eval_cpu(const std::vector<array>& inputs, array& out) {
auto size = in.data_size(); auto size = in.data_size();
vDSP_vsq(in.data<float>(), 1, out.data<float>(), 1, size); vDSP_vsq(in.data<float>(), 1, out.data<float>(), 1, size);
} else { } else {
unary(in, out, [](auto x) { return x * x; }); eval(inputs, out);
} }
} }
@@ -547,7 +539,7 @@ void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& b = inputs[1]; auto& b = inputs[1];
if (a.dtype() == float32) { if (a.dtype() == float32) {
binary( binary_op<float>(
a, a,
b, b,
out, out,
@@ -565,7 +557,7 @@ void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
vDSP_vsub((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n); vDSP_vsub((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n);
}); });
} else if (a.dtype() == int32) { } else if (a.dtype() == int32) {
binary( binary_op<int>(
a, a,
b, b,
out, out,
@@ -577,7 +569,7 @@ void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
}, },
UseDefaultBinaryOp()); UseDefaultBinaryOp());
} else { } else {
binary(a, b, out, [](auto x, auto y) { return x - y; }); eval(inputs, out);
} }
} }

View File

@@ -2,8 +2,8 @@
 #include <cassert>
+#include <Accelerate/Accelerate.h>
 #include <simd/vector.h>
-#include <vecLib/vDSP.h>
 #include "mlx/backend/common/reduce.h"
 #include "mlx/primitives.h"

View File

@@ -3,7 +3,10 @@
#include <cassert> #include <cassert>
#include <limits> #include <limits>
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#include <arm_neon.h> #include <arm_neon.h>
#endif
#include <simd/math.h> #include <simd/math.h>
#include <simd/vector.h> #include <simd/vector.h>
@@ -53,25 +56,26 @@ inline simd_float16 simd_fast_exp(simd_float16 x) {
return (*(simd_float16*)&epart) * x; return (*(simd_float16*)&epart) * x;
} }
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
/** /**
* The ARM neon equivalent of the fast exp above. * The ARM neon equivalent of the fast exp above.
*/ */
inline float16x8_t neon_fast_exp(float16x8_t x) { inline float16x8_t neon_fast_exp(float16x8_t x) {
x = vmulq_f16(x, vdupq_n_f16(1.442695)); // multiply with log_2(e) x = vmulq_f16(x, vdupq_n_f16(float16_t(1.442695f))); // multiply with log_2(e)
x = vmaxq_f16(x, vdupq_n_f16(-14)); // clamp under with -14 x = vmaxq_f16(x, vdupq_n_f16(float16_t(-14.f))); // clamp under with -14
x = vminq_f16(x, vdupq_n_f16(14)); // clamp over with 14 x = vminq_f16(x, vdupq_n_f16(float16_t(14.f))); // clamp over with 14
float16x8_t ipart = vrndmq_f16(vaddq_f16(x, vdupq_n_f16(0.5))); float16x8_t ipart = vrndmq_f16(vaddq_f16(x, vdupq_n_f16(float16_t(0.5f))));
float16x8_t fpart = vsubq_f16(x, ipart); float16x8_t fpart = vsubq_f16(x, ipart);
x = vdupq_n_f16(1.535336188319500e-4f); x = vdupq_n_f16(float16_t(1.535336188319500e-4f));
x = vfmaq_f16(vdupq_n_f16(1.339887440266574e-3f), x, fpart); x = vfmaq_f16(vdupq_n_f16(float16_t(1.339887440266574e-3f)), x, fpart);
x = vfmaq_f16(vdupq_n_f16(1.339887440266574e-3f), x, fpart); x = vfmaq_f16(vdupq_n_f16(float16_t(1.339887440266574e-3f)), x, fpart);
x = vfmaq_f16(vdupq_n_f16(9.618437357674640e-3f), x, fpart); x = vfmaq_f16(vdupq_n_f16(float16_t(9.618437357674640e-3f)), x, fpart);
x = vfmaq_f16(vdupq_n_f16(5.550332471162809e-2f), x, fpart); x = vfmaq_f16(vdupq_n_f16(float16_t(5.550332471162809e-2f)), x, fpart);
x = vfmaq_f16(vdupq_n_f16(2.402264791363012e-1f), x, fpart); x = vfmaq_f16(vdupq_n_f16(float16_t(2.402264791363012e-1f)), x, fpart);
x = vfmaq_f16(vdupq_n_f16(6.931472028550421e-1f), x, fpart); x = vfmaq_f16(vdupq_n_f16(float16_t(6.931472028550421e-1f)), x, fpart);
x = vfmaq_f16(vdupq_n_f16(1.000000000000000f), x, fpart); x = vfmaq_f16(vdupq_n_f16(float16_t(1.000000000000000f)), x, fpart);
// generate 2**ipart in the floating point representation using integer // generate 2**ipart in the floating point representation using integer
// bitshifting // bitshifting
@@ -107,53 +111,6 @@ inline float16_t neon_reduce_add(float16x8_t x) {
return vget_lane_f16(y, 0); return vget_lane_f16(y, 0);
} }
template <typename T, typename VT>
struct AccelerateSimdOps {
VT init(T a) {
return a;
}
VT load(const T* a) {
return *(VT*)a;
}
void store(T* dst, VT x) {
*(VT*)dst = x;
}
VT max(VT a, VT b) {
return simd_max(a, b);
}
VT exp(VT x) {
return simd_fast_exp(x);
}
VT add(VT a, VT b) {
return a + b;
}
VT sub(VT a, T b) {
return a - b;
}
VT mul(VT a, VT b) {
return a * b;
}
VT mul(VT a, T b) {
return a * b;
}
T reduce_max(VT x) {
return simd_reduce_max(x);
}
T reduce_add(VT x) {
return simd_reduce_add(x);
}
};
template <typename T, typename VT> template <typename T, typename VT>
struct NeonFp16SimdOps { struct NeonFp16SimdOps {
VT init(T a) { VT init(T a) {
@@ -201,6 +158,55 @@ struct NeonFp16SimdOps {
} }
}; };
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
template <typename T, typename VT>
struct AccelerateSimdOps {
VT init(T a) {
return a;
}
VT load(const T* a) {
return *(VT*)a;
}
void store(T* dst, VT x) {
*(VT*)dst = x;
}
VT max(VT a, VT b) {
return simd_max(a, b);
}
VT exp(VT x) {
return simd_fast_exp(x);
}
VT add(VT a, VT b) {
return a + b;
}
VT sub(VT a, T b) {
return a - b;
}
VT mul(VT a, VT b) {
return a * b;
}
VT mul(VT a, T b) {
return a * b;
}
T reduce_max(VT x) {
return simd_reduce_max(x);
}
T reduce_add(VT x) {
return simd_reduce_add(x);
}
};
template <typename T, typename AccT, typename VT, typename Ops, int N> template <typename T, typename AccT, typename VT, typename Ops, int N>
void softmax(const array& in, array& out) { void softmax(const array& in, array& out) {
Ops ops; Ops ops;
@@ -362,12 +368,16 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
AccelerateSimdOps<float, simd_float16>, AccelerateSimdOps<float, simd_float16>,
16>(in, out); 16>(in, out);
} else { } else {
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
softmax< softmax<
float16_t, float16_t,
float16_t, float16_t,
float16x8_t, float16x8_t,
NeonFp16SimdOps<float16_t, float16x8_t>, NeonFp16SimdOps<float16_t, float16x8_t>,
8>(in, out); 8>(in, out);
#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
eval(inputs, out); // Redirect to common backend for consistency
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
} }
break; break;
case bfloat16: case bfloat16:

View File

@@ -1,8 +1,8 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.
 #pragma once
-#include <vecLib/BNNS/bnns.h>
+#include <Accelerate/Accelerate.h>
 #include "mlx/dtype.h"
 namespace mlx::core {

View File

@@ -42,6 +42,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/masked_mm.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp

View File

@@ -66,7 +66,7 @@ void Copy::eval(const std::vector<array>& inputs, array& out) {
   out.copy_shared_buffer(inputs[0]);
 }
-void CustomVJP::eval(
+void CustomTransforms::eval(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
   assert(inputs.size() > outputs.size());

View File

@@ -205,8 +205,8 @@ void compiled_allocate_outputs(
     // - Donatable
     // - Correct size
     // - Not a constant
-    if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() &&
-        in.is_donatable() &&
+    if (in.flags().row_contiguous && in.size() == outputs[o].size() &&
+        in.itemsize() == outputs[o].itemsize() && in.is_donatable() &&
         constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
       if (move_buffers) {
         outputs[o].move_shared_buffer(

View File

@@ -4,6 +4,7 @@
#include "mlx/allocator.h" #include "mlx/allocator.h"
#include "mlx/backend/common/copy.h" #include "mlx/backend/common/copy.h"
#include "mlx/backend/common/utils.h"
namespace mlx::core { namespace mlx::core {
@@ -142,29 +143,31 @@ void copy_general(
const std::vector<int>& data_shape, const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides, const std::vector<stride_t>& i_strides,
int64_t i_offset) { int64_t i_offset) {
switch (src.ndim()) { auto [new_shape, new_strides] = collapse_contiguous_dims(
data_shape, std::vector<std::vector<stride_t>>{i_strides});
switch (new_shape.size()) {
case 1: case 1:
copy_general_dim1<SrcT, DstT, stride_t>( copy_general_dim1<SrcT, DstT, stride_t>(
src, dst, data_shape, i_strides, i_offset); src, dst, new_shape, new_strides[0], i_offset);
return; return;
case 2: case 2:
copy_general_dim2<SrcT, DstT, stride_t>( copy_general_dim2<SrcT, DstT, stride_t>(
src, dst, data_shape, i_strides, i_offset); src, dst, new_shape, new_strides[0], i_offset);
return; return;
case 3: case 3:
copy_general_dim3<SrcT, DstT, stride_t>( copy_general_dim3<SrcT, DstT, stride_t>(
src, dst, data_shape, i_strides, i_offset); src, dst, new_shape, new_strides[0], i_offset);
return; return;
case 4: case 4:
copy_general_dim4<SrcT, DstT, stride_t>( copy_general_dim4<SrcT, DstT, stride_t>(
src, dst, data_shape, i_strides, i_offset); src, dst, new_shape, new_strides[0], i_offset);
return; return;
} }
auto src_ptr = src.data<SrcT>() + i_offset; auto src_ptr = src.data<SrcT>() + i_offset;
auto dst_ptr = dst.data<DstT>(); auto dst_ptr = dst.data<DstT>();
for (size_t i = 0; i < dst.size(); ++i) { for (size_t i = 0; i < dst.size(); ++i) {
stride_t src_elem = elem_to_loc(i, data_shape, i_strides); stride_t src_elem = elem_to_loc(i, new_shape, new_strides[0]);
dst_ptr[i] = static_cast<DstT>(src_ptr[src_elem]); dst_ptr[i] = static_cast<DstT>(src_ptr[src_elem]);
} }
} }
@@ -195,10 +198,10 @@ inline void copy_general_general_dims(
const std::vector<int>& data_shape, const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides, const std::vector<stride_t>& i_strides,
const std::vector<stride_t>& o_strides, const std::vector<stride_t>& o_strides,
stride_t i_offset, int64_t i_offset,
stride_t o_offset) { int64_t o_offset) {
if constexpr (D > 1) { if constexpr (D > 1) {
int axis = src.ndim() - D; int axis = data_shape.size() - D;
auto stride_src = i_strides[axis]; auto stride_src = i_strides[axis];
auto stride_dst = o_strides[axis]; auto stride_dst = o_strides[axis];
auto N = data_shape[axis]; auto N = data_shape[axis];
@@ -209,7 +212,7 @@ inline void copy_general_general_dims(
o_offset += stride_dst; o_offset += stride_dst;
} }
} else { } else {
int axis = src.ndim() - 1; int axis = data_shape.size() - 1;
auto stride_src = i_strides[axis]; auto stride_src = i_strides[axis];
auto stride_dst = o_strides[axis]; auto stride_dst = o_strides[axis];
auto N = data_shape[axis]; auto N = data_shape[axis];
@@ -230,38 +233,76 @@ void copy_general_general(
const std::vector<int>& data_shape, const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides, const std::vector<stride_t>& i_strides,
const std::vector<stride_t>& o_strides, const std::vector<stride_t>& o_strides,
stride_t i_offset, int64_t i_offset,
stride_t o_offset) { int64_t o_offset) {
switch (src.ndim()) { auto [new_shape, new_strides] = collapse_contiguous_dims(
data_shape, std::vector<std::vector<stride_t>>{i_strides, o_strides});
switch (new_shape.size()) {
case 1: case 1:
copy_general_general_dims<SrcT, DstT, stride_t, 1>( copy_general_general_dims<SrcT, DstT, stride_t, 1>(
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset); src,
dst,
new_shape,
new_strides[0],
new_strides[1],
i_offset,
o_offset);
return; return;
case 2: case 2:
copy_general_general_dims<SrcT, DstT, stride_t, 2>( copy_general_general_dims<SrcT, DstT, stride_t, 2>(
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset); src,
dst,
new_shape,
new_strides[0],
new_strides[1],
i_offset,
o_offset);
return; return;
case 3: case 3:
copy_general_general_dims<SrcT, DstT, stride_t, 3>( copy_general_general_dims<SrcT, DstT, stride_t, 3>(
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset); src,
dst,
new_shape,
new_strides[0],
new_strides[1],
i_offset,
o_offset);
return; return;
case 4: case 4:
copy_general_general_dims<SrcT, DstT, stride_t, 4>( copy_general_general_dims<SrcT, DstT, stride_t, 4>(
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset); src,
dst,
new_shape,
new_strides[0],
new_strides[1],
i_offset,
o_offset);
return; return;
case 5: case 5:
copy_general_general_dims<SrcT, DstT, stride_t, 5>( copy_general_general_dims<SrcT, DstT, stride_t, 5>(
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset); src,
dst,
new_shape,
new_strides[0],
new_strides[1],
i_offset,
o_offset);
return; return;
} }
int size = std::accumulate( int size = std::accumulate(
data_shape.end() - 5, data_shape.end(), 1, std::multiplies<int>()); new_shape.end() - 5, new_shape.end(), 1, std::multiplies<int>());
for (int i = 0; i < src.size(); i += size) { for (int i = 0; i < src.size(); i += size) {
stride_t src_offset = i_offset + elem_to_loc(i, data_shape, i_strides); stride_t src_offset = i_offset + elem_to_loc(i, new_shape, new_strides[0]);
stride_t dst_offset = o_offset + elem_to_loc(i, dst.shape(), o_strides); stride_t dst_offset = o_offset + elem_to_loc(i, new_shape, new_strides[1]);
copy_general_general_dims<SrcT, DstT, stride_t, 5>( copy_general_general_dims<SrcT, DstT, stride_t, 5>(
src, dst, data_shape, i_strides, o_strides, src_offset, dst_offset); src,
dst,
new_shape,
new_strides[0],
new_strides[1],
src_offset,
dst_offset);
} }
} }
@@ -444,8 +485,17 @@ void copy_inplace(
} }
} }
template <> template void copy_inplace<size_t>(
void copy_inplace<int64_t>( const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<size_t>& i_strides,
const std::vector<size_t>& o_strides,
int64_t i_offset,
int64_t o_offset,
CopyType ctype);
template void copy_inplace<int64_t>(
const array& src, const array& src,
array& dst, array& dst,
const std::vector<int>& data_shape, const std::vector<int>& data_shape,
@@ -453,24 +503,6 @@ void copy_inplace<int64_t>(
const std::vector<int64_t>& o_strides, const std::vector<int64_t>& o_strides,
int64_t i_offset, int64_t i_offset,
int64_t o_offset, int64_t o_offset,
CopyType ctype) { CopyType ctype);
switch (ctype) {
case CopyType::General:
case CopyType::GeneralGeneral:
return copy_inplace_dispatch(
src,
dst,
ctype,
data_shape,
i_strides,
o_strides,
i_offset,
o_offset);
case CopyType::Scalar:
case CopyType::Vector:
return copy_inplace_dispatch(src, dst, ctype);
}
}
} // namespace mlx::core } // namespace mlx::core

View File

@@ -52,7 +52,7 @@ DEFAULT(Convolution)
 DEFAULT(Copy)
 DEFAULT(Cos)
 DEFAULT(Cosh)
-DEFAULT_MULTI(CustomVJP)
+DEFAULT_MULTI(CustomTransforms)
 DEFAULT_MULTI(Depends)
 DEFAULT(Divide)
 DEFAULT(NumberOfElements)
@@ -68,6 +68,7 @@ DEFAULT(Full)
 DEFAULT(Gather)
 DEFAULT(Greater)
 DEFAULT(GreaterEqual)
+DEFAULT(Hadamard)
 DEFAULT(Less)
 DEFAULT(LessEqual)
 DEFAULT(Load)

View File

@@ -0,0 +1,107 @@
// Copyright © 2024 Apple Inc.
#include <cassert>
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/hadamard.h"
#include "mlx/primitives.h"
namespace mlx::core {
// n = 2^k component
template <typename T>
void hadamard_n(array& out, int n, int m, float scale) {
for (int b = 0; b < out.size() / n; b++) {
size_t loc = b * n;
T* data_ptr = out.data<T>() + loc;
int h = 1;
int n_over_2 = n / 2;
while (h < n) {
for (int i = 0; i < n / 2; i++) {
int k = i & (h - 1);
int j = ((i - k) << 1) + k;
float x = *(data_ptr + j);
float y = *(data_ptr + j + h);
*(data_ptr + j) = x + y;
*(data_ptr + j + h) = x - y;
if (h == n_over_2) {
*(data_ptr + j) *= scale;
*(data_ptr + j + h) *= scale;
}
}
h <<= 1;
}
}
}
// m component
template <typename T>
void hadamard_m(array& out, int n, int m, float scale) {
auto h_matrices = hadamard_matrices();
auto& matrix = h_matrices[m];
auto start = 1;
auto end = matrix.find('\n', start);
std::vector<bool> hmat_vec;
while (end != std::string_view::npos) {
auto row = matrix.substr(start, end - start);
for (int i = 0; i < row.length(); i++) {
hmat_vec.push_back(row[i] == '+');
}
start = end + 1;
end = matrix.find('\n', start);
}
for (int b = 0; b < out.size() / m / n; b++) {
size_t loc = b * n * m;
T* data_ptr = out.data<T>() + loc;
for (int i = 0; i < n; i++) {
std::vector<float> out(m);
for (int j = 0; j < m; j++) {
for (int k = 0; k < m; k++) {
float x = *(data_ptr + i + k * n);
if (hmat_vec[k + j * m]) {
out[j] += x;
} else {
out[j] -= x;
}
}
}
for (int j = 0; j < m; j++) {
*(data_ptr + i + j * n) = out[j] * scale;
}
}
}
}
template <typename T>
void hadamard(array& out, int n, int m, float scale) {
float n_scale = m > 1 ? 1.0 : scale;
hadamard_n<T>(out, n, m, n_scale);
if (m > 1) {
hadamard_m<T>(out, n, m, scale);
}
}
void Hadamard::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
// Copy input to output
copy(in, out, CopyType::General);
int axis = out.ndim() - 1;
auto [n, m] = decompose_hadamard(out.shape(axis));
switch (in.dtype()) {
case float32:
return hadamard<float>(out, n, m, scale_);
case float16:
return hadamard<float16_t>(out, n, m, scale_);
case bfloat16:
return hadamard<bfloat16_t>(out, n, m, scale_);
default:
throw std::invalid_argument("[hadamard] Unsupported type.");
}
}
} // namespace mlx::core

View File

@@ -0,0 +1,105 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include <map>
#include "mlx/utils.h"
namespace mlx::core {
// From http://neilsloane.com/hadamard/
constexpr std::string_view h12 = R"(
+-++++++++++
--+-+-+-+-+-
+++-++----++
+---+--+-++-
+++++-++----
+-+---+--+-+
++--+++-++--
+--++---+--+
++----+++-++
+--+-++---+-
++++----+++-
+-+--+-++---
)";
constexpr std::string_view h20 = R"(
+----+----++--++-++-
-+----+---+++---+-++
--+----+---+++-+-+-+
---+----+---+++++-+-
----+----++--++-++-+
-+++++-----+--+++--+
+-+++-+---+-+--+++--
++-++--+---+-+--+++-
+++-+---+---+-+--+++
++++-----++--+-+--++
--++-+-++-+-----++++
---++-+-++-+---+-+++
+---++-+-+--+--++-++
++---++-+----+-+++-+
-++---++-+----+++++-
-+--+--++-+----+----
+-+-----++-+----+---
-+-+-+---+--+----+--
--+-+++------+----+-
+--+--++------+----+
)";
constexpr std::string_view h28 = R"(
+------++----++-+--+-+--++--
-+-----+++-----+-+--+-+--++-
--+-----+++---+-+-+----+--++
---+-----+++---+-+-+-+--+--+
----+-----+++---+-+-+++--+--
-----+-----++++--+-+--++--+-
------++----++-+--+-+--++--+
--++++-+-------++--+++-+--+-
---++++-+-----+-++--+-+-+--+
+---+++--+----++-++--+-+-+--
++---++---+----++-++--+-+-+-
+++---+----+----++-++--+-+-+
++++--------+-+--++-++--+-+-
-++++--------+++--++--+--+-+
-+-++-++--++--+--------++++-
+-+-++--+--++--+--------++++
-+-+-++--+--++--+----+---+++
+-+-+-++--+--+---+---++---++
++-+-+-++--+------+--+++---+
-++-+-+-++--+------+-++++---
+-++-+---++--+------+-++++--
-++--++-+-++-+++----++------
+-++--++-+-++-+++-----+-----
++-++---+-+-++-+++-----+----
-++-++-+-+-+-+--+++-----+---
--++-++++-+-+----+++-----+--
+--++-+-++-+-+----+++-----+-
++--++-+-++-+-+----++------+
)";
inline const std::map<int, std::string_view> hadamard_matrices() {
return {{12, h12}, {20, h20}, {28, h28}};
}
inline std::pair<int, int> decompose_hadamard(int n) {
// n = m*2^k
int m = 1;
if (!is_power_of_2(n)) {
auto h_matrices = hadamard_matrices();
for (auto [factor, _] : h_matrices) {
if (n % factor == 0) {
m = factor;
n /= factor;
break;
}
}
if (m == 1) {
throw std::invalid_argument(
"[hadamard] Only supports n = m*2^k where m in (1, 12, 20, 28).");
}
}
return {n, m};
}
} // namespace mlx::core

View File

@@ -10,9 +10,106 @@
#include <lapack.h> #include <lapack.h>
#endif #endif
// Wrapper to account for differences in
// LAPACK implementations (basically how to pass the 'uplo' string to fortran).
int strtri_wrapper(char uplo, char diag, float* matrix, int N) {
int info;
#ifdef LAPACK_FORTRAN_STRLEN_END
strtri_(
/* uplo = */ &uplo,
/* diag = */ &diag,
/* N = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* info = */ &info,
/* uplo_len = */ static_cast<size_t>(1),
/* diag_len = */ static_cast<size_t>(1));
#else
strtri_(
/* uplo = */ &uplo,
/* diag = */ &diag,
/* N = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* info = */ &info);
#endif
return info;
}
namespace mlx::core { namespace mlx::core {
void inverse_impl(const array& a, array& inv) { void general_inv(array& inv, int N, int i) {
int info;
auto ipiv = array::Data{allocator::malloc_or_wait(sizeof(int) * N)};
// Compute LU factorization.
sgetrf_(
/* m = */ &N,
/* n = */ &N,
/* a = */ inv.data<float>() + N * N * i,
/* lda = */ &N,
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: LU factorization failed with error code " << info;
throw std::runtime_error(ss.str());
}
static const int lwork_query = -1;
float workspace_size = 0;
// Compute workspace size.
sgetri_(
/* m = */ &N,
/* a = */ nullptr,
/* lda = */ &N,
/* ipiv = */ nullptr,
/* work = */ &workspace_size,
/* lwork = */ &lwork_query,
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: LU workspace calculation failed with error code "
<< info;
throw std::runtime_error(ss.str());
}
const int lwork = workspace_size;
auto scratch = array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
// Compute inverse.
sgetri_(
/* m = */ &N,
/* a = */ inv.data<float>() + N * N * i,
/* lda = */ &N,
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
/* work = */ static_cast<float*>(scratch.buffer.raw_ptr()),
/* lwork = */ &lwork,
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: inversion failed with error code " << info;
throw std::runtime_error(ss.str());
}
}
void tri_inv(array& inv, int N, int i, bool upper) {
const char uplo = upper ? 'L' : 'U';
const char diag = 'N';
int info = strtri_wrapper(uplo, diag, inv.data<float>() + N * N * i, N);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: triangular inversion failed with error code " << info;
throw std::runtime_error(ss.str());
}
}
void inverse_impl(const array& a, array& inv, bool tri, bool upper) {
// Lapack uses the column-major convention. We take advantage of the following // Lapack uses the column-major convention. We take advantage of the following
// identity to avoid transposing (see // identity to avoid transposing (see
// https://math.stackexchange.com/a/340234): // https://math.stackexchange.com/a/340234):
@@ -24,63 +121,11 @@ void inverse_impl(const array& a, array& inv) {
const int N = a.shape(-1); const int N = a.shape(-1);
const size_t num_matrices = a.size() / (N * N); const size_t num_matrices = a.size() / (N * N);
int info;
auto ipiv = array::Data{allocator::malloc_or_wait(sizeof(int) * N)};
for (int i = 0; i < num_matrices; i++) { for (int i = 0; i < num_matrices; i++) {
// Compute LU factorization. if (tri) {
sgetrf_( tri_inv(inv, N, i, upper);
/* m = */ &N, } else {
/* n = */ &N, general_inv(inv, N, i);
/* a = */ inv.data<float>() + N * N * i,
/* lda = */ &N,
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: LU factorization failed with error code " << info;
throw std::runtime_error(ss.str());
}
static const int lwork_query = -1;
float workspace_size = 0;
// Compute workspace size.
sgetri_(
/* m = */ &N,
/* a = */ nullptr,
/* lda = */ &N,
/* ipiv = */ nullptr,
/* work = */ &workspace_size,
/* lwork = */ &lwork_query,
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: LU workspace calculation failed with error code "
<< info;
throw std::runtime_error(ss.str());
}
const int lwork = workspace_size;
auto scratch =
array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
// Compute inverse.
sgetri_(
/* m = */ &N,
/* a = */ inv.data<float>() + N * N * i,
/* lda = */ &N,
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
/* work = */ static_cast<float*>(scratch.buffer.raw_ptr()),
/* lwork = */ &lwork,
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: inversion failed with error code " << info;
throw std::runtime_error(ss.str());
} }
} }
} }
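For reference, the identity the comment above relies on, written out: handing LAPACK a row-major buffer for A is equivalent to handing it A^T, and since (A^T)^{-1} = (A^{-1})^T, inverting that column-major view in place and reading the buffer back row-major yields A^{-1} directly:
LAPACK input (column-major view):  B = A^T
LAPACK output:                     B^{-1} = (A^T)^{-1} = (A^{-1})^T
row-major read of the result:      (B^{-1})^T = A^{-1}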
@@ -89,7 +134,7 @@ void Inverse::eval(const std::vector<array>& inputs, array& output) {
if (inputs[0].dtype() != float32) {
throw std::runtime_error("[Inverse::eval] only supports float32.");
}
inverse_impl(inputs[0], output);
inverse_impl(inputs[0], output, tri_, upper_);
}
} // namespace mlx::core

View File

@@ -405,7 +405,17 @@ void Reshape::eval(const std::vector<array>& inputs, array& out) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : CopyType::General);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto out_strides = make_contiguous_strides<size_t>(in.shape());
copy_inplace<size_t>(
in,
out,
in.shape(),
in.strides(),
out_strides,
0,
0,
CopyType::General);
} else {
shared_buffer_reshape(in, out_strides, out);
}

View File

@@ -113,14 +113,14 @@ void sort(const array& in, array& out, int axis) {
axis = axis < 0 ? axis + in.ndim() : axis;
size_t n_rows = in.size() / in.shape(axis);
auto remaining_shape = in.shape();
auto remaining_shape = out.shape();
remaining_shape.erase(remaining_shape.begin() + axis);
auto remaining_strides = in.strides();
auto remaining_strides = out.strides();
remaining_strides.erase(remaining_strides.begin() + axis);
size_t axis_stride = in.strides()[axis];
size_t axis_stride = out.strides()[axis];
int axis_size = in.shape(axis);
int axis_size = out.shape(axis);
// Perform sorting in place
for (int i = 0; i < n_rows; i++) {
@@ -143,34 +143,42 @@ void argsort(const array& in, array& out, int axis) {
axis = axis < 0 ? axis + in.ndim() : axis;
size_t n_rows = in.size() / in.shape(axis);
auto remaining_shape = in.shape();
remaining_shape.erase(remaining_shape.begin() + axis);
auto remaining_strides = in.strides();
remaining_strides.erase(remaining_strides.begin() + axis);
size_t axis_stride = in.strides()[axis];
auto in_remaining_shape = in.shape();
in_remaining_shape.erase(in_remaining_shape.begin() + axis);
auto in_remaining_strides = in.strides();
in_remaining_strides.erase(in_remaining_strides.begin() + axis);
auto out_remaining_shape = out.shape();
out_remaining_shape.erase(out_remaining_shape.begin() + axis);
auto out_remaining_strides = out.strides();
out_remaining_strides.erase(out_remaining_strides.begin() + axis);
size_t in_stride = in.strides()[axis];
size_t out_stride = out.strides()[axis];
int axis_size = in.shape(axis);
// Perform sorting
for (int i = 0; i < n_rows; i++) {
size_t loc = elem_to_loc(i, remaining_shape, remaining_strides);
const T* data_ptr = in.data<T>() + loc;
IdxT* idx_ptr = out.data<IdxT>() + loc;
StridedIterator st_(idx_ptr, axis_stride, 0);
StridedIterator ed_(idx_ptr, axis_stride, axis_size);
size_t in_loc = elem_to_loc(i, in_remaining_shape, in_remaining_strides);
size_t out_loc = elem_to_loc(i, out_remaining_shape, out_remaining_strides);
const T* data_ptr = in.data<T>() + in_loc;
IdxT* idx_ptr = out.data<IdxT>() + out_loc;
StridedIterator st_(idx_ptr, out_stride, 0);
StridedIterator ed_(idx_ptr, out_stride, axis_size);
// Initialize with iota
std::iota(st_, ed_, IdxT(0));
// Sort according to vals
StridedIterator st(idx_ptr, axis_stride, 0);
StridedIterator ed(idx_ptr, axis_stride, axis_size);
std::stable_sort(st, ed, [data_ptr, axis_stride](IdxT a, IdxT b) {
auto v1 = data_ptr[a * axis_stride];
auto v2 = data_ptr[b * axis_stride];
StridedIterator st(idx_ptr, out_stride, 0);
StridedIterator ed(idx_ptr, out_stride, axis_size);
std::stable_sort(st, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
auto v1 = data_ptr[a * in_stride];
auto v2 = data_ptr[b * in_stride];
return v1 < v2 || (v1 == v2 && a < b);
});
}
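The per-row body above is the usual iota-then-stable-sort argsort idiom; a minimal standard-library sketch of the same logic on a contiguous row (stride 1), not the MLX code itself:
#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

std::vector<uint32_t> argsort_row(const std::vector<float>& vals) {
  std::vector<uint32_t> idx(vals.size());
  std::iota(idx.begin(), idx.end(), 0u);  // initialize with iota
  std::stable_sort(idx.begin(), idx.end(), [&](uint32_t a, uint32_t b) {
    // compare by value, break ties by original index, as in the kernel
    return vals[a] < vals[b] || (vals[a] == vals[b] && a < b);
  });
  return idx;
}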

View File

@@ -29,6 +29,15 @@ inline size_t elem_to_loc(int elem, const array& a) {
return elem_to_loc(elem, a.shape(), a.strides());
}
template <typename stride_t>
std::vector<stride_t> make_contiguous_strides(const std::vector<int>& shape) {
std::vector<stride_t> strides(shape.size(), 1);
for (int i = shape.size() - 1; i > 0; i--) {
strides[i - 1] = strides[i] * shape[i];
}
return strides;
}
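A quick example of what this helper produces (row-major, C-order strides):
// For shape {2, 3, 4}:
auto strides = make_contiguous_strides<size_t>({2, 3, 4});
// strides == {12, 4, 1}; element (i, j, k) lives at offset i * 12 + j * 4 + k.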
// Collapse dims that are contiguous to possibly route to a better kernel
// e.g. for x = transpose(array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 2, 2}), {2, 0, 1})
// should return {{2, 4}, {{1, 2}}}.

View File

@@ -18,7 +18,7 @@ function(make_jit_source SRC_FILE)
${CMAKE_C_COMPILER}
${PROJECT_SOURCE_DIR}
${SRC_FILE}
"-D${MLX_METAL_VERSION}"
"-DMLX_METAL_VERSION=${MLX_METAL_VERSION}"
DEPENDS make_compiled_preamble.sh
kernels/${SRC_FILE}.h
${ARGN}
@@ -52,6 +52,7 @@ make_jit_source(
)
make_jit_source(scatter)
make_jit_source(gather)
make_jit_source(hadamard)
if (MLX_METAL_JIT)
target_sources(
@@ -113,6 +114,7 @@ if (MLX_METAL_JIT)
kernels/steel/conv/loaders/loader_general.h
)
make_jit_source(quantized)
make_jit_source(gemv_masked)
else()
target_sources(
mlx
@@ -132,6 +134,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp
@@ -147,6 +150,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
)
if (NOT MLX_METAL_PATH)

View File

@@ -242,8 +242,17 @@ void MetalAllocator::free(Buffer buffer) {
}
MetalAllocator& allocator() {
static MetalAllocator allocator_;
return allocator_;
// By creating the |allocator_| on heap, the destructor of MetalAllocator will
// not be called on exit and all the buffers will be leaked. This is necessary
// because releasing buffers can take more than 30sec when the program holds a
// lot of RAM (for example inferencing a LLM), and it would feel frozen to
// users when exiting.
// TODO(zcbenz): Consider using the `base::NoDestructor` class from Chromium
// when applying this pattern to more places, or when introducing sanitizers
// to MLX.
// https://source.chromium.org/chromium/chromium/src/+/main:base/no_destructor.h
static MetalAllocator* allocator_ = new MetalAllocator;
return *allocator_;
}
size_t set_cache_limit(size_t limit) {
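A generic sketch of the leak-on-exit pattern described in the comment above (not MLX code, and simpler than Chromium's base::NoDestructor):
// Construct once on the heap and never delete, so no destructor runs at exit.
template <typename T>
T& leaky_singleton() {
  static T* instance = new T();  // intentionally leaked
  return *instance;
}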

View File

@@ -21,10 +21,43 @@ namespace mlx::core {
constexpr int MAX_BINARY_SPECIALIZED_DIMS = 5;
std::string get_kernel_name(
BinaryOpType bopt,
const std::string& op,
const array& a,
bool use_2d,
int ndim) {
std::ostringstream kname;
switch (bopt) {
case BinaryOpType::ScalarScalar:
kname << "ss";
break;
case BinaryOpType::ScalarVector:
kname << (use_2d ? "sv2" : "sv");
break;
case BinaryOpType::VectorScalar:
kname << (use_2d ? "vs2" : "vs");
break;
case BinaryOpType::VectorVector:
kname << (use_2d ? "vv2" : "vv");
break;
case BinaryOpType::General:
kname << "g";
if (ndim <= MAX_BINARY_SPECIALIZED_DIMS) {
kname << ndim;
} else {
kname << "n";
}
break;
}
kname << op << type_to_name(a);
return kname.str();
}
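For example (assuming type_to_name() renders a float32 array as "float32", which is not shown in this diff), a float32 VectorVector "add" yields:
// get_kernel_name(BinaryOpType::VectorVector, "add", a, /* use_2d = */ true, ndim)
//   -> "vv2addfloat32"
// and with use_2d = false -> "vvaddfloat32".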
void binary_op_gpu_inplace(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::string op,
const std::string& op,
const Stream& s) {
auto& a = inputs[0];
auto& b = inputs[1];
@@ -41,35 +74,8 @@ void binary_op_gpu_inplace(
auto& strides_b = strides[1];
auto& strides_out = strides[2];
std::string kernel_name;
{
std::ostringstream kname;
switch (bopt) {
case BinaryOpType::ScalarScalar:
kname << "ss";
break;
case BinaryOpType::ScalarVector:
kname << "sv";
break;
case BinaryOpType::VectorScalar:
kname << "vs";
break;
case BinaryOpType::VectorVector:
kname << "vv";
break;
case BinaryOpType::General:
kname << "g";
if (shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
kname << shape.size();
} else {
kname << "n";
}
break;
}
kname << op << type_to_name(a);
kernel_name = kname.str();
}
bool use_2d = out.data_size() > UINT32_MAX;
std::string kernel_name = get_kernel_name(bopt, op, a, use_2d, shape.size());
auto& d = metal::device(s.device);
auto kernel =
@@ -117,9 +123,11 @@ void binary_op_gpu_inplace(
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
// Launch a 1D grid of threads
// Launch a 1D or 2D grid of threads
size_t nthreads = out.data_size();
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
MTL::Size grid_dims = use_2d
? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides())
: MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
@@ -132,7 +140,7 @@ void binary_op_gpu_inplace(
void binary_op_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::string op,
const std::string& op,
const Stream& s) {
assert(inputs.size() == 2);
auto& a = inputs[0];
@@ -146,7 +154,7 @@ void binary_op_gpu(
void binary_op_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::string op) {
const std::string& op) {
auto& s = outputs[0].primitive().stream();
binary_op_gpu(inputs, outputs, op, s);
}
@@ -154,7 +162,7 @@ void binary_op_gpu(
void binary_op_gpu_inplace(
const std::vector<array>& inputs,
array& out,
const std::string op,
const std::string& op,
const Stream& s) {
auto& a = inputs[0];
auto& b = inputs[1];
@@ -169,35 +177,8 @@ void binary_op_gpu_inplace(
auto& strides_b = strides[1];
auto& strides_out = strides[2];
std::string kernel_name;
{
std::ostringstream kname;
switch (bopt) {
case BinaryOpType::ScalarScalar:
kname << "ss";
break;
case BinaryOpType::ScalarVector:
kname << "sv";
break;
case BinaryOpType::VectorScalar:
kname << "vs";
break;
case BinaryOpType::VectorVector:
kname << "vv";
break;
case BinaryOpType::General:
kname << "g";
if (shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
kname << shape.size();
} else {
kname << "n";
}
break;
}
kname << op << type_to_name(a);
kernel_name = kname.str();
}
bool use_2d = out.data_size() > UINT32_MAX;
std::string kernel_name = get_kernel_name(bopt, op, a, use_2d, shape.size());
auto& d = metal::device(s.device);
auto kernel = get_binary_kernel(d, kernel_name, a.dtype(), out.dtype(), op);
@@ -237,10 +218,11 @@ void binary_op_gpu_inplace(
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
// Launch a 1D grid of threads
// Launch a 1D or 2D grid of threads
size_t nthreads =
bopt == BinaryOpType::General ? out.size() : out.data_size();
size_t nthreads = out.data_size();
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
@@ -253,7 +235,7 @@ void binary_op_gpu_inplace(
void binary_op_gpu(
const std::vector<array>& inputs,
array& out,
const std::string op,
const std::string& op,
const Stream& s) {
assert(inputs.size() == 2);
auto& a = inputs[0];
@@ -266,7 +248,7 @@ void binary_op_gpu(
void binary_op_gpu(
const std::vector<array>& inputs,
array& out,
const std::string op) {
const std::string& op) {
auto& s = out.primitive().stream();
binary_op_gpu(inputs, out, op, s);
}

View File

@@ -9,25 +9,25 @@ namespace mlx::core {
void binary_op_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::string op,
const std::string& op,
const Stream& s);
void binary_op_gpu(
const std::vector<array>& inputs,
array& out,
const std::string op,
const std::string& op,
const Stream& s);
void binary_op_gpu_inplace(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::string op,
const std::string& op,
const Stream& s);
void binary_op_gpu_inplace(
const std::vector<array>& inputs,
array& out,
const std::string op,
const std::string& op,
const Stream& s);
} // namespace mlx::core

View File

@@ -64,16 +64,17 @@ void copy_gpu_inplace(
auto& strides_in_ = strides[0];
auto& strides_out_ = strides[1];
bool use_2d = out.data_size() > UINT32_MAX;
auto& d = metal::device(s.device);
std::string kernel_name;
{
std::ostringstream kname;
switch (ctype) {
case CopyType::Scalar:
kname << "s";
kname << (use_2d ? "s2" : "s");
break;
case CopyType::Vector:
kname << "v";
kname << (use_2d ? "v2" : "v");
break;
case CopyType::General:
kname << "g";
@@ -139,7 +140,8 @@ void copy_gpu_inplace(
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
size_t nthreads = out.data_size();
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;

View File

@@ -14,7 +14,6 @@
#include "mlx/backend/metal/device.h" #include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/metal.h" #include "mlx/backend/metal/metal.h"
#include "mlx/backend/metal/metal_impl.h" #include "mlx/backend/metal/metal_impl.h"
#include "mlx/backend/metal/mps/gemm.h"
#include "mlx/backend/metal/utils.h" #include "mlx/backend/metal/utils.h"
namespace fs = std::filesystem; namespace fs = std::filesystem;
@@ -30,15 +29,29 @@ constexpr int MAX_DISPATCHES_PER_ENCODER = 2;
constexpr const char* default_mtllib_path = METAL_PATH;
constexpr auto get_metal_version() {
#if defined METAL_3_2
#if (MLX_METAL_VERSION >= 320)
return MTL::LanguageVersion3_2;
#elif defined METAL_3_1
#elif (MLX_METAL_VERSION >= 310)
return MTL::LanguageVersion3_1;
#else
return MTL::LanguageVersion3_0;
#endif
}
std::string get_colocated_mtllib_path(const std::string& lib_name) {
Dl_info info;
std::string mtllib_path;
std::string lib_ext = lib_name + ".metallib";
int success = dladdr((void*)get_colocated_mtllib_path, &info);
if (success) {
auto mtllib = fs::path(info.dli_fname).remove_filename() / lib_ext;
mtllib_path = mtllib.c_str();
}
return mtllib_path;
}
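Illustration with a made-up install location: if the process loaded the MLX shared library from /some/prefix/lib/ (hypothetical path), dladdr() resolves that directory and get_colocated_mtllib_path("mlx") returns "/some/prefix/lib/mlx.metallib"; if dladdr() fails, the empty string is returned.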
auto load_device() {
auto devices = MTL::CopyAllDevices();
auto device = static_cast<MTL::Device*>(devices->object(0))
@@ -126,6 +139,49 @@ MTL::Library* load_library(
} // namespace
CommandEncoder::CommandEncoder(MTL::CommandBuffer* cbuf) : cbuf(cbuf) {
enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
enc->retain();
}
CommandEncoder::~CommandEncoder() {
enc->endEncoding();
enc->release();
}
void CommandEncoder::set_input_array(
const array& a,
int idx,
int64_t offset /* = 0 */) {
auto r_buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
if (auto it = outputs.find(r_buf); it != outputs.end()) {
// Insert a barrier
enc->memoryBarrier(&r_buf, 1);
// Remove the output
outputs.erase(it);
}
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
auto base_offset = a.data<char>() -
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
base_offset += offset;
enc->setBuffer(a_buf, base_offset, idx);
}
void CommandEncoder::set_output_array(
array& a,
int idx,
int64_t offset /* = 0 */) {
// Add barriers before adding the output to the output set
set_input_array(a, idx, offset);
auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
if (concurrent) {
concurrent_outputs.insert(buf);
} else {
outputs.insert(buf);
}
}
void CommandEncoder::dispatchThreadgroups(
MTL::Size grid_dims,
MTL::Size group_dims) {
@@ -255,13 +311,9 @@ void Device::register_library(
}
}
void Device::register_library(
void Device::register_library(const std::string& lib_name) {
const std::string& lib_name,
const std::function<std::string(const std::string&)>& lib_path_func) {
if (auto it = library_map_.find(lib_name); it == library_map_.end()) {
std::string new_lib_path = lib_path_func(lib_name);
register_library(lib_name, get_colocated_mtllib_path(lib_name));
auto new_lib = load_library(device_, lib_name, new_lib_path.c_str());
library_map_.insert({lib_name, new_lib});
}
}
@@ -271,7 +323,7 @@ MTL::Library* Device::get_library_cache_(const std::string& lib_name) {
if (auto it = library_map_.find(lib_name); it != library_map_.end()) {
mtl_lib = it->second;
} else { // Look for metallib alongside library
register_library(lib_name);
register_library(lib_name, get_colocated_mtllib_path(lib_name));
mtl_lib = library_map_[lib_name];
}

View File

@@ -9,38 +9,16 @@
#include <unordered_map>
#include <unordered_set>
#include <dlfcn.h>
#include <filesystem>
#include "mlx/array.h"
#include "mlx/device.h"
namespace fs = std::filesystem;
namespace mlx::core::metal {
inline std::string get_colocated_mtllib_path(const std::string& lib_name) {
Dl_info info;
std::string mtllib_path;
std::string lib_ext = lib_name + ".metallib";
int success = dladdr((void*)get_colocated_mtllib_path, &info);
if (success) {
auto mtllib = fs::path(info.dli_fname).remove_filename() / lib_ext;
mtllib_path = mtllib.c_str();
}
return mtllib_path;
}
using MTLFCList =
std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;
struct CommandEncoder {
CommandEncoder(MTL::CommandBuffer* cbuf) : cbuf(cbuf) {
enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
enc->retain();
};
CommandEncoder(MTL::CommandBuffer* cbuf);
CommandEncoder(const CommandEncoder&) = delete;
CommandEncoder& operator=(const CommandEncoder&) = delete;
@@ -63,34 +41,8 @@ struct CommandEncoder {
return enc;
}
void set_input_array(const array& a, int idx, int64_t offset = 0) {
auto r_buf =
static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
if (auto it = outputs.find(r_buf); it != outputs.end()) {
// Insert a barrier
enc->memoryBarrier(&r_buf, 1);
// Remove the output
outputs.erase(it);
}
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
auto base_offset = a.data<char>() -
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
base_offset += offset;
enc->setBuffer(a_buf, base_offset, idx);
}
void set_output_array(array& a, int idx, int64_t offset = 0) {
// Add barriers before adding the output to the output set
set_input_array(a, idx, offset);
auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
if (concurrent) {
concurrent_outputs.insert(buf);
} else {
outputs.insert(buf);
}
}
void set_input_array(const array& a, int idx, int64_t offset = 0);
void set_output_array(array& a, int idx, int64_t offset = 0);
void dispatchThreadgroups(MTL::Size grid_dims, MTL::Size group_dims);
void dispatchThreads(MTL::Size grid_dims, MTL::Size group_dims);
@@ -98,10 +50,7 @@ struct CommandEncoder {
return ConcurrentContext(*this);
}
~CommandEncoder() {
enc->endEncoding();
enc->release();
}
~CommandEncoder();
private:
void maybe_split();
@@ -136,10 +85,8 @@ class Device {
void register_library(
const std::string& lib_name,
const std::string& lib_path);
void register_library(
const std::string& lib_name,
const std::function<std::string(const std::string&)>& lib_path_func =
get_colocated_mtllib_path);
void register_library(const std::string& lib_name);
MTL::Library* get_library(const std::string& name);

View File

@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2024 Apple Inc.
#include <cassert>
#include <complex>
#include <map>
@@ -12,8 +12,7 @@
#include "mlx/backend/metal/slicing.h" #include "mlx/backend/metal/slicing.h"
#include "mlx/backend/metal/unary.h" #include "mlx/backend/metal/unary.h"
#include "mlx/backend/metal/utils.h" #include "mlx/backend/metal/utils.h"
#include "mlx/mlx.h" #include "mlx/utils.h"
#include "mlx/primitives.h"
namespace mlx::core { namespace mlx::core {
@@ -785,10 +784,9 @@ void nd_fft_op(
fft_op(in_arr, out_arr, axis, inverse, step_real, inplace, s);
}
std::vector<array> copies = {temp1, temp2};
auto& d = metal::device(s.device);
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
[temp_arrs](MTL::CommandBuffer*) mutable { temp_arrs.clear(); });
}
void FFT::eval_gpu(const std::vector<array>& inputs, array& out) {

View File

@@ -0,0 +1,203 @@
// Copyright © 2024 Apple Inc.
#include <map>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/common/hadamard.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
constexpr int MAX_HADAMARD_THREADS_PER_GROUP = 256;
constexpr int MAX_HADAMARD_BYTES = 32768; // 32KB
std::string gen_hadamard_codelet(int m) {
// Generate a O(m^2) hadamard codelet for a given M
// using the hadamard matrices above
//
// e.g. m = 2
// METAL_FUNC void hadamard_m(thread float *x) {
// float tmp[2];
// tmp[0] = + x[0] + x[1];
// tmp[1] = + x[0] - x[1];
// for (int i = 0; i < 2; i++) { x[i] = tmp[i]; }
// }
//
auto h_matrices = hadamard_matrices();
auto& matrix = h_matrices[m];
std::ostringstream source;
source << "METAL_FUNC void hadamard_radix_m(thread float *x) {" << std::endl;
if (m == 1) {
source << "}" << std::endl;
return source.str();
}
source << " float tmp[" << m << "];" << std::endl;
auto start = 1;
auto end = matrix.find('\n', start);
int index = 0;
while (end != std::string_view::npos) {
source << " tmp[" << index << "] = ";
auto row = matrix.substr(start, end - start);
for (int i = 0; i < row.length(); i++) {
source << " " << row[i] << " x[" << i << "]";
}
source << ";" << std::endl;
start = end + 1;
end = matrix.find('\n', start);
index++;
}
source << " for (int i = 0; i < " << m << "; i++) { x[i] = tmp[i]; }"
<< std::endl;
source << "}" << std::endl;
return source.str();
}
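A worked trace, assuming hadamard_matrices() encodes H2 as the string "\n++\n+-\n" (a leading separator, then one row of '+'/'-' signs per line, which is the shape the parser above expects): start = 1 picks out row "++" and emits tmp[0] = + x[0] + x[1];, the next iteration picks "+-" and emits tmp[1] = + x[0] - x[1];, reproducing the m = 2 codelet shown in the comment at the top of this function.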
void launch_hadamard(
const array& in,
array& out,
int batch_size,
int threads_per,
const std::string kernel_name,
float scale,
const Stream& s) {
auto& d = metal::device(s.device);
const auto& lib_name = kernel_name.substr(1);
auto lib = d.get_library(lib_name);
auto kernel = d.get_kernel(kernel_name, lib);
assert(threads_per <= kernel->maxTotalThreadsPerThreadgroup());
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(&scale, sizeof(float), 2);
MTL::Size group_dims = MTL::Size(1, threads_per, 1);
MTL::Size grid_dims = MTL::Size(batch_size, threads_per, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = stream();
auto& in = inputs[0];
std::vector<array> copies;
// Only support the last axis for now
int axis = in.ndim() - 1;
auto check_input = [&copies, &s](const array& x) {
// TODO(alexbarron) pass strides to kernel to relax this constraint
bool no_copy = x.flags().row_contiguous;
if (no_copy) {
return x;
} else {
copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
copy_gpu(x, copies.back(), CopyType::General, s);
return copies.back();
}
};
const array& in_contiguous = check_input(in);
if (in_contiguous.is_donatable()) {
out.move_shared_buffer(in_contiguous);
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
auto [n, m] = decompose_hadamard(in.shape(axis));
if (n * (int)size_of(in.dtype()) > MAX_HADAMARD_BYTES) {
throw std::invalid_argument(
"[hadamard] For n = m*2^k, 2^k > 8192 for FP32 or 2^k > 16384 for FP16/BF16 NYI");
}
int max_radix = std::min(n, 16);
// Use read_width 2 for m = 28 to avoid register spilling
int read_width = (n == 2 || m == 28) ? 2 : 4;
std::ostringstream kname;
kname << "hadamard_" << n * m << "_" << type_to_name(out);
auto kernel_name = kname.str();
auto& d = metal::device(s.device);
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
auto codelet = gen_hadamard_codelet(m);
kernel_source << metal::utils() << codelet << metal::hadamard();
kernel_source << get_template_definition(
"n" + kernel_name,
"hadamard_n",
get_type_string(in.dtype()),
n,
max_radix,
read_width);
kernel_source << get_template_definition(
"m" + kernel_name,
"hadamard_m",
get_type_string(in.dtype()),
n,
m,
read_width);
lib = d.get_library(lib_name, kernel_source.str());
}
int batch_size = in.size() / n;
int threads_per = n / max_radix;
if (m > 1) {
// When m is greater than 1, we decompose the
// computation into two uploads to the GPU:
//
// e.g. len(x) = 12*4 = 48, m = 12, n = 4
//
// y = h48 @ x
//
// Upload 1:
// tmp = a.reshape(12, 4) @ h4
//
// Upload 2:
// y = h12 @ tmp
array temp(in.shape(), in.dtype(), nullptr, {});
temp.set_data(allocator::malloc_or_wait(temp.nbytes()));
copies.push_back(temp);
launch_hadamard(
in_contiguous,
temp,
batch_size,
threads_per,
"n" + kernel_name,
1.0,
s);
// Metal sometimes reports 256 max threads per group for hadamard_m kernel
threads_per = std::min(n / read_width, MAX_HADAMARD_THREADS_PER_GROUP);
batch_size = in.size() / m / read_width / threads_per;
launch_hadamard(
temp, out, batch_size, threads_per, "m" + kernel_name, scale_, s);
} else {
launch_hadamard(
in_contiguous,
out,
batch_size,
threads_per,
"n" + kernel_name,
scale_,
s);
}
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}
} // namespace mlx::core
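For reference, the decomposition behind the two-upload scheme in Hadamard::eval_gpu above: a length m*n Hadamard transform factors as a Kronecker product, H_{mn} = H_m x H_n (Kronecker), so with x reshaped row-major into an (m, n) matrix X,
y = H_{mn} x   <=>   Y = H_m (X H_n^T),
and H_n from the Sylvester construction is symmetric, so X H_n^T = X H_n. Upload 1 computes X H_n (the length-n transform along the last axis) and upload 2 applies H_m across the other axis; with m = 12 and n = 4 this is exactly the h48 = h12, h4 example in the comment.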

View File

@@ -0,0 +1,25 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view gemv_masked_kernel = R"(
template [[host_name("{name}")]] [[kernel]] void
gemv_{trans}masked<{itype}, {outm_t}, {opm_t}, {bm}, {bn}, {sm}, {sn}, {tm}, {tn}, {nc}>(
const device {itype}* mat [[buffer(0)]],
const device {itype}* in_vec [[buffer(1)]],
device {itype}* out_vec [[buffer(3)]],
const constant int& in_vec_size [[buffer(4)]],
const constant int& out_vec_size [[buffer(5)]],
const constant int& marix_ld [[buffer(6)]],
const constant int& batch_ndim [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]],
const constant size_t* vector_batch_stride [[buffer(11)]],
const constant size_t* matrix_batch_stride [[buffer(12)]],
const device {outm_t}* out_mask [[buffer(20)]],
const device {opm_t}* mat_mask [[buffer(21)]],
const device {opm_t}* vec_mask [[buffer(22)]],
const constant int* mask_strides [[buffer(23)]],
const constant size_t* mask_batch_strides [[buffer(24)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]);
)";

View File

@@ -18,6 +18,7 @@ const char* binary();
const char* binary_two();
const char* copy();
const char* fft();
const char* hadamard();
const char* quantized();
const char* ternary();
const char* scan();
@@ -32,5 +33,6 @@ const char* steel_gemm_splitk();
const char* conv();
const char* steel_conv();
const char* steel_conv_general();
const char* gemv_masked();
} // namespace mlx::core::metal

View File

@@ -1,81 +0,0 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view block_sort_kernels = R"(
template [[host_name("carg_{0}")]] [[kernel]] void
block_sort<{1}, {2}, true, {3}, {4}>(
const device {1}* inp [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant int& size_sorted_axis [[buffer(2)]],
const constant int& stride_sorted_axis [[buffer(3)]],
const constant int& stride_segment_axis [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
template [[host_name("ncarg_{0}")]] [[kernel]] void
block_sort_nc<{1}, {2}, true, {3}, {4}>(
const device {1}* inp [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant int& size_sorted_axis [[buffer(2)]],
const constant int& stride_sorted_axis [[buffer(3)]],
const constant int& nc_dim [[buffer(4)]],
const device int* nc_shape [[buffer(5)]],
const device size_t* nc_strides [[buffer(6)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
template [[host_name("c_{0}")]] [[kernel]] void
block_sort<{1}, {2}, false, {3}, {4}>(
const device {1}* inp [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant int& size_sorted_axis [[buffer(2)]],
const constant int& stride_sorted_axis [[buffer(3)]],
const constant int& stride_segment_axis [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
template [[host_name("nc_{0}")]] [[kernel]] void
block_sort_nc<{1}, {2}, false, {3}, {4}>(
const device {1}* inp [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant int& size_sorted_axis [[buffer(2)]],
const constant int& stride_sorted_axis [[buffer(3)]],
const constant int& nc_dim [[buffer(4)]],
const device int* nc_shape [[buffer(5)]],
const device size_t* nc_strides [[buffer(6)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
)";
constexpr std::string_view multiblock_sort_kernels = R"(
template [[host_name("sort_{0}")]] [[kernel]] void
mb_block_sort<{1}, {2}, true, {3}, {4}>(
const device {1}* inp [[buffer(0)]],
device {1}* out_vals [[buffer(1)]],
device {2}* out_idxs [[buffer(2)]],
const constant int& size_sorted_axis [[buffer(3)]],
const constant int& stride_sorted_axis [[buffer(4)]],
const constant int& nc_dim [[buffer(5)]],
const device int* nc_shape [[buffer(6)]],
const device size_t* nc_strides [[buffer(7)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
template [[host_name("partition_{0}")]] [[kernel]] void
mb_block_partition<{1}, {2}, true, {3}, {4}>(
device {2}* block_partitions [[buffer(0)]],
const device {1}* dev_vals [[buffer(1)]],
const device {2}* dev_idxs [[buffer(2)]],
const constant int& size_sorted_axis [[buffer(3)]],
const constant int& merge_tiles [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 tgp_dims [[threads_per_threadgroup]]);
template [[host_name("merge_{0}")]] [[kernel]] void
mb_block_merge<{1}, {2}, true, {3}, {4}>(
const device {2}* block_partitions [[buffer(0)]],
const device {1}* dev_vals_in [[buffer(1)]],
const device {2}* dev_idxs_in [[buffer(2)]],
device {1}* dev_vals_out [[buffer(3)]],
device {2}* dev_idxs_out [[buffer(4)]],
const constant int& size_sorted_axis [[buffer(5)]],
const constant int& merge_tiles [[buffer(6)]],
const constant int& num_tiles [[buffer(7)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
)";

View File

@@ -4,11 +4,11 @@
#include "mlx/backend/common/compiled.h" #include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/jit/arange.h" #include "mlx/backend/metal/jit/arange.h"
#include "mlx/backend/metal/jit/copy.h" #include "mlx/backend/metal/jit/copy.h"
#include "mlx/backend/metal/jit/gemv_masked.h"
#include "mlx/backend/metal/jit/includes.h" #include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/jit/reduce.h" #include "mlx/backend/metal/jit/reduce.h"
#include "mlx/backend/metal/jit/scan.h" #include "mlx/backend/metal/jit/scan.h"
#include "mlx/backend/metal/jit/softmax.h" #include "mlx/backend/metal/jit/softmax.h"
#include "mlx/backend/metal/jit/sort.h"
#include "mlx/backend/metal/jit/steel_conv.h" #include "mlx/backend/metal/jit/steel_conv.h"
#include "mlx/backend/metal/jit/steel_gemm.h" #include "mlx/backend/metal/jit/steel_gemm.h"
#include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/kernels.h"
@@ -51,10 +51,12 @@ MTL::ComputePipelineState* get_unary_kernel(
std::ostringstream kernel_source;
auto u_def = get_template_definition(
"v" + lib_name, "unary_v", get_type_string(out_type), op);
auto u2_def = get_template_definition(
"v2" + lib_name, "unary_v2", get_type_string(out_type), op);
auto g_def = get_template_definition(
"g" + lib_name, "unary_g", get_type_string(out_type), op);
kernel_source << metal::utils() << metal::unary_ops() << metal::unary()
<< u_def << g_def;
<< u_def << u2_def << g_def;
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
@@ -71,6 +73,9 @@ void add_binary_kernels(
{"vs", "binary_vs"}, {"vs", "binary_vs"},
{"sv", "binary_sv"}, {"sv", "binary_sv"},
{"vv", "binary_vv"}, {"vv", "binary_vv"},
{"vs2", "binary_vs2"},
{"sv2", "binary_sv2"},
{"vv2", "binary_vv2"},
{"g1", "binary_g_nd1"}, {"g1", "binary_g_nd1"},
{"g2", "binary_g_nd2"}, {"g2", "binary_g_nd2"},
{"g3", "binary_g_nd3"}, {"g3", "binary_g_nd3"},
@@ -147,6 +152,7 @@ MTL::ComputePipelineState* get_ternary_kernel(
std::ostringstream kernel_source;
const std::map<std::string, std::string> kernel_types = {
{"v", "ternary_v"},
{"v2", "ternary_v2"},
{"g", "ternary_g"}, {"g", "ternary_g"},
{"g1", "ternary_g_nd1"}, {"g1", "ternary_g_nd1"},
{"g2", "ternary_g_nd2"}, {"g2", "ternary_g_nd2"},
@@ -251,14 +257,29 @@ MTL::ComputePipelineState* get_sort_kernel(
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::sort()
<< fmt::format(
block_sort_kernels,
lib_name,
get_type_string(in.dtype()),
get_type_string(out.dtype()),
bn,
tn);
auto in_type = get_type_string(in.dtype());
auto out_type = get_type_string(out.dtype());
kernel_source << metal::utils() << metal::sort();
for (bool is_argsort : {true, false}) {
std::string bool_string = is_argsort ? "true" : "false";
std::string func_string = is_argsort ? "carg_" : "c_";
kernel_source << get_template_definition(
func_string + lib_name,
"block_sort",
in_type,
out_type,
bool_string,
bn,
tn);
kernel_source << get_template_definition(
"n" + func_string + lib_name,
"block_sort_nc",
in_type,
out_type,
bool_string,
bn,
tn);
}
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
@@ -275,14 +296,21 @@ MTL::ComputePipelineState* get_mb_sort_kernel(
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::sort()
<< fmt::format(
multiblock_sort_kernels,
lib_name,
get_type_string(in.dtype()),
get_type_string(idx.dtype()),
bn,
tn);
kernel_source << metal::utils() << metal::sort();
std::vector<std::pair<std::string, std::string>> kernel_types = {
{"sort_", "mb_block_sort"},
{"partition_", "mb_block_partition"},
{"merge_", "mb_block_merge"}};
for (auto [name, func] : kernel_types) {
kernel_source << get_template_definition(
name + lib_name,
func,
get_type_string(in.dtype()),
get_type_string(idx.dtype()),
"true",
bn,
tn);
}
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
@@ -475,6 +503,49 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_gemv_masked_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out,
const std::optional<array>& mask_out,
const std::optional<array>& mask_op,
bool transpose_mat,
int bm,
int bn,
int sm,
int sn,
int tm,
int tn,
bool contiguous) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
auto out_mask_type = mask_out.has_value()
? get_type_string((*mask_out).dtype())
: "nomask_t";
auto op_mask_type =
mask_op.has_value() ? get_type_string((*mask_op).dtype()) : "nomask_t";
kernel_source << metal::utils() << metal::gemv_masked()
<< fmt::format(
gemv_masked_kernel,
"name"_a = lib_name,
"itype"_a = get_type_string(out.dtype()),
"outm_t"_a = out_mask_type,
"opm_t"_a = op_mask_type,
"bm"_a = bm,
"bn"_a = bn,
"sm"_a = sm,
"sn"_a = sn,
"tm"_a = tm,
"tn"_a = tn,
"trans"_a = transpose_mat ? "t_" : "",
"nc"_a = contiguous ? "0" : "1");
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_steel_conv_kernel(
metal::Device& d,
const std::string& kernel_name,

View File

@@ -151,6 +151,21 @@ MTL::ComputePipelineState* get_steel_conv_kernel(
int n_channel_specialization,
bool small_filter);
MTL::ComputePipelineState* get_gemv_masked_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out,
const std::optional<array>& mask_out,
const std::optional<array>& mask_op,
bool transpose_mat,
int bm,
int bn,
int sm,
int sn,
int tm,
int tn,
bool contiguous);
MTL::ComputePipelineState* get_steel_conv_general_kernel(
metal::Device& d,
const std::string& kernel_name,

View File

@@ -4,11 +4,12 @@ set(
bf16_math.h
complex.h
defines.h
expm1f.h
utils.h
)
function(build_kernel_base TARGET SRCFILE DEPS)
set(METAL_FLAGS -Wall -Wextra -fno-fast-math -D${MLX_METAL_VERSION})
set(METAL_FLAGS -Wall -Wextra -fno-fast-math)
if(MLX_METAL_DEBUG)
set(METAL_FLAGS ${METAL_FLAGS}
-gline-tables-only
@@ -37,7 +38,6 @@ endfunction(build_kernel)
build_kernel(arg_reduce)
build_kernel(conv steel/conv/params.h)
build_kernel(gemv steel/utils.h)
build_kernel(gemv_masked steel/utils.h)
build_kernel(layer_norm)
build_kernel(random)
build_kernel(rms_norm)
@@ -120,6 +120,7 @@ build_kernel(
steel/gemm/kernels/steel_gemm_splitk
${STEEL_HEADERS}
)
build_kernel(gemv_masked steel/utils.h)
endif()

View File

@@ -6,7 +6,7 @@
using namespace metal;
#if defined METAL_3_1 || defined METAL_3_2 || (__METAL_VERSION__ >= 310)
#if (MLX_METAL_VERSION >= 310) || (__METAL_VERSION__ >= 310)
typedef bfloat bfloat16_t;

View File

@@ -369,7 +369,7 @@ instantiate_metal_math_funcs(
return static_cast<otype>(__metal_simd_xor(static_cast<ctype>(data))); \
}
#if defined METAL_3_1 || defined METAL_3_2 || (__METAL_VERSION__ >= 310)
#if (MLX_METAL_VERSION >= 310) || (__METAL_VERSION__ >= 310)
#define bfloat16_to_uint16(x) as_type<uint16_t>(x)
#define uint16_to_bfloat16(x) as_type<bfloat16_t>(x)

View File

@@ -36,6 +36,39 @@ template <typename T, typename U, typename Op>
c[index] = Op()(a[index], b[index]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_sv2(
device const T* a,
device const T* b,
device U* c,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
c[offset] = Op()(a[0], b[offset]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_vs2(
device const T* a,
device const T* b,
device U* c,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
c[offset] = Op()(a[offset], b[0]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_vv2(
device const T* a,
device const T* b,
device U* c,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
c[offset] = Op()(a[offset], b[offset]);
}
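The *_2 variants exist because a 1D grid indexes threads with a 32-bit coordinate; a worked example with illustrative numbers:
// A contiguous output of 2^33 elements exceeds UINT32_MAX, so the host side
// dispatches a 2D grid instead, e.g. grid_dim = (131072, 65536).
// The thread at index = (x, y) then touches
//   size_t offset = index.x + grid_dim.x * size_t(index.y);
// e.g. x = 5, y = 3  ->  5 + 131072 * 3 = 393221,
// with the size_t promotion keeping the multiply out of 32-bit overflow.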
template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd1(
device const T* a,

View File

@@ -14,6 +14,9 @@
instantiate_kernel("sv" #op #tname, binary_sv, itype, otype, op) \ instantiate_kernel("sv" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vs" #op #tname, binary_vs, itype, otype, op) \ instantiate_kernel("vs" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vv" #op #tname, binary_vv, itype, otype, op) \ instantiate_kernel("vv" #op #tname, binary_vv, itype, otype, op) \
instantiate_kernel("sv2" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2" #op #tname, binary_vv2, itype, otype, op) \
instantiate_kernel("gn" #op #tname, binary_g, itype, otype, op) \ instantiate_kernel("gn" #op #tname, binary_g, itype, otype, op) \
instantiate_kernel("g1" #op #tname, binary_g_nd1, itype, otype, op) \ instantiate_kernel("g1" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g2" #op #tname, binary_g_nd2, itype, otype, op) \ instantiate_kernel("g2" #op #tname, binary_g_nd2, itype, otype, op) \

View File

@@ -48,6 +48,48 @@ template <typename T, typename U, typename Op>
d[index] = out[1];
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_sv2(
device const T* a,
device const T* b,
device U* c,
device U* d,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
auto out = Op()(a[0], b[offset]);
c[offset] = out[0];
d[offset] = out[1];
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_vs2(
device const T* a,
device const T* b,
device U* c,
device U* d,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
auto out = Op()(a[offset], b[0]);
c[offset] = out[0];
d[offset] = out[1];
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_vv2(
device const T* a,
device const T* b,
device U* c,
device U* d,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
auto out = Op()(a[offset], b[offset]);
c[offset] = out[0];
d[offset] = out[1];
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd1(
device const T* a,

View File

@@ -12,6 +12,9 @@
instantiate_kernel("sv" #op #tname, binary_sv, itype, otype, op) \ instantiate_kernel("sv" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vs" #op #tname, binary_vs, itype, otype, op) \ instantiate_kernel("vs" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vv" #op #tname, binary_vv, itype, otype, op) \ instantiate_kernel("vv" #op #tname, binary_vv, itype, otype, op) \
instantiate_kernel("sv2" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2" #op #tname, binary_vv2, itype, otype, op) \
instantiate_kernel("gn" #op #tname, binary_g, itype, otype, op) \ instantiate_kernel("gn" #op #tname, binary_g, itype, otype, op) \
instantiate_kernel("g1" #op #tname, binary_g_nd1, itype, otype, op) \ instantiate_kernel("g1" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g2" #op #tname, binary_g_nd2, itype, otype, op) \ instantiate_kernel("g2" #op #tname, binary_g_nd2, itype, otype, op) \

View File

@@ -344,12 +344,12 @@ winograd_conv_2d_weight_transform(
const short sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
// Initialize G matrix
simdgroup_matrix<T, 8, 8> G;
simdgroup_matrix<float, 8, 8> G;
G.thread_elements()[0] = WGT::wt_transform[sm][sn];
G.thread_elements()[1] = WGT::wt_transform[sm][sn + 1];
// Initialize Gt matrix
simdgroup_matrix<T, 8, 8> Gt;
simdgroup_matrix<float, 8, 8> Gt;
Gt.thread_elements()[0] = WGT::wt_transform[sn][sm];
Gt.thread_elements()[1] = WGT::wt_transform[sn + 1][sm];
@@ -381,15 +381,15 @@ winograd_conv_2d_weight_transform(
threadgroup_barrier(mem_flags::mem_threadgroup);
// Do transform and store the result
for (int c = 0; c < BC; ++c) {
simdgroup_matrix<T, 8, 8> g;
simdgroup_matrix<float, 8, 8> g;
g.thread_elements()[0] =
sm < R && sn < R ? Ws[simd_group_id][sm][sn][c] : T(0);
g.thread_elements()[1] =
sm < R && sn + 1 < R ? Ws[simd_group_id][sm][sn + 1][c] : T(0);
simdgroup_matrix<T, 8, 8> g_out = (G * g) * Gt;
simdgroup_matrix<float, 8, 8> g_out = (G * g) * Gt;
wt_out_0[c * O] = g_out.thread_elements()[0];
wt_out_0[c * O] = static_cast<T>(g_out.thread_elements()[0]);
wt_out_1[c * O] = g_out.thread_elements()[1];
wt_out_1[c * O] = static_cast<T>(g_out.thread_elements()[1]);
}
wt_in += BC;
@@ -433,12 +433,12 @@ winograd_conv_2d_input_transform(
const short sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
// Initialize B matrix
simdgroup_matrix<T, 8, 8> B;
simdgroup_matrix<float, 8, 8> B;
B.thread_elements()[0] = WGT::in_transform[sm][sn];
B.thread_elements()[1] = WGT::in_transform[sm][sn + 1];
// Initialize Bt matrix
simdgroup_matrix<T, 8, 8> Bt;
simdgroup_matrix<float, 8, 8> Bt;
Bt.thread_elements()[0] = WGT::in_transform[sn][sm];
Bt.thread_elements()[1] = WGT::in_transform[sn + 1][sm];
@@ -493,13 +493,13 @@ winograd_conv_2d_input_transform(
threadgroup_barrier(mem_flags::mem_threadgroup);
// Do transform and store the result
for (int c = simd_group_id; c < BC; c += N_SIMD_GROUPS) {
simdgroup_matrix<T, 8, 8> I;
simdgroup_matrix<float, 8, 8> I;
I.thread_elements()[0] = Is[sm][sn][c];
I.thread_elements()[1] = Is[sm][sn + 1][c];
simdgroup_matrix<T, 8, 8> I_out = (Bt * I) * B;
simdgroup_matrix<float, 8, 8> I_out = (Bt * I) * B;
inp_out_0[c] = I_out.thread_elements()[0];
inp_out_0[c] = static_cast<T>(I_out.thread_elements()[0]);
inp_out_1[c] = I_out.thread_elements()[1];
inp_out_1[c] = static_cast<T>(I_out.thread_elements()[1]);
}
inp_in += BC;
@@ -543,12 +543,12 @@ winograd_conv_2d_output_transform(
const short sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
// Initialize A matrix
simdgroup_matrix<T, 8, 8> B;
simdgroup_matrix<float, 8, 8> B;
B.thread_elements()[0] = WGT::out_transform[sm][sn];
B.thread_elements()[1] = WGT::out_transform[sm][sn + 1];
// Initialize At matrix
simdgroup_matrix<T, 8, 8> Bt;
simdgroup_matrix<float, 8, 8> Bt;
Bt.thread_elements()[0] = WGT::out_transform[sn][sm];
Bt.thread_elements()[1] = WGT::out_transform[sn + 1][sm];
@@ -597,16 +597,16 @@ winograd_conv_2d_output_transform(
threadgroup_barrier(mem_flags::mem_threadgroup);
// Do transform and store the result
for (int c = simd_group_id; c < BO; c += N_SIMD_GROUPS) {
simdgroup_matrix<T, 8, 8> O_mat;
simdgroup_matrix<float, 8, 8> O_mat;
O_mat.thread_elements()[0] = out_in_0[c];
O_mat.thread_elements()[1] = out_in_1[c];
simdgroup_matrix<T, 8, 8> O_out = (Bt * (O_mat * B));
simdgroup_matrix<float, 8, 8> O_out = (Bt * (O_mat * B));
if ((sm < M) && (sn < M)) {
Os[sm][sn][c] = O_out.thread_elements()[0];
Os[sm][sn][c] = static_cast<T>(O_out.thread_elements()[0]);
}
if ((sm < M) && ((sn + 1) < M)) {
Os[sm][sn + 1][c] = O_out.thread_elements()[1];
Os[sm][sn + 1][c] = static_cast<T>(O_out.thread_elements()[1]);
}
}
@@ -650,4 +650,5 @@ winograd_conv_2d_output_transform(
// clang-format off // clang-format off
instantiate_winograd_conv_2d(float32, float); instantiate_winograd_conv_2d(float32, float);
instantiate_winograd_conv_2d(bfloat16, bfloat16_t);
instantiate_winograd_conv_2d(float16, half); // clang-format on instantiate_winograd_conv_2d(float16, half); // clang-format on
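
This change keeps the 8x8 simdgroup transform matrices in float and only casts back to T (half or bfloat16_t) when storing, so the Winograd transforms accumulate at full precision. A minimal scalar sketch of that pattern in plain C++ (illustrative only, not the Metal kernel; T stands in for the 16-bit element type):

// Computes out = B^T * in * B, accumulating in float and casting on store.
template <typename T>
void transform_tile_fp32(const T* in, const float (&B)[8][8], T* out) {
  float tmp[8][8];
  for (int i = 0; i < 8; ++i) {
    for (int j = 0; j < 8; ++j) {
      float acc = 0.0f;
      for (int k = 0; k < 8; ++k) {
        acc += B[k][i] * static_cast<float>(in[k * 8 + j]); // (B^T * in)
      }
      tmp[i][j] = acc;
    }
  }
  for (int i = 0; i < 8; ++i) {
    for (int j = 0; j < 8; ++j) {
      float acc = 0.0f;
      for (int k = 0; k < 8; ++k) {
        acc += tmp[i][k] * B[k][j];                          // (... * B)
      }
      out[i * 8 + j] = static_cast<T>(acc);                  // cast only here
    }
  }
}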

View File

@@ -16,6 +16,26 @@ template <typename T, typename U>
dst[index] = static_cast<U>(src[index]);
}
template <typename T, typename U>
[[kernel]] void copy_s2(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
dst[offset] = static_cast<U>(src[0]);
}
template <typename T, typename U>
[[kernel]] void copy_v2(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
dst[offset] = static_cast<U>(src[offset]);
}
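// Note on the 2-D variants above: the element offset is computed as
//   offset = index.x + grid_dim.x * size_t(index.y)
// widening to size_t before the multiply, so copies with more elements than
// a 32-bit thread index can address do not overflow.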
template <typename T, typename U>
[[kernel]] void copy_g_nd1(
device const T* src [[buffer(0)]],

View File

@@ -5,95 +5,23 @@
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/copy.h"
#define instantiate_copy(name, itype, otype, ctype) \
template [[host_name(name)]] [[kernel]] void copy_##ctype<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
uint index [[thread_position_in_grid]]);
#define instantiate_copy_g_dim(name, itype, otype, dims) \
template [[host_name("g" #dims "_" name)]] [[kernel]] void \
copy_g_nd<itype, otype, dims>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \
constant const int64_t* src_strides [[buffer(3)]], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
template [[host_name("gg" #dims "_" name)]] [[kernel]] void \
copy_gg_nd<itype, otype, dims>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int64_t* dst_strides [[buffer(4)]], \
uint3 index [[thread_position_in_grid]]);
#define instantiate_copy_g_nd(name, itype, otype) \
template [[host_name("g1_" name)]] [[kernel]] void copy_g_nd1<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t& src_stride [[buffer(3)]], \
uint index [[thread_position_in_grid]]); \
template [[host_name("g2_" name)]] [[kernel]] void copy_g_nd2<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]); \
template [[host_name("g3_" name)]] [[kernel]] void copy_g_nd3<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
template [[host_name("gg1_" name )]] [[kernel]] void \
copy_gg_nd1<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t& src_stride [[buffer(3)]], \
constant const int64_t& dst_stride [[buffer(4)]], \
uint index [[thread_position_in_grid]]); \
template [[host_name("gg2_" name)]] [[kernel]] void \
copy_gg_nd2<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int64_t* dst_strides [[buffer(4)]], \
uint2 index [[thread_position_in_grid]]); \
template [[host_name("gg3_" name)]] [[kernel]] void \
copy_gg_nd3<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int64_t* dst_strides [[buffer(4)]], \
uint3 index [[thread_position_in_grid]]); \
instantiate_copy_g_dim(name, itype, otype, 4) \
instantiate_copy_g_dim(name, itype, otype, 5)
#define instantiate_copy_g(name, itype, otype) \
template [[host_name("g_" name)]] [[kernel]] void copy_g<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int& ndim [[buffer(5)]], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
template [[host_name("gg_" name)]] [[kernel]] void copy_gg<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int64_t* dst_strides [[buffer(4)]], \
constant const int& ndim [[buffer(5)]], \
uint3 index [[thread_position_in_grid]]);
#define instantiate_copy_all(tname, itype, otype) \
- instantiate_copy("s_copy" #tname, itype, otype, s) \
- instantiate_copy("v_copy" #tname, itype, otype, v) \
- instantiate_copy_g("copy" #tname, itype, otype) \
- instantiate_copy_g_nd("copy" #tname, itype, otype)
+ instantiate_kernel("s_copy" #tname, copy_s, itype, otype) \
+ instantiate_kernel("v_copy" #tname, copy_v, itype, otype) \
+ instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype) \
+ instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype) \
+ instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype) \
+ instantiate_kernel("g2_copy" #tname, copy_g_nd2, itype, otype) \
+ instantiate_kernel("g3_copy" #tname, copy_g_nd3, itype, otype) \
+ instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, itype, otype) \
+ instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype) \
+ instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype) \
+ instantiate_kernel("g4_copy" #tname, copy_g_nd, itype, otype, 4) \
+ instantiate_kernel("g5_copy" #tname, copy_g_nd, itype, otype, 5) \
+ instantiate_kernel("gg4_copy" #tname, copy_gg_nd, itype, otype, 4) \
+ instantiate_kernel("gg5_copy" #tname, copy_gg_nd, itype, otype, 5) \
+ instantiate_kernel("g_copy" #tname, copy_g, itype, otype) \
+ instantiate_kernel("gg_copy" #tname, copy_gg, itype, otype)
#define instantiate_copy_itype(itname, itype) \
instantiate_copy_all(itname ##bool_, itype, bool) \
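
The rewritten instantiate_copy_all replaces the hand-written instantiation macros above with one instantiate_kernel(...) call per kernel name. The general pattern is an explicit template instantiation hidden behind a variadic macro; a minimal C++ sketch of that pattern (the INSTANTIATE_COPY helper below is hypothetical, not MLX's actual instantiate_kernel):

#include <cstddef>

// A stand-in "kernel": in Metal this would carry [[kernel]] and
// [[host_name(...)]] attributes; here it is an ordinary function template.
template <typename T, typename U>
void copy_v(const T* src, U* dst, std::size_t index) {
  dst[index] = static_cast<U>(src[index]);
}

// Hypothetical helper: one line per dtype pair instead of repeating the
// whole kernel signature for every instantiation.
#define INSTANTIATE_COPY(func, itype, otype) \
  template void func<itype, otype>(const itype*, otype*, std::size_t);

INSTANTIATE_COPY(copy_v, float, float)
INSTANTIATE_COPY(copy_v, float, int)
INSTANTIATE_COPY(copy_v, int, float)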

View File

@@ -83,6 +83,7 @@ float expm1f(float a) {
r = expm1f_scaled_unchecked(a, 1.0f);
/* handle severe overflow and underflow */
if (abs(a - 1.0f) > 88.0f) {
+ r = pow(2, a);
r = fma(r, r, -1.0f);
}
return r;

View File

@@ -0,0 +1,819 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/kernels/steel/utils.h"
using namespace metal;
#define MLX_MTL_CONST static constant constexpr const
#define MLX_MTL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
struct _NoMask {
char x;
constexpr METAL_FUNC operator bool() {
return true;
}
constexpr METAL_FUNC operator bool() const threadgroup {
return true;
}
constexpr METAL_FUNC operator bool() const device {
return true;
}
constexpr METAL_FUNC operator bool() const constant {
return true;
}
};
typedef struct _NoMask nomask_t;
template <typename OutT, typename InT = OutT>
struct ScaleOp {
OutT scale;
METAL_FUNC OutT apply(InT x) const {
return static_cast<OutT>(x) * scale;
}
};
template <
typename T,
typename out_mask_t,
typename op_mask_t,
const int BM, /* Threadgroup rows (in simdgroups) */
const int BN, /* Threadgroup cols (in simdgroups) */
const int SM, /* Simdgroup rows (in threads) */
const int SN, /* Simdgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN> /* Thread cols (in elements) */
struct GEMVKernel {
MLX_MTL_CONST int threadsM = BM * SM;
MLX_MTL_CONST int threadsN = BN * SN;
MLX_MTL_CONST int blockM = threadsM * TM;
MLX_MTL_CONST int blockN = threadsN * TN;
static_assert(SM * SN == 32, "simdgroup can only have 32 threads");
static_assert(
SN == 8 || SN == 16 || SN == 32,
"gemv block must have a width of 8, 16, or 32");
static_assert(blockN >= blockM, "Masked gemv must have blockN >= blockM");
MLX_MTL_CONST bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
MLX_MTL_CONST bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;
MLX_MTL_CONST bool has_mul_operand_mask =
has_operand_mask && !metal::is_same_v<op_mask_t, bool>;
MLX_MTL_CONST bool has_mul_output_mask =
has_output_mask && !metal::is_same_v<out_mask_t, bool>;
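// The instantiations later in this diff pair each operand with one of three
// mask types: nomask_t (no masking), bool (a block is either used or
// skipped), or T itself, in which case the block is also scaled by the mask
// value; the has_*_mask / has_mul_*_mask flags above encode those cases.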
// - The matrix of size (M = out_vec_size, K = in_vec_size) is divided up
// into blocks of (blockM, blockN) divided among threadgroups
// - Every thread works on a block of (TM, TN)
// - We assume each threadgroup has (threadsN, threadsM, 1) threads
//
// 1. A thread loads TN elements each from mat along TM rows
// and the corresponding scalar from the vector
// 2. The thread then multiplies and adds to accumulate its local result for
// the block
// 3. At the end, each thread has accumulated results over all blocks across
// the rows. These are then summed up across the threadgroup
// 4. Each threadgroup writes its accumulated blockM outputs
//
// Edge case handling:
// - The threadgroup with the largest tid has blocks that exceed the matrix
// * The blocks that start outside the matrix are never read (thread results
// remain zero)
// * The last thread that partially overlaps with the matrix is shifted
// inwards such that the thread block fits exactly in the matrix
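//
// Example: with the (BM, BN, SM, SN, TM, TN) = (2, 1, 4, 8, 4, 4)
// configuration instantiated later in this diff, threadsM = 8, threadsN = 8,
// blockM = 32 and blockN = 32, so each threadgroup owns 32 output rows and
// consumes in_vec in chunks of 32 elements per main-loop iteration.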
MLX_MTL_CONST short tgp_mem_size = BN > 1 ? BN*(blockM + TM) : 0;
MLX_MTL_CONST bool needs_tgp_reduction = BN > 1;
static METAL_FUNC void
load_unsafe(const device T* src, thread T dst[TN], const int src_offset = 0) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
dst[tn] = src[src_offset + tn];
}
}
static METAL_FUNC void load_safe(
const device T* src,
thread T dst[TN],
const int src_offset = 0,
const int src_size = TN) {
if (src_offset + TN <= src_size) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
dst[tn] = src[src_offset + tn];
}
} else { // Edgecase
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
dst[tn] = src_offset + tn < src_size ? src[src_offset + tn] : 0;
}
}
}
static METAL_FUNC void run(
const device T* mat [[buffer(0)]],
const device T* in_vec [[buffer(1)]],
device T* out_vec [[buffer(3)]],
const constant int& in_vec_size [[buffer(4)]],
const constant int& out_vec_size [[buffer(5)]],
const constant int& matrix_ld [[buffer(6)]],
const device out_mask_t* out_mask [[buffer(20)]],
const device op_mask_t* mat_mask [[buffer(21)]],
const device op_mask_t* vec_mask [[buffer(22)]],
const constant int* mask_strides [[buffer(23)]],
threadgroup T* tgp_memory [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
// Appease compiler
(void)lid;
// Thread local accumulation results
thread T result[TM] = {0};
thread T inter[TN];
thread T v_coeff[TN];
const int thrM = SN != 32 ? simd_lid / SN : 0;
const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid);
const int sgN = BN != 1 ? (simd_gid % BN) : 0;
const int simdM = BN != 1 ? SM * (simd_gid / BN) : int(SM * simd_gid);
const int simdN = BN != 1 ? SN * (simd_gid % BN) : 0;
int bm = (simdM + thrM) * TM;
int bn = (simdN + thrN) * TN;
// Block position
int out_row = tid.x * blockM + bm;
// Exit simdgroup if rows out of bound
if (out_row >= out_vec_size)
return;
// Adjust tail simdgroup to ensure in bound reads
out_row = out_row + TM <= out_vec_size ? out_row : out_vec_size - TM;
// Prepare mask offsets
const constant int* out_mask_strides = mask_strides;
const constant int* mat_mask_strides =
mask_strides + (has_output_mask ? 2 : 0);
const constant int* vec_mask_strides =
mat_mask_strides + (has_operand_mask ? 2 : 0);
const int m_block_idx = blockN > blockM ? out_row / blockN : int(tid.x);
const int out_mask_offset =
!has_output_mask ? 0 : m_block_idx * out_mask_strides[1];
int mat_mask_offset =
!has_operand_mask ? 0 : m_block_idx * mat_mask_strides[1];
int vec_mask_offset = 0;
const int mat_mask_step = !has_operand_mask ? 0 : mat_mask_strides[0];
const int vec_mask_step = !has_operand_mask ? 0 : vec_mask_strides[1];
T out_scale{1};
// Check output mask
if (has_output_mask) {
auto mask_out = out_mask[out_mask_offset];
// Write zeros and return if mask is 0
if (!mask_out) {
if (simdN == 0 && thrN == 0) {
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
out_vec[out_row + tm] = T(0.);
}
}
return;
}
// Store scalar if multiplicative mask
if (has_mul_output_mask) {
out_scale = T(mask_out);
}
}
// Advance matrix
mat += out_row * matrix_ld;
// Prepare for loop
constexpr const uniform<int> loop_stride = make_uniform(blockN);
const uniform<int> in_size = make_uniform(in_vec_size);
const uniform<int> n_iter = in_size / loop_stride;
const uniform<int> last_iter = loop_stride * n_iter;
const uniform<int> leftover = in_size - last_iter;
// Loop over in_vec in blocks of blockN
for (int i = 0; i < n_iter; ++i) {
if (!has_operand_mask ||
(bool(mat_mask[mat_mask_offset]) &&
bool(vec_mask[vec_mask_offset]))) {
T block_scale{1};
if (has_mul_operand_mask) {
block_scale =
T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
}
load_unsafe(in_vec, v_coeff, bn);
// Apply scale
if (has_mul_operand_mask) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
v_coeff[tn] *= block_scale;
}
}
// Per thread work loop
int mat_offset = 0;
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
// Load for the row
load_unsafe(mat, inter, mat_offset + bn);
// Accumulate results
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
result[tm] += inter[tn] * v_coeff[tn];
}
mat_offset += matrix_ld;
}
}
bn += blockN;
mat_mask_offset += mat_mask_step;
vec_mask_offset += vec_mask_step;
}
if (leftover > 0 &&
(!has_operand_mask ||
(bool(mat_mask[mat_mask_offset]) &&
bool(vec_mask[vec_mask_offset])))) {
T block_scale{1};
if (has_mul_operand_mask) {
block_scale =
T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
}
load_safe(in_vec, v_coeff, bn, in_size);
// Apply scale
if (has_mul_operand_mask) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
v_coeff[tn] *= block_scale;
}
}
// Per thread work loop
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
// Load for the row
load_safe(&mat[tm * matrix_ld], inter, bn, in_size);
// Accumulate results
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
result[tm] += inter[tn] * v_coeff[tn];
}
}
}
// Apply out scale
if (has_mul_output_mask) {
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
result[tm] *= out_scale;
}
}
// Simdgroup accumulations
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
MLX_MTL_PRAGMA_UNROLL
for (ushort sn = (SN / 2); sn >= 1; sn >>= 1) {
result[tm] += simd_shuffle_down(result[tm], sn);
}
}
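// (With SN = 8 the shuffle offsets are 4, 2, 1, so after this loop the lane
// with thrN == 0 in each simdgroup row holds the sum over the SN threads
// that covered different chunks of in_vec for the same output rows.)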
// Threadgroup accumulation results
if (needs_tgp_reduction) {
threadgroup T* tgp_results = tgp_memory + sgN * (blockM + TM) + bm;
if (thrN == 0) {
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
tgp_results[tm] = result[tm];
}
threadgroup_barrier(mem_flags::mem_none);
if (sgN == 0) {
MLX_MTL_PRAGMA_UNROLL
for (int sgn = 1; sgn < BN; sgn++) {
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
result[tm] += tgp_results[sgn * (blockM + TM) + tm];
}
}
}
}
}
// Write outputs
if (simdN == 0 && thrN == 0) {
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
out_vec[out_row + tm] = result[tm];
}
}
}
};
///////////////////////////////////////////////////////////////////////////////
/// Vector matrix multiplication
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
typename out_mask_t,
typename op_mask_t,
const int BM, /* Threadgroup rows (in simdgroups) */
const int BN, /* Threadgroup cols (in simdgroups) */
const int SM, /* Simdgroup rows (in threads) */
const int SN, /* Simdgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN> /* Thread cols (in elements) */
struct GEMVTKernel {
MLX_MTL_CONST int threadsM = BM * SM;
MLX_MTL_CONST int threadsN = BN * SN;
MLX_MTL_CONST int blockM = threadsM * TM;
MLX_MTL_CONST int blockN = threadsN * TN;
static_assert(SM * SN == 32, "simdgroup can only have 32 threads");
MLX_MTL_CONST bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
MLX_MTL_CONST bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;
MLX_MTL_CONST bool has_mul_operand_mask =
has_operand_mask && !metal::is_same_v<op_mask_t, bool>;
MLX_MTL_CONST bool has_mul_output_mask =
has_output_mask && !metal::is_same_v<out_mask_t, bool>;
// - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up
// into blocks of (blockM, blockN) divided among threadgroups
// - Every thread works on a block of (TM, TN)
// - We assume each threadgroup has (threadsN, threadsM, 1) threads
//
// 1. A thread loads TN elements each from mat along TM contiguous rows
// and the corresponding scalar from the vector
// 2. The thread then accumulates its local result for the block
// 3. At the end, each thread has accumulated results over all blocks across
// the rows. These are then summed up across the threadgroup
// 4. Each threadgroup writes its accumulated BN * TN outputs
//
// Edge case handling:
// - The threadgroup with the largest tid has blocks that exceed the matrix
// * The blocks that start outside the matrix are never read (thread results
// remain zero)
// * The last thread that partially overlaps with the matrix is shifted
// inwards such that the thread block fits exactly in the matrix
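//
// Example: with the (BM, BN, SM, SN, TM, TN) = (1, 4, 8, 4, 8, 4)
// configuration instantiated later in this diff, threadsM = 8, threadsN = 16,
// blockM = 64 and blockN = 64, so each threadgroup covers 64 output columns
// and consumes in_vec in chunks of 64 rows per main-loop iteration.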
MLX_MTL_CONST short tgp_mem_size = BM > 1 ? BM*(blockN + TN) : 0;
MLX_MTL_CONST bool needs_tgp_reduction = BM > 1;
static METAL_FUNC void run(
const device T* mat [[buffer(0)]],
const device T* in_vec [[buffer(1)]],
device T* out_vec [[buffer(3)]],
const constant int& in_vec_size [[buffer(4)]],
const constant int& out_vec_size [[buffer(5)]],
const constant int& marix_ld [[buffer(6)]],
const device out_mask_t* out_mask [[buffer(20)]],
const device op_mask_t* mat_mask [[buffer(21)]],
const device op_mask_t* vec_mask [[buffer(22)]],
const constant int* mask_strides [[buffer(23)]],
threadgroup T* tgp_memory [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
// Appease compiler
(void)lid;
// Thread local accumulation results
T result[TN] = {0};
T inter[TN];
T v_coeff[TM];
const int thrM = SN != 32 ? simd_lid / SN : 0;
const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid);
const int sgM = BN != 1 ? (simd_gid / BN) : int(simd_gid);
const int sgN = BN != 1 ? (simd_gid % BN) : 0;
const int simdM = SM * sgM;
const int simdN = SN * sgN;
int cm = (simdM + thrM);
int cn = (simdN + thrN);
int bm = cm * TM;
int bn = cn * TN;
int out_col = tid.x * blockN + bn;
// Prepare mask offsets
const constant int* out_mask_strides = mask_strides;
const constant int* mat_mask_strides =
out_mask_strides + (has_output_mask ? 2 : 0);
const constant int* vec_mask_strides =
mat_mask_strides + (has_operand_mask ? 2 : 0);
const int n_block_idx = blockM > blockN ? out_col / blockM : int(tid.x);
const int out_mask_offset =
!has_output_mask ? 0 : n_block_idx; // * out_mask_strides[0];
int mat_mask_offset =
!has_operand_mask ? 0 : n_block_idx * mat_mask_strides[0];
int vec_mask_offset = 0;
const int mat_mask_step = !has_operand_mask ? 0 : mat_mask_strides[1];
const int vec_mask_step = !has_operand_mask ? 0 : vec_mask_strides[0];
T out_scale{1};
// Check output mask
if (has_output_mask) {
auto mask_out = out_mask[out_mask_offset];
// Write zeros and return if mask is 0
if (!mask_out) {
if (cm == 0 && out_col < out_vec_size) {
if (out_col + TN <= out_vec_size) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
out_vec[out_col + tn] = T(0.);
}
} else {
for (int tn = 0; tn < TN && (out_col + tn) < out_vec_size; tn++) {
out_vec[out_col + tn] = T(0.);
}
}
}
return;
}
// Store scalar if multiplicative mask
if (has_mul_output_mask) {
out_scale = T(mask_out);
}
}
// Prepare for loop
constexpr const uniform<int> loop_stride = make_uniform(blockM);
const uniform<int> in_size = make_uniform(in_vec_size);
const uniform<int> n_iter = in_size / loop_stride;
const uniform<int> last_iter = loop_stride * n_iter;
const uniform<int> leftover = in_size - last_iter;
// Edgecase handling
if (out_col < out_vec_size) {
out_col = (out_col + TN) <= out_vec_size ? out_col : out_vec_size - TN;
// Per thread accumulation main loop
for (int i = 0; i < n_iter; ++i) {
// Adding a threadgroup_barrier improves performance slightly
// This is possibly because it helps exploit the cache better
threadgroup_barrier(mem_flags::mem_none);
if (!has_operand_mask ||
(bool(mat_mask[mat_mask_offset]) &&
bool(vec_mask[vec_mask_offset]))) {
T block_scale{1};
if (has_mul_operand_mask) {
block_scale =
T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
}
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
v_coeff[tm] = in_vec[bm + tm];
}
// Apply scale
if (has_mul_operand_mask) {
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
v_coeff[tm] *= block_scale;
}
}
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
for (int tn = 0; tn < TN; tn++) {
inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
}
for (int tn = 0; tn < TN; tn++) {
result[tn] += v_coeff[tm] * inter[tn];
}
}
}
bm += blockM;
mat_mask_offset += mat_mask_step;
vec_mask_offset += vec_mask_step;
}
if (leftover > 0 &&
(!has_operand_mask ||
(bool(mat_mask[mat_mask_offset]) &&
bool(vec_mask[vec_mask_offset])))) {
T block_scale{1};
if (has_mul_operand_mask) {
block_scale =
T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
}
for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) {
v_coeff[tm] = in_vec[bm + tm];
if (has_mul_operand_mask) {
v_coeff[tm] *= block_scale;
}
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
}
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
result[tn] += v_coeff[tm] * inter[tn];
}
}
}
}
// Apply out scale
if (has_mul_output_mask) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
result[tn] *= out_scale;
}
}
// Simdgroup accumulations
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
MLX_MTL_PRAGMA_UNROLL
for (ushort sm = (SM / 2); sm >= 1; sm >>= 1) {
result[tn] += simd_shuffle_down(result[tn], SN * sm);
}
}
// Threadgroup accumulation results
if (needs_tgp_reduction) {
threadgroup T* tgp_results = tgp_memory + sgM * (blockN + TN) + bn;
if (thrM == 0) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
tgp_results[tn] = result[tn];
}
threadgroup_barrier(mem_flags::mem_none);
if (sgM == 0) {
MLX_MTL_PRAGMA_UNROLL
for (int sgm = 1; sgm < BM; sgm++) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
result[tn] += tgp_results[sgm * (blockN + TN) + tn];
}
}
}
}
}
// Threadgroup accumulation and writing out results
if (cm == 0 && out_col < out_vec_size) {
MLX_MTL_PRAGMA_UNROLL
for (int j = 0; j < TN; j++) {
out_vec[out_col + j] = result[j];
}
}
}
};
///////////////////////////////////////////////////////////////////////////////
/// Matrix vector multiplication
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
typename out_mask_t,
typename op_mask_t,
const int BM, /* Threadgroup rows (in simdgroups) */
const int BN, /* Threadgroup cols (in simdgroups) */
const int SM, /* Simdgroup rows (in threads) */
const int SN, /* Simdgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN, /* Thread cols (in elements) */
const bool kDoNCBatch> /* Batch ndim > 1 */
[[kernel, max_total_threads_per_threadgroup(BM* BN * 32)]] void gemv_masked(
const device T* mat [[buffer(0)]],
const device T* in_vec [[buffer(1)]],
device T* out_vec [[buffer(3)]],
const constant int& in_vec_size [[buffer(4)]],
const constant int& out_vec_size [[buffer(5)]],
const constant int& marix_ld [[buffer(6)]],
const constant int& batch_ndim [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]],
const constant size_t* vector_batch_stride [[buffer(11)]],
const constant size_t* matrix_batch_stride [[buffer(12)]],
const device out_mask_t* out_mask [[buffer(20)]],
const device op_mask_t* mat_mask [[buffer(21)]],
const device op_mask_t* vec_mask [[buffer(22)]],
const constant int* mask_strides [[buffer(23)]],
const constant size_t* mask_batch_strides [[buffer(24)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using gemv_kernel =
GEMVKernel<T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN>;
threadgroup T tgp_memory
[gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
constexpr bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
constexpr bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;
// Update batch offsets
if (kDoNCBatch) {
in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);
mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);
if (has_output_mask) {
out_mask +=
elem_to_loc(tid.z, batch_shape, mask_batch_strides, batch_ndim);
mask_batch_strides += batch_ndim;
}
if (has_operand_mask) {
const constant size_t* mask_strides_mat = mask_batch_strides;
const constant size_t* mask_strides_vec = mask_strides_mat + batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim);
mat_mask += batch_offsets.x;
vec_mask += batch_offsets.y;
}
} else {
in_vec += tid.z * vector_batch_stride[0];
mat += tid.z * matrix_batch_stride[0];
if (has_output_mask) {
out_mask += tid.z * mask_batch_strides[0];
mask_batch_strides += batch_ndim;
}
if (has_operand_mask) {
mat_mask += tid.z * mask_batch_strides[0];
vec_mask += tid.z * mask_batch_strides[batch_ndim];
}
}
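// Batch indexing example for the kDoNCBatch branch above (assuming
// elem_to_loc does the usual row-major unraveling): with batch_shape = {2, 3}
// and tid.z = 4, the batch element is (1, 1), so the vector offset is
// vector_batch_stride[0] + vector_batch_stride[1].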
out_vec += tid.z * out_vec_size;
gemv_kernel::run(
mat,
in_vec,
out_vec,
in_vec_size,
out_vec_size,
marix_ld,
out_mask,
mat_mask,
vec_mask,
mask_strides,
gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory,
tid,
lid,
simd_gid,
simd_lid);
}
///////////////////////////////////////////////////////////////////////////////
/// Vector matrix multiplication
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
typename out_mask_t,
typename op_mask_t,
const int BM, /* Threadgroup rows (in simdgroups) */
const int BN, /* Threadgroup cols (in simdgroups) */
const int SM, /* Simdgroup rows (in threads) */
const int SN, /* Simdgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN, /* Thread cols (in elements) */
const bool kDoNCBatch> /* Batch ndim > 1 */
[[kernel, max_total_threads_per_threadgroup(BM* BN * 32)]] void gemv_t_masked(
const device T* mat [[buffer(0)]],
const device T* in_vec [[buffer(1)]],
device T* out_vec [[buffer(3)]],
const constant int& in_vec_size [[buffer(4)]],
const constant int& out_vec_size [[buffer(5)]],
const constant int& marix_ld [[buffer(6)]],
const constant int& batch_ndim [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]],
const constant size_t* vector_batch_stride [[buffer(11)]],
const constant size_t* matrix_batch_stride [[buffer(12)]],
const device out_mask_t* out_mask [[buffer(20)]],
const device op_mask_t* mat_mask [[buffer(21)]],
const device op_mask_t* vec_mask [[buffer(22)]],
const constant int* mask_strides [[buffer(23)]],
const constant size_t* mask_batch_strides [[buffer(24)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using gemv_kernel =
GEMVTKernel<T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN>;
threadgroup T tgp_memory
[gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
constexpr bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
constexpr bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;
// Update batch offsets
if (kDoNCBatch) {
in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);
mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);
if (has_output_mask) {
out_mask +=
elem_to_loc(tid.z, batch_shape, mask_batch_strides, batch_ndim);
mask_batch_strides += batch_ndim;
}
if (has_operand_mask) {
const constant size_t* mask_strides_mat = mask_batch_strides;
const constant size_t* mask_strides_vec = mask_strides_mat + batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim);
mat_mask += batch_offsets.x;
vec_mask += batch_offsets.y;
}
} else {
in_vec += tid.z * vector_batch_stride[0];
mat += tid.z * matrix_batch_stride[0];
if (has_output_mask) {
out_mask += tid.z * mask_batch_strides[0];
mask_batch_strides += batch_ndim;
}
if (has_operand_mask) {
mat_mask += tid.z * mask_batch_strides[0];
vec_mask += tid.z * mask_batch_strides[batch_ndim];
}
}
out_vec += tid.z * out_vec_size;
gemv_kernel::run(
mat,
in_vec,
out_vec,
in_vec_size,
out_vec_size,
marix_ld,
out_mask,
mat_mask,
vec_mask,
mask_strides,
gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory,
tid,
lid,
simd_gid,
simd_lid);
}
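
For reference, the block-mask semantics implemented by the kernels above can be written as a scalar C++ sketch (illustrative only; it assumes one mask value per (row-block, k-block) of the matrix, per row-block of the output, and per k-block of the vector; a zero mask skips a block, a non-zero floating-point mask also scales it):

#include <vector>

// masked_gemv_ref is a hypothetical reference helper, not part of the patch.
std::vector<float> masked_gemv_ref(
    const std::vector<float>& mat,      // M x K, row-major
    const std::vector<float>& vec,      // K
    const std::vector<float>& out_mask, // ceil(M / bm) entries
    const std::vector<float>& mat_mask, // ceil(M / bm) x ceil(K / bk)
    const std::vector<float>& vec_mask, // ceil(K / bk) entries
    int M, int K, int bm, int bk) {
  std::vector<float> out(M, 0.0f);
  int kblocks = (K + bk - 1) / bk;
  for (int i = 0; i < M; ++i) {
    int ib = i / bm;
    if (out_mask[ib] == 0.0f) {
      continue; // masked-out row blocks stay zero
    }
    float acc = 0.0f;
    for (int j = 0; j < K; ++j) {
      int jb = j / bk;
      float scale = mat_mask[ib * kblocks + jb] * vec_mask[jb];
      if (scale == 0.0f) {
        continue; // skip fully masked blocks
      }
      acc += scale * mat[i * K + j] * vec[j];
    }
    out[i] = out_mask[ib] * acc; // multiplicative output mask
  }
  return out;
}

Boolean masks correspond to 0/1 values here; the Metal kernels fold the per-block factor in once as block_scale and out_scale rather than rescaling every element.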

View File

@@ -1,5 +1,6 @@
// Copyright © 2023-2024 Apple Inc.
+ // clang-format off
#include <metal_simdgroup>
#include <metal_stdlib>
@@ -7,726 +8,7 @@
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"
- #include "mlx/backend/metal/kernels/steel/utils.h"
+ #include "mlx/backend/metal/kernels/gemv_masked.h"
using namespace metal;
///////////////////////////////////////////////////////////////////////////////
/// Matrix vector multiplication
///////////////////////////////////////////////////////////////////////////////
#define MLX_MTL_CONST static constant constexpr const
struct _NoMask {
char x;
constexpr METAL_FUNC operator bool() {
return true;
}
constexpr METAL_FUNC operator bool() const threadgroup {
return true;
}
constexpr METAL_FUNC operator bool() const device {
return true;
}
constexpr METAL_FUNC operator bool() const constant {
return true;
}
};
typedef struct _NoMask nomask_t;
template <typename OutT, typename InT = OutT>
struct ScaleOp {
OutT scale;
METAL_FUNC OutT apply(InT x) const {
return static_cast<OutT>(x) * scale;
}
};
template <
typename T,
typename out_mask_t,
typename op_mask_t,
const int BM, /* Threadgroup rows (in simdgroups) */
const int BN, /* Threadgroup cols (in simdgroups) */
const int SM, /* Simdgroup rows (in threads) */
const int SN, /* Simdgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN> /* Thread cols (in elements) */
struct GEMVKernel {
MLX_MTL_CONST int threadsM = BM * SM;
MLX_MTL_CONST int threadsN = BN * SN;
MLX_MTL_CONST int blockM = threadsM * TM;
MLX_MTL_CONST int blockN = threadsN * TN;
static_assert(SM * SN == 32, "simdgroup can only have 32 threads");
static_assert(
SN == 8 || SN == 16 || SN == 32,
"gemv block must have a width of 8, 16, or 32");
static_assert(blockN >= blockM, "Masked gemv must have blockN >= blockM");
MLX_MTL_CONST bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
MLX_MTL_CONST bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;
MLX_MTL_CONST bool has_mul_operand_mask =
has_operand_mask && !metal::is_same_v<op_mask_t, bool>;
MLX_MTL_CONST bool has_mul_output_mask =
has_output_mask && !metal::is_same_v<out_mask_t, bool>;
// - The matrix of size (M = out_vec_size, K = in_vec_size) is divided up
// into blocks of (blockM, blockN) divided among threadgroups
// - Every thread works on a block of (TM, TN)
// - We assume each threadgroup has (threadsN, threadsM, 1) threads
//
// 1. A thread loads TN elements each from mat along TM rows
// and the corresponding scalar from the vector
// 2. The thread then multiplies and adds to accumulate its local result for
// the block
// 3. At the end, each thread has accumulated results over all blocks across
// the rows. These are then summed up across the threadgroup
// 4. Each threadgroup writes its accumulated blockM outputs
//
// Edge case handling:
// - The threadgroup with the largest tid has blocks that exceed the matrix
// * The blocks that start outside the matrix are never read (thread results
// remain zero)
// * The last thread that partially overlaps with the matrix is shifted
// inwards such that the thread block fits exactly in the matrix
MLX_MTL_CONST short tgp_mem_size = BN > 1 ? BN*(blockM + TM) : 0;
MLX_MTL_CONST bool needs_tgp_reduction = BN > 1;
static METAL_FUNC void
load_unsafe(const device T* src, thread T dst[TN], const int src_offset = 0) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
dst[tn] = src[src_offset + tn];
}
}
static METAL_FUNC void load_safe(
const device T* src,
thread T dst[TN],
const int src_offset = 0,
const int src_size = TN) {
if (src_offset + TN <= src_size) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
dst[tn] = src[src_offset + tn];
}
} else { // Edgecase
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
dst[tn] = src_offset + tn < src_size ? src[src_offset + tn] : 0;
}
}
}
static METAL_FUNC void run(
const device T* mat [[buffer(0)]],
const device T* in_vec [[buffer(1)]],
device T* out_vec [[buffer(3)]],
const constant int& in_vec_size [[buffer(4)]],
const constant int& out_vec_size [[buffer(5)]],
const constant int& matrix_ld [[buffer(6)]],
const device out_mask_t* out_mask [[buffer(20)]],
const device op_mask_t* mat_mask [[buffer(21)]],
const device op_mask_t* vec_mask [[buffer(22)]],
const constant int* mask_strides [[buffer(23)]],
threadgroup T* tgp_memory [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
// Appease compiler
(void)lid;
// Thread local accumulation results
thread T result[TM] = {0};
thread T inter[TN];
thread T v_coeff[TN];
const int thrM = SN != 32 ? simd_lid / SN : 0;
const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid);
const int sgN = BN != 1 ? (simd_gid % BN) : 0;
const int simdM = BN != 1 ? SM * (simd_gid / BN) : int(SM * simd_gid);
const int simdN = BN != 1 ? SN * (simd_gid % BN) : 0;
int bm = (simdM + thrM) * TM;
int bn = (simdN + thrN) * TN;
// Block position
int out_row = tid.x * blockM + bm;
// Exit simdgroup if rows out of bound
if (out_row >= out_vec_size)
return;
// Adjust tail simdgroup to ensure in bound reads
out_row = out_row + TM <= out_vec_size ? out_row : out_vec_size - TM;
// Prepare mask offsets
const constant int* out_mask_strides = mask_strides;
const constant int* mat_mask_strides =
mask_strides + (has_output_mask ? 2 : 0);
const constant int* vec_mask_strides =
mat_mask_strides + (has_operand_mask ? 2 : 0);
const int m_block_idx = blockN > blockM ? out_row / blockN : int(tid.x);
const int out_mask_offset =
!has_output_mask ? 0 : m_block_idx * out_mask_strides[1];
int mat_mask_offset =
!has_operand_mask ? 0 : m_block_idx * mat_mask_strides[1];
int vec_mask_offset = 0;
const int mat_mask_step = !has_operand_mask ? 0 : mat_mask_strides[0];
const int vec_mask_step = !has_operand_mask ? 0 : vec_mask_strides[1];
T out_scale{1};
// Check output mask
if (has_output_mask) {
auto mask_out = out_mask[out_mask_offset];
// Write zeros and return if mask is 0
if (!mask_out) {
if (simdN == 0 && thrN == 0) {
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
out_vec[out_row + tm] = T(0.);
}
}
return;
}
// Store scalar if multiplicative mask
if (has_mul_output_mask) {
out_scale = T(mask_out);
}
}
// Advance matrix
mat += out_row * matrix_ld;
// Prepare for loop
constexpr const uniform<int> loop_stride = make_uniform(blockN);
const uniform<int> in_size = make_uniform(in_vec_size);
const uniform<int> n_iter = in_size / loop_stride;
const uniform<int> last_iter = loop_stride * n_iter;
const uniform<int> leftover = in_size - last_iter;
// Loop over in_vec in blocks of blockN
for (int i = 0; i < n_iter; ++i) {
if (!has_operand_mask ||
(bool(mat_mask[mat_mask_offset]) &&
bool(vec_mask[vec_mask_offset]))) {
T block_scale{1};
if (has_mul_operand_mask) {
block_scale =
T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
}
load_unsafe(in_vec, v_coeff, bn);
// Apply scale
if (has_mul_operand_mask) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
v_coeff[tn] *= block_scale;
}
}
// Per thread work loop
int mat_offset = 0;
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
// Load for the row
load_unsafe(mat, inter, mat_offset + bn);
// Accumulate results
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
result[tm] += inter[tn] * v_coeff[tn];
}
mat_offset += matrix_ld;
}
}
bn += blockN;
mat_mask_offset += mat_mask_step;
vec_mask_offset += vec_mask_step;
}
if (leftover > 0 &&
(!has_operand_mask ||
(bool(mat_mask[mat_mask_offset]) &&
bool(vec_mask[vec_mask_offset])))) {
T block_scale{1};
if (has_mul_operand_mask) {
block_scale =
T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
}
load_safe(in_vec, v_coeff, bn, in_size);
// Apply scale
if (has_mul_operand_mask) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
v_coeff[tn] *= block_scale;
}
}
// Per thread work loop
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
// Load for the row
load_safe(&mat[tm * matrix_ld], inter, bn, in_size);
// Accumulate results
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
result[tm] += inter[tn] * v_coeff[tn];
}
}
}
// Apply out scale
if (has_mul_output_mask) {
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
result[tm] *= out_scale;
}
}
// Simdgroup accumulations
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
MLX_MTL_PRAGMA_UNROLL
for (ushort sn = (SN / 2); sn >= 1; sn >>= 1) {
result[tm] += simd_shuffle_down(result[tm], sn);
}
}
// Threadgroup accumulation results
if (needs_tgp_reduction) {
threadgroup T* tgp_results = tgp_memory + sgN * (blockM + TM) + bm;
if (thrN == 0) {
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
tgp_results[tm] = result[tm];
}
threadgroup_barrier(mem_flags::mem_none);
if (sgN == 0) {
MLX_MTL_PRAGMA_UNROLL
for (int sgn = 1; sgn < BN; sgn++) {
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
result[tm] += tgp_results[sgn * (blockM + TM) + tm];
}
}
}
}
}
// Write outputs
if (simdN == 0 && thrN == 0) {
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
out_vec[out_row + tm] = result[tm];
}
}
}
};
///////////////////////////////////////////////////////////////////////////////
/// Vector matrix multiplication
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
typename out_mask_t,
typename op_mask_t,
const int BM, /* Threadgroup rows (in simdgroups) */
const int BN, /* Threadgroup cols (in simdgroups) */
const int SM, /* Simdgroup rows (in threads) */
const int SN, /* Simdgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN> /* Thread cols (in elements) */
struct GEMVTKernel {
MLX_MTL_CONST int threadsM = BM * SM;
MLX_MTL_CONST int threadsN = BN * SN;
MLX_MTL_CONST int blockM = threadsM * TM;
MLX_MTL_CONST int blockN = threadsN * TN;
static_assert(SM * SN == 32, "simdgroup can only have 32 threads");
MLX_MTL_CONST bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
MLX_MTL_CONST bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;
MLX_MTL_CONST bool has_mul_operand_mask =
has_operand_mask && !metal::is_same_v<op_mask_t, bool>;
MLX_MTL_CONST bool has_mul_output_mask =
has_output_mask && !metal::is_same_v<out_mask_t, bool>;
// - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up
// into blocks of (blockM, blockN) divided among threadgroups
// - Every thread works on a block of (TM, TN)
// - We assume each threadgroup has (threadsN, threadsM, 1) threads
//
// 1. A thread loads TN elements each from mat along TM contiguous rows
// and the corresponding scalar from the vector
// 2. The thread then accumulates its local result for the block
// 3. At the end, each thread has accumulated results over all blocks across
// the rows. These are then summed up across the threadgroup
// 4. Each threadgroup writes its accumulated BN * TN outputs
//
// Edge case handling:
// - The threadgroup with the largest tid has blocks that exceed the matrix
// * The blocks that start outside the matrix are never read (thread results
// remain zero)
// * The last thread that partially overlaps with the matrix is shifted
// inwards such that the thread block fits exactly in the matrix
MLX_MTL_CONST short tgp_mem_size = BM > 1 ? BM*(blockN + TN) : 0;
MLX_MTL_CONST bool needs_tgp_reduction = BM > 1;
static METAL_FUNC void run(
const device T* mat [[buffer(0)]],
const device T* in_vec [[buffer(1)]],
device T* out_vec [[buffer(3)]],
const constant int& in_vec_size [[buffer(4)]],
const constant int& out_vec_size [[buffer(5)]],
const constant int& marix_ld [[buffer(6)]],
const device out_mask_t* out_mask [[buffer(20)]],
const device op_mask_t* mat_mask [[buffer(21)]],
const device op_mask_t* vec_mask [[buffer(22)]],
const constant int* mask_strides [[buffer(23)]],
threadgroup T* tgp_memory [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
// Appease compiler
(void)lid;
// Thread local accumulation results
T result[TN] = {0};
T inter[TN];
T v_coeff[TM];
const int thrM = SN != 32 ? simd_lid / SN : 0;
const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid);
const int sgM = BN != 1 ? (simd_gid / BN) : int(simd_gid);
const int sgN = BN != 1 ? (simd_gid % BN) : 0;
const int simdM = SM * sgM;
const int simdN = SN * sgN;
int cm = (simdM + thrM);
int cn = (simdN + thrN);
int bm = cm * TM;
int bn = cn * TN;
int out_col = tid.x * blockN + bn;
// Prepare mask offsets
const constant int* out_mask_strides = mask_strides;
const constant int* mat_mask_strides =
out_mask_strides + (has_output_mask ? 2 : 0);
const constant int* vec_mask_strides =
mat_mask_strides + (has_operand_mask ? 2 : 0);
const int n_block_idx = blockM > blockN ? out_col / blockM : int(tid.x);
const int out_mask_offset =
!has_output_mask ? 0 : n_block_idx; // * out_mask_strides[0];
int mat_mask_offset =
!has_operand_mask ? 0 : n_block_idx * mat_mask_strides[0];
int vec_mask_offset = 0;
const int mat_mask_step = !has_operand_mask ? 0 : mat_mask_strides[1];
const int vec_mask_step = !has_operand_mask ? 0 : vec_mask_strides[0];
T out_scale{1};
// Check output mask
if (has_output_mask) {
auto mask_out = out_mask[out_mask_offset];
// Write zeros and return if mask is 0
if (!mask_out) {
if (cm == 0 && out_col < out_vec_size) {
if (out_col + TN <= out_vec_size) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
out_vec[out_col + tn] = T(0.);
}
} else {
for (int tn = 0; tn < TN && (out_col + tn) < out_vec_size; tn++) {
out_vec[out_col + tn] = T(0.);
}
}
}
return;
}
// Store scalar if multiplicative mask
if (has_mul_output_mask) {
out_scale = T(mask_out);
}
}
// Prepare for loop
constexpr const uniform<int> loop_stride = make_uniform(blockM);
const uniform<int> in_size = make_uniform(in_vec_size);
const uniform<int> n_iter = in_size / loop_stride;
const uniform<int> last_iter = loop_stride * n_iter;
const uniform<int> leftover = in_size - last_iter;
// Edgecase handling
if (out_col < out_vec_size) {
out_col = (out_col + TN) <= out_vec_size ? out_col : out_vec_size - TN;
// Per thread accumulation main loop
for (int i = 0; i < n_iter; ++i) {
// Adding a threadgroup_barrier improves performance slightly
// This is possibly it may help exploit cache better
threadgroup_barrier(mem_flags::mem_none);
if (!has_operand_mask ||
(bool(mat_mask[mat_mask_offset]) &&
bool(vec_mask[vec_mask_offset]))) {
T block_scale{1};
if (has_mul_operand_mask) {
block_scale =
T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
}
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
v_coeff[tm] = in_vec[bm + tm];
}
// Apply scale
if (has_mul_operand_mask) {
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
v_coeff[tm] *= block_scale;
}
}
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
for (int tn = 0; tn < TN; tn++) {
inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
}
for (int tn = 0; tn < TN; tn++) {
result[tn] += v_coeff[tm] * inter[tn];
}
}
}
bm += blockM;
mat_mask_offset += mat_mask_step;
vec_mask_offset += vec_mask_step;
}
if (leftover > 0 &&
(!has_operand_mask ||
(bool(mat_mask[mat_mask_offset]) &&
bool(vec_mask[vec_mask_offset])))) {
T block_scale{1};
if (has_mul_operand_mask) {
block_scale =
T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]);
}
for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) {
v_coeff[tm] = in_vec[bm + tm];
if (has_mul_operand_mask) {
v_coeff[tm] *= block_scale;
}
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
}
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
result[tn] += v_coeff[tm] * inter[tn];
}
}
}
}
// Apply out scale
if (has_mul_output_mask) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
result[tn] *= out_scale;
}
}
// Simdgroup accumulations
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
MLX_MTL_PRAGMA_UNROLL
for (ushort sm = (SM / 2); sm >= 1; sm >>= 1) {
result[tn] += simd_shuffle_down(result[tn], SN * sm);
}
}
// Threadgroup accumulation results
if (needs_tgp_reduction) {
threadgroup T* tgp_results = tgp_memory + sgM * (blockN + TN) + bn;
if (thrM == 0) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
tgp_results[tn] = result[tn];
}
threadgroup_barrier(mem_flags::mem_none);
if (sgM == 0) {
MLX_MTL_PRAGMA_UNROLL
for (int sgm = 1; sgm < BM; sgm++) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
result[tn] += tgp_results[sgm * (blockN + TN) + tn];
}
}
}
}
}
// Threadgroup accumulation and writing out results
if (cm == 0 && out_col < out_vec_size) {
MLX_MTL_PRAGMA_UNROLL
for (int j = 0; j < TN; j++) {
out_vec[out_col + j] = result[j];
}
}
}
};
///////////////////////////////////////////////////////////////////////////////
/// Matrix vector multiplication
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
typename out_mask_t,
typename op_mask_t,
const int BM, /* Threadgroup rows (in simdgroups) */
const int BN, /* Threadgroup cols (in simdgroups) */
const int SM, /* Simdgroup rows (in threads) */
const int SN, /* Simdgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN, /* Thread cols (in elements) */
const bool kDoNCBatch> /* Batch ndim > 1 */
[[kernel, max_total_threads_per_threadgroup(BM* BN * 32)]] void gemv_masked(
const device T* mat [[buffer(0)]],
const device T* in_vec [[buffer(1)]],
device T* out_vec [[buffer(3)]],
const constant int& in_vec_size [[buffer(4)]],
const constant int& out_vec_size [[buffer(5)]],
const constant int& marix_ld [[buffer(6)]],
const constant int& batch_ndim [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]],
const constant size_t* vector_batch_stride [[buffer(11)]],
const constant size_t* matrix_batch_stride [[buffer(12)]],
const device out_mask_t* out_mask [[buffer(20)]],
const device op_mask_t* mat_mask [[buffer(21)]],
const device op_mask_t* vec_mask [[buffer(22)]],
const constant int* mask_strides [[buffer(23)]],
const constant size_t* mask_batch_strides [[buffer(24)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using gemv_kernel =
GEMVKernel<T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN>;
threadgroup T tgp_memory
[gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
constexpr bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
constexpr bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;
// Update batch offsets
if (kDoNCBatch) {
in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);
mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);
if (has_output_mask) {
out_mask +=
elem_to_loc(tid.z, batch_shape, mask_batch_strides, batch_ndim);
mask_batch_strides += batch_ndim;
}
if (has_operand_mask) {
const constant size_t* mask_strides_mat = mask_batch_strides;
const constant size_t* mask_strides_vec = mask_strides_mat + batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim);
mat_mask += batch_offsets.x;
vec_mask += batch_offsets.y;
}
} else {
in_vec += tid.z * vector_batch_stride[0];
mat += tid.z * matrix_batch_stride[0];
if (has_output_mask) {
out_mask += tid.z * mask_batch_strides[0];
mask_batch_strides += batch_ndim;
}
if (has_operand_mask) {
mat_mask += tid.z * mask_batch_strides[0];
vec_mask += tid.z * mask_batch_strides[batch_ndim];
}
}
out_vec += tid.z * out_vec_size;
gemv_kernel::run(
mat,
in_vec,
out_vec,
in_vec_size,
out_vec_size,
marix_ld,
out_mask,
mat_mask,
vec_mask,
mask_strides,
gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory,
tid,
lid,
simd_gid,
simd_lid);
}
#define instantiate_gemv_helper( \
outm_n, outm_t, opm_n, opm_t, name, itype, bm, bn, sm, sn, tm, tn, nc) \
@@ -754,7 +36,6 @@ template <
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
- // clang-format off
#define instantiate_gemv_base(name, itype, bm, bn, sm, sn, tm, tn, nc) \
instantiate_gemv_helper(bool_, bool, bool_, bool, name, itype, bm, bn, sm, sn, tm, tn, nc) \
instantiate_gemv_helper(name, itype, name, itype, name, itype, bm, bn, sm, sn, tm, tn, nc) \
@@ -763,125 +44,23 @@ template <
instantiate_gemv_helper(nomask, nomask_t, name, itype, name, itype, bm, bn, sm, sn, tm, tn, nc) \
instantiate_gemv_helper(nomask, nomask_t, bool_, bool, name, itype, bm, bn, sm, sn, tm, tn, nc) \
instantiate_gemv_helper(bool_, bool, nomask, nomask_t, name, itype, bm, bn, sm, sn, tm, tn, nc) \
- instantiate_gemv_helper(name, itype, nomask, nomask_t, name, itype, bm, bn, sm, sn, tm, tn, nc) // clang-format on
+ instantiate_gemv_helper(name, itype, nomask, nomask_t, name, itype, bm, bn, sm, sn, tm, tn, nc)
- // clang-format off
#define instantiate_gemv(name, itype, bm, bn, sm, sn, tm, tn) \
instantiate_gemv_base(name, itype, bm, bn, sm, sn, tm, tn, 0) \
- instantiate_gemv_base(name, itype, bm, bn, sm, sn, tm, tn, 1) // clang-format on
+ instantiate_gemv_base(name, itype, bm, bn, sm, sn, tm, tn, 1)
- // clang-format off
#define instantiate_gemv_blocks(name, itype) \
instantiate_gemv(name, itype, 2, 1, 4, 8, 1, 4) \
instantiate_gemv(name, itype, 2, 1, 4, 8, 4, 4) \
instantiate_gemv(name, itype, 2, 1, 2, 16, 1, 4) \
instantiate_gemv(name, itype, 2, 1, 2, 16, 4, 4) \
- instantiate_gemv(name, itype, 4, 1, 2, 16, 4, 4) // clang-format on
+ instantiate_gemv(name, itype, 4, 1, 2, 16, 4, 4)
instantiate_gemv_blocks(float32, float);
instantiate_gemv_blocks(float16, half);
instantiate_gemv_blocks(bfloat16, bfloat16_t);
///////////////////////////////////////////////////////////////////////////////
/// Vector matrix multiplication
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
typename out_mask_t,
typename op_mask_t,
const int BM, /* Threadgroup rows (in simdgroups) */
const int BN, /* Threadgroup cols (in simdgroups) */
const int SM, /* Simdgroup rows (in threads) */
const int SN, /* Simdgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN, /* Thread cols (in elements) */
const bool kDoNCBatch> /* Batch ndim > 1 */
[[kernel, max_total_threads_per_threadgroup(BM* BN * 32)]] void gemv_t_masked(
const device T* mat [[buffer(0)]],
const device T* in_vec [[buffer(1)]],
device T* out_vec [[buffer(3)]],
const constant int& in_vec_size [[buffer(4)]],
const constant int& out_vec_size [[buffer(5)]],
const constant int& marix_ld [[buffer(6)]],
const constant int& batch_ndim [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]],
const constant size_t* vector_batch_stride [[buffer(11)]],
const constant size_t* matrix_batch_stride [[buffer(12)]],
const device out_mask_t* out_mask [[buffer(20)]],
const device op_mask_t* mat_mask [[buffer(21)]],
const device op_mask_t* vec_mask [[buffer(22)]],
const constant int* mask_strides [[buffer(23)]],
const constant size_t* mask_batch_strides [[buffer(24)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using gemv_kernel =
GEMVTKernel<T, out_mask_t, op_mask_t, BM, BN, SM, SN, TM, TN>;
threadgroup T tgp_memory
[gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
constexpr bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
constexpr bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;
// Update batch offsets
if (kDoNCBatch) {
in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);
mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);
if (has_output_mask) {
out_mask +=
elem_to_loc(tid.z, batch_shape, mask_batch_strides, batch_ndim);
mask_batch_strides += batch_ndim;
}
if (has_operand_mask) {
const constant size_t* mask_strides_mat = mask_batch_strides;
const constant size_t* mask_strides_vec = mask_strides_mat + batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim);
mat_mask += batch_offsets.x;
vec_mask += batch_offsets.y;
}
} else {
in_vec += tid.z * vector_batch_stride[0];
mat += tid.z * matrix_batch_stride[0];
if (has_output_mask) {
out_mask += tid.z * mask_batch_strides[0];
mask_batch_strides += batch_ndim;
}
if (has_operand_mask) {
mat_mask += tid.z * mask_batch_strides[0];
vec_mask += tid.z * mask_batch_strides[batch_ndim];
}
}
out_vec += tid.z * out_vec_size;
gemv_kernel::run(
mat,
in_vec,
out_vec,
in_vec_size,
out_vec_size,
marix_ld,
out_mask,
mat_mask,
vec_mask,
mask_strides,
gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory,
tid,
lid,
simd_gid,
simd_lid);
}
#define instantiate_gemv_t_helper( \
outm_n, outm_t, opm_n, opm_t, name, itype, bm, bn, sm, sn, tm, tn, nc) \
template [[host_name("gemv_t_outmask_" #outm_n "_opmask_" #opm_n "_" #name \
@@ -908,7 +87,6 @@ template <
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
- // clang-format off
#define instantiate_gemv_t_base(name, itype, bm, bn, sm, sn, tm, tn, nc) \
instantiate_gemv_t_helper(bool_, bool, bool_, bool, name, itype, bm, bn, sm, sn, tm, tn, nc) \
instantiate_gemv_t_helper(name, itype, name, itype, name, itype, bm, bn, sm, sn, tm, tn, nc) \
@@ -917,23 +95,20 @@ template <
instantiate_gemv_t_helper(nomask, nomask_t, name, itype, name, itype, bm, bn, sm, sn, tm, tn, nc) \
instantiate_gemv_t_helper(nomask, nomask_t, bool_, bool, name, itype, bm, bn, sm, sn, tm, tn, nc) \
instantiate_gemv_t_helper(bool_, bool, nomask, nomask_t, name, itype, bm, bn, sm, sn, tm, tn, nc) \
- instantiate_gemv_t_helper(name, itype, nomask, nomask_t, name, itype, bm, bn, sm, sn, tm, tn, nc) // clang-format on
+ instantiate_gemv_t_helper(name, itype, nomask, nomask_t, name, itype, bm, bn, sm, sn, tm, tn, nc)
- // clang-format off
#define instantiate_gemv_t(name, itype, bm, bn, sm, sn, tm, tn) \
instantiate_gemv_t_base(name, itype, bm, bn, sm, sn, tm, tn, 0) \
- instantiate_gemv_t_base(name, itype, bm, bn, sm, sn, tm, tn, 1) // clang-format on
+ instantiate_gemv_t_base(name, itype, bm, bn, sm, sn, tm, tn, 1)
- // clang-format off
#define instantiate_gemv_t_blocks(name, itype) \
instantiate_gemv_t(name, itype, 1, 1, 8, 4, 4, 1) \
instantiate_gemv_t(name, itype, 1, 2, 8, 4, 4, 4) \
instantiate_gemv_t(name, itype, 1, 1, 8, 4, 8, 1) \
instantiate_gemv_t(name, itype, 1, 1, 8, 4, 8, 4) \
instantiate_gemv_t(name, itype, 1, 2, 8, 4, 8, 4) \
- instantiate_gemv_t(name, itype, 1, 4, 8, 4, 8, 4) // clang-format on
+ instantiate_gemv_t(name, itype, 1, 4, 8, 4, 8, 4)
instantiate_gemv_t_blocks(float32, float);
instantiate_gemv_t_blocks(float16, half);
instantiate_gemv_t_blocks(bfloat16, bfloat16_t); // clang-format on


@@ -0,0 +1,167 @@
// Copyright © 2024 Apple Inc.
#include <metal_common>
#include <metal_compute>
#include "mlx/backend/metal/kernels/steel/defines.h"
using namespace metal;
// Thread local Hadamard transform for 2^R
template <short R>
METAL_FUNC void radix_func(thread float* x) {
constexpr short logR = __builtin_ctz(R);
short h = 1;
STEEL_PRAGMA_UNROLL
for (short s = 0; s < logR; s++) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < R / 2; i++) {
short k = i & (h - 1);
short j = ((i - k) << 1) + k;
float a = x[j];
float b = x[j + h];
x[j] = a + b;
x[j + h] = a - b;
}
h <<= 1;
}
}
template <typename T, int N, int max_radix, int read_width>
[[kernel]] void hadamard_n(
const device T* in [[buffer(0)]],
device T* out [[buffer(1)]],
constant const float& scale,
uint3 elem [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]) {
// Compute a Hadamard transform of size N = 2^k
//
// Equivalent to:
// from scipy.linalg import hadamard
// y = hadamard(len(x)) @ x
constexpr short num_threads = N / max_radix;
constexpr short logN = __builtin_ctz(N);
constexpr short logR = __builtin_ctz(max_radix);
constexpr short num_steps = logN / logR;
constexpr short logFinal = logN % logR;
constexpr short final_radix = 1 << (logFinal);
int batch_idx = elem.x * N;
short i = elem.y;
threadgroup T buf[N];
// Read values from device
STEEL_PRAGMA_UNROLL
for (short j = 0; j < max_radix / read_width; j++) {
short index = j * read_width * num_threads + i * read_width;
STEEL_PRAGMA_UNROLL
for (short r = 0; r < read_width; r++) {
buf[index + r] = in[batch_idx + index + r];
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
float x[max_radix];
short h = 1;
STEEL_PRAGMA_UNROLL
for (short s = 0; s < num_steps; s++) {
short k = i & (h - 1);
short j = ((i - k) << logR) + k;
STEEL_PRAGMA_UNROLL
for (short r = 0; r < max_radix; r++) {
x[r] = buf[j + h * r];
}
radix_func<max_radix>(x);
STEEL_PRAGMA_UNROLL
for (short r = 0; r < max_radix; r++) {
buf[j + h * r] = T(x[r]);
}
h <<= logR;
threadgroup_barrier(mem_flags::mem_threadgroup);
}
// Do the final radix
// e.g. max_radix = 16
// N = 1024 = 16 * 16 * 4
if (final_radix > 1) {
// Each thread does multiple butterflies
STEEL_PRAGMA_UNROLL
for (int t = 0; t < max_radix / final_radix; t++) {
short index = i + t * num_threads;
short k = index & (h - 1);
short j = ((index - k) << logFinal) + k;
STEEL_PRAGMA_UNROLL
for (short r = 0; r < final_radix; r++) {
x[r] = buf[j + h * r];
}
radix_func<final_radix>(x);
STEEL_PRAGMA_UNROLL
for (short r = 0; r < final_radix; r++) {
buf[j + h * r] = T(x[r]);
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
// Write values to device
STEEL_PRAGMA_UNROLL
for (short j = 0; j < max_radix / read_width; j++) {
short index = j * read_width * num_threads + i * read_width;
STEEL_PRAGMA_UNROLL
for (short r = 0; r < read_width; r++) {
out[batch_idx + index + r] = T(buf[index + r] * scale);
}
}
}
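To make the butterfly structure above concrete, here is a plain C++ version of the same radix-2 recursion run over a full power-of-two vector on the CPU. The name fwht and the use of std::vector are only for this sketch, and no output scaling is applied (the kernel multiplies by scale on the way out).

#include <cstddef>
#include <vector>

// In-place Walsh-Hadamard transform of a length-N buffer, N a power of two.
// Same sum/difference butterflies as radix_func, just looped over the whole
// vector instead of being distributed across threads.
void fwht(std::vector<float>& x) {
  const size_t n = x.size();
  for (size_t h = 1; h < n; h <<= 1) {      // stage stride: 1, 2, 4, ...
    for (size_t j = 0; j < n; j += 2 * h) { // start of each butterfly block
      for (size_t k = j; k < j + h; ++k) {
        float a = x[k];
        float b = x[k + h];
        x[k] = a + b;     // sum lane
        x[k + h] = a - b; // difference lane
      }
    }
  }
}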
template <typename T, int N, int M, int read_width>
[[kernel]] void hadamard_m(
const device T* in [[buffer(0)]],
device T* out [[buffer(1)]],
constant const float& scale,
uint3 elem [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]) {
// Compute a Hadamard transform of size M
// using a naive O(M^2) codelet.
//
// This kernel is the second stage in the computation
// of a Hadamard transform of size M*N where N = 2^k.
int index = elem.x * grid.y + elem.y;
short i = index % (N / read_width);
int batch_idx = index / (N / read_width) * M * N;
float x[read_width][M];
STEEL_PRAGMA_UNROLL
for (short c = 0; c < M; c++) {
STEEL_PRAGMA_UNROLL
for (short r = 0; r < read_width; r++) {
x[r][c] = in[batch_idx + c * N + i * read_width + r];
}
}
STEEL_PRAGMA_UNROLL
for (short r = 0; r < read_width; r++) {
// This function is JIT compiled for M
// using the Hadamard matrix strings in `metal/hadamard.cpp`
hadamard_radix_m(x[r]);
}
// Write back to device
STEEL_PRAGMA_UNROLL
for (short c = 0; c < M; c++) {
STEEL_PRAGMA_UNROLL
for (short r = 0; r < read_width; r++) {
out[batch_idx + c * N + i * read_width + r] = T(x[r][c] * scale);
}
}
}


@@ -34,7 +34,7 @@ template <typename T, int N_READS = RMS_N_READS>
  threadgroup float local_mean[1];
  threadgroup float local_normalizer[1];
  x += gid * axis_size + lid * N_READS;
  x += gid * size_t(axis_size) + lid * N_READS;
  w += w_stride * lid * N_READS;
  b += b_stride * lid * N_READS;
@@ -89,7 +89,7 @@ template <typename T, int N_READS = RMS_N_READS>
  float normalizer = local_normalizer[0];
  // Write the outputs
  out += gid * axis_size + lid * N_READS;
  out += gid * size_t(axis_size) + lid * N_READS;
  if (lid * N_READS + N_READS <= axis_size) {
    for (int i = 0; i < N_READS; i++) {
      thread_x[i] = (thread_x[i] - mean) * normalizer;
@@ -131,7 +131,7 @@ template <typename T, int N_READS = RMS_N_READS>
  threadgroup float local_mean[1];
  threadgroup float local_normalizer[1];
  x += gid * axis_size + lid * N_READS;
  x += gid * size_t(axis_size) + lid * N_READS;
  w += w_stride * lid * N_READS;
  b += b_stride * lid * N_READS;
@@ -188,7 +188,7 @@ template <typename T, int N_READS = RMS_N_READS>
  float normalizer = local_normalizer[0];
  // Write the outputs
  out += gid * axis_size + lid * N_READS;
  out += gid * size_t(axis_size) + lid * N_READS;
  for (uint r = 0; r < axis_size; r += lsize * N_READS) {
    if (r + lid * N_READS + N_READS <= axis_size) {
      for (int i = 0; i < N_READS; i++) {
@@ -223,8 +223,8 @@ template <typename T, int N_READS = RMS_N_READS>
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  // Advance the input pointers
  x += gid * axis_size + lid * N_READS;
  g += gid * axis_size + lid * N_READS;
  x += gid * size_t(axis_size) + lid * N_READS;
  g += gid * size_t(axis_size) + lid * N_READS;
  w += w_stride * lid * N_READS;
  // Allocate registers for the computation and accumulators
@@ -321,8 +321,8 @@ template <typename T, int N_READS = RMS_N_READS>
  float normalizer2 = normalizer * normalizer;
  // Write the outputs
  gx += gid * axis_size + lid * N_READS;
  gw += gid * axis_size + lid * N_READS;
  gx += gid * size_t(axis_size) + lid * N_READS;
  gw += gid * size_t(axis_size) + lid * N_READS;
  if (lid * N_READS + N_READS <= axis_size) {
    for (int i = 0; i < N_READS; i++) {
      thread_x[i] = (thread_x[i] - mean) * normalizer;
@@ -360,8 +360,8 @@ template <typename T, int N_READS = RMS_N_READS>
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  // Advance the input pointers
  x += gid * axis_size + lid * N_READS;
  g += gid * axis_size + lid * N_READS;
  x += gid * size_t(axis_size) + lid * N_READS;
  g += gid * size_t(axis_size) + lid * N_READS;
  w += w_stride * lid * N_READS;
  // Allocate registers for the accumulators
@@ -457,8 +457,8 @@ template <typename T, int N_READS = RMS_N_READS>
  float normalizer2 = normalizer * normalizer;
  // Write the outputs
  gx += gid * axis_size + lid * N_READS;
  gw += gid * axis_size + lid * N_READS;
  gx += gid * size_t(axis_size) + lid * N_READS;
  gw += gid * size_t(axis_size) + lid * N_READS;
  for (uint r = 0; r < axis_size; r += lsize * N_READS) {
    if (r + lid * N_READS + N_READS <= axis_size) {
      for (int i = 0; i < N_READS; i++) {
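The repeated gid * axis_size to gid * size_t(axis_size) edits in this and the following files all address the same 32-bit overflow: the row-offset product is otherwise computed in 32-bit arithmetic and wraps once an offset exceeds 2^32 - 1, which can happen for large arrays. A small C++ illustration of the failure mode and the fix, assuming typical 32-bit uint/int and a 64-bit size_t (the concrete numbers are made up):

#include <cstdint>
#include <cstdio>

int main() {
  uint32_t gid = 70000;      // row index (threadgroup id)
  int32_t axis_size = 70000; // elements per row
  // Multiplying in 32 bits wraps around: 70000 * 70000 > 2^32 - 1.
  uint32_t wrapped = gid * static_cast<uint32_t>(axis_size);
  // Widening one operand first keeps the whole product in 64 bits.
  size_t correct = size_t(gid) * size_t(axis_size);
  std::printf("wrapped: %u, correct: %zu\n", wrapped, correct);
  return 0;
}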


@@ -690,12 +690,12 @@ METAL_FUNC void qvm_impl(
template <
    typename T,
    const int BM,
    const int BK,
    const int BN,
    const int group_size,
    const int bits,
    const bool aligned_N>
    const bool aligned_N,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
METAL_FUNC void qmm_t_impl(
    const device T* x,
    const device uint32_t* w,
@@ -812,11 +812,11 @@ METAL_FUNC void qmm_t_impl(
template <
    typename T,
    const int BM,
    const int BK,
    const int BN,
    const int group_size,
    const int bits>
    const int bits,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
METAL_FUNC void qmm_n_impl(
    const device T* x,
    const device uint32_t* w,
@@ -1099,7 +1099,7 @@ template <
  threadgroup T Xs[BM * BK_padded];
  threadgroup T Ws[BN * BK_padded];
  qmm_t_impl<T, BM, BK, BN, group_size, bits, aligned_N>(
  qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
      x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
}
@@ -1131,7 +1131,7 @@ template <
  threadgroup T Xs[BM * BK_padded];
  threadgroup T Ws[BK * BN_padded];
  qmm_n_impl<T, BM, BK, BN, group_size, bits>(
  qmm_n_impl<T, group_size, bits, BM, BK, BN>(
      x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
}
@@ -1382,7 +1382,7 @@ template <
      s_strides,
      b_strides,
      tid);
  qmm_t_impl<T, BM, BK, BN, group_size, bits, aligned_N>(
  qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
      x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
}
@@ -1450,6 +1450,147 @@ template <
      s_strides,
      b_strides,
      tid);
  qmm_n_impl<T, BM, BK, BN, group_size, bits>(
  qmm_n_impl<T, group_size, bits, BM, BK, BN>(
      x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
}
template <typename T, const int group_size, const int bits>
[[kernel]] void affine_quantize(
const device T* w [[buffer(0)]],
device uint8_t* out [[buffer(1)]],
device T* scales [[buffer(2)]],
device T* biases [[buffer(3)]],
uint index [[thread_position_in_grid]]) {
constexpr T eps = T(1e-7);
constexpr int simd_size = 32;
constexpr int uint8_bits = 8;
constexpr T n_bins = (1 << bits) - 1;
constexpr int packs_per_int = uint8_bits / bits;
constexpr int values_per_reduce = group_size / simd_size;
constexpr int writes_per_reduce = packs_per_int / values_per_reduce;
constexpr int writes_per_pack =
writes_per_reduce > 1 ? 1 : values_per_reduce / packs_per_int;
static_assert(
group_size % simd_size == 0,
"Group size must be divisible by simd size.");
int in_index = index * values_per_reduce;
int out_index = index * writes_per_pack;
T w_thread[values_per_reduce];
T w_min = Limits<T>::max;
T w_max = 0;
#pragma clang loop unroll(full)
for (int i = 0; i < values_per_reduce; i++) {
T val = w[in_index + i];
w_thread[i] = val;
w_min = min(w_min, val);
w_max = max(w_max, val);
}
w_min = simd_min(w_min);
w_max = simd_max(w_max);
T scale = max((w_max - w_min) / n_bins, eps);
bool side = abs(w_min) > abs(w_max);
scale = side ? scale : -scale;
T edge = side ? w_min : w_max;
T q0 = round(edge / scale);
bool at_zero = q0 == 0.0f;
scale = at_zero ? scale : edge / q0;
T bias = at_zero ? T(0) : edge;
// Write out the scales and biases
int gindex = in_index / group_size;
if (in_index % group_size == 0) {
scales[gindex] = scale;
biases[gindex] = bias;
}
uint8_t output = 0;
#pragma clang loop unroll(full)
for (int i = 0; i < values_per_reduce; i++) {
uint8_t val = min(round((w_thread[i] - bias) / scale), n_bins);
if (bits == 8) {
output = val;
} else {
output += val << (bits * (i % packs_per_int));
}
if (packs_per_int < values_per_reduce &&
i % packs_per_int == packs_per_int - 1) {
out[out_index + i / packs_per_int] = output;
output = 0;
} else {
#pragma clang loop unroll(full)
for (int j = 0; j < writes_per_reduce - 1; j++) {
uint8_t sval = simd_shuffle_down(val, j + 1);
output += sval << (bits * (values_per_reduce + j + i));
}
}
}
if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) {
out[out_index / writes_per_reduce] = output;
}
}
template <typename T, const int group_size, const int bits>
[[kernel]] void affine_quantize_scales_biases(
const device T* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
device uint8_t* out [[buffer(3)]],
uint index [[thread_position_in_grid]]) {
constexpr int uint8_bits = 8;
constexpr int packs_per_int = uint8_bits / bits;
constexpr T n_bins = (1 << bits) - 1;
int in_index = index * packs_per_int;
int gindex = in_index / group_size;
T scale = scales[gindex];
T bias = biases[gindex];
uint8_t output = 0;
#pragma clang loop unroll(full)
for (int i = 0; i < packs_per_int; i++) {
uint8_t val = min(round((w[in_index + i] - bias) / scale), n_bins);
if (bits == 8) {
output = val;
} else {
output += val << (bits * i);
}
}
out[index] = output;
}
template <typename T, const int group_size, const int bits>
[[kernel]] void affine_dequantize(
const device uint8_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
const device T* biases [[buffer(2)]],
device T* out [[buffer(3)]],
uint index [[thread_position_in_grid]]) {
constexpr int uint8_bits = 8;
constexpr int packs_per_int = uint8_bits / bits;
int oindex = index * packs_per_int;
int gindex = oindex / group_size;
T scale = scales[gindex];
T bias = biases[gindex];
uint val = w[index];
#pragma clang loop unroll(full)
for (int i = 0; i < packs_per_int; i++) {
uint8_t d;
if (bits == 2) {
d = (val >> (bits * i)) & 0x03;
} else if (bits == 4) {
d = (val >> (bits * i)) & 0x0f;
} else if (bits == 8) {
d = val;
}
out[oindex + i] = scale * d + bias;
}
}
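For orientation, the affine scheme these kernels implement boils down to: each group of group_size weights gets one scale and one bias, every weight maps to an integer code in [0, 2^bits - 1], and dequantization is scale * code + bias. The C++ sketch below is deliberately simplified (it skips the sign/edge handling, the SIMD reductions, and the packing of codes into uint8_t words that the Metal code performs), and all names in it are illustrative:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Quantize one group of weights: derive scale/bias from the group's range and
// map each value to a code in [0, 2^bits - 1].
void affine_quantize_group(
    const std::vector<float>& w,
    int bits,
    std::vector<uint8_t>& codes,
    float& scale,
    float& bias) {
  const float n_bins = float((1 << bits) - 1);
  float w_min = *std::min_element(w.begin(), w.end());
  float w_max = *std::max_element(w.begin(), w.end());
  scale = std::max((w_max - w_min) / n_bins, 1e-7f); // avoid divide-by-zero
  bias = w_min;
  codes.resize(w.size());
  for (size_t i = 0; i < w.size(); ++i) {
    float q = std::round((w[i] - bias) / scale);
    codes[i] = uint8_t(std::min(std::max(q, 0.0f), n_bins));
  }
}

// Reconstruction rule, matching the scale * d + bias line in affine_dequantize.
float affine_dequantize_one(uint8_t code, float scale, float bias) {
  return scale * float(code) + bias;
}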


@@ -5,241 +5,67 @@
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h" #include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/quantized.h" #include "mlx/backend/metal/kernels/quantized.h"
#define instantiate_quantized(name, type, group_size, bits) \
#define instantiate_qmv_fast(itype, group_size, bits) \
instantiate_kernel( \ instantiate_kernel( \
"qmv_" #itype "_gs_" #group_size "_b_" #bits "_fast", \ #name "_" #type "_gs_" #group_size "_b_" #bits, \
qmv_fast, \ name, \
itype, \ type, \
group_size, \ group_size, \
bits) bits)
#define instantiate_qmv_fast_types(group_size, bits) \ #define instantiate_quantized_types(name, group_size, bits) \
instantiate_qmv_fast(float, group_size, bits) \ instantiate_quantized(name, float, group_size, bits) \
instantiate_qmv_fast(float16_t, group_size, bits) \ instantiate_quantized(name, float16_t, group_size, bits) \
instantiate_qmv_fast(bfloat16_t, group_size, bits) instantiate_quantized(name, bfloat16_t, group_size, bits)
instantiate_qmv_fast_types(128, 2) #define instantiate_quantized_groups(name, bits) \
instantiate_qmv_fast_types(128, 4) instantiate_quantized_types(name, 128, bits) \
instantiate_qmv_fast_types(128, 8) instantiate_quantized_types(name, 64, bits) \
instantiate_qmv_fast_types( 64, 2) instantiate_quantized_types(name, 32, bits)
instantiate_qmv_fast_types( 64, 4)
instantiate_qmv_fast_types( 64, 8)
instantiate_qmv_fast_types( 32, 2)
instantiate_qmv_fast_types( 32, 4)
instantiate_qmv_fast_types( 32, 8)
#define instantiate_qmv(itype, group_size, bits) \ #define instantiate_quantized_all(name) \
instantiate_kernel( \ instantiate_quantized_groups(name, 2) \
"qmv_" #itype "_gs_" #group_size "_b_" #bits, \ instantiate_quantized_groups(name, 4) \
qmv, \ instantiate_quantized_groups(name, 8)
itype, \
group_size, \
bits)
#define instantiate_qmv_types(group_size, bits) \ instantiate_quantized_all(qmv_fast)
instantiate_qmv(float, group_size, bits) \ instantiate_quantized_all(qmv)
instantiate_qmv(float16_t, group_size, bits) \ instantiate_quantized_all(qvm)
instantiate_qmv(bfloat16_t, group_size, bits) instantiate_quantized_all(qmm_n)
instantiate_quantized_all(bs_qmv_fast)
instantiate_quantized_all(bs_qmv)
instantiate_quantized_all(bs_qvm)
instantiate_quantized_all(bs_qmm_n)
instantiate_quantized_all(affine_quantize)
instantiate_quantized_all(affine_quantize_scales_biases)
instantiate_quantized_all(affine_dequantize)
instantiate_qmv_types(128, 2) #define instantiate_quantized_aligned(name, type, group_size, bits, aligned) \
instantiate_qmv_types(128, 4) instantiate_kernel( \
instantiate_qmv_types(128, 8) #name "_" #type "_gs_" #group_size "_b_" #bits "_alN_" #aligned, \
instantiate_qmv_types( 64, 2) name, \
instantiate_qmv_types( 64, 4) type, \
instantiate_qmv_types( 64, 8) group_size, \
instantiate_qmv_types( 32, 2) bits, \
instantiate_qmv_types( 32, 4) aligned)
instantiate_qmv_types( 32, 8)
#define instantiate_qvm(itype, group_size, bits) \ #define instantiate_quantized_types_aligned(name, group_size, bits) \
instantiate_kernel( \ instantiate_quantized_aligned(name, float, group_size, bits, true) \
"qvm_" #itype "_gs_" #group_size "_b_" #bits, \ instantiate_quantized_aligned(name, float16_t, group_size, bits, true) \
qvm, \ instantiate_quantized_aligned(name, bfloat16_t, group_size, bits, true) \
itype, \ instantiate_quantized_aligned(name, float, group_size, bits, false) \
group_size, \ instantiate_quantized_aligned(name, float16_t, group_size, bits, false) \
bits) instantiate_quantized_aligned(name, bfloat16_t, group_size, bits, false)
#define instantiate_qvm_types(group_size, bits) \ #define instantiate_quantized_groups_aligned(name, bits) \
instantiate_qvm(float, group_size, bits) \ instantiate_quantized_types_aligned(name, 128, bits) \
instantiate_qvm(float16_t, group_size, bits) \ instantiate_quantized_types_aligned(name, 64, bits) \
instantiate_qvm(bfloat16_t, group_size, bits) instantiate_quantized_types_aligned(name, 32, bits)
instantiate_qvm_types(128, 2) #define instantiate_quantized_all_aligned(name) \
instantiate_qvm_types(128, 4) instantiate_quantized_groups_aligned(name, 2) \
instantiate_qvm_types(128, 8) instantiate_quantized_groups_aligned(name, 4) \
instantiate_qvm_types( 64, 2) instantiate_quantized_groups_aligned(name, 8) \
instantiate_qvm_types( 64, 4)
instantiate_qvm_types( 64, 8)
instantiate_qvm_types( 32, 2)
instantiate_qvm_types( 32, 4)
instantiate_qvm_types( 32, 8)
#define instantiate_qmm_t(itype, group_size, bits, aligned_N) \ instantiate_quantized_all_aligned(qmm_t)
instantiate_kernel( \ instantiate_quantized_all_aligned(bs_qmm_t) // clang-format on
"qmm_t_" #itype "_gs_" #group_size "_b_" #bits "_alN_" #aligned_N, \
qmm_t, \
itype, \
group_size, \
bits, \
aligned_N)
#define instantiate_qmm_t_types(group_size, bits) \
instantiate_qmm_t(float, group_size, bits, false) \
instantiate_qmm_t(float16_t, group_size, bits, false) \
instantiate_qmm_t(bfloat16_t, group_size, bits, false) \
instantiate_qmm_t(float, group_size, bits, true) \
instantiate_qmm_t(float16_t, group_size, bits, true) \
instantiate_qmm_t(bfloat16_t, group_size, bits, true)
instantiate_qmm_t_types(128, 2)
instantiate_qmm_t_types(128, 4)
instantiate_qmm_t_types(128, 8)
instantiate_qmm_t_types( 64, 2)
instantiate_qmm_t_types( 64, 4)
instantiate_qmm_t_types( 64, 8)
instantiate_qmm_t_types( 32, 2)
instantiate_qmm_t_types( 32, 4)
instantiate_qmm_t_types( 32, 8)
#define instantiate_qmm_n(itype, group_size, bits) \
instantiate_kernel( \
"qmm_n_" #itype "_gs_" #group_size "_b_" #bits, \
qmm_n, \
itype, \
group_size, \
bits)
#define instantiate_qmm_n_types(group_size, bits) \
instantiate_qmm_n(float, group_size, bits) \
instantiate_qmm_n(float16_t, group_size, bits) \
instantiate_qmm_n(bfloat16_t, group_size, bits)
instantiate_qmm_n_types(128, 2)
instantiate_qmm_n_types(128, 4)
instantiate_qmm_n_types(128, 8)
instantiate_qmm_n_types( 64, 2)
instantiate_qmm_n_types( 64, 4)
instantiate_qmm_n_types( 64, 8)
instantiate_qmm_n_types( 32, 2)
instantiate_qmm_n_types( 32, 4)
instantiate_qmm_n_types( 32, 8)
#define instantiate_bs_qmv_fast(itype, group_size, bits) \
instantiate_kernel( \
"bs_qmv_" #itype "_gs_" #group_size "_b_" #bits "_fast", \
bs_qmv_fast, \
itype, \
group_size, \
bits)
#define instantiate_bs_qmv_fast_types(group_size, bits) \
instantiate_bs_qmv_fast(float, group_size, bits) \
instantiate_bs_qmv_fast(float16_t, group_size, bits) \
instantiate_bs_qmv_fast(bfloat16_t, group_size, bits)
instantiate_bs_qmv_fast_types(128, 2)
instantiate_bs_qmv_fast_types(128, 4)
instantiate_bs_qmv_fast_types(128, 8)
instantiate_bs_qmv_fast_types( 64, 2)
instantiate_bs_qmv_fast_types( 64, 4)
instantiate_bs_qmv_fast_types( 64, 8)
instantiate_bs_qmv_fast_types( 32, 2)
instantiate_bs_qmv_fast_types( 32, 4)
instantiate_bs_qmv_fast_types( 32, 8)
#define instantiate_bs_qmv(itype, group_size, bits) \
instantiate_kernel( \
"bs_qmv_" #itype "_gs_" #group_size "_b_" #bits, \
bs_qmv, \
itype, \
group_size, \
bits)
#define instantiate_bs_qmv_types(group_size, bits) \
instantiate_bs_qmv(float, group_size, bits) \
instantiate_bs_qmv(float16_t, group_size, bits) \
instantiate_bs_qmv(bfloat16_t, group_size, bits)
instantiate_bs_qmv_types(128, 2)
instantiate_bs_qmv_types(128, 4)
instantiate_bs_qmv_types(128, 8)
instantiate_bs_qmv_types( 64, 2)
instantiate_bs_qmv_types( 64, 4)
instantiate_bs_qmv_types( 64, 8)
instantiate_bs_qmv_types( 32, 2)
instantiate_bs_qmv_types( 32, 4)
instantiate_bs_qmv_types( 32, 8)
#define instantiate_bs_qvm(itype, group_size, bits) \
instantiate_kernel( \
"bs_qvm_" #itype "_gs_" #group_size "_b_" #bits, \
bs_qvm, \
itype, \
group_size, \
bits)
#define instantiate_bs_qvm_types(group_size, bits) \
instantiate_bs_qvm(float, group_size, bits) \
instantiate_bs_qvm(float16_t, group_size, bits) \
instantiate_bs_qvm(bfloat16_t, group_size, bits)
instantiate_bs_qvm_types(128, 2)
instantiate_bs_qvm_types(128, 4)
instantiate_bs_qvm_types(128, 8)
instantiate_bs_qvm_types( 64, 2)
instantiate_bs_qvm_types( 64, 4)
instantiate_bs_qvm_types( 64, 8)
instantiate_bs_qvm_types( 32, 2)
instantiate_bs_qvm_types( 32, 4)
instantiate_bs_qvm_types( 32, 8)
#define instantiate_bs_qmm_t(itype, group_size, bits, aligned_N) \
instantiate_kernel( \
"bs_qmm_t_" #itype "_gs_" #group_size "_b_" #bits "_alN_" #aligned_N, \
bs_qmm_t, \
itype, \
group_size, \
bits, \
aligned_N)
#define instantiate_bs_qmm_t_types(group_size, bits) \
instantiate_bs_qmm_t(float, group_size, bits, false) \
instantiate_bs_qmm_t(float16_t, group_size, bits, false) \
instantiate_bs_qmm_t(bfloat16_t, group_size, bits, false) \
instantiate_bs_qmm_t(float, group_size, bits, true) \
instantiate_bs_qmm_t(float16_t, group_size, bits, true) \
instantiate_bs_qmm_t(bfloat16_t, group_size, bits, true)
instantiate_bs_qmm_t_types(128, 2)
instantiate_bs_qmm_t_types(128, 4)
instantiate_bs_qmm_t_types(128, 8)
instantiate_bs_qmm_t_types( 64, 2)
instantiate_bs_qmm_t_types( 64, 4)
instantiate_bs_qmm_t_types( 64, 8)
instantiate_bs_qmm_t_types( 32, 2)
instantiate_bs_qmm_t_types( 32, 4)
instantiate_bs_qmm_t_types( 32, 8)
#define instantiate_bs_qmm_n(itype, group_size, bits) \
instantiate_kernel( \
"bs_qmm_n_" #itype "_gs_" #group_size "_b_" #bits, \
bs_qmm_n, \
itype, \
group_size, \
bits)
#define instantiate_bs_qmm_n_types(group_size, bits) \
instantiate_bs_qmm_n(float, group_size, bits) \
instantiate_bs_qmm_n(float16_t, group_size, bits) \
instantiate_bs_qmm_n(bfloat16_t, group_size, bits)
instantiate_bs_qmm_n_types(128, 2)
instantiate_bs_qmm_n_types(128, 4)
instantiate_bs_qmm_n_types(128, 8)
instantiate_bs_qmm_n_types( 64, 2)
instantiate_bs_qmm_n_types( 64, 4)
instantiate_bs_qmm_n_types( 64, 8)
instantiate_bs_qmm_n_types( 32, 2)
instantiate_bs_qmm_n_types( 32, 4)
instantiate_bs_qmm_n_types( 32, 8) // clang-format on

#define instantiate_quantized(name, type, group_size, bits) \
  instantiate_kernel( \
      #name "_" #type "_gs_" #group_size "_b_" #bits, \
      name, \
      type, \
      group_size, \
      bits)

#define instantiate_quantized_types(name, group_size, bits) \
  instantiate_quantized(name, float, group_size, bits) \
  instantiate_quantized(name, float16_t, group_size, bits) \
  instantiate_quantized(name, bfloat16_t, group_size, bits)

#define instantiate_quantized_groups(name, bits) \
  instantiate_quantized_types(name, 128, bits) \
  instantiate_quantized_types(name, 64, bits) \
  instantiate_quantized_types(name, 32, bits)

#define instantiate_quantized_all(name) \
  instantiate_quantized_groups(name, 2) \
  instantiate_quantized_groups(name, 4) \
  instantiate_quantized_groups(name, 8)

instantiate_quantized_all(qmv_fast)
instantiate_quantized_all(qmv)
instantiate_quantized_all(qvm)
instantiate_quantized_all(qmm_n)
instantiate_quantized_all(bs_qmv_fast)
instantiate_quantized_all(bs_qmv)
instantiate_quantized_all(bs_qvm)
instantiate_quantized_all(bs_qmm_n)
instantiate_quantized_all(affine_quantize)
instantiate_quantized_all(affine_quantize_scales_biases)
instantiate_quantized_all(affine_dequantize)

#define instantiate_quantized_aligned(name, type, group_size, bits, aligned) \
  instantiate_kernel( \
      #name "_" #type "_gs_" #group_size "_b_" #bits "_alN_" #aligned, \
      name, \
      type, \
      group_size, \
      bits, \
      aligned)

#define instantiate_quantized_types_aligned(name, group_size, bits) \
  instantiate_quantized_aligned(name, float, group_size, bits, true) \
  instantiate_quantized_aligned(name, float16_t, group_size, bits, true) \
  instantiate_quantized_aligned(name, bfloat16_t, group_size, bits, true) \
  instantiate_quantized_aligned(name, float, group_size, bits, false) \
  instantiate_quantized_aligned(name, float16_t, group_size, bits, false) \
  instantiate_quantized_aligned(name, bfloat16_t, group_size, bits, false)

#define instantiate_quantized_groups_aligned(name, bits) \
  instantiate_quantized_types_aligned(name, 128, bits) \
  instantiate_quantized_types_aligned(name, 64, bits) \
  instantiate_quantized_types_aligned(name, 32, bits)

#define instantiate_quantized_all_aligned(name) \
  instantiate_quantized_groups_aligned(name, 2) \
  instantiate_quantized_groups_aligned(name, 4) \
  instantiate_quantized_groups_aligned(name, 8) \

instantiate_quantized_all_aligned(qmm_t)
instantiate_quantized_all_aligned(bs_qmm_t) // clang-format on


@@ -43,20 +43,22 @@ rbits threefry2x32_hash(const thread uint2& key, uint2 count) {
  auto half_size = grid_dim.y - odd;
  out += index.x * bytes_per_key;
  bool drop_last = odd && (index.y == half_size);
  auto count = uint2(index.y, drop_last ? 0 : index.y + grid_dim.y);
  auto bits = threefry2x32_hash(key, count);
  auto bits = threefry2x32_hash(
      key, uint2(index.y, drop_last ? 0 : index.y + grid_dim.y));
  size_t idx = size_t(index.y) << 2;
  for (int i = 0; i < 4; ++i) {
    out[4 * count.x + i] = bits.bytes[0][i];
    out[idx + i] = bits.bytes[0][i];
  }
  if (!drop_last) {
    idx = (drop_last ? 0 : size_t(index.y) + grid_dim.y) << 2;
    if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) {
      int edge_bytes = (bytes_per_key % 4);
      for (int i = 0; i < edge_bytes; ++i) {
        out[4 * count.y + i] = bits.bytes[1][i];
        out[idx + i] = bits.bytes[1][i];
      }
    } else {
      for (int i = 0; i < 4; ++i) {
        out[4 * count.y + i] = bits.bytes[1][i];
        out[idx + i] = bits.bytes[1][i];
      }
    }
  }
@@ -77,22 +79,24 @@ rbits threefry2x32_hash(const thread uint2& key, uint2 count) {
  auto k2_elem = elem_to_loc(kidx + 1, key_shape, key_strides, ndim);
  auto key = uint2(keys[k1_elem], keys[k2_elem]);
  auto half_size = grid_dim.y - odd;
  out += index.x * bytes_per_key;
  out += size_t(index.x) * bytes_per_key;
  bool drop_last = odd && (index.y == half_size);
  auto count = uint2(index.y, drop_last ? 0 : index.y + grid_dim.y);
  auto bits = threefry2x32_hash(key, count);
  auto bits = threefry2x32_hash(
      key, uint2(index.y, drop_last ? 0 : index.y + grid_dim.y));
  size_t idx = size_t(index.y) << 2;
  for (int i = 0; i < 4; ++i) {
    out[4 * count.x + i] = bits.bytes[0][i];
    out[idx + i] = bits.bytes[0][i];
  }
  if (!drop_last) {
    idx = (drop_last ? 0 : size_t(index.y) + grid_dim.y) << 2;
    if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) {
      int edge_bytes = (bytes_per_key % 4);
      for (int i = 0; i < edge_bytes; ++i) {
        out[4 * count.y + i] = bits.bytes[1][i];
        out[idx + i] = bits.bytes[1][i];
      }
    } else {
      for (int i = 0; i < 4; ++i) {
        out[4 * count.y + i] = bits.bytes[1][i];
        out[idx + i] = bits.bytes[1][i];
      }
    }
  }


@@ -24,7 +24,7 @@ template <typename T, int N_READS = RMS_N_READS>
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  float acc = 0;
  x += gid * axis_size + lid * N_READS;
  x += gid * size_t(axis_size) + lid * N_READS;
  w += w_stride * lid * N_READS;
  if (lid * N_READS + N_READS <= axis_size) {
    for (int i = 0; i < N_READS; i++) {
@@ -62,7 +62,7 @@ template <typename T, int N_READS = RMS_N_READS>
  threadgroup_barrier(mem_flags::mem_threadgroup);
  // Write the outputs
  out += gid * axis_size + lid * N_READS;
  out += gid * size_t(axis_size) + lid * N_READS;
  if (lid * N_READS + N_READS <= axis_size) {
    for (int i = 0; i < N_READS; i++) {
      out[i] = w[w_stride * i] * static_cast<T>(x[i] * local_inv_mean[0]);
@@ -92,7 +92,7 @@ template <typename T, int N_READS = RMS_N_READS>
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  float acc = 0;
  x += gid * axis_size + lid * N_READS;
  x += gid * size_t(axis_size) + lid * N_READS;
  w += w_stride * lid * N_READS;
  for (uint r = 0; r < axis_size; r += lsize * N_READS) {
    if (r + lid * N_READS + N_READS <= axis_size) {
@@ -132,7 +132,7 @@ template <typename T, int N_READS = RMS_N_READS>
  threadgroup_barrier(mem_flags::mem_threadgroup);
  // Write the outputs
  out += gid * axis_size + lid * N_READS;
  out += gid * size_t(axis_size) + lid * N_READS;
  for (uint r = 0; r < axis_size; r += lsize * N_READS) {
    if (r + lid * N_READS + N_READS <= axis_size) {
      for (int i = 0; i < N_READS; i++) {
@@ -165,8 +165,8 @@ template <typename T, int N_READS = RMS_N_READS>
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  // Advance the input pointers
  x += gid * axis_size + lid * N_READS;
  g += gid * axis_size + lid * N_READS;
  x += gid * size_t(axis_size) + lid * N_READS;
  g += gid * size_t(axis_size) + lid * N_READS;
  w += w_stride * lid * N_READS;
  // Allocate registers for the computation and accumulators
@@ -233,8 +233,8 @@ template <typename T, int N_READS = RMS_N_READS>
  float normalizer3 = normalizer * normalizer * normalizer;
  // Write the outputs
  gx += gid * axis_size + lid * N_READS;
  gw += gid * axis_size + lid * N_READS;
  gx += gid * size_t(axis_size) + lid * N_READS;
  gw += gid * size_t(axis_size) + lid * N_READS;
  if (lid * N_READS + N_READS <= axis_size) {
    for (int i = 0; i < N_READS; i++) {
      gx[i] = static_cast<T>(
@@ -270,8 +270,8 @@ template <typename T, int N_READS = RMS_N_READS>
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  // Advance the input pointers
  x += gid * axis_size + lid * N_READS;
  g += gid * axis_size + lid * N_READS;
  x += gid * size_t(axis_size) + lid * N_READS;
  g += gid * size_t(axis_size) + lid * N_READS;
  w += w_stride * lid * N_READS;
  // Allocate registers for the accumulators
@@ -337,8 +337,8 @@ template <typename T, int N_READS = RMS_N_READS>
  float normalizer3 = normalizer * normalizer * normalizer;
  // Write the outputs
  gx += gid * axis_size + lid * N_READS;
  gw += gid * axis_size + lid * N_READS;
  gx += gid * size_t(axis_size) + lid * N_READS;
  gw += gid * size_t(axis_size) + lid * N_READS;
  for (uint r = 0; r < axis_size; r += lsize * N_READS) {
    if (r + lid * N_READS + N_READS <= axis_size) {
      for (int i = 0; i < N_READS; i++) {


@@ -6,36 +6,17 @@
#include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/utils.h"
template <typename T, bool traditional, bool forward> template <typename T, bool traditional, bool forward>
[[kernel]] void rope( [[kernel]] void rope_single(
const device T* in [[buffer(0)]], const device T* in [[buffer(0)]],
device T* out [[buffer(1)]], device T* out [[buffer(1)]],
constant const size_t strides[3],
constant const size_t out_strides[3],
constant const int& offset, constant const int& offset,
constant const float& base, constant const float& base,
constant const float& scale, constant const float& scale,
uint3 pos [[thread_position_in_grid]], constant const size_t& stride,
uint3 grid [[threads_per_grid]]) { uint2 pos [[thread_position_in_grid]],
// Compute the input and output indices uint2 grid [[threads_per_grid]]) {
uint in_index_1, in_index_2;
uint out_index_1, out_index_2;
if (traditional) {
out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] +
pos.z * out_strides[0];
out_index_2 = out_index_1 + 1;
in_index_1 =
2 * pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0];
in_index_2 = in_index_1 + strides[2];
} else {
out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] +
pos.z * out_strides[0];
out_index_2 = out_index_1 + grid.x * out_strides[2];
in_index_1 = pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0];
in_index_2 = in_index_1 + grid.x * strides[2];
}
// Figure out L and d. // Figure out L and d.
float L = scale * static_cast<float>(pos.y + offset); float L = scale * static_cast<float>(offset);
float d = static_cast<float>(pos.x) / static_cast<float>(grid.x); float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
// Compute costheta, sintheta // Compute costheta, sintheta
@@ -43,6 +24,21 @@ template <typename T, bool traditional, bool forward>
float costheta = metal::fast::cos(theta); float costheta = metal::fast::cos(theta);
float sintheta = metal::fast::sin(theta); float sintheta = metal::fast::sin(theta);
// Compute the input and output indices
uint in_index_1, in_index_2;
uint out_index_1, out_index_2;
if (traditional) {
out_index_1 = 2 * pos.x + pos.y * stride;
out_index_2 = out_index_1 + 1;
in_index_1 = 2 * pos.x + pos.y * stride;
in_index_2 = in_index_1 + 1;
} else {
out_index_1 = pos.x + pos.y * stride;
out_index_2 = out_index_1 + grid.x;
in_index_1 = pos.x + pos.y * stride;
in_index_2 = in_index_1 + grid.x;
}
// Read and write the output // Read and write the output
float x1 = static_cast<float>(in[in_index_1]); float x1 = static_cast<float>(in[in_index_1]);
float x2 = static_cast<float>(in[in_index_2]); float x2 = static_cast<float>(in[in_index_2]);
@@ -59,19 +55,97 @@ template <typename T, bool traditional, bool forward>
out[out_index_2] = static_cast<T>(rx2); out[out_index_2] = static_cast<T>(rx2);
} }
#define instantiate_rope(name, type, traditional, forward) \ template <typename T, bool traditional, bool forward, int N = 4>
template [[host_name("rope_" #name)]] [[kernel]] void \ [[kernel]] void rope(
rope<type, traditional, forward>( \ const device T* in [[buffer(0)]],
const device type* in [[buffer(0)]], \ device T* out [[buffer(1)]],
device type* out [[buffer(1)]], \ constant const int& offset,
constant const size_t strides[3], \ constant const float& base,
constant const size_t out_strides[3], \ constant const float& scale,
constant const int& offset, \ constant const size_t strides[3],
constant const float& base, \ constant const size_t out_strides[3],
constant const float& scale, \ constant const size_t& n_batch,
uint3 pos [[thread_position_in_grid]], \ uint3 pos [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]) {
// Figure out L and d.
float L = scale * static_cast<float>(pos.y + offset);
float d = static_cast<float>(pos.x) / static_cast<float>(grid.x);
// Compute costheta, sintheta
float theta = L * metal::exp2(-d * base);
float costheta = metal::fast::cos(theta);
float sintheta = metal::fast::sin(theta);
// Compute the input and output indices
size_t in_index_1, in_index_2;
size_t out_index_1, out_index_2;
if (traditional) {
out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] +
N * pos.z * out_strides[0];
out_index_2 = out_index_1 + 1;
in_index_1 =
2 * pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
in_index_2 = in_index_1 + strides[2];
} else {
out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] +
N * pos.z * out_strides[0];
out_index_2 = out_index_1 + grid.x * out_strides[2];
in_index_1 =
pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0];
in_index_2 = in_index_1 + grid.x * strides[2];
}
for (int i = 0; i < N && pos.z * N + i < n_batch; ++i) {
// Read and write the output
float x1 = static_cast<float>(in[in_index_1]);
float x2 = static_cast<float>(in[in_index_2]);
float rx1;
float rx2;
if (forward) {
rx1 = x1 * costheta - x2 * sintheta;
rx2 = x1 * sintheta + x2 * costheta;
} else {
rx1 = x2 * sintheta + x1 * costheta;
rx2 = x2 * costheta - x1 * sintheta;
}
out[out_index_1] = static_cast<T>(rx1);
out[out_index_2] = static_cast<T>(rx2);
in_index_1 += strides[0];
in_index_2 += strides[0];
out_index_1 += out_strides[0];
out_index_2 += out_strides[0];
}
}
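The loop above applies the usual rotary-embedding rotation to pairs of elements. A scalar C++ sketch of the per-pair math, mirroring the kernel's convention that base is already the log2'd frequency base and d is the pair's fractional position along the feature dimension (the function name and interface are illustrative only):

#include <cmath>
#include <utility>

// Rotate one (x1, x2) pair. L = scale * position, d in [0, 1) selects the
// frequency for this feature pair, base is log2 of the rope base as in the
// kernel. forward=false applies the inverse rotation (used for the VJP).
std::pair<float, float> rope_pair(
    float x1, float x2, float L, float d, float base, bool forward) {
  float theta = L * std::exp2(-d * base);
  float c = std::cos(theta);
  float s = std::sin(theta);
  if (forward) {
    return {x1 * c - x2 * s, x1 * s + x2 * c};
  }
  return {x2 * s + x1 * c, x2 * c - x1 * s};
}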
#define instantiate_rope_g(name, type, traditional, forward) \
template [[host_name("rope_" #name)]] [[kernel]] void \
rope<type, traditional, forward>( \
const device type* in [[buffer(0)]], \
device type* out [[buffer(1)]], \
constant const int& offset, \
constant const float& base, \
constant const float& scale, \
constant const size_t strides[3], \
constant const size_t out_strides[3], \
constant const size_t& n_batch, \
uint3 pos [[thread_position_in_grid]], \
uint3 grid [[threads_per_grid]]); uint3 grid [[threads_per_grid]]);
#define instantiate_rope_s(name, type, traditional, forward) \
template [[host_name("rope_single_" #name)]] [[kernel]] void \
rope_single<type, traditional, forward>( \
const device type* in [[buffer(0)]], \
device type* out [[buffer(1)]], \
constant const int& offset, \
constant const float& base, \
constant const float& scale, \
constant const size_t& stride, \
uint2 pos [[thread_position_in_grid]], \
uint2 grid [[threads_per_grid]]);
#define instantiate_rope(name, type, traditional, forward) \
instantiate_rope_s(name, type, traditional, forward) \
instantiate_rope_g(name, type, traditional, forward)
// clang-format off
instantiate_rope(traditional_float16, half, true, true)
instantiate_rope(traditional_bfloat16, bfloat16_t, true, true)
@@ -84,4 +158,4 @@ instantiate_rope(vjp_traditional_bfloat16, bfloat16_t, true, false)
instantiate_rope(vjp_traditional_float32, float, true, false)
instantiate_rope(vjp_float16, half, false, false)
instantiate_rope(vjp_bfloat16, bfloat16_t, false, false)
instantiate_rope(vjp_float32, float, false, false) // clang-format on


@@ -25,7 +25,7 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
  AccT ld[N_READS];
  in += gid * axis_size + lid * N_READS;
  in += gid * size_t(axis_size) + lid * N_READS;
  if (lid * N_READS + N_READS <= axis_size) {
    for (int i = 0; i < N_READS; i++) {
      ld[i] = AccT(in[i]);
@@ -83,7 +83,7 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
  normalizer = 1 / local_normalizer[0];
  // Normalize and write to the output
  out += gid * axis_size + lid * N_READS;
  out += gid * size_t(axis_size) + lid * N_READS;
  if (lid * N_READS + N_READS <= axis_size) {
    for (int i = 0; i < N_READS; i++) {
      out[i] = T(ld[i] * normalizer);
@@ -107,7 +107,7 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
    uint lsize [[threads_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  in += gid * axis_size;
  in += gid * size_t(axis_size);
  constexpr int SIMD_SIZE = 32;
@@ -170,7 +170,7 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
  // Finally given the normalizer and max value we can directly write the
  // softmax output
  out += gid * axis_size;
  out += gid * size_t(axis_size);
  for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize));
       r++) {
    int offset = r * lsize * N_READS + lid * N_READS;


@@ -235,19 +235,21 @@ struct KernelMergeSort {
      const device T* inp,
      device U* out,
      const constant int& size_sorted_axis,
      const constant int& stride_sorted_axis,
      const constant int& stride_segment_axis,
      const constant int& in_stride_sorted_axis,
      const constant int& out_stride_sorted_axis,
      const constant int& in_stride_segment_axis,
      const constant int& out_stride_segment_axis,
      threadgroup val_t* tgp_vals,
      threadgroup idx_t* tgp_idxs,
      uint3 tid [[threadgroup_position_in_grid]],
      uint3 lid [[thread_position_in_threadgroup]]) {
    // tid.y tells us the segment index
    inp += tid.y * stride_segment_axis;
    out += tid.y * stride_segment_axis;
    inp += tid.y * in_stride_segment_axis;
    out += tid.y * out_stride_segment_axis;
    // Copy into threadgroup memory
    for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) {
      tgp_vals[i] = i < size_sorted_axis ? inp[i * stride_sorted_axis]
      tgp_vals[i] = i < size_sorted_axis ? inp[i * in_stride_sorted_axis]
                                         : val_t(CompareOp::init);
      if (ARG_SORT) {
        tgp_idxs[i] = i;
@@ -264,9 +266,9 @@ struct KernelMergeSort {
    // Write output
    for (int i = lid.x; i < size_sorted_axis; i += BLOCK_THREADS) {
      if (ARG_SORT) {
        out[i * stride_sorted_axis] = tgp_idxs[i];
        out[i * out_stride_sorted_axis] = tgp_idxs[i];
      } else {
        out[i * stride_sorted_axis] = tgp_vals[i];
        out[i * out_stride_sorted_axis] = tgp_vals[i];
      }
    }
  }
@@ -282,8 +284,10 @@ template <
    const device T* inp [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant int& size_sorted_axis [[buffer(2)]],
    const constant int& stride_sorted_axis [[buffer(3)]],
    const constant int& stride_segment_axis [[buffer(4)]],
    const constant int& in_stride_sorted_axis [[buffer(3)]],
    const constant int& out_stride_sorted_axis [[buffer(4)]],
    const constant int& in_stride_segment_axis [[buffer(5)]],
    const constant int& out_stride_segment_axis [[buffer(6)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {
  using sort_kernel =
@@ -298,8 +302,10 @@ template <
        inp,
        out,
        size_sorted_axis,
        stride_sorted_axis,
        stride_segment_axis,
        in_stride_sorted_axis,
        out_stride_sorted_axis,
        in_stride_segment_axis,
        out_stride_segment_axis,
        tgp_vals,
        tgp_idxs,
        tid,
@@ -310,8 +316,10 @@ template <
        inp,
        out,
        size_sorted_axis,
        stride_sorted_axis,
        stride_segment_axis,
        in_stride_sorted_axis,
        out_stride_sorted_axis,
        in_stride_segment_axis,
        out_stride_segment_axis,
        tgp_vals,
        nullptr,
        tid,
@@ -331,10 +339,12 @@ template <
    const device T* inp [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant int& size_sorted_axis [[buffer(2)]],
    const constant int& stride_sorted_axis [[buffer(3)]],
    const constant int& nc_dim [[buffer(4)]],
    const device int* nc_shape [[buffer(5)]],
    const device size_t* nc_strides [[buffer(6)]],
    const constant int& in_stride_sorted_axis [[buffer(3)]],
    const constant int& out_stride_sorted_axis [[buffer(4)]],
    const constant int& nc_dim [[buffer(5)]],
    const device int* nc_shape [[buffer(6)]],
    const device size_t* in_nc_strides [[buffer(7)]],
    const device size_t* out_nc_strides [[buffer(8)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {
  using sort_kernel =
@@ -342,9 +352,10 @@ template <
  using val_t = typename sort_kernel::val_t;
  using idx_t = typename sort_kernel::idx_t;
  auto block_idx = elem_to_loc(tid.y, nc_shape, nc_strides, nc_dim);
  inp += block_idx;
  out += block_idx;
  auto in_block_idx = elem_to_loc(tid.y, nc_shape, in_nc_strides, nc_dim);
  auto out_block_idx = elem_to_loc(tid.y, nc_shape, out_nc_strides, nc_dim);
  inp += in_block_idx;
  out += out_block_idx;
  if (ARG_SORT) {
    threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
@@ -353,7 +364,9 @@ template <
        inp,
        out,
        size_sorted_axis,
        stride_sorted_axis,
        in_stride_sorted_axis,
        out_stride_sorted_axis,
        zero_helper,
        zero_helper,
        tgp_vals,
        tgp_idxs,
@@ -365,7 +378,9 @@ template <
        inp,
        out,
        size_sorted_axis,
        stride_sorted_axis,
        in_stride_sorted_axis,
        out_stride_sorted_axis,
        zero_helper,
        zero_helper,
        tgp_vals,
        nullptr,
@@ -507,13 +522,13 @@ template <
    bool ARG_SORT,
    short BLOCK_THREADS,
    short N_PER_THREAD>
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void
mb_block_partition(
[[kernel]] void mb_block_partition(
    device idx_t* block_partitions [[buffer(0)]],
    const device val_t* dev_vals [[buffer(1)]],
    const device idx_t* dev_idxs [[buffer(2)]],
    const constant int& size_sorted_axis [[buffer(3)]],
    const constant int& merge_tiles [[buffer(4)]],
    const constant int& n_blocks [[buffer(5)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 tgp_dims [[threads_per_threadgroup]]) {
@@ -528,23 +543,29 @@ mb_block_partition(
  dev_vals += tid.y * size_sorted_axis;
  dev_idxs += tid.y * size_sorted_axis;
  // Find location in merge step
  int merge_group = lid.x / merge_tiles;
  int merge_lane = lid.x % merge_tiles;
  for (int i = lid.x; i <= n_blocks; i += tgp_dims.x) {
    // Find location in merge step
    int merge_group = i / merge_tiles;
    int merge_lane = i % merge_tiles;
    int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;
    int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;
    int A_st = min(size_sorted_axis, sort_st);
    int A_ed = min(size_sorted_axis, sort_st + sort_sz / 2);
    int B_st = A_ed;
    int B_ed = min(size_sorted_axis, B_st + sort_sz / 2);
    int partition_at = min(B_ed - A_st, sort_kernel::N_PER_BLOCK * merge_lane);
    int partition = sort_kernel::merge_partition(
        dev_vals + A_st, dev_vals + B_st, A_ed - A_st, B_ed - B_st, partition_at);
        dev_vals + A_st,
        dev_vals + B_st,
        A_ed - A_st,
        B_ed - B_st,
        partition_at);
    block_partitions[lid.x] = A_st + partition;
    block_partitions[i] = A_st + partition;
  }
}
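The partition kernel above asks merge_partition for the split point of a diagonal when merging two sorted runs A and B. A standalone C++ sketch of that merge-path binary search (a generic version written for this document, not the repo's exact routine):

#include <algorithm>
#include <vector>

// Given sorted runs A and B and a target output position `diag`, return how
// many elements of A land before that position in the merged output; the
// remaining diag - result elements come from B. Binary search over the diagonal.
int merge_partition_ref(
    const std::vector<int>& A, const std::vector<int>& B, int diag) {
  int lo = std::max(0, diag - int(B.size()));
  int hi = std::min(diag, int(A.size()));
  while (lo < hi) {
    int mid = (lo + hi) / 2;
    // If A[mid] does not exceed the B element paired with it on this diagonal,
    // the split point lies to the right of mid.
    if (A[mid] <= B[diag - 1 - mid]) {
      lo = mid + 1;
    } else {
      hi = mid;
    }
  }
  return lo;
}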
template <


@@ -10,28 +10,10 @@
#define instantiate_block_sort( \
    name, itname, itype, otname, otype, arg_sort, bn, tn) \
  instantiate_kernel("c" #name "_" #itname "_" #otname "_bn" #bn "_tn" #tn, \
      block_sort, itype, otype, arg_sort, bn, tn) \
  instantiate_kernel("nc" #name "_" #itname "_" #otname "_bn" #bn "_tn" #tn, \
      block_sort_nc, itype, otype, arg_sort, bn, tn)
  template [[host_name("c" #name "_" #itname "_" #otname "_bn" #bn \
      "_tn" #tn)]] [[kernel]] void \
  block_sort<itype, otype, arg_sort, bn, tn>( \
      const device itype* inp [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant int& size_sorted_axis [[buffer(2)]], \
const constant int& stride_sorted_axis [[buffer(3)]], \
const constant int& stride_segment_axis [[buffer(4)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]); \
template [[host_name("nc" #name "_" #itname "_" #otname "_bn" #bn "_tn" #tn \
)]] [[kernel]] void \
block_sort_nc<itype, otype, arg_sort, bn, tn>( \
const device itype* inp [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant int& size_sorted_axis [[buffer(2)]], \
const constant int& stride_sorted_axis [[buffer(3)]], \
const constant int& nc_dim [[buffer(4)]], \
const device int* nc_shape [[buffer(5)]], \
const device size_t* nc_strides [[buffer(6)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
#define instantiate_arg_block_sort_base(itname, itype, bn, tn) \
  instantiate_block_sort( \
@@ -69,43 +51,12 @@ instantiate_block_sort_long(int64, int64_t)
#define instantiate_multi_block_sort( \
    vtname, vtype, itname, itype, arg_sort, bn, tn) \
  instantiate_kernel("sort_mbsort_" #vtname "_" #itname "_bn" #bn "_tn" #tn, \
      mb_block_sort, vtype, itype, arg_sort, bn, tn) \
  instantiate_kernel("partition_mbsort_" #vtname "_" #itname "_bn" #bn "_tn" #tn, \
      mb_block_partition, vtype, itype, arg_sort, bn, tn) \
  instantiate_kernel("merge_mbsort_" #vtname "_" #itname "_bn" #bn "_tn" #tn, \
      mb_block_merge, vtype, itype, arg_sort, bn, tn)
  template [[host_name("sort_mbsort_" #vtname "_" #itname "_bn" #bn \
      "_tn" #tn)]] [[kernel]] void \
  mb_block_sort<vtype, itype, arg_sort, bn, tn>( \
      const device vtype* inp [[buffer(0)]], \
      device vtype* out_vals [[buffer(1)]], \
      device itype* out_idxs [[buffer(2)]], \
const constant int& size_sorted_axis [[buffer(3)]], \
const constant int& stride_sorted_axis [[buffer(4)]], \
const constant int& nc_dim [[buffer(5)]], \
const device int* nc_shape [[buffer(6)]], \
const device size_t* nc_strides [[buffer(7)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]); \
template [[host_name("partition_mbsort_" #vtname "_" #itname "_bn" #bn \
"_tn" #tn)]] [[kernel]] void \
mb_block_partition<vtype, itype, arg_sort, bn, tn>( \
device itype * block_partitions [[buffer(0)]], \
const device vtype* dev_vals [[buffer(1)]], \
const device itype* dev_idxs [[buffer(2)]], \
const constant int& size_sorted_axis [[buffer(3)]], \
const constant int& merge_tiles [[buffer(4)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 tgp_dims [[threads_per_threadgroup]]); \
template [[host_name("merge_mbsort_" #vtname "_" #itname "_bn" #bn \
"_tn" #tn)]] [[kernel]] void \
mb_block_merge<vtype, itype, arg_sort, bn, tn>( \
const device itype* block_partitions [[buffer(0)]], \
const device vtype* dev_vals_in [[buffer(1)]], \
const device itype* dev_idxs_in [[buffer(2)]], \
device vtype* dev_vals_out [[buffer(3)]], \
device itype* dev_idxs_out [[buffer(4)]], \
const constant int& size_sorted_axis [[buffer(5)]], \
const constant int& merge_tiles [[buffer(6)]], \
const constant int& num_tiles [[buffer(7)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
#define instantiate_multi_block_sort_base(vtname, vtype) \
  instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 512, 8)


@@ -10,6 +10,18 @@ template <typename T, typename Op>
  d[index] = Op()(a[index], b[index], c[index]);
}
template <typename T, typename Op>
[[kernel]] void ternary_v2(
device const bool* a,
device const T* b,
device const T* c,
device T* d,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
d[offset] = Op()(a[offset], b[offset], c[offset]);
}
template <typename T, typename Op>
[[kernel]] void ternary_g_nd1(
    device const bool* a,


@@ -11,6 +11,7 @@
#define instantiate_ternary_all(op, tname, type) \
  instantiate_kernel("v_" #op #tname, ternary_v, type, op) \
  instantiate_kernel("v2_" #op #tname, ternary_v2, type, op) \
  instantiate_kernel("g_" #op #tname, ternary_g, type, op) \
  instantiate_kernel("g1_" #op #tname, ternary_g_nd1, type, op) \
  instantiate_kernel("g2_" #op #tname, ternary_g_nd2, type, op) \


@@ -8,6 +8,16 @@ template <typename T, typename Op>
  out[index] = Op()(in[index]);
}
template <typename T, typename Op>
[[kernel]] void unary_v2(
device const T* in,
device T* out,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
out[offset] = Op()(in[offset]);
}
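// Illustrative note (assumptions, not from the original source): the v2
// variants exist because a flat 1D launch indexes elements with a 32-bit
// counter per grid axis. Flattening a 2D grid as
//   size_t offset = index.x + grid_dim.x * size_t(index.y);
// keeps each axis below 2^32 while covering larger arrays; e.g. a
// hypothetical grid of (2^30, 8) threads addresses 2^33 elements.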
template <typename T, typename Op>
[[kernel]] void unary_g(
    device const T* in,


@@ -5,8 +5,9 @@
#include "mlx/backend/metal/kernels/unary_ops.h" #include "mlx/backend/metal/kernels/unary_ops.h"
#include "mlx/backend/metal/kernels/unary.h" #include "mlx/backend/metal/kernels/unary.h"
#define instantiate_unary_all(op, tname, type) \ #define instantiate_unary_all(op, tname, type) \
instantiate_kernel("v" #op #tname, unary_v, type, op) \ instantiate_kernel("v" #op #tname, unary_v, type, op) \
instantiate_kernel("v2" #op #tname, unary_v2, type, op) \
instantiate_kernel("g" #op #tname, unary_g, type, op) instantiate_kernel("g" #op #tname, unary_g, type, op)
#define instantiate_unary_float(op) \ #define instantiate_unary_float(op) \


@@ -11,187 +11,14 @@
#include "mlx/backend/metal/kernels/defines.h" #include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/steel/gemm/params.h" #include "mlx/backend/metal/kernels/steel/gemm/params.h"
#include "mlx/backend/metal/matmul.h" #include "mlx/backend/metal/matmul.h"
#include "mlx/backend/metal/mps/gemm.h"
#include "mlx/backend/metal/utils.h" #include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
#include "mlx/utils.h" #include "mlx/utils.h"
namespace mlx::core { namespace mlx::core {
///////////////////////////////////////////////////////////////////////////////
// MPS Matmul fallback
///////////////////////////////////////////////////////////////////////////////
namespace {
bool use_mps() {
auto get_val = []() {
if (const char* buff_str = std::getenv("MLX_USE_MPS")) {
return std::string(buff_str) != "OFF";
} else {
return false;
}
};
static bool use_mps_ = get_val();
return use_mps_;
}
#define MAX_OPS_PER_BUFFER max_ops_per_buffer()
inline void mps_matmul(
const Stream& s,
metal::Device& d,
const array& a,
const array& b,
array& out,
int M,
int N,
int K,
int batch_size_out,
int lda,
int ldb,
bool transpose_a,
bool transpose_b,
std::vector<array>& copies,
float alpha = 1.0f,
float beta = 0.0f) {
MPS::DataType mps_dtype = MPS::DataTypeFloat32;
if (out.dtype() == float16) {
mps_dtype = MPS::DataTypeFloat16;
} else if (out.dtype() == bfloat16) {
mps_dtype = MPS::DataTypeBFloat16;
}
// Used batched MPSMatrixMultiplication if batch_size_out > 1
// We only accept the following cases:
// 1. Both a, b have batch_size_out matrices worth of data
// 2. Only one of a or b has batch_size_out matrices worth of data and
// the other has matrix worth of data
// The matrix dimensions of a and b are sure to be regularly strided
if (batch_size_out > 1) {
// No broadcasting defaults
auto batch_size_a = a.data_size() / (M * K);
auto batch_size_b = b.data_size() / (K * N);
auto matrix_stride_a = M * K;
auto matrix_stride_b = K * N;
auto matrix_stride_out = M * N;
// At this point, batch_size_a, batch_size_b show the number of matrices
// in data, no broadcasted strides considered
if (batch_size_out == std::max(batch_size_a, batch_size_b)) {
// Handle simple broadcasting
if (std::min(batch_size_a, batch_size_b) == 1) {
matrix_stride_a = (batch_size_a == 1) ? 0 : matrix_stride_a;
matrix_stride_b = (batch_size_b == 1) ? 0 : matrix_stride_b;
batch_size_a = batch_size_out;
batch_size_b = batch_size_out;
}
// Only proceed if broadcasting between a and b is simple
// At this point, batch_size_a, batch_size_b show the number of matrices
// after broadcasting
if (batch_size_a == batch_size_b) {
auto a_desc = MPS::MatrixDescriptor::matrixDescriptor(
(M * K) / lda,
lda,
batch_size_a,
lda * a.itemsize(),
(matrix_stride_a * a.itemsize()),
mps_dtype);
auto b_desc = MPS::MatrixDescriptor::matrixDescriptor(
(K * N) / ldb,
ldb,
batch_size_b,
ldb * b.itemsize(),
(matrix_stride_b * b.itemsize()),
mps_dtype);
auto out_desc = MPS::MatrixDescriptor::matrixDescriptor(
M,
N,
batch_size_out,
N * out.itemsize(),
matrix_stride_out * out.itemsize(),
mps_dtype);
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
auto a_mat = MPS::Matrix::alloc()->init(a_buf, a_desc);
auto b_buf = static_cast<const MTL::Buffer*>(b.buffer().ptr());
auto b_mat = MPS::Matrix::alloc()->init(b_buf, b_desc);
auto out_buf = static_cast<MTL::Buffer*>(out.buffer().ptr());
auto out_mat = MPS::Matrix::alloc()->init(out_buf, out_desc);
auto kernel = MPS::MatrixMultiplication::alloc()->init(
d.mtl_device(), transpose_a, transpose_b, M, N, K, alpha, beta);
auto command_buffer = d.get_command_buffer(s.index);
kernel->setBatchSize(batch_size_out);
kernel->setBatchStart(0);
kernel->encodeToCommandBuffer(command_buffer, a_mat, b_mat, out_mat);
command_buffer->addCompletedHandler(
[a_mat, b_mat, out_mat, kernel, copies](
MTL::CommandBuffer*) mutable {
a_mat->release();
b_mat->release();
out_mat->release();
kernel->release();
copies.clear();
});
return;
}
}
}
// Schedule as many calls to MPSMatrixMultiplication as needed otherwise
auto a_desc = MPS::MatrixDescriptor::matrixDescriptor(
a.data_size() / lda, lda, lda * a.itemsize(), mps_dtype);
auto b_desc = MPS::MatrixDescriptor::matrixDescriptor(
b.data_size() / ldb, ldb, ldb * b.itemsize(), mps_dtype);
auto out_desc = MPS::MatrixDescriptor::matrixDescriptor(
batch_size_out * M, N, N * out.itemsize(), mps_dtype);
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
auto a_mat = MPS::Matrix::alloc()->init(a_buf, a_desc);
auto b_buf = static_cast<const MTL::Buffer*>(b.buffer().ptr());
auto b_mat = MPS::Matrix::alloc()->init(b_buf, b_desc);
auto out_buf = static_cast<MTL::Buffer*>(out.buffer().ptr());
auto out_mat = MPS::Matrix::alloc()->init(out_buf, out_desc);
auto kernel = MPS::MatrixMultiplication::alloc()->init(
d.mtl_device(), transpose_a, transpose_b, M, N, K, alpha, beta);
auto command_buffer = d.get_command_buffer(s.index);
for (int i = 0; i < batch_size_out; ++i) {
auto a_row = elem_to_loc(M * K * i, a.shape(), a.strides()) / lda;
auto b_row = elem_to_loc(K * N * i, b.shape(), b.strides()) / ldb;
kernel->setLeftMatrixOrigin({a_row, 0, 0});
kernel->setRightMatrixOrigin({b_row, 0, 0});
kernel->setResultMatrixOrigin({i * static_cast<size_t>(M), 0, 0});
kernel->encodeToCommandBuffer(command_buffer, a_mat, b_mat, out_mat);
}
command_buffer->addCompletedHandler(
[a_mat, b_mat, out_mat, kernel, copies](MTL::CommandBuffer*) mutable {
a_mat->release();
b_mat->release();
out_mat->release();
kernel->release();
copies.clear();
});
}
inline auto collapse_batches(const array& a, const array& b) {
  // Get and check the shape for the batched dims
  std::vector<int> A_bshape{a.shape().begin(), a.shape().end() - 2};
@@ -860,26 +687,6 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
/////////////////////////////////////////////////////////////////////////////
// Gemm specialization
if (use_mps()) {
d.end_encoding(s.index);
return mps_matmul(
s,
d,
a,
b,
out,
M,
N,
K,
batch_size_out,
a_cols,
b_cols,
a_transposed,
b_transposed,
copies);
}
  return steel_matmul(
      /* const Stream& s = */ s,
      /* metal::Device& d = */ d,
@@ -1529,8 +1336,22 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
kname << "_nc" << !contiguous_kernel; kname << "_nc" << !contiguous_kernel;
// Encode and dispatch kernel // Encode and dispatch kernel
auto kernel = get_gemv_masked_kernel(
d,
kname.str(),
out,
has_out_mask ? std::optional<array>{inputs[2]} : std::nullopt,
has_op_mask ? std::optional<array>{inputs.back()} : std::nullopt,
transpose_mat,
bm,
bn,
sm,
sn,
tm,
tn,
contiguous_kernel);
  auto& compute_encoder = d.get_command_encoder(s.index);
  auto kernel = d.get_kernel(kname.str());
  compute_encoder->setComputePipelineState(kernel);
  int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;


@@ -1,14 +1,6 @@
// Copyright © 2023 Apple Inc.
#include <algorithm>
#include <cassert>
#include <sstream>
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/mps/gemm.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/utils.h"
namespace mlx::core {
@@ -48,4 +40,4 @@ void steel_matmul(
    std::vector<size_t> A_batch_stride = {},
    std::vector<size_t> B_batch_stride = {});
} // namespace mlx::core


@@ -75,6 +75,9 @@ std::function<void()> make_task(array arr, bool signal) {
    if (!arr.is_tracer()) {
      arr.detach();
    }
    for (auto& out : outputs) {
      out.set_status(array::Status::available);
    }
    if (signal || d.get_command_buffer_ops(s.index) >= MAX_OPS_PER_BUFFER) {
      d.end_encoding(s.index);


@@ -1,370 +0,0 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include <Metal/Metal.hpp>
#define _MPS_PRIVATE_CLS(symbol) (MTL::Private::Class::s_k##symbol)
#define _MPS_PRIVATE_SEL(accessor) (MTL::Private::Selector::s_k##accessor)
namespace MTL::Private::Class {
_MTL_PRIVATE_DEF_CLS(MPSMatrixDescriptor);
_MTL_PRIVATE_DEF_CLS(MPSMatrix);
_MTL_PRIVATE_DEF_CLS(MPSVectorDescriptor);
_MTL_PRIVATE_DEF_CLS(MPSVector);
_MTL_PRIVATE_DEF_CLS(MPSKernel);
_MTL_PRIVATE_DEF_CLS(MPSMatrixMultiplication);
_MTL_PRIVATE_DEF_CLS(MPSMatrixVectorMultiplication);
} // namespace MTL::Private::Class
namespace MTL::Private::Selector {
_MTL_PRIVATE_DEF_SEL(
matrixDescriptorWithRows_columns_rowBytes_dataType,
"matrixDescriptorWithRows:columns:rowBytes:dataType:");
_MTL_PRIVATE_DEF_SEL(
matrixDescriptorWithRows_columns_matrices_rowBytes_matrixBytes_dataType,
"matrixDescriptorWithRows:columns:matrices:rowBytes:matrixBytes:dataType:");
_MTL_PRIVATE_DEF_SEL(rows, "rows");
_MTL_PRIVATE_DEF_SEL(initWithBuffer_descriptor, "initWithBuffer:descriptor:");
_MTL_PRIVATE_DEF_SEL(
initWithDevice_,
"initWithDevice:transposeLeft:transposeRight:"
"resultRows:resultColumns:interiorColumns:alpha:beta:");
_MTL_PRIVATE_DEF_SEL(
encodeToCommandBuffer_leftMatrix_rightMatrix_resultMatrix,
"encodeToCommandBuffer:leftMatrix:rightMatrix:resultMatrix:");
_MTL_PRIVATE_DEF_SEL(setLeftMatrixOrigin_, "setLeftMatrixOrigin:");
_MTL_PRIVATE_DEF_SEL(setRightMatrixOrigin_, "setRightMatrixOrigin:");
_MTL_PRIVATE_DEF_SEL(setResultMatrixOrigin_, "setResultMatrixOrigin:");
_MTL_PRIVATE_DEF_SEL(setBatchStart_, "setBatchStart:");
_MTL_PRIVATE_DEF_SEL(setBatchSize_, "setBatchSize:");
_MTL_PRIVATE_DEF_SEL(
vectorDescriptorWithLength_dataType,
"vectorDescriptorWithLength:dataType:");
_MTL_PRIVATE_DEF_SEL(
vectorDescriptorWithLength_vectors_vectorBytes_dataType,
"vectorDescriptorWithLength:vectors:vectorBytes:dataType:");
_MTL_PRIVATE_DEF_SEL(
initWithDevice_transpose_rows_columns_alpha_beta,
"initWithDevice:transpose:rows:columns:alpha:beta:");
_MTL_PRIVATE_DEF_SEL(
encodeToCommandBuffer_inputMatrix_inputVector_resultVector,
"encodeToCommandBuffer:inputMatrix:inputVector:resultVector:");
} // namespace MTL::Private::Selector
namespace MPS {
typedef enum DataType : uint32_t {
DataTypeFloatBit = 0x10000000,
DataTypeAlternateEncodingBit = 0x80000000,
DataTypeFloat16 = DataTypeFloatBit | 16,
DataTypeFloat32 = DataTypeFloatBit | 32,
DataTypeBFloat16 = DataTypeAlternateEncodingBit | DataTypeFloat16
} DataType;
class MatrixDescriptor : public NS::Copying<MatrixDescriptor> {
public:
static class MatrixDescriptor* matrixDescriptor(
NS::UInteger rows,
NS::UInteger columns,
NS::UInteger rowBytes,
NS::UInteger dataType);
static class MatrixDescriptor* matrixDescriptor(
NS::UInteger rows,
NS::UInteger columns,
NS::UInteger matrices,
NS::UInteger rowBytes,
NS::UInteger matrixBytes,
NS::UInteger dataType);
NS::UInteger rows() const;
};
class Matrix : public NS::Referencing<Matrix> {
public:
static class Matrix* alloc();
Matrix* init(MTL::Buffer* buffer, MatrixDescriptor* descriptor);
Matrix* init(const MTL::Buffer* buffer, MatrixDescriptor* descriptor);
};
class Kernel : public NS::Referencing<Kernel> {
public:
NS::String* label() const;
MTL::Device* device() const;
};
class MatrixMultiplication
: public NS::Referencing<MatrixMultiplication, Kernel> {
public:
static class MatrixMultiplication* alloc();
MatrixMultiplication* init(
MTL::Device* device,
bool transposeLeft,
bool transposeRight,
NS::UInteger resultRows,
NS::UInteger resultColumns,
NS::UInteger interiorColumns,
double alpha,
double beta);
void encodeToCommandBuffer(
MTL::CommandBuffer* commandBuffer,
Matrix* leftMatrix,
Matrix* rightMatrix,
Matrix* resultMatrix);
void setLeftMatrixOrigin(MTL::Origin origin);
void setRightMatrixOrigin(MTL::Origin origin);
void setResultMatrixOrigin(MTL::Origin origin);
void setBatchStart(NS::UInteger batchStart);
void setBatchSize(NS::UInteger batchSize);
};
class VectorDescriptor : public NS::Copying<VectorDescriptor> {
public:
static class VectorDescriptor* vectorDescriptor(
NS::UInteger length,
NS::UInteger dataType);
static class VectorDescriptor* vectorDescriptor(
NS::UInteger length,
NS::UInteger vectors,
NS::UInteger vectorBytes,
NS::UInteger dataType);
};
class Vector : public NS::Referencing<Vector> {
public:
static class Vector* alloc();
Vector* init(MTL::Buffer* buffer, VectorDescriptor* descriptor);
Vector* init(const MTL::Buffer* buffer, VectorDescriptor* descriptor);
};
class MatrixVectorMultiplication
: public NS::Referencing<MatrixVectorMultiplication, Kernel> {
public:
static class MatrixVectorMultiplication* alloc();
MatrixVectorMultiplication* init(
MTL::Device* device,
bool transpose,
NS::UInteger rows,
NS::UInteger columns,
double alpha,
double beta);
void encodeToCommandBuffer(
MTL::CommandBuffer* commandBuffer,
Matrix* inputMatrix,
Vector* inputVector,
Vector* resultVector);
};
_MTL_INLINE MatrixDescriptor* MatrixDescriptor::matrixDescriptor(
NS::UInteger rows,
NS::UInteger columns,
NS::UInteger rowBytes,
NS::UInteger dataType) {
return Object::sendMessage<MatrixDescriptor*>(
_MPS_PRIVATE_CLS(MPSMatrixDescriptor),
_MPS_PRIVATE_SEL(matrixDescriptorWithRows_columns_rowBytes_dataType),
rows,
columns,
rowBytes,
dataType);
}
_MTL_INLINE MatrixDescriptor* MatrixDescriptor::matrixDescriptor(
NS::UInteger rows,
NS::UInteger columns,
NS::UInteger matrices,
NS::UInteger rowBytes,
NS::UInteger matrixBytes,
NS::UInteger dataType) {
return Object::sendMessage<MatrixDescriptor*>(
_MPS_PRIVATE_CLS(MPSMatrixDescriptor),
_MPS_PRIVATE_SEL(
matrixDescriptorWithRows_columns_matrices_rowBytes_matrixBytes_dataType),
rows,
columns,
matrices,
rowBytes,
matrixBytes,
dataType);
}
_MTL_INLINE NS::UInteger MatrixDescriptor::rows() const {
return Object::sendMessage<NS::UInteger>(this, _MPS_PRIVATE_SEL(rows));
}
_MTL_INLINE Matrix* Matrix::alloc() {
return NS::Object::alloc<Matrix>(_MPS_PRIVATE_CLS(MPSMatrix));
}
_MTL_INLINE Matrix* Matrix::init(
MTL::Buffer* buffer,
MatrixDescriptor* descriptor) {
return Object::sendMessage<Matrix*>(
this, _MPS_PRIVATE_SEL(initWithBuffer_descriptor), buffer, descriptor);
}
_MTL_INLINE Matrix* Matrix::init(
const MTL::Buffer* buffer,
MatrixDescriptor* descriptor) {
return init(const_cast<MTL::Buffer*>(buffer), descriptor);
}
_MTL_INLINE NS::String* Kernel::label() const {
return Object::sendMessage<NS::String*>(this, _MPS_PRIVATE_SEL(label));
}
_MTL_INLINE MTL::Device* Kernel::device() const {
return Object::sendMessage<MTL::Device*>(this, _MPS_PRIVATE_SEL(device));
}
_MTL_INLINE MatrixMultiplication* MatrixMultiplication::alloc() {
return NS::Object::alloc<MatrixMultiplication>(
_MPS_PRIVATE_CLS(MPSMatrixMultiplication));
}
_MTL_INLINE MatrixMultiplication* MatrixMultiplication::init(
MTL::Device* device,
bool transposeLeft,
bool transposeRight,
NS::UInteger resultRows,
NS::UInteger resultColumns,
NS::UInteger interiorColumns,
double alpha,
double beta) {
return Object::sendMessage<MatrixMultiplication*>(
this,
_MPS_PRIVATE_SEL(initWithDevice_),
device,
transposeLeft,
transposeRight,
resultRows,
resultColumns,
interiorColumns,
alpha,
beta);
}
_MTL_INLINE void MatrixMultiplication::encodeToCommandBuffer(
MTL::CommandBuffer* commandBuffer,
Matrix* leftMatrix,
Matrix* rightMatrix,
Matrix* resultMatrix) {
return Object::sendMessage<void>(
this,
_MPS_PRIVATE_SEL(
encodeToCommandBuffer_leftMatrix_rightMatrix_resultMatrix),
commandBuffer,
leftMatrix,
rightMatrix,
resultMatrix);
}
_MTL_INLINE void MatrixMultiplication::setLeftMatrixOrigin(MTL::Origin origin) {
Object::sendMessage<void>(
this, _MPS_PRIVATE_SEL(setLeftMatrixOrigin_), origin);
}
_MTL_INLINE void MatrixMultiplication::setRightMatrixOrigin(
MTL::Origin origin) {
Object::sendMessage<void>(
this, _MPS_PRIVATE_SEL(setRightMatrixOrigin_), origin);
}
_MTL_INLINE void MatrixMultiplication::setResultMatrixOrigin(
MTL::Origin origin) {
Object::sendMessage<void>(
this, _MPS_PRIVATE_SEL(setResultMatrixOrigin_), origin);
}
_MTL_INLINE void MatrixMultiplication::setBatchStart(NS::UInteger batchStart) {
Object::sendMessage<void>(this, _MPS_PRIVATE_SEL(setBatchStart_), batchStart);
}
_MTL_INLINE void MatrixMultiplication::setBatchSize(NS::UInteger batchSize) {
Object::sendMessage<void>(this, _MPS_PRIVATE_SEL(setBatchSize_), batchSize);
}
_MTL_INLINE VectorDescriptor* VectorDescriptor::vectorDescriptor(
NS::UInteger length,
NS::UInteger dataType) {
return Object::sendMessage<VectorDescriptor*>(
_MPS_PRIVATE_CLS(MPSVectorDescriptor),
_MPS_PRIVATE_SEL(vectorDescriptorWithLength_dataType),
length,
dataType);
}
_MTL_INLINE VectorDescriptor* VectorDescriptor::vectorDescriptor(
NS::UInteger length,
NS::UInteger vectors,
NS::UInteger vectorBytes,
NS::UInteger dataType) {
return Object::sendMessage<VectorDescriptor*>(
_MPS_PRIVATE_CLS(MPSVectorDescriptor),
_MPS_PRIVATE_SEL(vectorDescriptorWithLength_vectors_vectorBytes_dataType),
length,
vectors,
vectorBytes,
dataType);
}
_MTL_INLINE Vector* Vector::alloc() {
return NS::Object::alloc<Vector>(_MPS_PRIVATE_CLS(MPSVector));
}
_MTL_INLINE Vector* Vector::init(
MTL::Buffer* buffer,
VectorDescriptor* descriptor) {
return Object::sendMessage<Vector*>(
this, _MPS_PRIVATE_SEL(initWithBuffer_descriptor), buffer, descriptor);
}
_MTL_INLINE Vector* Vector::init(
const MTL::Buffer* buffer,
VectorDescriptor* descriptor) {
return init(const_cast<MTL::Buffer*>(buffer), descriptor);
}
_MTL_INLINE MatrixVectorMultiplication* MatrixVectorMultiplication::alloc() {
return NS::Object::alloc<MatrixVectorMultiplication>(
_MPS_PRIVATE_CLS(MPSMatrixVectorMultiplication));
}
_MTL_INLINE MatrixVectorMultiplication* MatrixVectorMultiplication::init(
MTL::Device* device,
bool transpose,
NS::UInteger rows,
NS::UInteger columns,
double alpha,
double beta) {
return Object::sendMessage<MatrixVectorMultiplication*>(
this,
_MPS_PRIVATE_SEL(initWithDevice_transpose_rows_columns_alpha_beta),
device,
transpose,
rows,
columns,
alpha,
beta);
}
_MTL_INLINE void MatrixVectorMultiplication::encodeToCommandBuffer(
MTL::CommandBuffer* commandBuffer,
Matrix* inputMatrix,
Vector* inputVector,
Vector* resultVector) {
return Object::sendMessage<void>(
this,
_MPS_PRIVATE_SEL(
encodeToCommandBuffer_inputMatrix_inputVector_resultVector),
commandBuffer,
inputMatrix,
inputVector,
resultVector);
}
} // namespace MPS


@@ -169,6 +169,23 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
  return d.get_kernel(kernel_name);
}
MTL::ComputePipelineState* get_gemv_masked_kernel(
metal::Device& d,
const std::string& kernel_name,
const array&,
const std::optional<array>&,
const std::optional<array>&,
bool,
int,
int,
int,
int,
int,
int,
bool) {
return d.get_kernel(kernel_name);
}
MTL::ComputePipelineState* get_steel_conv_kernel(
    metal::Device& d,
    const std::string& kernel_name,


@@ -171,7 +171,7 @@ void Copy::eval_gpu(const std::vector<array>& inputs, array& out) {
  eval(inputs, out);
}
void CustomVJP::eval_gpu(
void CustomTransforms::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  eval(inputs, outputs);
@@ -273,7 +273,18 @@ void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
  auto [copy_necessary, out_strides] = prepare_reshape(in, out);
  if (copy_necessary) {
    copy_gpu(in, out, CopyType::General);
    out.set_data(allocator::malloc_or_wait(out.nbytes()));
    auto out_strides = make_contiguous_strides<size_t>(in.shape());
    copy_gpu_inplace(
        in,
        out,
        in.shape(),
        in.strides(),
        out_strides,
        0,
        0,
        CopyType::General,
        stream());
  } else {
    shared_buffer_reshape(in, out_strides, out);
  }


@@ -7,6 +7,7 @@
#include "mlx/backend/metal/device.h" #include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h" #include "mlx/backend/metal/utils.h"
#include "mlx/fast_primitives.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
namespace mlx::core { namespace mlx::core {
@@ -47,8 +48,8 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
  if (B < 6 && O % 8 == 0 && D % 512 == 0 && D >= 512) {
    std::ostringstream kname;
    auto type_string = get_type_string(x.dtype());
    kname << "qmv_" << type_string << "_gs_" << group_size_ << "_b_" << bits_
          << "_fast";
    kname << "qmv_fast_" << type_string << "_gs_" << group_size_ << "_b_"
          << bits_;
    // Encode and dispatch kernel
    auto& compute_encoder = d.get_command_encoder(s.index);
@@ -270,8 +271,8 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
  if (B < 6 && O % 8 == 0 && D % 512 == 0 && D >= 512) {
    std::ostringstream kname;
    auto type_string = get_type_string(x.dtype());
    kname << "bs_qmv_" << type_string << "_gs_" << group_size_ << "_b_"
          << bits_ << "_fast";
    kname << "bs_qmv_fast_" << type_string << "_gs_" << group_size_ << "_b_"
          << bits_;
    // Encode and dispatch kernel
    auto& compute_encoder = d.get_command_encoder(s.index);
@@ -513,4 +514,82 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
  }
}
void fast::AffineQuantize::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
bool compute_scale_bias = inputs.size() == 1;
auto& w_pre = inputs[0];
auto& out = outputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
std::vector<array> copies;
auto ensure_row_contiguous = [&copies, &s](const array& arr) {
if (arr.flags().row_contiguous) {
return arr;
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
copies.push_back(arr_copy);
return arr_copy;
}
};
auto w = ensure_row_contiguous(w_pre);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_input_array(w, 0);
if (!compute_scale_bias) {
auto& scales_pre = inputs[1];
auto& biases_pre = inputs[2];
auto scales = ensure_row_contiguous(scales_pre);
auto biases = ensure_row_contiguous(biases_pre);
compute_encoder.set_input_array(scales, 1);
compute_encoder.set_input_array(biases, 2);
compute_encoder.set_output_array(out, 3);
} else {
auto& scales = outputs[1];
auto& biases = outputs[2];
scales.set_data(allocator::malloc_or_wait(scales.nbytes()));
biases.set_data(allocator::malloc_or_wait(biases.nbytes()));
compute_encoder.set_output_array(out, 1);
compute_encoder.set_output_array(scales, 2);
compute_encoder.set_output_array(biases, 3);
}
std::ostringstream kname;
auto type_string = dequantize_ ? get_type_string(out.dtype())
: get_type_string(w_pre.dtype());
auto kernel_func = "affine_quantize_scales_biases";
if (dequantize_) {
kernel_func = "affine_dequantize";
} else if (compute_scale_bias) {
kernel_func = "affine_quantize";
}
kname << kernel_func << "_" << type_string << "_gs_" << group_size_ << "_b_"
<< bits_;
auto template_def = get_template_definition(
kname.str(), kernel_func, type_string, group_size_, bits_);
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
compute_encoder->setComputePipelineState(kernel);
// Treat uint32 as uint8 in kernel
constexpr int uint8_per_uint32 = 4;
constexpr int simd_size = 32;
int packs_per_int = 8 / bits_;
int per_thread = compute_scale_bias ? group_size_ / simd_size : packs_per_int;
size_t nthreads =
dequantize_ ? w.size() * uint8_per_uint32 : w.size() / per_thread;
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
auto group_dims = MTL::Size(thread_group_size, 1, 1);
auto grid_dims = MTL::Size(nthreads, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}
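// Illustrative sizing example (parameter values assumed): with bits_ = 4 and
// group_size_ = 64, packs_per_int = 8 / 4 = 2. Quantizing while computing
// scales/biases uses per_thread = group_size_ / simd_size = 64 / 32 = 2, so
// nthreads = w.size() / 2; dequantizing instead launches
// w.size() * uint8_per_uint32 = 4 * w.size() threads, one per packed byte.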
} // namespace mlx::core


@@ -5,6 +5,8 @@
namespace mlx::core::fast {
constexpr int n_per_thread = 4;
void RoPE::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
@@ -62,8 +64,11 @@ void RoPE::eval_gpu(
  out_strides[1] = out.strides()[ndim - 2];
  out_strides[2] = out.strides()[ndim - 1];
  // Special case for inference (single time step and contiguous)
  bool single = in.flags().row_contiguous && (mat_size == in.shape(-1));
  std::ostringstream kname;
  kname << "rope_" << (forward_ ? "" : "vjp_")
  kname << "rope_" << (single ? "single_" : "") << (forward_ ? "" : "vjp_")
        << (traditional_ ? "traditional_" : "") << type_to_name(in);
  auto kernel = d.get_kernel(kname.str());
  auto& compute_encoder = d.get_command_encoder(s.index);
@@ -72,18 +77,28 @@ void RoPE::eval_gpu(
  compute_encoder->setComputePipelineState(kernel);
  compute_encoder.set_input_array(donated ? out : in, 0);
  compute_encoder.set_output_array(out, 1);
  compute_encoder->setBytes(&strides, 3 * sizeof(size_t), 2);
  compute_encoder->setBytes(&out_strides, 3 * sizeof(size_t), 3);
  compute_encoder->setBytes(&offset_, sizeof(int), 4);
  compute_encoder->setBytes(&base, sizeof(float), 5);
  compute_encoder->setBytes(&scale_, sizeof(float), 6);
  int dim0 = dims_ / 2;
  int dim1 = in.shape(-2);
  int dim2 = in.size() / mat_size;
  auto group_dims = get_block_dims(dim0, dim1, dim2);
  auto grid_dims = MTL::Size(dim0, dim1, dim2);
  compute_encoder.dispatchThreads(grid_dims, group_dims);
  compute_encoder->setBytes(&offset_, sizeof(int), 2);
  compute_encoder->setBytes(&base, sizeof(float), 3);
  compute_encoder->setBytes(&scale_, sizeof(float), 4);
  size_t n_batch = in.size() / mat_size;
  if (single) {
    compute_encoder->setBytes(&out_strides[1], sizeof(size_t), 5);
    uint32_t dim0 = dims_ / 2;
    auto group_dims = get_block_dims(dim0, n_batch, 1);
    auto grid_dims = MTL::Size(dim0, n_batch, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
compute_encoder->setBytes(&strides, 3 * sizeof(size_t), 5);
compute_encoder->setBytes(&out_strides, 3 * sizeof(size_t), 6);
compute_encoder->setBytes(&n_batch, sizeof(size_t), 7);
uint32_t dim0 = dims_ / 2;
uint32_t dim1 = in.shape(-2);
uint32_t dim2 = (n_batch + n_per_thread - 1) / n_per_thread;
auto group_dims = get_block_dims(dim0, dim1, dim2);
auto grid_dims = MTL::Size(dim0, dim1, dim2);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
}
} // namespace mlx::core::fast
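A brief illustrative sketch of the dispatch split above (the concrete shapes are assumed, and mat_size is computed earlier in RoPE::eval_gpu, outside this hunk, here taken to be the size of the last two axes):
// Assume in.shape() == {1, 8, 1, 128}, dims_ == 128, n_per_thread == 4.
// Then mat_size == 1 * 128 == in.shape(-1) and the input is row contiguous,
// so the "single" (one time step) path is taken:
//   n_batch   = in.size() / mat_size = 8
//   grid_dims = MTL::Size(dims_ / 2, n_batch, 1) = (64, 8, 1)
// A multi-step input instead takes the general path, which also binds the
// strides and n_batch and launches
//   (dims_ / 2, in.shape(-2), (n_batch + 3) / 4) threads.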


@@ -1,5 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include "mlx/backend/metal/copy.h"


@@ -24,8 +24,11 @@ void single_block_sort(
  // Prepare shapes
  int n_rows = in.size() / in.shape(axis);
  std::vector<size_t> nc_str = in.strides();
  nc_str.erase(nc_str.begin() + axis);
  std::vector<size_t> in_nc_str = in.strides();
  in_nc_str.erase(in_nc_str.begin() + axis);
  std::vector<size_t> out_nc_str = out.strides();
  out_nc_str.erase(out_nc_str.begin() + axis);
  std::vector<int> nc_shape = in.shape();
  nc_shape.erase(nc_shape.begin() + axis);
@@ -33,21 +36,28 @@ void single_block_sort(
  int nc_dim = nc_shape.size();
  int size_sorted_axis = in.shape(axis);
  int stride_sorted_axis = in.strides()[axis];
  int stride_segment_axis = *std::min_element(nc_str.begin(), nc_str.end());
  // Check if remaining strides are contiguous
  bool contiguous_write = true;
  if (axis != in.ndim() - 1 && axis != 0) {
    for (int i = 0; i < nc_str.size() - 1; ++i) {
      size_t expected = nc_str[i + 1] * nc_str[i + 1];
      contiguous_write &= (nc_str[i] == expected);
    }
  }
  int in_stride_sorted_axis = in.strides()[axis];
  int out_stride_sorted_axis = out.strides()[axis];
  int in_stride_segment_axis =
      *std::min_element(in_nc_str.begin(), in_nc_str.end());
  int out_stride_segment_axis =
      *std::min_element(out_nc_str.begin(), out_nc_str.end());
  // We can only use the contiguous kernel if the sorted axis
  // has the largest or smallest stride.
  // We also need the input to be contiguous
  bool contiguous = in.flags().contiguous;
  auto check_strides = [](array x, int sort_stride) {
    int min_stride = *std::min_element(x.strides().begin(), x.strides().end());
    int max_stride = *std::max_element(x.strides().begin(), x.strides().end());
    return sort_stride == min_stride || sort_stride == max_stride;
  };
  contiguous &= check_strides(in, in_stride_sorted_axis);
  contiguous &= check_strides(out, out_stride_sorted_axis);
  // Prepare kernel name
  std::ostringstream kname;
  kname << (contiguous_write ? "c" : "nc");
  kname << (contiguous ? "c" : "nc");
  if (argsort) {
    kname << "arg";
  }
@@ -64,14 +74,17 @@ void single_block_sort(
  compute_encoder.set_input_array(in, 0);
  compute_encoder.set_output_array(out, 1);
  compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 2);
  compute_encoder->setBytes(&stride_sorted_axis, sizeof(int), 3);
  compute_encoder->setBytes(&in_stride_sorted_axis, sizeof(int), 3);
  compute_encoder->setBytes(&out_stride_sorted_axis, sizeof(int), 4);
  if (contiguous_write) {
    compute_encoder->setBytes(&stride_segment_axis, sizeof(int), 4);
  if (contiguous) {
    compute_encoder->setBytes(&in_stride_segment_axis, sizeof(int), 5);
    compute_encoder->setBytes(&out_stride_segment_axis, sizeof(int), 6);
  } else {
    compute_encoder->setBytes(&nc_dim, sizeof(int), 4);
    compute_encoder->setBytes(nc_shape.data(), nc_dim * sizeof(int), 5);
    compute_encoder->setBytes(nc_str.data(), nc_dim * sizeof(size_t), 6);
    compute_encoder->setBytes(&nc_dim, sizeof(int), 5);
    compute_encoder->setBytes(nc_shape.data(), nc_dim * sizeof(int), 6);
    compute_encoder->setBytes(in_nc_str.data(), nc_dim * sizeof(size_t), 7);
    compute_encoder->setBytes(out_nc_str.data(), nc_dim * sizeof(size_t), 8);
  }
  MTL::Size group_dims = MTL::Size(bn, 1, 1);
@@ -164,6 +177,8 @@ void multi_block_sort(
  array dev_vals_out = dev_vals_1;
  array dev_idxs_out = dev_idxs_1;
  int n_thr_per_group = (n_blocks + 1) < 1024 ? (n_blocks + 1) : 1024;
  for (int merge_tiles = 2; (merge_tiles / 2) < n_blocks; merge_tiles *= 2) {
    dev_vals_in = ping ? dev_vals_1 : dev_vals_0;
    dev_idxs_in = ping ? dev_idxs_1 : dev_idxs_0;
@@ -186,8 +201,9 @@ void multi_block_sort(
    compute_encoder.set_input_array(dev_idxs_in, 2);
    compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 3);
    compute_encoder->setBytes(&merge_tiles, sizeof(int), 4);
    compute_encoder->setBytes(&n_blocks, sizeof(int), 5);
    MTL::Size group_dims = MTL::Size(n_blocks + 1, 1, 1);
    MTL::Size group_dims = MTL::Size(n_thr_per_group, 1, 1);
    MTL::Size grid_dims = MTL::Size(1, n_rows, 1);
    compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
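A worked example of the threadgroup cap added above (the block count is assumed for illustration):
// With n_blocks = 1500 merge partitions, the old code requested a
// threadgroup of n_blocks + 1 = 1501 threads, more than the usual 1024
// maxTotalThreadsPerThreadgroup limit. The new code clamps it:
//   n_thr_per_group = (n_blocks + 1) < 1024 ? (n_blocks + 1) : 1024  // 1024
// and additionally binds n_blocks at buffer index 5, presumably so the
// partition kernel can loop when there are more boundaries than threads.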


@@ -32,6 +32,7 @@ void ternary_op_gpu_inplace(
  auto& strides_c = strides[2];
  auto& strides_out = strides[3];
  bool use_2d = out.data_size() > UINT_MAX;
  std::string kernel_name;
  {
    std::ostringstream kname;
@@ -40,6 +41,8 @@ void ternary_op_gpu_inplace(
      if (shape.size() <= MAX_TERNARY_SPECIALIZED_DIMS) {
        kname << shape.size();
      }
    } else if (use_2d) {
      kname << "v2";
    } else {
      kname << "v";
    }


@@ -25,11 +25,14 @@ void unary_op_gpu_inplace(
  auto& d = metal::device(s.device);
  std::string kernel_name = (contig ? "v" : "g") + op + type_to_name(out);
  size_t nthreads = contig ? in.data_size() : in.size();
  bool use_2d = nthreads > UINT32_MAX;
  std::string kernel_name =
      (contig ? (use_2d ? "v2" : "v") : "g") + op + type_to_name(out);
  auto kernel = get_unary_kernel(d, kernel_name, out.dtype(), op);
  size_t nthreads = contig ? in.data_size() : in.size();
  MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
  MTL::Size grid_dims = use_2d ? get_2d_grid_dims(in.shape(), in.strides())
                               : MTL::Size(nthreads, 1, 1);
  NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
  if (thread_group_size > nthreads) {
    thread_group_size = nthreads;

mlx/backend/metal/utils.cpp (new file, 116 lines)

@@ -0,0 +1,116 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/utils.h"
using namespace mlx;
namespace mlx::core {
std::string type_to_name(const array& a) {
std::string tname;
switch (a.dtype()) {
case bool_:
tname = "bool_";
break;
case uint8:
tname = "uint8";
break;
case uint16:
tname = "uint16";
break;
case uint32:
tname = "uint32";
break;
case uint64:
tname = "uint64";
break;
case int8:
tname = "int8";
break;
case int16:
tname = "int16";
break;
case int32:
tname = "int32";
break;
case int64:
tname = "int64";
break;
case float16:
tname = "float16";
break;
case float32:
tname = "float32";
break;
case bfloat16:
tname = "bfloat16";
break;
case complex64:
tname = "complex64";
break;
}
return tname;
}
MTL::Size get_block_dims(int dim0, int dim1, int dim2) {
int pows[3] = {0, 0, 0};
int sum = 0;
while (true) {
int presum = sum;
// Check all the pows
if (dim0 >= (1 << (pows[0] + 1))) {
pows[0]++;
sum++;
}
if (sum == 10) {
break;
}
if (dim1 >= (1 << (pows[1] + 1))) {
pows[1]++;
sum++;
}
if (sum == 10) {
break;
}
if (dim2 >= (1 << (pows[2] + 1))) {
pows[2]++;
sum++;
}
if (sum == presum || sum == 10) {
break;
}
}
return MTL::Size{1ul << pows[0], 1ul << pows[1], 1ul << pows[2]};
}
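// Illustrative trace (example values, not from the original file): for
// dim0 = 30, dim1 = 7, dim2 = 5 the loop raises the exponents round-robin
// while each dimension still fits the next power of two, ending with
// pows = {4, 2, 2}, i.e. a 16 x 4 x 4 = 256 thread block. The sum == 10
// checks cap the block at 2^10 = 1024 threads in total.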
MTL::Size get_2d_grid_dims(
const std::vector<int>& shape,
const std::vector<size_t>& strides) {
// Dims with strides of 0 are ignored as they
// correspond to broadcasted dimensions
size_t grid_x = 1;
size_t grid_y = 1;
for (int i = 0; i < shape.size(); ++i) {
if (strides[i] == 0) {
continue;
}
if (grid_x * shape[i] < UINT32_MAX) {
grid_x *= shape[i];
} else {
grid_y *= shape[i];
}
}
if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) {
throw std::runtime_error("Unable to safely factor shape.");
}
return MTL::Size(
static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
}
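// Illustrative example (shape assumed): for shape = {3, 4, 1000000, 700}
// with ordinary non-zero strides, grid_x accumulates 3 * 4 * 1000000 =
// 12,000,000; multiplying in the last axis would pass UINT32_MAX, so 700
// goes to grid_y, giving a 12,000,000 x 700 grid. Axes with stride 0
// (broadcasted) are skipped and do not contribute to either factor.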
std::string get_primitive_string(Primitive* primitive) {
std::ostringstream op_t;
primitive->print(op_t);
return op_t.str();
}
} // namespace mlx::core


@@ -8,8 +8,6 @@
namespace mlx::core {
namespace {
using metal::CommandEncoder;
template <typename T>
@@ -27,82 +25,22 @@ set_vector_bytes(CommandEncoder& enc, const std::vector<T>& vec, int idx) {
  return set_vector_bytes(enc, vec, vec.size(), idx);
}
std::string type_to_name(const array& a);
std::string type_to_name(const array& a) {
std::string tname;
switch (a.dtype()) {
case bool_:
tname = "bool_";
break;
case uint8:
tname = "uint8";
break;
case uint16:
tname = "uint16";
break;
case uint32:
tname = "uint32";
break;
case uint64:
tname = "uint64";
break;
case int8:
tname = "int8";
break;
case int16:
tname = "int16";
break;
case int32:
tname = "int32";
break;
case int64:
tname = "int64";
break;
case float16:
tname = "float16";
break;
case float32:
tname = "float32";
break;
case bfloat16:
tname = "bfloat16";
break;
case complex64:
tname = "complex64";
break;
}
return tname;
}
// Compute the thread block dimensions which fit the given
// input dimensions.
// - The thread block dimensions will be powers of two
// - The thread block size will be less than 1024
MTL::Size get_block_dims(int dim0, int dim1, int dim2);
// Computes a 2D grid where each element is < UINT_MAX
// Assumes:
// - overall size (product of non-broadcasted dimensions) is < UINT_MAX^2
// - shape and strides correspond to a contiguous (no holes) but
//   possibly broadcasted array
MTL::Size get_2d_grid_dims(
    const std::vector<int>& shape,
    const std::vector<size_t>& strides);
MTL::Size get_block_dims(int dim0, int dim1, int dim2) {
  int pows[3] = {0, 0, 0};
  int sum = 0;
  while (true) {
    int presum = sum;
    // Check all the pows
    if (dim0 >= (1 << (pows[0] + 1))) {
      pows[0]++;
      sum++;
    }
    if (sum == 10) {
      break;
    }
    if (dim1 >= (1 << (pows[1] + 1))) {
pows[1]++;
sum++;
}
if (sum == 10) {
break;
}
if (dim2 >= (1 << (pows[2] + 1))) {
pows[2]++;
sum++;
}
if (sum == presum || sum == 10) {
break;
}
}
return MTL::Size{1ul << pows[0], 1ul << pows[1], 1ul << pows[2]};
}
inline NS::String* make_string(std::ostringstream& os) {
  std::string string = os.str();
@@ -130,23 +68,6 @@ inline void debug_set_primitive_buffer_label(
#endif
}
std::string get_primitive_string(Primitive* primitive);
bool is_power_of_2(int n) {
return ((n & (n - 1)) == 0) && n != 0;
}
int next_power_of_2(int n) {
if (is_power_of_2(n)) {
return n;
}
return pow(2, std::ceil(std::log2(n)));
}
std::string get_primitive_string(Primitive* primitive) {
std::ostringstream op_t;
primitive->print(op_t);
return op_t.str();
}
} // namespace
} // namespace mlx::core


@@ -42,7 +42,7 @@ NO_CPU(Convolution)
NO_CPU(Copy)
NO_CPU(Cos)
NO_CPU(Cosh)
NO_CPU_MULTI(CustomVJP)
NO_CPU_MULTI(CustomTransforms)
NO_CPU_MULTI(Depends)
NO_CPU(Divide)
NO_CPU_MULTI(DivMod)
@@ -61,6 +61,7 @@ NO_CPU(GatherMM)
NO_CPU(GatherQMM)
NO_CPU(Greater)
NO_CPU(GreaterEqual)
NO_CPU(Hadamard)
NO_CPU(Less)
NO_CPU(LessEqual)
NO_CPU(Load)


@@ -43,7 +43,7 @@ NO_GPU(Convolution)
NO_GPU(Copy)
NO_GPU(Cos)
NO_GPU(Cosh)
NO_GPU_MULTI(CustomVJP)
NO_GPU_MULTI(CustomTransforms)
NO_GPU_MULTI(Depends)
NO_GPU(Divide)
NO_GPU_MULTI(DivMod)
@@ -62,6 +62,7 @@ NO_GPU(GatherMM)
NO_GPU(GatherQMM)
NO_GPU(Greater)
NO_GPU(GreaterEqual)
NO_GPU(Hadamard)
NO_GPU(Less)
NO_GPU(LessEqual)
NO_GPU(Load)
@@ -117,6 +118,7 @@ NO_GPU_MULTI(RMSNorm)
NO_GPU_MULTI(RMSNormVJP)
NO_GPU_MULTI(RoPE)
NO_GPU(ScaledDotProductAttention)
NO_GPU_MULTI(AffineQuantize)
} // namespace fast
} // namespace mlx::core


@@ -266,6 +266,10 @@ class CompilerCache {
    cache_.erase(fun_id);
  }
  void clear() {
    cache_.clear();
  }
 private:
  CompilerCache() {
    // Make sure the allocator is fully
@@ -859,6 +863,10 @@ void compile_erase(std::uintptr_t fun_id) {
  detail::compiler_cache().erase(fun_id);
}
void compile_clear_cache() {
  detail::compiler_cache().clear();
}
} // namespace detail
std::function<std::vector<array>(const std::vector<array>&)> compile(

View File

@@ -1,11 +1,8 @@
// Copyright © 2023-2024 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include <cstdint>
#include <sstream>
#include <vector>
#include "mlx/dtype.h"
#include "mlx/utils.h"
namespace mlx::core {
@@ -178,67 +175,4 @@ bool issubdtype(const Dtype::Category& a, const Dtype::Category& b) {
          [static_cast<uint32_t>(b)];
}
// Array protocol typestring for Dtype
std::string dtype_to_array_protocol(const Dtype& t) {
std::ostringstream r;
if (size_of(t) > 1)
r << (is_big_endian() ? ">" : "<");
else
r << "|";
r << kindof(t) << (int)size_of(t);
return r.str();
}
// Dtype from array protocol type string
Dtype dtype_from_array_protocol(std::string_view t) {
if (t.length() == 2 || t.length() == 3) {
std::string_view r = t.length() == 3 ? t.substr(1, 2) : t;
if (r == "V2") {
return bfloat16;
}
uint8_t size = r[1] - '0';
switch (r[0]) {
case 'b': {
if (size == 1)
return bool_;
}
case 'i': {
if (size == 1)
return int8;
else if (size == 2)
return int16;
else if (size == 4)
return int32;
else if (size == 8)
return int64;
}
case 'u': {
if (size == 1)
return uint8;
else if (size == 2)
return uint16;
else if (size == 4)
return uint32;
else if (size == 8)
return uint64;
}
case 'f': {
if (size == 2)
return float16;
else if (size == 4)
return float32;
}
case 'c': {
return complex64;
}
}
}
throw std::invalid_argument(
"[from_str] Invalid array protocol type-string: " + std::string(t));
}
} // namespace mlx::core


@@ -4,8 +4,6 @@
#include <complex>
#include <cstdint>
#include <ostream>
#include <string>
#include "mlx/types/complex.h"
#include "mlx/types/half_types.h"
@@ -103,9 +101,4 @@ struct TypeToDtype {
  operator Dtype();
};
// Array protocol typestring for Dtype
std::string dtype_to_array_protocol(const Dtype& t);
// Dtype from array protocol type string
Dtype dtype_from_array_protocol(std::string_view t);
} // namespace mlx::core

mlx/einsum.cpp (new file, 859 lines)

@@ -0,0 +1,859 @@
// Copyright © 2024 Apple Inc.
#include <numeric>
#include <sstream>
#include <unordered_map>
#include <unordered_set>
#include "mlx/einsum.h"
#include "mlx/ops.h"
namespace mlx::core {
namespace {
// The MLX einsum implementation is based on NumPy (which is based on
// opt_einsum):
// https://github.com/numpy/numpy/blob/1d49c7f7ff527c696fc26ab2278ad51632a66660/numpy/_core/einsumfunc.py#L743
// https://github.com/dgasmith/opt_einsum
using CharSet = std::unordered_set<char>;
// A helper struct to hold the string and set
// representation of a subscript to avoid needing
// to recompute the set
struct Subscript {
Subscript(std::string str, CharSet set)
: str(std::move(str)), set(std::move(set)) {};
std::string str;
CharSet set;
};
struct PathInfo {
size_t naive_cost;
size_t naive_scaling;
size_t optimized_cost;
size_t optimized_scaling;
size_t largest_term;
};
struct PathNode {
PathNode(
std::vector<Subscript> inputs,
Subscript output,
std::vector<int> positions)
: inputs(std::move(inputs)),
output(std::move(output)),
positions(std::move(positions)) {};
std::vector<Subscript> inputs;
Subscript output;
std::vector<int> positions;
};
// Parse the comma separated subscripts into a vector of strings. If the
// output subscripts are missing they are inferred.
//
// For example:
// "ij,jk -> ik" becomes {{"ij", "jk"}, "ik"}
// "ij,jk" becomes {{"ij", "jk"}, "ik"}
std::pair<std::vector<std::string>, std::string> parse(std::string subscripts) {
std::string lhs, rhs;
// Start by removing all white space
subscripts.erase(
std::remove(subscripts.begin(), subscripts.end(), ' '), subscripts.end());
if (auto pos = subscripts.find("->"); pos != std::string::npos) {
// Explicit mode
lhs = subscripts.substr(0, pos);
rhs = subscripts.substr(pos + 2);
} else {
// Implicit mode:
// - repeats are summed
// - remaining output axes are ordered alphabetically
lhs = subscripts;
std::unordered_map<char, int> temp;
for (auto& c : subscripts) {
if (c == ',') {
continue;
}
auto inserted = temp.insert({c, 0});
inserted.first->second++;
}
for (auto& k : temp) {
if (k.second == 1) {
rhs += k.first;
}
}
std::sort(rhs.begin(), rhs.end());
}
std::vector<std::string> input_list;
std::stringstream ss(lhs);
std::string token;
while (getline(ss, token, ',')) {
input_list.push_back(token);
}
return {input_list, rhs};
}
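// Illustrative examples (not exercised here): parse("ij,jk->ik") and
// parse("ij , jk") both give {{"ij", "jk"}, "ik"}. In implicit mode the
// output keeps only the subscripts that appear exactly once, sorted
// alphabetically, so parse("ba") gives output "ab" (a transpose) and
// parse("ii") gives an empty output (a trace).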
// Check if two sets are disjoint
bool disjoint(const CharSet& x, const CharSet& y) {
for (auto& c : x) {
if (y.find(c) != y.end()) {
return false;
}
}
return true;
}
template <typename T>
size_t term_size(const T& term, std::unordered_map<char, int> dict) {
size_t size = 1;
for (auto c : term) {
size *= dict[c];
}
return size;
}
size_t flop_count(
const CharSet& term,
bool inner,
int num_terms,
std::unordered_map<char, int> dict) {
size_t size = term_size(term, dict);
auto op_factor = 1;
if ((num_terms - 1) > op_factor) {
op_factor = num_terms - 1;
}
if (inner) {
op_factor += 1;
}
return size * op_factor;
}
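// Worked example (sizes assumed): for term = {i, j, k} with i = 2, j = 3,
// k = 4, term_size is 24. Contracting two operands (num_terms = 2) over an
// inner (summed) index leaves op_factor at max(1, num_terms - 1) = 1 and
// then adds 1 for the reduction, so the estimated cost is 24 * 2 = 48 flops.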
std::pair<size_t, int> compute_cost_and_scaling(
const std::vector<Subscript>& inputs,
const Subscript& output,
std::unordered_map<char, int> dim_map) {
CharSet contractions;
for (auto& in : inputs) {
contractions.insert(in.set.begin(), in.set.end());
}
bool inner = false;
for (auto c : contractions) {
if (output.set.find(c) == output.set.end()) {
inner = true;
break;
}
}
auto cost = flop_count(contractions, inner, inputs.size(), dim_map);
return {cost, contractions.size()};
}
std::tuple<std::vector<PathNode>, size_t, int> greedy_path(
std::vector<Subscript> inputs,
const Subscript& output,
std::unordered_map<char, int> dim_map,
size_t cost_limit,
size_t memory_limit) {
// Helper struct for building the greedy path
struct Contraction {
Contraction(
size_t size,
size_t cost,
CharSet output,
int dims,
int x,
int y)
: size(size),
cost(cost),
output(std::move(output)),
dims(dims),
x(x),
y(y) {};
int64_t size; // Size difference, can be negative
size_t cost;
CharSet output;
int dims; // Number of dimensions in the contraction
int x;
int y;
};
// Start by iterating over all possible combinations
std::vector<std::pair<int, int>> pos_pairs;
for (int i = 0; i < inputs.size(); ++i) {
for (int j = i + 1; j < inputs.size(); ++j) {
pos_pairs.emplace_back(i, j);
}
}
std::vector<PathNode> path;
std::vector<Contraction> possible_contractions;
size_t path_cost = 0;
int path_scaling = 0;
auto num_in = inputs.size();
for (int i = 0; i < num_in - 1; ++i) {
auto add_contraction = [&](int p1, int p2) {
CharSet new_term;
CharSet contractions(inputs[p1].set.begin(), inputs[p1].set.end());
contractions.insert(inputs[p2].set.begin(), inputs[p2].set.end());
for (int i = 0; i < inputs.size(); i++) {
if (i == p1 || i == p2) {
continue;
}
auto& in = inputs[i].set;
for (auto c : in) {
if (contractions.find(c) != contractions.end()) {
new_term.insert(c);
}
}
}
for (auto c : output.set) {
if (contractions.find(c) != contractions.end()) {
new_term.insert(c);
}
}
// Ignore if:
// - The size of the new result is greater than the memory limit
// - The cost is larger than the naive cost
auto new_size = term_size(new_term, dim_map);
if (new_size > memory_limit) {
return;
}
int64_t removed_size = term_size(inputs[p1].set, dim_map) +
term_size(inputs[p2].set, dim_map) - new_size;
bool inner = contractions.size() > new_term.size();
auto cost = flop_count(contractions, inner, 2, dim_map);
if (path_cost + cost > cost_limit) {
return;
}
possible_contractions.emplace_back(
removed_size, cost, std::move(new_term), contractions.size(), p1, p2);
};
for (auto& [p1, p2] : pos_pairs) {
// Ignore outer products
if (!disjoint(inputs[p1].set, inputs[p2].set)) {
add_contraction(p1, p2);
}
}
// If there's nothing in the contraction list,
// go over the pairs again without ignoring outer products
if (possible_contractions.empty()) {
for (auto& [p1, p2] : pos_pairs) {
add_contraction(p1, p2);
}
}
if (possible_contractions.empty()) {
// Default to naive einsum for the remaining inputs
std::vector<int> positions(inputs.size());
std::iota(positions.begin(), positions.end(), 0);
auto [cost, scale] = compute_cost_and_scaling(inputs, output, dim_map);
path.emplace_back(std::move(inputs), output, std::move(positions));
path_cost += cost;
path_scaling = std::max(scale, path_scaling);
break;
}
// Find the best contraction
auto& best = *std::min_element(
possible_contractions.begin(),
possible_contractions.end(),
[](const auto& x, const auto& y) {
return x.size > y.size || (x.size == y.size && x.cost < y.cost);
});
path_scaling = std::max(best.dims, path_scaling);
// Construct the output subscripts
std::string out_str(best.output.begin(), best.output.end());
// TODO, sorting by dimension size seems suboptimal?
std::sort(out_str.begin(), out_str.end(), [&dim_map](auto x, auto y) {
return dim_map[x] < dim_map[y];
});
Subscript new_output(std::move(out_str), std::move(best.output));
// Add the chosen contraction to the path
{
std::vector<Subscript> in_terms;
in_terms.push_back(std::move(inputs[best.x]));
in_terms.push_back(std::move(inputs[best.y]));
path.emplace_back(
std::move(in_terms), new_output, std::vector<int>{best.x, best.y});
}
// Remove used terms
inputs.erase(inputs.begin() + best.y);
inputs.erase(inputs.begin() + best.x);
// Add the new result
inputs.push_back(std::move(new_output));
// Update the existing contractions based on the selected one
std::vector<Contraction> updated_contractions;
for (auto& contraction : possible_contractions) {
// Drop contractions which contain either selected term
if (contraction.x == best.x || contraction.x == best.y ||
contraction.y == best.x || contraction.y == best.y) {
continue;
}
// Update the positions of other contractions
int x =
contraction.x - (contraction.x > best.x) - (contraction.x > best.y);
int y =
contraction.y - (contraction.y > best.x) - (contraction.y > best.y);
contraction.x = x;
contraction.y = y;
updated_contractions.push_back(std::move(contraction));
}
pos_pairs.clear();
for (int i = 0; i < inputs.size() - 1; ++i) {
pos_pairs.emplace_back(i, inputs.size() - 1);
}
path_cost += best.cost;
possible_contractions = std::move(updated_contractions);
}
return {path, path_cost, path_scaling};
}
// Assumes inputs have already had repeats and single axis sums collapsed
bool can_dot(const std::vector<Subscript>& inputs, const Subscript& output) {
if (inputs.size() != 2) {
return false;
}
for (auto c : inputs[0].set) {
// Use batched tensordot if anything is being contracted
if (output.set.find(c) == output.set.end()) {
return true;
}
}
return false;
}
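// Examples (illustrative): for inputs {"ij", "jk"} with output "ik" the
// subscript 'j' is contracted away, so can_dot returns true and the pair can
// be dispatched to batch_tensordot below. For a pure outer product such as
// {"i", "j"} -> "ij" nothing is contracted and can_dot returns false, as it
// does whenever there are more or fewer than two inputs.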
array batch_tensordot(
array a,
array b,
std::vector<int> a_contract,
std::vector<int> a_batch,
std::vector<int> a_concat,
std::vector<int> b_contract,
std::vector<int> b_batch,
std::vector<int> b_concat,
StreamOrDevice s) {
// Broadcast contracting dimensions
{
auto a_shape = a.shape();
auto b_shape = b.shape();
for (int i = 0; i < a_contract.size(); ++i) {
auto d = std::max(a.shape(a_contract[i]), b.shape(b_contract[i]));
a_shape[a_contract[i]] = d;
b_shape[b_contract[i]] = d;
}
a = broadcast_to(a, a_shape, s);
b = broadcast_to(b, b_shape, s);
}
auto transpose_reshape = [&s](
const array& x,
const std::vector<int>& i,
const std::vector<int>& j,
const std::vector<int>& k) {
std::vector<int> reorder(i.begin(), i.end());
reorder.insert(reorder.end(), j.begin(), j.end());
reorder.insert(reorder.end(), k.begin(), k.end());
int size1 = 1;
for (auto s : j) {
size1 *= x.shape(s);
}
int size2 = 1;
for (auto s : k) {
size2 *= x.shape(s);
}
std::vector<int> shape;
for (auto ax : i) {
shape.push_back(x.shape(ax));
}
shape.push_back(size1);
shape.push_back(size2);
return reshape(transpose(x, reorder, s), std::move(shape), s);
};
std::vector<int> out_shape;
for (auto ax : a_batch) {
out_shape.push_back(a.shape(ax));
}
for (auto ax : a_concat) {
out_shape.push_back(a.shape(ax));
}
for (auto ax : b_concat) {
out_shape.push_back(b.shape(ax));
}
a = transpose_reshape(a, a_batch, a_concat, a_contract);
b = transpose_reshape(b, b_batch, b_contract, b_concat);
return reshape(matmul(a, b, s), std::move(out_shape), s);
}
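// Illustrative call (axis lists assumed): to evaluate "bij,bjk->bik" one
// would pass a_batch = {0}, a_concat = {1}, a_contract = {2} and
// b_batch = {0}, b_contract = {1}, b_concat = {2}. transpose_reshape folds
// a into (batch, I, J) and b into (batch, J, K), matmul yields (batch, I, K),
// and the final reshape restores out_shape, which was built from the batch,
// a_concat, and b_concat axis sizes.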
// Collapse repeated subscripts and return the resulting array. The subscript
// is also updated in place. For example:
// - Given an input with shape (4, 4) and subscript "ii", returns
// the diagonal of shape (4,) and updates the subscript to "i".
// - Given an input with shape (4, 2, 4, 2) and subscript "ijij",
// returns an output with shape (4, 2) and updates the subscript
// to "ij".
array collapse_repeats(array in, Subscript& subscript, StreamOrDevice s) {
// Build a list of (repeat chars, num repeats)
auto& str = subscript.str;
std::vector<std::pair<char, int>> repeats;
std::string new_str;
{
std::string repeat_str;
std::string no_repeat_str;
std::unordered_map<char, int> counts;
for (int i = 0; i < str.size(); ++i) {
auto [it, _] = counts.insert({str[i], 0});
it->second++;
}
for (auto& v : counts) {
if (v.second > 1) {
repeats.emplace_back(v.first, v.second);
repeat_str += v.first;
}
}
for (auto& c : str) {
if (counts[c] == 1) {
no_repeat_str += c;
}
}
new_str = repeat_str + no_repeat_str;
}
// Build the inputs for gather
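// e.g. for subscript "ii" on a (4, 4) input: slice_sizes becomes {1, 1},
// axes = {0, 1}, and indices holds two copies of arange(4), so the gather
// reads the diagonal; the leftover singleton axes are squeezed below.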
auto slice_sizes = in.shape();
std::vector<int> axes;
std::vector<array> indices;
int n_expand = repeats.size();
for (auto [c, v] : repeats) {
for (int i = 0; i < str.size(); ++i) {
if (str[i] == c) {
slice_sizes[i] = 1;
axes.push_back(i);
}
}
std::vector<int> idx_shape(n_expand--, 1);
idx_shape[0] = in.shape(axes.back());
auto idx = reshape(arange(in.shape(axes.back()), s), idx_shape, s);
for (int i = 0; i < v; ++i) {
indices.push_back(idx);
}
}
in = gather(in, indices, axes, slice_sizes, s);
// Update subscript string with removed dups
str = new_str;
// Squeeze singleton dimensions left over from the gather
for (auto& ax : axes) {
ax += indices[0].ndim();
}
return squeeze(in, axes, s);
}
// Collapse repeated indices and sum out indices that appear in only a
// single input (and not in the output). For example:
// - "aa" becomes "a"
// - "ij,jk->k" becomes "j,jk->k" (the "i" axis is summed out)
void preprocess_einsum_inputs(
std::vector<Subscript>& inputs,
const Subscript& output,
const std::vector<int>& positions,
std::vector<array>& operands,
StreamOrDevice s) {
// Collapse repeat indices
for (int i = 0; i < inputs.size(); ++i) {
auto& in = inputs[i];
if (in.set.size() < in.str.size()) {
operands[positions[i]] = collapse_repeats(operands[positions[i]], in, s);
}
}
// Sum indices that are only in a single input
{
std::unordered_map<char, int> counts;
for (auto& in : inputs) {
for (auto c : in.set) {
auto inserted = counts.insert({c, 0});
inserted.first->second++;
}
}
for (auto c : output.set) {
auto inserted = counts.insert({c, 0});
inserted.first->second++;
}
for (int i = 0; i < inputs.size(); ++i) {
auto& in = inputs[i];
std::vector<int> sum_axes;
for (int ax = 0; ax < in.str.size(); ++ax) {
if (counts[in.str[ax]] == 1) {
sum_axes.push_back(ax);
}
}
if (!sum_axes.empty()) {
operands[positions[i]] =
sum(operands[positions[i]], sum_axes, false, s);
}
for (auto it = sum_axes.rbegin(); it != sum_axes.rend(); ++it) {
in.set.erase(in.str[*it]);
in.str.erase(in.str.begin() + *it);
}
}
}
}
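// Fallback contraction: expand and transpose every operand to a common axis
// order covering all subscript characters, multiply them together with
// broadcasting, sum over axes that are not in the output, and transpose to
// the requested output order. For example, "ij,jk->ik" becomes an
// elementwise product over axes (i, j, k) summed over "j".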
array einsum_naive(
std::vector<Subscript> inputs,
const Subscript& output,
const std::vector<int>& positions,
std::vector<array> operands,
StreamOrDevice s) {
// Map each character to an axis
std::unordered_map<char, int> char_to_ax;
for (auto& in : inputs) {
for (auto c : in.str) {
char_to_ax.insert({c, char_to_ax.size()});
}
}
// Expand and transpose inputs as needed
for (int i = 0; i < inputs.size(); ++i) {
int pos = positions[i];
auto& op = operands[pos];
// Add missing dimensions at the end
if (op.ndim() != char_to_ax.size()) {
auto shape = op.shape();
shape.insert(shape.end(), char_to_ax.size() - shape.size(), 1);
op = reshape(op, std::move(shape), s);
}
// Transpose:
// - Build a vector of (char, ax) pairs for the current input
// - Sort the vector by the canonical axis in char_to_ax
// - Extract the sorted axis to get transpose order
std::vector<std::pair<char, int>> str_ax;
for (auto c : inputs[i].str) {
str_ax.emplace_back(c, str_ax.size());
}
for (auto [c, ax] : char_to_ax) {
if (inputs[i].set.find(c) == inputs[i].set.end()) {
str_ax.emplace_back(c, str_ax.size());
}
}
std::sort(
str_ax.begin(),
str_ax.end(),
[&char_to_ax](const auto& x, const auto& y) {
return char_to_ax[x.first] < char_to_ax[y.first];
});
// Skip the transpose if not needed
if (std::is_sorted(
str_ax.begin(), str_ax.end(), [](const auto& x, const auto& y) {
return x.second < y.second;
})) {
continue;
}
std::vector<int> reorder;
for (auto [c, ax] : str_ax) {
reorder.push_back(ax);
}
op = transpose(op, reorder, s);
}
// Multiply and sum
auto out = operands[positions[0]];
for (int i = 1; i < positions.size(); ++i) {
out = multiply(out, operands[positions[i]], s);
}
std::vector<int> sum_axes;
for (auto [c, ax] : char_to_ax) {
if (output.set.find(c) == output.set.end()) {
sum_axes.push_back(ax);
}
}
if (!sum_axes.empty()) {
out = sum(out, sum_axes, false, s);
}
// Transpose output if needed
std::vector<int> reorder;
for (auto c : output.str) {
reorder.push_back(char_to_ax[c]);
}
for (auto& r : reorder) {
int offset = 0;
for (auto s : sum_axes) {
if (r > s) {
offset++;
}
}
r -= offset;
}
return transpose(out, reorder, s);
}
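// Parse and validate the subscripts against the operands, build the
// character-to-dimension map (allowing broadcast of singleton axes), and
// compute the contraction path: a single node when there are at most two
// operands, otherwise the greedy path.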
std::pair<std::vector<PathNode>, PathInfo> einsum_path_helper(
const std::string& subscripts,
const std::vector<array>& operands,
const std::string& fn_name) {
if (operands.size() == 0) {
std::ostringstream msg;
msg << "[" << fn_name << "] At least one operand is required.";
throw std::invalid_argument(msg.str());
}
auto [in_subscripts, out_subscript] = parse(subscripts);
if (operands.size() != in_subscripts.size()) {
std::ostringstream msg;
msg << "[" << fn_name << "] Number of operands, " << operands.size()
<< ", does not match number of input subscripts, "
<< in_subscripts.size();
throw std::invalid_argument(msg.str());
}
auto check_letters = [&](const auto& subscript) {
for (auto c : subscript) {
if (!isalpha(c)) {
std::ostringstream msg;
msg << "[" << fn_name << "] Subscripts must be letters, but got '" << c
<< "'.";
throw std::invalid_argument(msg.str());
}
}
};
for (auto& in : in_subscripts) {
check_letters(in);
}
check_letters(out_subscript);
CharSet out_set(out_subscript.begin(), out_subscript.end());
if (out_set.size() != out_subscript.size()) {
std::ostringstream msg;
msg << "[" << fn_name << "] Repeat indices not allowed in output.";
throw std::invalid_argument(msg.str());
}
Subscript output(out_subscript, std::move(out_set));
std::unordered_map<char, int> dim_map;
std::vector<Subscript> inputs;
for (int i = 0; i < in_subscripts.size(); ++i) {
auto& in = in_subscripts[i];
CharSet in_set(in.begin(), in.end());
inputs.emplace_back(in, in_set);
if (in.size() != operands[i].ndim()) {
std::ostringstream msg;
msg << "[" << fn_name << "] Invalid number of subscripts " << in.size()
<< " for input " << i << " with " << operands[i].ndim()
<< " dimensions.";
throw std::invalid_argument(msg.str());
}
// Check repeat subscripts are valid
if (in_set.size() < in.size()) {
std::unordered_map<char, int> local_dims;
for (int j = 0; j < in.size(); ++j) {
auto dim = operands[i].shape(j);
auto inserted = local_dims.insert({in[j], dim});
if (!inserted.second) {
if (inserted.first->second != dim) {
std::ostringstream msg;
msg << "[" << fn_name << "] Dimensions of repeated subscripts "
<< "do not have the same size (" << inserted.first->second
<< " != " << dim << ").";
throw std::invalid_argument(msg.str());
}
}
}
}
for (int j = 0; j < in.size(); j++) {
auto c = in[j];
auto dim = operands[i].shape(j);
auto inserted = dim_map.insert({c, dim});
auto& in_dim = inserted.first->second;
if (dim != 1 && in_dim != 1 && in_dim != dim) {
std::ostringstream msg;
msg << "[" << fn_name << "] Cannot broadcast dimension " << j
<< " of input " << i << " with shape " << operands[i].shape()
<< " to size " << in_dim << ".";
throw std::invalid_argument(msg.str());
}
// Ensure the broadcasted size is used
in_dim = std::max(in_dim, dim);
}
}
size_t max_size = term_size(out_subscript, dim_map);
for (auto& in : in_subscripts) {
max_size = std::max(max_size, term_size(in, dim_map));
}
PathInfo path_info;
// Get the full naive cost
std::tie(path_info.naive_cost, path_info.naive_scaling) =
compute_cost_and_scaling(inputs, output, dim_map);
// Calculate the path
std::vector<PathNode> path;
if (inputs.size() <= 2) {
std::vector<int> positions(in_subscripts.size());
std::iota(positions.begin(), positions.end(), 0);
path.emplace_back(
std::move(inputs), std::move(output), std::move(positions));
} else {
std::tie(path, path_info.optimized_cost, path_info.optimized_scaling) =
greedy_path(inputs, output, dim_map, path_info.naive_cost, max_size);
// Set the final output subscript to the actual output
path.back().output = std::move(output);
}
return {path, path_info};
}
} // namespace
std::pair<std::vector<std::vector<int>>, std::string> einsum_path(
const std::string& subscripts,
const std::vector<array>& operands) {
auto [path, path_info] =
einsum_path_helper(subscripts, operands, "einsum_path");
std::vector<std::vector<int>> pos_path;
for (auto& p : path) {
pos_path.push_back(p.positions);
}
std::ostringstream path_print;
path_print << " Complete contraction: " << subscripts << "\n"
<< " Naive scaling: " << path_info.naive_scaling << "\n"
<< " Optimized scaling: " << path_info.optimized_scaling
<< "\n"
<< " Naive FLOP count: " << path_info.naive_cost << "\n"
<< " Optimized FLOP count: " << path_info.optimized_cost << "\n";
// TODO add more info here
return {pos_path, path_print.str()};
}
array einsum(
const std::string& subscripts,
const std::vector<array>& operands,
StreamOrDevice s /* = {} */) {
auto [path, path_info] = einsum_path_helper(subscripts, operands, "einsum");
auto inputs = operands;
for (auto& node : path) {
preprocess_einsum_inputs(
node.inputs, node.output, node.positions, inputs, s);
if (can_dot(node.inputs, node.output)) {
auto& in_a = node.inputs[0];
auto& in_b = node.inputs[1];
auto& out = node.output;
std::vector<int> a_contract;
std::vector<int> a_batch;
std::vector<int> a_concat;
for (int i = 0; i < in_a.str.size(); ++i) {
auto c = in_a.str[i];
if (out.set.find(c) == out.set.end()) {
// Not in the output, contraction
a_contract.push_back(i);
} else if (in_b.set.find(c) != in_b.set.end()) {
// Not a contraction but in both inputs, batch dim
a_batch.push_back(i);
} else {
// Not a batch dim or contract dim, so concat dim
a_concat.push_back(i);
}
}
std::vector<int> b_contract;
std::vector<int> b_batch;
std::vector<int> b_concat;
for (auto a_i : a_contract) {
b_contract.push_back(in_b.str.find(in_a.str[a_i]));
}
for (auto a_i : a_batch) {
b_batch.push_back(in_b.str.find(in_a.str[a_i]));
}
for (int i = 0; i < in_b.str.size(); ++i) {
auto c = in_b.str[i];
if (out.set.find(c) != out.set.end() &&
in_a.set.find(c) == in_a.set.end()) {
b_concat.push_back(i);
}
}
auto& a = inputs[node.positions[0]];
auto& b = inputs[node.positions[1]];
std::unordered_map<char, int> char_map;
for (auto i : a_batch) {
char_map.insert({in_a.str[i], char_map.size()});
}
for (auto i : a_concat) {
char_map.insert({in_a.str[i], char_map.size()});
}
for (auto i : b_concat) {
char_map.insert({in_b.str[i], char_map.size()});
}
inputs.emplace_back(batch_tensordot(
a,
b,
std::move(a_contract),
std::move(a_batch),
std::move(a_concat),
std::move(b_contract),
std::move(b_batch),
std::move(b_concat),
s));
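// batch_tensordot lays the result out as (batch..., a concat..., b concat...);
// char_map records that order so the result can be transposed to match this
// node's output subscript.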
std::vector<int> reorder;
for (auto c : node.output.str) {
reorder.push_back(char_map[c]);
}
inputs.back() = transpose(inputs.back(), reorder, s);
} else {
inputs.emplace_back(
einsum_naive(node.inputs, node.output, node.positions, inputs, s));
}
// Positions are sorted in increasing order, so erase from the back to keep earlier indices valid
for (auto it = node.positions.rbegin(); it != node.positions.rend(); ++it) {
inputs.erase(inputs.begin() + *it);
}
}
return inputs.front();
}
} // namespace mlx::core

mlx/einsum.h Normal file

@@ -0,0 +1,22 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include <string>
#include <tuple>
#include <vector>
#include "mlx/array.h"
#include "mlx/utils.h"
namespace mlx::core {
std::pair<std::vector<std::vector<int>>, std::string> einsum_path(
const std::string& subscripts,
const std::vector<array>& operands);
array einsum(
const std::string& subscripts,
const std::vector<array>& operands,
StreamOrDevice s = {});
} // namespace mlx::core
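A minimal usage sketch of the new public API (illustrative only, not part of this diff; it assumes the mlx/mlx.h umbrella header and the ones() op, both already in the repo):
// einsum_usage.cpp -- hypothetical example, not included in this change
#include <iostream>
#include "mlx/mlx.h"
int main() {
  using namespace mlx::core;
  auto a = ones({2, 3, 4});
  auto b = ones({2, 4, 5});
  // Batched matrix multiply written as an einsum; "j" is contracted, so this
  // contraction can take the batched tensordot (matmul) path.
  auto c = einsum("bij,bjk->bik", {a, b});
  std::cout << "result shape: " << c.shape(0) << " x " << c.shape(1) << " x "
            << c.shape(2) << std::endl;  // 2 x 3 x 5
  // Inspect the planned contraction order and its estimated cost.
  auto [path, description] = einsum_path("bij,bjk->bik", {a, b});
  std::cout << description << std::endl;
  return 0;
}
For a two-operand contraction the returned path is simply {{0, 1}}; the description string is the summary assembled in einsum_path above.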

mlx/fast.cpp

@@ -18,7 +18,7 @@ std::vector<array> Custom::vjp(
auto [_, vjps] = mlx::core::vjp(fallback_, primals, cotangents);
std::vector<array> vjp_outs;
for (int i = 0, j = 0; i < vjps.size(); ++i) {
- if (i < argnums.size() && i == argnums[j]) {
+ if (j < argnums.size() && i == argnums[j]) {
vjp_outs.push_back(vjps[i]);
j++;
}
@@ -30,15 +30,16 @@ std::vector<array> Custom::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
- auto [_, jvps] = mlx::core::jvp(fallback_, primals, tangents);
- std::vector<array> jvp_outs;
- for (int i = 0, j = 0; i < jvps.size(); ++i) {
- if (i < argnums.size() && i == argnums[j]) {
- jvp_outs.push_back(jvps[i]);
- j++;
- }
- }
- return jvp_outs;
+ std::vector<array> all_tangents;
+ for (int i = 0, j = 0; i < primals.size(); i++) {
+ if (j < argnums.size() && i == argnums[j]) {
+ all_tangents.emplace_back(tangents[j++]);
+ } else {
+ all_tangents.emplace_back(zeros_like(primals[i]));
+ }
+ }
+ auto [_, jvps] = mlx::core::jvp(fallback_, primals, all_tangents);
+ return jvps;
}
std::pair<std::vector<array>, std::vector<int>> Custom::vmap(
@@ -609,4 +610,253 @@ bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const {
return needs_mask_ == a_other.needs_mask_ && scale_ == a_other.scale_;
}
array pack_and_quantize(
array& packed_w,
const array& scales,
const array& biases,
int group_size,
int bits,
const Stream& s) {
int el_per_int = 32 / bits;
array zero(0, packed_w.dtype());
array n_bins((1 << bits) - 1, packed_w.dtype()); // 2**bits - 1
array shifts = power(array(2, uint32), arange(0, 32, bits, uint32, s), s);
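// Quantize to integers in [0, n_bins] and pack el_per_int of them into each
// uint32; e.g. with bits = 4, each packed word is sum_k (q_k << (4 * k))
// over a group of 8 values.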
packed_w = astype(
clip(
round(divide(subtract(packed_w, biases, s), scales, s), s),
zero,
n_bins),
uint32);
packed_w = reshape(packed_w, {packed_w.shape(0), -1, el_per_int}, s);
packed_w = sum(
multiply(packed_w, shifts, s), /* axis= */ 2, /* keepdims= */ false, s);
return packed_w;
}
std::tuple<array, array, array>
affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) {
auto s = to_stream(s_);
if (group_size != 32 && group_size != 64 && group_size != 128) {
std::ostringstream msg;
msg << "[quantize] The requested group size " << group_size
<< " is not supported. The supported group sizes are 64 and 128.";
throw std::invalid_argument(msg.str());
}
if (bits != 2 && bits != 4 && bits != 8) {
std::ostringstream msg;
msg << "[quantize] The requested number of bits " << bits
<< " is not supported. The supported bits are 2, 4 and 8.";
throw std::invalid_argument(msg.str());
}
if (w.ndim() < 2) {
std::ostringstream msg;
msg << "[quantize] The matrix to be quantized must have at least 2 dimension "
<< "but it has only " << w.ndim() << ".";
throw std::invalid_argument(msg.str());
}
if ((w.shape(-1) % group_size) != 0) {
std::ostringstream msg;
msg << "[quantize] The last dimension of the matrix needs to be divisible by "
<< "the quantization group size " << group_size
<< ". However the provided " << " matrix has shape " << w.shape();
throw std::invalid_argument(msg.str());
}
int el_per_int = 32 / bits;
if (w.shape(-1) < 32 * el_per_int) {
std::ostringstream msg;
msg << "[quantize] The feature dimension (2nd dimension of the matrix) is "
<< "too small for quantization. We support >=512 for 2 bits, "
<< ">= 256 for 4 bits and >= 128 for 8 bits. The provided matrix has "
<< "shape " << w.shape() << ".";
throw std::invalid_argument(msg.str());
}
auto fallback = [group_size, bits, el_per_int, s](
const std::vector<array>& inputs) -> std::vector<array> {
auto& w = inputs[0];
auto wshape = w.shape();
wshape.back() = -1;
array zero(0, w.dtype());
array n_bins((1 << bits) - 1, w.dtype()); // 2**bits - 1
array eps(1e-7, w.dtype());
array packed_w = reshape(w, {-1, w.shape(-1) / group_size, group_size}, s);
array w_max = max(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
array w_min = min(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
array mask = greater(abs(w_min, s), abs(w_max, s), s);
array scales =
maximum(divide(subtract(w_max, w_min, s), n_bins, s), eps, s);
scales = where(mask, scales, negative(scales), s);
array edge = where(mask, w_min, w_max, s);
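// Snap the scale so that the dominant endpoint ("edge") is exactly
// representable: scale = edge / round(edge / scale); the bias sits at that
// edge, or at 0 when the edge rounds to 0.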
array q0 = round(divide(edge, scales, s), s);
scales = where(not_equal(q0, zero, s), divide(edge, q0, s), scales);
array biases = where(equal(q0, zero, s), zero, edge);
packed_w = pack_and_quantize(packed_w, scales, biases, group_size, bits, s);
return {
reshape(packed_w, wshape, s),
reshape(scales, wshape, s),
reshape(biases, wshape, s),
};
};
std::vector<array> outputs;
if (s.device == Device::gpu) {
auto wq_shape = w.shape();
wq_shape.back() = w.shape(-1) / el_per_int;
auto sshape = w.shape();
sshape.back() = w.shape(-1) / group_size;
outputs = array::make_arrays(
{wq_shape, sshape, sshape},
{uint32, w.dtype(), w.dtype()},
std::make_shared<AffineQuantize>(s, fallback, group_size, bits, false),
{w});
} else {
outputs = fallback({w});
}
return {outputs[0], outputs[1], outputs[2]};
}
array affine_quantize(
const array& w,
const array& scales,
const array& biases,
int group_size,
int bits,
StreamOrDevice s_) {
auto s = to_stream(s_);
int el_per_int = 32 / bits;
auto fallback = [group_size, bits, el_per_int, s](
const std::vector<array>& inputs) -> std::vector<array> {
auto& w = inputs[0];
auto scales = expand_dims(inputs[1], -1, s);
auto biases = expand_dims(inputs[2], -1, s);
auto wshape = w.shape();
wshape.back() = -1;
array packed_w = reshape(w, {-1, w.shape(-1) / group_size, group_size}, s);
packed_w = pack_and_quantize(packed_w, scales, biases, group_size, bits, s);
return {reshape(packed_w, wshape, s)};
};
if (s.device == Device::gpu) {
auto out_shape = w.shape();
out_shape.back() = w.shape(-1) / el_per_int;
return array(
out_shape,
uint32,
std::make_shared<AffineQuantize>(s, fallback, group_size, bits, false),
{w, scales, biases});
}
return fallback({w, scales, biases})[0];
}
array affine_dequantize(
const array& w,
const array& scales,
const array& biases,
int group_size,
int bits,
StreamOrDevice s_) {
if (bits <= 0) {
std::ostringstream msg;
msg << "[dequantize] Invalid value for bits: " << bits;
throw std::invalid_argument(msg.str());
}
if (group_size <= 0) {
std::ostringstream msg;
msg << "[dequantize] Invalid value for group_size: " << group_size;
throw std::invalid_argument(msg.str());
}
if (w.ndim() < 2 || scales.ndim() < 2 || biases.ndim() < 2) {
std::ostringstream msg;
msg << "[quantize] The matrix to be quantized must have at least 2 dimension "
<< "but it has only " << w.ndim() << ".";
throw std::invalid_argument(msg.str());
}
auto wshape = w.shape();
auto sshape = scales.shape();
auto bshape = biases.shape();
wshape.back() = -1;
sshape.back() = -1;
bshape.back() = -1;
if (wshape != sshape || wshape != bshape) {
throw std::invalid_argument(
"[dequantize] Shape of scales and biases does not match the matrix");
}
if (w.dtype() != uint32) {
throw std::invalid_argument(
"[dequantize] The matrix should be given as a uint32");
}
// Packing into uint32
int el_per_int = 32 / bits;
if (w.shape(-1) * el_per_int != scales.shape(-1) * group_size) {
std::ostringstream msg;
msg << "[dequantize] Shape of scales and biases does not match the matrix "
<< "given the quantization parameters. Provided matrix of shape "
<< w.shape() << " and scales/biases of shape " << scales.shape()
<< " with group_size=" << group_size << " and bits=" << bits << ".";
throw std::invalid_argument(msg.str());
}
auto s = to_stream(s_);
auto fallback =
[&wshape, &sshape, &scales, &biases, group_size, bits, el_per_int, s](
const std::vector<array>& inputs) -> std::vector<array> {
auto& w = inputs[0];
auto& scales = inputs[1];
auto& biases = inputs[2];
std::vector<array> parts;
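// Each uint32 packs 32 / bits values; field k is isolated by shifting it to
// the top of the word and back down with unsigned (zero-filling) shifts.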
for (int start = 0; start < 32; start += bits) {
int shift_left = 32 - (start + bits);
int shift_right = shift_left + start;
parts.push_back(expand_dims(
right_shift(
left_shift(w, array(shift_left, uint32), s),
array(shift_right, uint32),
s),
-1,
s));
}
array w_full = concatenate(parts, -1, s);
// Dequantize
wshape.push_back(group_size);
w_full = reshape(w_full, wshape, s);
w_full = multiply(w_full, expand_dims(scales, -1, s), s);
w_full = add(w_full, expand_dims(biases, -1, s), s);
w_full = reshape(w_full, sshape, s);
return {w_full};
};
if (s.device == Device::gpu) {
auto out_shape = w.shape();
out_shape.back() = w.shape(-1) * el_per_int;
return array(
out_shape,
scales.dtype(),
std::make_shared<AffineQuantize>(s, fallback, group_size, bits, true),
{w, scales, biases});
}
return fallback({w, scales, biases})[0];
}
} // namespace mlx::core::fast } // namespace mlx::core::fast
