Compare commits

..

25 Commits

Author SHA1 Message Date
Awni Hannun
079882495d version bump (#1172) 2024-05-31 12:29:12 -07:00
K Venkat Ramnan
ab977109db feat: Added dlpack device (#1165)
* feat: Added dlpack device

* feat: Added device_id to dlpack device

* feat: Added device_id to dlpack device

* doc: updated conversion docs

* doc: updated numpy.rst dlpack information

* doc: updated numpy.rst dlpack information

* Update docs/src/usage/numpy.rst

* Update docs/src/usage/numpy.rst

---------

Co-authored-by: Venkat Ramnan Kalyanakumar <venkatramnankalyanakumar@Venkats-MacBook-Air.local>
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-05-31 12:29:01 -07:00
Awni Hannun
fd1c08137b stable cumprod grad at 0 (#1167) 2024-05-31 12:28:42 -07:00
Jagrit Digani
76b6cece46 Fix multi-block sort stride management (#1169)
* Fix multi-block sort stride management

* Add seed to tests
2024-05-31 11:10:54 -07:00
Jagrit Digani
9f0df51f8d Fix matvec vector stride bug (#1168) 2024-05-29 12:18:28 -07:00
Awni Hannun
e7a2a3dcd1 Fix a couple bugs (#1161)
* fix jit reduce for RMS norm

* make strides a single buffer

* better eval error message

* fix compiling with inf and bf16

* fix cpu compile with bf16
2024-05-28 15:18:18 -07:00
Awni Hannun
a87ef5bfc1 fix broadcast bug in bitwise ops (#1157) 2024-05-24 11:44:40 -07:00
Awni Hannun
9f9cb7a2ef version bump (#1154) 2024-05-23 18:08:08 -07:00
Awni Hannun
7e26fd8032 Option to JIT steel gemm / conv (#1139) 2024-05-23 18:07:34 -07:00
Jagrit Digani
eab2685c67 Float mask update (#1152)
* Float mask update

* Update CPU impl
2024-05-23 17:20:44 -07:00
Angelos Katharopoulos
50dfb664db Comms (#1097)
* Start the communications branch using MPI
* Add ops and primitives
* Add python bindings for distributed
2024-05-23 17:04:02 -07:00
Awni Hannun
0189ab6ab6 More jitting (#1132)
* docs + circle min size build

* jit scan, arange, softmax

* add sort

* jit reductions

* remove print

* fix deps

* clean includes / nits
2024-05-23 16:23:44 -07:00
Rifur13
9401507336 Add groups to 2-D convolutions (#1129)
* Added groups to 2-D convolutions. Only implemented for **some** specializations.

Also fixed 1D grouped convs with different kernel strides and added more tests.

* fix channels condition
2024-05-22 20:01:44 -07:00
Awni Hannun
eb8321d863 list based indexing (#1150) 2024-05-22 15:52:05 -07:00
Abe Leininger
79ef49b2c2 add mx.trace (#1143) (#1147)
* working c++ trace implementation

* updated throw + added overloads

* added python binding for trace function

* pre-commit reformatting

* add trace to docs

* resolve comments

* remove to_stream call
2024-05-22 15:50:27 -07:00
Awni Hannun
e110ca11e2 Fix offset bug for device buffers (#1151)
* fix bug with large offsets for buffers

* add a test

* remove test as its too big for small machine
2024-05-22 15:50:05 -07:00
Awni Hannun
226748b3e7 JIT compile option for binary minimization (#1091)
* try cpp 20 for compile

* unary, binary, ternary in jit

* nits

* fix gather/scatter

* fix rebase

* reorg compile

* add ternary to compile

* jit copy

* jit compile flag

* fix build

* use linked function for ternary

* some nits

* docs + circle min size build

* docs + circle min size build

* fix extension

* fix no cpu build

* improve includes
2024-05-22 12:57:13 -07:00
Awni Hannun
d568c7ee36 Rename block sparse (#1149)
* block_sparse_mm to gather_mm

* rename

* nit

* nit
2024-05-22 07:48:34 -07:00
Awni Hannun
e6fecbb3e1 Some fixes in docs (#1141)
* fixes in docs

* nit
2024-05-20 11:51:47 -07:00
Angelos Katharopoulos
da83f899bb Improve qvm speed (#1140) 2024-05-20 09:20:44 -07:00
jlwitthuhn
7e5674d8be Treate 'minimum' differently in cosine decay (#1138) 2024-05-20 08:00:48 -07:00
Shixian Sheng
0a558577bf Update README.md (#1136) 2024-05-20 06:16:40 -07:00
Awni Hannun
fb71a82ada Fix copy bug with many dims (#1137) 2024-05-17 21:10:03 -07:00
Awni Hannun
23406c9e9e Choose the right MLX bf16 for extensions (#1135)
* default to custom bf

* choose right bf

* fix extensions

* fix circle conf
2024-05-17 15:09:28 -07:00
Luca Arnaboldi
b3ec792380 Implemented Cholesky on CPU (#1119) 2024-05-17 12:31:59 -07:00
167 changed files with 11099 additions and 6677 deletions

View File

@@ -71,6 +71,7 @@ jobs:
name: Install dependencies
command: |
brew install python@3.8
brew install openmpi
python3.8 -m venv env
source env/bin/activate
pip install --upgrade pip
@@ -96,10 +97,14 @@ jobs:
source env/bin/activate
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
mpirun -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
- run:
name: Build example extension
command: |
cd examples/extensions && python3.8 -m pip install .
source env/bin/activate
cd examples/extensions
pip install -r requirements.txt
python setup.py build_ext -j8
- store_test_results:
path: test-results
- run:
@@ -111,7 +116,13 @@ jobs:
name: Run CPP tests
command: |
DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./build/tests/tests
DEVICE=cpu ./build/tests/tests
- run:
name: Build small binary
command: |
source env/bin/activate
cd build/
cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel -DBUILD_SHARED_LIBS=ON -DMLX_BUILD_CPU=OFF -DMLX_BUILD_SAFETENSORS=OFF -DMLX_BUILD_GGUF=OFF -DMLX_METAL_JIT=ON
make -j
build_release:
parameters:

View File

@@ -16,6 +16,7 @@ MLX was developed with contributions from the following individuals:
- Luca Arnaboldi: Added `Ceil` and `Floor` ops; implemented pickling, copy and deepcopy for mlx arrays.
- Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention`
- AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`. Comparison works for non array objects in `mlx.core.array`. Exception handling for invalid operations in `mlx.core.array`.
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />

View File

@@ -20,10 +20,11 @@ option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
if(NOT MLX_VERSION)
set(MLX_VERSION 0.13.1)
set(MLX_VERSION 0.14.1)
endif()
# --------------------- Processor tests -------------------------
@@ -109,7 +110,7 @@ elseif (MLX_BUILD_METAL)
$<INSTALL_INTERFACE:include/metal_cpp>
)
target_link_libraries(
mlx
mlx PUBLIC
${METAL_LIB}
${FOUNDATION_LIB}
${QUARTZ_LIB})
@@ -122,7 +123,7 @@ if (MLX_BUILD_CPU)
if (MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
set(MLX_BUILD_ACCELERATE ON)
target_link_libraries(mlx ${ACCELERATE_LIBRARY})
target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY})
add_compile_definitions(ACCELERATE_NEW_LAPACK)
else()
message(STATUS "Accelerate or arm neon not found, using default backend.")
@@ -145,7 +146,7 @@ if (MLX_BUILD_CPU)
message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
target_link_libraries(mlx ${LAPACK_LIBRARIES})
target_link_libraries(mlx PUBLIC ${LAPACK_LIBRARIES})
# List blas after lapack otherwise we may accidentally incldue an old version
# of lapack.h from the include dirs of blas.
find_package(BLAS REQUIRED)
@@ -160,12 +161,17 @@ if (MLX_BUILD_CPU)
message(STATUS "Blas lib " ${BLAS_LIBRARIES})
message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
target_link_libraries(mlx ${BLAS_LIBRARIES})
target_link_libraries(mlx PUBLIC ${BLAS_LIBRARIES})
endif()
else()
set(MLX_BUILD_ACCELERATE OFF)
endif()
find_package(MPI)
if (MPI_FOUND)
target_include_directories(mlx PRIVATE ${MPI_INCLUDE_PATH})
endif()
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)
target_include_directories(
@@ -175,6 +181,14 @@ target_include_directories(
$<INSTALL_INTERFACE:include>
)
FetchContent_Declare(fmt
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
GIT_TAG 10.2.1
EXCLUDE_FROM_ALL
)
FetchContent_MakeAvailable(fmt)
target_link_libraries(mlx PRIVATE fmt::fmt-header-only)
if (MLX_BUILD_PYTHON_BINDINGS)
message(STATUS "Building Python bindings.")
find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)

View File

@@ -88,13 +88,13 @@ for more information on building the C++ and Python APIs from source.
## Contributing
Check out the [contribution guidelines](CONTRIBUTING.md) for more information
Check out the [contribution guidelines](https://github.com/ml-explore/mlx/tree/main/CONTRIBUTING.md) for more information
on contributing to MLX. See the
[docs](https://ml-explore.github.io/mlx/build/html/install.html) for more
information on building from source, and running tests.
We are grateful for all of [our
contributors](ACKNOWLEDGMENTS.md#Individual-Contributors). If you contribute
contributors](https://github.com/ml-explore/mlx/tree/main/ACKNOWLEDGMENTS.md#Individual-Contributors). If you contribute
to MLX and wish to be acknowledged, please add your name to the list in your
pull request.

View File

@@ -28,11 +28,11 @@ def bench(f, a, b):
return (e - s) * 1e-9
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0)):
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
def mx_conv_2D(a, b):
ys = []
for i in range(N_iter_func):
y = mx.conv2d(a, b, stride=strides, padding=padding)
y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
mx.eval(ys)
return ys
@@ -40,12 +40,12 @@ def make_mx_conv_2D(strides=(1, 1), padding=(0, 0)):
return mx_conv_2D
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0)):
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
@torch.no_grad()
def pt_conv_2D(a, b):
ys = []
for i in range(N_iter_func):
y = torch.conv2d(a, b, stride=strides, padding=padding)
y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
torch.mps.synchronize()
return ys
@@ -53,11 +53,13 @@ def make_pt_conv_2D(strides=(1, 1), padding=(0, 0)):
return pt_conv_2D
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, np_dtype):
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
scale = 1.0 / math.sqrt(kH * kH * C)
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (O, kH, kW, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
np_dtype
)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
@@ -67,15 +69,15 @@ def bench_shape(N, H, W, C, kH, kW, O, strides, padding, np_dtype):
torch.mps.synchronize()
f_mx = make_mx_conv_2D(strides, padding)
f_pt = make_pt_conv_2D(strides, padding)
f_mx = make_mx_conv_2D(strides, padding, groups)
f_pt = make_pt_conv_2D(strides, padding, groups)
time_torch = bench(f_pt, a_pt, b_pt)
time_mlx = bench(f_mx, a_mx, b_mx)
out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding)
out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
out_pt = torch.conv2d(
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
)
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
out_pt = out_pt.numpy(force=True)
@@ -84,7 +86,7 @@ def bench_shape(N, H, W, C, kH, kW, O, strides, padding, np_dtype):
if not np.allclose(out_pt, out_mx, atol=atol):
print(
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
)
return time_mlx, time_torch
@@ -95,35 +97,40 @@ if __name__ == "__main__":
dtypes = ("float32",)
shapes = (
(4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2)),
(4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2)),
(4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2)),
(4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2)),
(4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2)),
(4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2)),
(4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2)),
(4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2)),
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2)),
(4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2)),
(4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2)),
(4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2)),
(4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2)),
(4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2)),
(4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2)),
(4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2)),
(4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
(4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
(4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 2),
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 16),
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 64),
(4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
(4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
(4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
)
for dtype in dtypes:
print("(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, diff%")
for N, H, W, C, kH, kW, O, strides, padding in shapes:
print(
"(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
)
for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
np_dtype = getattr(np, dtype)
time_mlx, time_torch = bench_shape(
N, H, W, C, kH, kW, O, strides, padding, np_dtype
N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
)
diff = time_torch / time_mlx - 1.0
print(
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {100. * diff:+5.2f}%"
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
)
if time_mlx >= 2.0 * time_torch:
print("ATTENTION ^^^^^^^")

View File

@@ -163,6 +163,8 @@ should point to the path to the built metal library.
- ON
* - MLX_BUILD_GGUF
- ON
* - MLX_METAL_JIT
- OFF
.. note::
@@ -184,21 +186,30 @@ should point to the path to the built metal library.
Binary Size Minimization
~~~~~~~~~~~~~~~~~~~~~~~~
To produce a smaller binary use the CMake flags `CMAKE_BUILD_TYPE=MinSizeRel`
and `BUILD_SHARED_LIBS=ON`.
To produce a smaller binary use the CMake flags ``CMAKE_BUILD_TYPE=MinSizeRel``
and ``BUILD_SHARED_LIBS=ON``.
The MLX CMake build has several additional options to make smaller binaries.
For example, if you don't need the CPU backend or support for safetensors and
GGUF, you can do:
```shell
cmake .. \
-DCMAKE_BUILD_TYPE=MinSizeRel \
-DBUILD_SHARED_LIBS=ON \
-DMLX_BUILD_CPU=ON \
-DMLX_BUILD_SAFETENSORS=OFF \
-DMLX_BUILD_GGUF=OFF
```
.. code-block:: shell
cmake ..
-DCMAKE_BUILD_TYPE=MinSizeRel \
-DBUILD_SHARED_LIBS=ON \
-DMLX_BUILD_CPU=OFF \
-DMLX_BUILD_SAFETENSORS=OFF \
-DMLX_BUILD_GGUF=OFF \
-DMLX_METAL_JIT=ON
THE ``MLX_METAL_JIT`` flag minimizes the size of the MLX Metal library which
contains pre-built GPU kernels. This substantially reduces the size of the
Metal library by run-time compiling kernels the first time they are used in MLX
on a given machine. Note run-time compilation incurs a cold-start cost which can
be anwywhere from a few hundred millisecond to a few seconds depending on the
application. Once a kernel is compiled, it will be cached by the system. The
Metal kernel cache persists accross reboots.
Troubleshooting
^^^^^^^^^^^^^^^

View File

@@ -10,5 +10,6 @@ Linear Algebra
inv
norm
cholesky
qr
svd

View File

@@ -35,7 +35,6 @@ Operations
bitwise_or
bitwise_xor
block_masked_mm
block_sparse_mm
broadcast_to
ceil
clip
@@ -69,6 +68,8 @@ Operations
floor
floor_divide
full
gather_mm
gather_qmm
greater
greater_equal
identity
@@ -149,6 +150,7 @@ Operations
tensordot
tile
topk
trace
transpose
tri
tril

View File

@@ -3,7 +3,11 @@
Conversion to NumPy and Other Frameworks
========================================
MLX array implements the `Python Buffer Protocol <https://docs.python.org/3/c-api/buffer.html>`_.
MLX array supports conversion between other frameworks with either:
* The `Python Buffer Protocol <https://docs.python.org/3/c-api/buffer.html>`_.
* `DLPack <https://dmlc.github.io/dlpack/latest/>`_.
Let's convert an array to NumPy and back.
.. code-block:: python

View File

@@ -9,3 +9,4 @@ build_example(tutorial.cpp)
build_example(linear_regression.cpp)
build_example(logistic_regression.cpp)
build_example(metal_capture.cpp)
build_example(distributed.cpp)

View File

@@ -0,0 +1,22 @@
// Copyright © 2024 Apple Inc.
#include <iostream>
#include "mlx/mlx.h"
using namespace mlx::core;
int main() {
if (!distributed::is_available()) {
std::cout << "No communication backend found" << std::endl;
return 1;
}
auto global_group = distributed::init();
std::cout << global_group.rank() << " / " << global_group.size() << std::endl;
array x = ones({10});
array out = distributed::all_reduce_sum(x, global_group);
std::cout << out << std::endl;
}

View File

@@ -21,4 +21,4 @@ python setup.py build_ext -j8 --inplace
```
python test.py
`
```

View File

@@ -25,6 +25,7 @@ else()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_cpu)
endif()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/distributed)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
if (MLX_BUILD_ACCELERATE)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate)

View File

@@ -32,8 +32,6 @@ DEFAULT(ArgReduce)
DEFAULT(ArgSort)
DEFAULT(AsStrided)
DEFAULT(BlockMaskedMM)
DEFAULT(BlockSparseMM)
DEFAULT(BlockSparseQMM)
DEFAULT(Broadcast)
DEFAULT(Ceil)
DEFAULT(Concatenate)
@@ -49,6 +47,8 @@ DEFAULT(ErfInv)
DEFAULT(FFT)
DEFAULT(Floor)
DEFAULT(Gather)
DEFAULT(GatherMM)
DEFAULT(GatherQMM)
DEFAULT(Greater)
DEFAULT(GreaterEqual)
DEFAULT(Less)
@@ -80,6 +80,7 @@ DEFAULT(StopGradient)
DEFAULT_MULTI(SVD)
DEFAULT(Transpose)
DEFAULT(Inverse)
DEFAULT(Cholesky)
void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);

View File

@@ -56,6 +56,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
)

View File

@@ -1,6 +1,8 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include <cassert>
#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/common/utils.h"

View File

@@ -0,0 +1,101 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <lapack.h>
#endif
namespace mlx::core {
namespace {
// Delegate to the Cholesky factorization taking into account differences in
// LAPACK implementations (basically how to pass the 'uplo' string to fortran).
int spotrf_wrapper(char uplo, float* matrix, int N) {
int info;
#ifdef LAPACK_FORTRAN_STRLEN_END
spotrf_(
/* uplo = */ &uplo,
/* n = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* info = */ &info,
/* uplo_len = */ static_cast<size_t>(1));
#else
spotrf_(
/* uplo = */ &uplo,
/* n = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* info = */ &info);
#endif
return info;
}
} // namespace
void cholesky_impl(const array& a, array& factor, bool upper) {
// Lapack uses the column-major convention. We take advantage of the fact that
// the matrix should be symmetric:
// (A)ᵀ = A
// and that a column-major lower triangular matrix is a row-major upper
// triangular matrix, so uplo is the opposite of what we would expect from
// upper
char uplo = (upper) ? 'L' : 'U';
// The decomposition is computed in place, so just copy the input to the
// output.
copy(
a,
factor,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
const int N = a.shape(-1);
const size_t num_matrices = a.size() / (N * N);
float* matrix = factor.data<float>();
for (int i = 0; i < num_matrices; i++) {
// Compute Cholesky factorization.
int info = spotrf_wrapper(uplo, matrix, N);
// TODO: We do nothing when the matrix is not positive semi-definite
// because throwing an error would result in a crash. If we figure out how
// to catch errors from the implementation we should throw.
if (info < 0) {
std::stringstream msg;
msg << "[cholesky] Cholesky decomposition failed with error code "
<< info;
throw std::runtime_error(msg.str());
}
// Zero out the upper/lower triangle while advancing the pointer to the
// next matrix at the same time.
for (int row = 0; row < N; row++) {
if (upper) {
std::fill(matrix, matrix + row, 0);
} else {
std::fill(matrix + row + 1, matrix + N, 0);
}
matrix += N;
}
}
}
void Cholesky::eval(const std::vector<array>& inputs, array& output) {
if (inputs[0].dtype() != float32) {
throw std::runtime_error("[Cholesky::eval] only supports float32.");
}
cholesky_impl(inputs[0], output, upper_);
}
} // namespace mlx::core

View File

@@ -111,13 +111,17 @@ void slow_conv_2D(
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
const int iH = 1 + in_dilation[0] * (in.shape(1) - 1); // Input spatial dim
const int iW = 1 + in_dilation[1] * (in.shape(2) - 1); // Input spatial dim
const int C = in.shape(3); // In channels
const int oH = out.shape(1); // Output spatial dim
const int oW = out.shape(2); // Output spatial dim
const int O = wt.shape(0); // Out channels
const int C = wt.shape(3); // In channels
const int wH = wt.shape(1); // Weight spatial dim
const int wW = wt.shape(2); // Weight spatial dim
const int groups = C / wt.shape(3);
const int C_per_group = wt.shape(3);
const int O_per_group = O / groups;
const size_t in_stride_N = in.strides()[0];
const size_t in_stride_H = in.strides()[1];
const size_t in_stride_W = in.strides()[2];
@@ -141,33 +145,35 @@ void slow_conv_2D(
int ih_base = oh * wt_strides[0] - padding[0];
int iw_base = ow * wt_strides[1] - padding[1];
for (int o = 0; o < O; ++o) {
float r = 0.;
for (int g = 0; g < groups; ++g) {
for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
float r = 0.;
for (int wh = 0; wh < wH; ++wh) {
for (int ww = 0; ww < wW; ++ww) {
int wh_flip = flip ? wH - wh - 1 : wh;
int ww_flip = flip ? wW - ww - 1 : ww;
int ih = ih_base + wh_flip * wt_dilation[0];
int iw = iw_base + ww_flip * wt_dilation[1];
for (int wh = 0; wh < wH; ++wh) {
for (int ww = 0; ww < wW; ++ww) {
int wh_flip = flip ? wH - wh - 1 : wh;
int ww_flip = flip ? wW - ww - 1 : ww;
int ih = ih_base + wh_flip * wt_dilation[0];
int iw = iw_base + ww_flip * wt_dilation[1];
const T* wt_ptr_pt = wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
const T* in_ptr_pt = in_ptr + ih * in_stride_H + iw * in_stride_W;
const T* wt_ptr_pt =
wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
const T* in_ptr_pt =
in_ptr + ih * in_stride_H + iw * in_stride_W;
for (int c = 0; c < C; ++c) {
r += static_cast<float>(in_ptr_pt[0]) *
static_cast<float>(wt_ptr_pt[0]);
in_ptr_pt += in_stride_C;
wt_ptr_pt += wt_stride_C;
} // c
for (int c = g * C_per_group; c < (g + 1) * C_per_group; ++c) {
r += static_cast<float>(in_ptr_pt[c * in_stride_C]) *
static_cast<float>(
wt_ptr_pt[(c % C_per_group) * wt_stride_C]);
} // c
} // ww
} // wh
} // ww
} // wh
out_ptr[0] = static_cast<T>(r);
out_ptr += out_stride_O;
wt_ptr += wt_stride_O;
} // o
out_ptr[0] = static_cast<T>(r);
out_ptr += out_stride_O;
wt_ptr += wt_stride_O;
} // o
} // g
};
int jump_h = flip ? -wt_dilation[0] : wt_dilation[0];
@@ -219,41 +225,43 @@ void slow_conv_2D(
int wh_base = base_h[oh % f_out_jump_h];
int ww_base = base_w[ow % f_out_jump_w];
for (int o = 0; o < O; ++o) {
float r = 0.;
for (int g = 0; g < groups; ++g) {
for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
float r = 0.;
for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) {
for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) {
int wh_flip = flip ? wH - wh - 1 : wh;
int ww_flip = flip ? wW - ww - 1 : ww;
int ih = ih_base + wh_flip * wt_dilation[0];
int iw = iw_base + ww_flip * wt_dilation[1];
for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) {
for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) {
int wh_flip = flip ? wH - wh - 1 : wh;
int ww_flip = flip ? wW - ww - 1 : ww;
int ih = ih_base + wh_flip * wt_dilation[0];
int iw = iw_base + ww_flip * wt_dilation[1];
if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) {
const T* wt_ptr_pt =
wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) {
const T* wt_ptr_pt =
wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
int ih_dil = !is_idil_one ? (ih / in_dilation[0]) : ih;
int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw;
int ih_dil = !is_idil_one ? (ih / in_dilation[0]) : ih;
int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw;
const T* in_ptr_pt =
in_ptr + ih_dil * in_stride_H + iw_dil * in_stride_W;
const T* in_ptr_pt =
in_ptr + ih_dil * in_stride_H + iw_dil * in_stride_W;
for (int c = 0; c < C; ++c) {
r += static_cast<float>(in_ptr_pt[0]) *
static_cast<float>(wt_ptr_pt[0]);
in_ptr_pt += in_stride_C;
wt_ptr_pt += wt_stride_C;
} // c
for (int c = g * C_per_group; c < (g + 1) * C_per_group;
++c) {
r += static_cast<float>(in_ptr_pt[c * in_stride_C]) *
static_cast<float>(
wt_ptr_pt[(c % C_per_group) * wt_stride_C]);
} // c
} // ih, iw check
} // ww
} // wh
} // ih, iw check
} // ww
} // wh
out_ptr[0] = static_cast<T>(r);
out_ptr += out_stride_O;
wt_ptr += wt_stride_O;
} // o
out_ptr[0] = static_cast<T>(r);
out_ptr += out_stride_O;
wt_ptr += wt_stride_O;
} // o
} // g
};
int oH_border_0 = 0;

View File

@@ -256,7 +256,7 @@ void copy_general_general(
}
int size = std::accumulate(
data_shape.begin() - 5, data_shape.end(), 1, std::multiplies<int>());
data_shape.end() - 5, data_shape.end(), 1, std::multiplies<int>());
for (int i = 0; i < src.size(); i += size) {
stride_t src_offset = i_offset + elem_to_loc(i, data_shape, i_strides);
stride_t dst_offset = o_offset + elem_to_loc(i, dst.shape(), o_strides);

View File

@@ -43,8 +43,8 @@ DEFAULT(AsType)
DEFAULT(AsStrided)
DEFAULT(Broadcast)
DEFAULT(BlockMaskedMM)
DEFAULT(BlockSparseMM)
DEFAULT(BlockSparseQMM)
DEFAULT(GatherMM)
DEFAULT(GatherQMM)
DEFAULT_MULTI(DivMod)
DEFAULT(Ceil)
DEFAULT(Concatenate)
@@ -113,6 +113,7 @@ DEFAULT(Tan)
DEFAULT(Tanh)
DEFAULT(Transpose)
DEFAULT(Inverse)
DEFAULT(Cholesky)
namespace {

View File

@@ -28,6 +28,7 @@ const char* get_kernel_preamble() {
return R"preamble(
$INCLUDES
$CONTENT
using namespace mlx::core;
using namespace mlx::core::detail;
)preamble";
}

View File

@@ -17,24 +17,25 @@ namespace mlx::core {
namespace {
template <typename T>
template <typename T, typename mask_t>
inline void mask_matrix(
T* data,
const bool* mask,
const mask_t* mask,
int block_size,
const int X,
const int Y,
const size_t X_data_str,
const size_t Y_data_str,
const size_t X_mask_str,
const size_t Y_mask_str) {
const size_t Y_mask_str,
const size_t mask_offset) {
int tX = (X + block_size - 1) / block_size;
int tY = (Y + block_size - 1) / block_size;
for (int i = 0; i < tX; i++) {
for (int j = 0; j < tY; j++) {
bool do_mask = mask[i * X_mask_str + j * Y_mask_str];
if (!do_mask) {
mask_t do_mask = mask[mask_offset + i * X_mask_str + j * Y_mask_str];
if (do_mask != 1) {
int loc_x = i * block_size;
int loc_y = j * block_size;
T* data_block = data + loc_x * X_data_str + loc_y * Y_data_str;
@@ -43,7 +44,11 @@ inline void mask_matrix(
int size_y = std::min(block_size, Y - loc_y);
for (int ii = 0; ii < size_x; ii++) {
for (int jj = 0; jj < size_y; jj++) {
data_block[ii * X_data_str + jj * Y_data_str] = T(0.);
if constexpr (std::is_same_v<mask_t, bool>) {
data_block[ii * X_data_str + jj * Y_data_str] = T(0.);
} else {
data_block[ii * X_data_str + jj * Y_data_str] *= do_mask;
}
}
}
}
@@ -62,36 +67,39 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
auto& a_pre = inputs[0];
auto& b_pre = inputs[1];
auto& out_mask = inputs[2];
auto check_transpose = [](const array& arr, bool do_copy) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (stx == arr.shape(-1) && sty == 1) {
if (do_copy) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::Vector);
return std::make_tuple(false, stx, arr_copy);
}
return std::make_tuple(false, stx, arr);
} else if (stx == 1 && sty == arr.shape(-2)) {
if (do_copy) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::Vector);
return std::make_tuple(true, sty, arr_copy);
}
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
size_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
}
};
auto check_transpose =
[](const array& arr, bool do_copy, bool expand_all = false) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (!expand_all && stx == arr.shape(-1) && sty == 1) {
if (do_copy) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::Vector);
return std::make_tuple(false, stx, arr_copy);
}
return std::make_tuple(false, stx, arr);
} else if (!expand_all && stx == 1 && sty == arr.shape(-2)) {
if (do_copy) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::Vector);
return std::make_tuple(true, sty, arr_copy);
}
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
size_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
}
};
bool has_op_mask = inputs.size() > 3;
auto [a_transposed, lda, a] = check_transpose(a_pre, has_op_mask);
auto [b_transposed, ldb, b] = check_transpose(b_pre, has_op_mask);
bool has_out_mask = inputs.size() == 3 || inputs.size() == 5;
auto [a_transposed, lda, a] =
check_transpose(a_pre, has_op_mask, inputs.back().dtype() != bool_);
auto [b_transposed, ldb, b] =
check_transpose(b_pre, has_op_mask, inputs.back().dtype() != bool_);
size_t M = a.shape(-2);
size_t N = b.shape(-1);
@@ -114,27 +122,42 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
int Y,
size_t X_data_str,
size_t Y_data_str) {
const bool* mask_ptr = mask.data<bool>() +
elem_to_loc(mask.shape(-1) * mask.shape(-2) * batch_idx,
mask.shape(),
mask.strides());
size_t mask_offset = elem_to_loc(
mask.shape(-1) * mask.shape(-2) * batch_idx,
mask.shape(),
mask.strides());
size_t X_mask_str = mask.strides()[mask.ndim() - 2];
size_t Y_mask_str = mask.strides()[mask.ndim() - 1];
return mask_matrix(
data,
mask_ptr,
block_size,
X,
Y,
X_data_str,
Y_data_str,
X_mask_str,
Y_mask_str);
if (mask.dtype() == bool_) {
return mask_matrix(
data,
mask.data<bool>(),
block_size,
X,
Y,
X_data_str,
Y_data_str,
X_mask_str,
Y_mask_str,
mask_offset);
} else {
return mask_matrix(
data,
mask.data<float>(),
block_size,
X,
Y,
X_data_str,
Y_data_str,
X_mask_str,
Y_mask_str,
mask_offset);
}
};
for (int i = 0; i < (a.size() / (M * K)); ++i) {
for (int i = 0; i < (out.size() / (M * size_t(N))); ++i) {
// Adjust pointer
float* ai =
a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides());
@@ -144,7 +167,7 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
// Zero out blocks in a and b if needed
if (has_op_mask) {
auto& a_mask = inputs[3];
auto& a_mask = inputs[inputs.size() - 2];
mask_array(
a_mask,
ai,
@@ -155,7 +178,7 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
a_transposed ? 1 : lda,
a_transposed ? lda : 1);
auto& b_mask = inputs[4];
auto& b_mask = inputs[inputs.size() - 1];
mask_array(
b_mask,
bi,
@@ -186,14 +209,16 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
);
// Zero out blocks in out
mask_array(out_mask, ci, block_size_, i, M, N, N, 1);
if (has_out_mask) {
mask_array(inputs[2], ci, block_size_, i, M, N, N, 1);
}
}
}
void BlockSparseMM::eval(const std::vector<array>& inputs, array& out) {
void GatherMM::eval(const std::vector<array>& inputs, array& out) {
if (out.dtype() != float32) {
throw std::runtime_error(
"[BlockSparseMM::eval] Currently only supports float32.");
"[GatherMM::eval] Currently only supports float32.");
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
@@ -277,4 +302,4 @@ void BlockSparseMM::eval(const std::vector<array>& inputs, array& out) {
}
}
} // namespace mlx::core
} // namespace mlx::core

View File

@@ -357,7 +357,7 @@ void QuantizedMatmul::eval(const std::vector<array>& inputs, array& out) {
_qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
}
void BlockSparseQMM::eval(const std::vector<array>& inputs, array& out) {
void GatherQMM::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 6);
auto& x_pre = inputs[0];

View File

@@ -1,33 +1,125 @@
add_custom_command(
OUTPUT compiled_preamble.cpp
function(make_jit_source SRC_FILE)
# This function takes a metal header file,
# runs the C preprocessesor on it, and makes
# the processed contents available as a string in a C++ function
# mlx::core::metal::${SRC_NAME}()
#
# To use the function, declare it in jit/includes.h and
# include jit/includes.h.
#
# Additional arguments to this function are treated as dependencies
# in the Cmake build system.
get_filename_component(SRC_NAME ${SRC_FILE} NAME)
add_custom_command(
OUTPUT jit/${SRC_NAME}.cpp
COMMAND /bin/bash
${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
${CMAKE_CURRENT_BINARY_DIR}/jit
${CMAKE_C_COMPILER}
${PROJECT_SOURCE_DIR}
${SRC_FILE}
"-D${MLX_METAL_VERSION}"
DEPENDS make_compiled_preamble.sh
kernels/compiled_preamble.h
kernels/unary.h
kernels/binary.h
kernels/bf16.h
kernels/erf.h
kernels/expm1f.h
kernels/utils.h
kernels/bf16_math.h
)
kernels/${SRC_FILE}.h
${ARGN}
)
add_custom_target(${SRC_NAME} DEPENDS jit/${SRC_NAME}.cpp)
add_dependencies(mlx ${SRC_NAME})
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_BINARY_DIR}/jit/${SRC_NAME}.cpp
)
endfunction(make_jit_source)
add_custom_target(
compiled_preamble
DEPENDS compiled_preamble.cpp
make_jit_source(
utils
kernels/bf16.h
kernels/complex.h
kernels/defines.h
)
make_jit_source(
unary_ops
kernels/erf.h
kernels/expm1f.h
)
make_jit_source(binary_ops)
make_jit_source(ternary_ops)
make_jit_source(
reduce_utils
kernels/atomic.h
kernels/reduction/ops.h
)
make_jit_source(scatter)
make_jit_source(gather)
add_dependencies(mlx compiled_preamble)
if (MLX_METAL_JIT)
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/jit_kernels.cpp
)
make_jit_source(arange)
make_jit_source(copy)
make_jit_source(unary)
make_jit_source(binary)
make_jit_source(binary_two)
make_jit_source(ternary)
make_jit_source(softmax)
make_jit_source(scan)
make_jit_source(sort)
make_jit_source(
reduce
kernels/reduction/reduce_all.h
kernels/reduction/reduce_col.h
kernels/reduction/reduce_row.h
)
make_jit_source(
steel/gemm/gemm
kernels/steel/utils.h
kernels/steel/gemm/loader.h
kernels/steel/gemm/mma.h
kernels/steel/gemm/params.h
kernels/steel/gemm/transforms.h
)
make_jit_source(steel/gemm/kernels/steel_gemm_fused)
make_jit_source(
steel/gemm/kernels/steel_gemm_masked
kernels/steel/defines.h
)
make_jit_source(steel/gemm/kernels/steel_gemm_splitk)
make_jit_source(
steel/conv/conv
kernels/steel/utils.h
kernels/steel/defines.h
kernels/steel/gemm/mma.h
kernels/steel/gemm/transforms.h
kernels/steel/conv/params.h
kernels/steel/conv/loader.h
kernels/steel/conv/loaders/loader_channel_l.h
kernels/steel/conv/loaders/loader_channel_n.h
)
make_jit_source(
steel/conv/kernels/steel_conv
)
make_jit_source(
steel/conv/kernels/steel_conv_general
kernels/steel/defines.h
kernels/steel/conv/loaders/loader_general.h
)
else()
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/nojit_kernels.cpp
)
endif()
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
@@ -46,7 +138,8 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp
)
if (NOT MLX_METAL_PATH)

View File

@@ -0,0 +1,322 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/common/binary.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
constexpr int MAX_BINARY_SPECIALIZED_DIMS = 5;
void binary_op(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::string op) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, outputs[0], bopt, true);
set_binary_op_output_data(a, b, outputs[1], bopt, true);
auto& out = outputs[0];
if (out.size() == 0) {
return;
}
// Try to collapse contiguous dims
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
auto& strides_a = strides[0];
auto& strides_b = strides[1];
auto& strides_out = strides[2];
std::string kernel_name;
{
std::ostringstream kname;
switch (bopt) {
case BinaryOpType::ScalarScalar:
kname << "ss";
break;
case BinaryOpType::ScalarVector:
kname << "sv";
break;
case BinaryOpType::VectorScalar:
kname << "vs";
break;
case BinaryOpType::VectorVector:
kname << "vv";
break;
case BinaryOpType::General:
kname << "g";
if (shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
kname << shape.size();
} else {
kname << "n";
}
break;
}
kname << op << type_to_name(a);
kernel_name = kname.str();
}
auto& s = out.primitive().stream();
auto& d = metal::device(s.device);
auto kernel = get_binary_two_kernel(d, kernel_name, a, outputs[0]);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
// - If a is donated it goes to the first output
// - If b is donated it goes to the first output if a was not donated
// otherwise it goes to the second output
bool donate_a = a.data_shared_ptr() == nullptr;
bool donate_b = b.data_shared_ptr() == nullptr;
compute_encoder.set_input_array(donate_a ? outputs[0] : a, 0);
compute_encoder.set_input_array(
donate_b ? (donate_a ? outputs[1] : outputs[0]) : b, 1);
compute_encoder.set_output_array(outputs[0], 2);
compute_encoder.set_output_array(outputs[1], 3);
if (bopt == BinaryOpType::General) {
auto ndim = shape.size();
if (ndim > 3) {
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 4);
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 5);
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 6);
} else {
// The shape is implicit in the grid for <= 3D
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
}
if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
compute_encoder->setBytes(&ndim, sizeof(int), 7);
}
// Launch up to 3D grid of threads
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
size_t rest = out.size() / (dim0 * dim1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size != 1024) {
throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
}
auto group_dims = get_block_dims(dim0, dim1, rest);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
// Launch a 1D grid of threads
size_t nthreads = out.data_size();
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
}
void binary_op(
const std::vector<array>& inputs,
array& out,
const std::string op) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt, true);
if (out.size() == 0) {
return;
}
// Try to collapse contiguous dims
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
auto& strides_a = strides[0];
auto& strides_b = strides[1];
auto& strides_out = strides[2];
std::string kernel_name;
{
std::ostringstream kname;
switch (bopt) {
case BinaryOpType::ScalarScalar:
kname << "ss";
break;
case BinaryOpType::ScalarVector:
kname << "sv";
break;
case BinaryOpType::VectorScalar:
kname << "vs";
break;
case BinaryOpType::VectorVector:
kname << "vv";
break;
case BinaryOpType::General:
kname << "g";
if (shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
kname << shape.size();
} else {
kname << "n";
}
break;
}
kname << op << type_to_name(a);
kernel_name = kname.str();
}
auto& s = out.primitive().stream();
auto& d = metal::device(s.device);
auto kernel = get_binary_kernel(d, kernel_name, a, out);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
bool donate_a = a.data_shared_ptr() == nullptr;
bool donate_b = b.data_shared_ptr() == nullptr;
compute_encoder.set_input_array(donate_a ? out : a, 0);
compute_encoder.set_input_array(donate_b ? out : b, 1);
compute_encoder.set_output_array(out, 2);
if (bopt == BinaryOpType::General) {
auto ndim = shape.size();
if (ndim > 3) {
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 3);
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
} else {
// The shape is implicit in the grid for <= 3D
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 3);
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 4);
}
if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
compute_encoder->setBytes(&ndim, sizeof(int), 6);
}
// Launch up to 3D grid of threads
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
size_t rest = out.size() / (dim0 * dim1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size != 1024) {
throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
}
auto group_dims = get_block_dims(dim0, dim1, rest);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
// Launch a 1D grid of threads
size_t nthreads =
bopt == BinaryOpType::General ? out.size() : out.data_size();
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
}
void Add::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "add");
}
void ArcTan2::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "arctan2");
}
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
switch (op_) {
case BitwiseBinary::And:
binary_op(inputs, out, "bitwise_and");
break;
case BitwiseBinary::Or:
binary_op(inputs, out, "bitwise_or");
break;
case BitwiseBinary::Xor:
binary_op(inputs, out, "bitwise_xor");
break;
case BitwiseBinary::LeftShift:
binary_op(inputs, out, "left_shift");
break;
case BitwiseBinary::RightShift:
binary_op(inputs, out, "right_shift");
break;
}
}
void Divide::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "div");
}
void DivMod::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
binary_op(inputs, outputs, "divmod");
}
void Remainder::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "rem");
}
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, equal_nan_ ? "naneq" : "eq");
}
void Greater::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "ge");
}
void GreaterEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "geq");
}
void Less::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "le");
}
void LessEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "leq");
}
void LogicalAnd::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "land");
}
void LogicalOr::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "lor");
}
void LogAddExp::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "lae");
}
void Maximum::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "max");
}
void Minimum::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "min");
}
void Multiply::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "mul");
}
void NotEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "neq");
}
void Power::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "pow");
}
void Subtract::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "sub");
}
} // namespace mlx::core

View File

@@ -4,8 +4,8 @@
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/metal/compiled_preamble.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/graph_utils.h"
#include "mlx/primitives.h"
@@ -56,12 +56,15 @@ inline void build_kernel(
} else {
add_indices = true;
os << " device const " << get_type_string(x.dtype()) << "* " << xname
<< " [[buffer(" << cnt++ << ")]]," << std::endl
<< " constant const size_t* " << xname << "_strides [[buffer("
<< cnt++ << ")]]," << std::endl;
<< " [[buffer(" << cnt++ << ")]]," << std::endl;
}
}
if (add_indices) {
os << " constant const size_t* in_strides [[buffer(" << cnt++
<< ")]],\n";
}
// Add the output arguments
for (auto& x : outputs) {
os << " device " << get_type_string(x.dtype()) << "* "
@@ -110,13 +113,17 @@ inline void build_kernel(
}
// Read the inputs in tmps
for (auto& x : inputs) {
int nc_in_count = 0;
for (int i = 0; i < inputs.size(); ++i) {
auto& x = inputs[i];
auto& xname = namer.get_name(x);
if (is_constant(x)) {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = ";
auto type_str = get_type_string(x.dtype());
os << " auto tmp_" << xname << " = static_cast<"
<< get_type_string(x.dtype()) << ">(";
print_constant(os, x);
os << ";" << std::endl;
os << ");" << std::endl;
} else if (is_scalar(x)) {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[0];" << std::endl;
@@ -124,17 +131,20 @@ inline void build_kernel(
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[index];" << std::endl;
} else if (!dynamic_dims) {
int offset = nc_in_count * ndim;
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[";
os << "index_0 * " << xname << "_strides[0]";
os << "index_0 * " << "in_strides[" << offset << "]";
for (int i = 1; i < ndim; i++) {
os << " + index_" << i << " * " << xname << "_strides[" << i << "]";
os << " + index_" << i << " * " << "in_strides[" << offset + i << "]";
}
os << "];" << std::endl;
nc_in_count++;
} else {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[elem_to_loc(index, output_shape, " << xname
<< "_strides, ndim)];" << std::endl;
<< xname << "[elem_to_loc(index, output_shape, in_strides + "
<< nc_in_count * ndim << ", ndim)];" << std::endl;
nc_in_count++;
}
}
@@ -190,7 +200,8 @@ void Compiled::eval_gpu(
// If not we have to build it ourselves
if (lib == nullptr) {
std::ostringstream kernel;
kernel << metal::get_kernel_preamble() << std::endl;
kernel << metal::utils() << metal::unary_ops() << metal::binary_ops()
<< metal::ternary_ops();
build_kernel(
kernel,
kernel_lib_ + "_contiguous",
@@ -295,6 +306,7 @@ void Compiled::eval_gpu(
// Put the inputs in
int cnt = 0;
int stride_idx = 1; // idx 0 is the output strides
std::vector<size_t> in_strides;
for (int i = 0; i < inputs.size(); i++) {
if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
continue;
@@ -302,13 +314,17 @@ void Compiled::eval_gpu(
auto& x = inputs[i];
compute_encoder.set_input_array(x, cnt++);
if (!contiguous && !is_scalar(x)) {
compute_encoder->setBytes(
strides[stride_idx].data(),
strides[stride_idx].size() * sizeof(size_t),
cnt++);
in_strides.insert(
in_strides.end(),
strides[stride_idx].begin(),
strides[stride_idx].end());
stride_idx++;
}
}
if (!in_strides.empty()) {
compute_encoder->setBytes(
in_strides.data(), in_strides.size() * sizeof(size_t), cnt++);
}
compiled_allocate_outputs(
inputs, outputs, inputs_, constant_ids_, contiguous, true);

View File

@@ -1,9 +0,0 @@
// Copyright © 2023-24 Apple Inc.
#pragma once
namespace mlx::core::metal {
const char* get_kernel_preamble();
}

View File

@@ -7,6 +7,7 @@
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
#include "mlx/backend/metal/matmul.h"
@@ -257,15 +258,19 @@ void implicit_gemm_conv_2D_gpu(
const array& wt,
array out,
const MLXConvParams<2>& conv_params) {
const int groups = conv_params.groups;
const int C_per_group = conv_params.C / conv_params.groups;
const int O_per_group = conv_params.O / conv_params.groups;
// Deduce implicit gemm size
int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1];
int implicit_N = conv_params.O;
int implicit_K = conv_params.wS[0] * conv_params.wS[1] * conv_params.C;
const int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1];
const int implicit_N = O_per_group;
const int implicit_K = conv_params.wS[0] * conv_params.wS[1] * C_per_group;
// Determine block and warp tiles
int wm = 2, wn = 2;
int bm = implicit_M >= 8192 && conv_params.C >= 64 ? 64 : 32;
int bm = implicit_M >= 8192 && C_per_group >= 64 ? 64 : 32;
int bn = (bm == 64 || implicit_N >= 64) ? 64 : 32;
int bk = 16;
@@ -281,15 +286,15 @@ void implicit_gemm_conv_2D_gpu(
// Fix small channel specialization
int n_channel_specialization = 0;
int channel_k_iters = ((conv_params.C + bk - 1) / bk);
int channel_k_iters = ((C_per_group + bk - 1) / bk);
int gemm_k_iters = conv_params.wS[0] * conv_params.wS[1] * channel_k_iters;
if (conv_params.C <= 2) {
if (C_per_group <= 2) {
gemm_k_iters = (implicit_K + bk - 1) / bk;
n_channel_specialization = conv_params.C;
} else if (conv_params.C <= 4) {
n_channel_specialization = C_per_group;
} else if (C_per_group <= 4) {
gemm_k_iters = ((conv_params.wS[0] * conv_params.wS[1] * 4) + bk - 1) / bk;
n_channel_specialization = conv_params.C;
n_channel_specialization = C_per_group;
}
bool small_filter = (!n_channel_specialization) &&
@@ -331,7 +336,17 @@ void implicit_gemm_conv_2D_gpu(
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
auto kernel = get_steel_conv_kernel(
d,
kname.str(),
out,
bm,
bn,
bk,
wm,
wn,
n_channel_specialization,
small_filter);
compute_encoder->setComputePipelineState(kernel);
// Deduce grid launch dimensions
@@ -340,7 +355,7 @@ void implicit_gemm_conv_2D_gpu(
size_t grid_dim_x = tn * tile;
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, 1);
MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, groups);
// Encode arrays
compute_encoder.set_input_array(in, 0);
@@ -484,7 +499,8 @@ void implicit_gemm_conv_2D_general_gpu(
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
auto kernel =
get_steel_conv_general_kernel(d, kname.str(), out, bm, bn, bk, wm, wn);
compute_encoder->setComputePipelineState(kernel);
// Deduce grid launch dimensions
@@ -703,6 +719,7 @@ void conv_2D_gpu(
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
const int groups,
bool flip,
std::vector<array>& copies) {
// Make conv params
@@ -718,12 +735,12 @@ void conv_2D_gpu(
/* const int kdil[NDIM] = */ {wt_dilation[0], wt_dilation[1]},
/* const int idil[NDIM] = */ {in_dilation[0], in_dilation[1]},
/* const size_t in_strides[NDIM + 2] = */
{in.strides()[0], in.strides()[1], in.strides()[2], in.strides()[3]},
{in.strides(0), in.strides(1), in.strides(2), in.strides(3)},
/* const size_t wt_strides[NDIM + 2] = */
{wt.strides()[0], wt.strides()[1], wt.strides()[2], wt.strides()[3]},
{wt.strides(0), wt.strides(1), wt.strides(2), wt.strides(3)},
/* const size_t out_strides[NDIM + 2] = */
{out.strides()[0], out.strides()[1], out.strides()[2], out.strides()[3]},
/* const int groups = */ 1,
{out.strides(0), out.strides(1), out.strides(2), out.strides(3)},
/* const int groups = */ groups,
/* const bool flip = */ flip,
};
@@ -735,6 +752,18 @@ void conv_2D_gpu(
bool channels_large = (conv_params.C + conv_params.O) >= 512;
bool channels_med = (conv_params.C + conv_params.O) >= 256;
if (groups > 1) {
const int C_per_group = conv_params.C / groups;
const int O_per_group = conv_params.O / groups;
if (is_idil_one && (C_per_group <= 4 || C_per_group % 16 == 0) &&
(O_per_group <= 16 || O_per_group % 16 == 0)) {
return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
} else {
return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params);
}
}
// Direct to winograd conv
if (!flip && is_stride_one && is_kdil_one && is_idil_one &&
conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
@@ -860,6 +889,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
kernel_strides_,
kernel_dilation_,
input_dilation_,
groups_,
flip_,
copies);
}

View File

@@ -4,12 +4,14 @@
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
constexpr int MAX_COPY_SPECIALIZED_DIMS = 5;
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
if (ctype == CopyType::Vector) {
// If the input is donateable, we are doing a vector copy and the types
@@ -62,27 +64,34 @@ void copy_gpu_inplace(
auto& strides_out_ = strides[1];
auto& d = metal::device(s.device);
std::ostringstream kname;
switch (ctype) {
case CopyType::Scalar:
kname << "scopy";
break;
case CopyType::Vector:
kname << "vcopy";
break;
case CopyType::General:
kname << "gcopy";
break;
case CopyType::GeneralGeneral:
kname << "ggcopy";
break;
std::string kernel_name;
{
std::ostringstream kname;
switch (ctype) {
case CopyType::Scalar:
kname << "s";
break;
case CopyType::Vector:
kname << "v";
break;
case CopyType::General:
kname << "g";
break;
case CopyType::GeneralGeneral:
kname << "gg";
break;
}
if ((ctype == CopyType::General || ctype == CopyType::GeneralGeneral) &&
shape.size() <= MAX_COPY_SPECIALIZED_DIMS) {
kname << shape.size();
}
kname << "_copy";
kname << type_to_name(in) << type_to_name(out);
kernel_name = kname.str();
}
kname << type_to_name(in) << type_to_name(out);
if ((ctype == CopyType::General || ctype == CopyType::GeneralGeneral) &&
shape.size() <= MAX_COPY_SPECIALIZED_DIMS) {
kname << "_" << shape.size();
}
auto kernel = d.get_kernel(kname.str());
auto kernel = get_copy_kernel(d, kernel_name, in, out);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
bool donate_in = in.data_shared_ptr() == nullptr;
@@ -106,7 +115,7 @@ void copy_gpu_inplace(
set_vector_bytes(compute_encoder, strides_out, ndim, 4);
}
if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
if (ndim > MAX_COPY_SPECIALIZED_DIMS) {
compute_encoder->setBytes(&ndim, sizeof(int), 5);
}

View File

@@ -285,7 +285,6 @@ MTL::Library* Device::get_library_(const std::string& source_string) {
NS::Error* error = nullptr;
auto options = MTL::CompileOptions::alloc()->init();
options->setFastMathEnabled(false);
options->setLanguageVersion(get_metal_version());
auto mtl_lib = device_->newLibrary(ns_code, options, &error);
options->release();

View File

@@ -63,7 +63,7 @@ struct CommandEncoder {
return enc;
}
void set_input_array(const array& a, int idx, int offset = 0) {
void set_input_array(const array& a, int idx, int64_t offset = 0) {
auto r_buf =
static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
if (auto it = outputs.find(r_buf); it != outputs.end()) {
@@ -80,7 +80,7 @@ struct CommandEncoder {
enc->setBuffer(a_buf, base_offset, idx);
}
void set_output_array(array& a, int idx, int offset = 0) {
void set_output_array(array& a, int idx, int64_t offset = 0) {
// Add barriers before adding the output to the output set
set_input_array(a, idx, offset);
auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());

View File

@@ -1,24 +1,35 @@
// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include <cassert>
#include <numeric>
#include <sstream>
#include <fmt/format.h>
#include "mlx/backend/common/binary.h"
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/jit/indexing.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core {
namespace {
constexpr int METAL_MAX_INDEX_ARRAYS = 20;
constexpr int METAL_MAX_INDEX_ARRAYS = 10;
} // namespace
std::pair<std::string, std::string> make_index_args(
const std::string& idx_type,
int nidx) {
std::ostringstream idx_args;
std::ostringstream idx_arr;
for (int i = 0; i < nidx; ++i) {
idx_args << fmt::format(
"const device {0} *idx{1} [[buffer({2})]],", idx_type, i, 20 + i);
idx_arr << fmt::format("idx{0}", i);
if (i < nidx - 1) {
idx_args << "\n";
idx_arr << ",";
}
}
return {idx_args.str(), idx_arr.str()};
}
void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& src = inputs[0];
@@ -42,15 +53,41 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
int idx_ndim = nidx ? inputs[1].ndim() : 0;
size_t ndim = src.ndim();
std::ostringstream kname;
std::string lib_name;
std::string kernel_name;
std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
kname << "gather" << type_to_name(src) << idx_type_name << "_" << nidx;
if (idx_ndim <= 1) {
kname << "_" << idx_ndim;
{
std::ostringstream kname;
kname << "gather" << type_to_name(out) << idx_type_name << "_" << nidx
<< "_" << idx_ndim;
lib_name = kname.str();
kernel_name = lib_name;
}
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::gather();
std::string out_type_str = get_type_string(out.dtype());
std::string idx_type_str =
nidx ? get_type_string(inputs[1].dtype()) : "bool";
auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx);
// Index dimension specializations
kernel_source << fmt::format(
gather_kernels,
type_to_name(out) + idx_type_name,
out_type_str,
idx_type_str,
nidx,
idx_args,
idx_arr,
idx_ndim);
lib = d.get_library(lib_name, kernel_source.str());
}
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
auto kernel = d.get_kernel(kernel_name, lib);
compute_encoder->setComputePipelineState(kernel);
size_t slice_size = 1;
@@ -102,8 +139,8 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder->setBytes(&idx_ndim, sizeof(int), 9);
// Set index buffers
for (int i = 1; i < nidx + 1; ++i) {
compute_encoder.set_input_array(inputs[i], 20 + i);
for (int i = 0; i < nidx; ++i) {
compute_encoder.set_input_array(inputs[i + 1], 20 + i);
}
// Launch grid
@@ -139,10 +176,6 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = stream();
auto& d = metal::device(s.device);
// Get kernel name
std::ostringstream kname;
std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
int idx_ndim = nidx ? inputs[1].ndim() : 0;
bool index_nd1_specialization = (idx_ndim == 1);
@@ -159,32 +192,86 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
index_nd1_specialization &= inputs[i].flags().row_contiguous;
}
if (index_nd1_specialization) {
kname << "scatter_1d_index" << type_to_name(out) << idx_type_name;
} else {
kname << "scatter" << type_to_name(out) << idx_type_name;
}
std::string lib_name;
std::string kernel_name;
std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
std::string op_name;
switch (reduce_type_) {
case Scatter::None:
kname << "_none";
op_name = "none";
break;
case Scatter::Sum:
kname << "_sum";
op_name = "sum";
break;
case Scatter::Prod:
kname << "_prod";
op_name = "prod";
break;
case Scatter::Max:
kname << "_max";
op_name = "max";
break;
case Scatter::Min:
kname << "_min";
op_name = "min";
break;
}
kname << "_" << nidx;
{
std::ostringstream kname;
if (index_nd1_specialization) {
kname << "scatter_1d_index" << type_to_name(out) << idx_type_name;
} else {
kname << "scatter" << type_to_name(out) << idx_type_name;
}
kname << "_" << op_name << "_" << nidx;
lib_name = kname.str();
kernel_name = kname.str();
}
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::reduce_utils()
<< metal::scatter();
std::string out_type_str = get_type_string(out.dtype());
std::string idx_type_str =
nidx ? get_type_string(inputs[1].dtype()) : "bool";
std::string op_type;
switch (reduce_type_) {
case Scatter::None:
op_type = "None";
break;
case Scatter::Sum:
op_type = "Sum<{0}>";
break;
case Scatter::Prod:
op_type = "Prod<{0}>";
break;
case Scatter::Max:
op_type = "Max<{0}>";
break;
case Scatter::Min:
op_type = "Min<{0}>";
break;
}
if (reduce_type_ != Scatter::None) {
op_type = fmt::format(op_type, out_type_str);
}
auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx);
kernel_source << fmt::format(
scatter_kernels,
type_to_name(out) + idx_type_name + "_" + op_name,
out_type_str,
idx_type_str,
op_type,
nidx,
idx_args,
idx_arr);
lib = d.get_library(lib_name, kernel_source.str());
}
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
auto kernel = d.get_kernel(kernel_name, lib);
auto& upd = inputs.back();
size_t nthreads = upd.size();
@@ -209,8 +296,8 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder->setBytes(&upd_size, sizeof(size_t), 5);
// Set index buffers
for (int i = 1; i < nidx + 1; ++i) {
compute_encoder.set_input_array(inputs[i], 20 + i);
for (int i = 0; i < nidx; ++i) {
compute_encoder.set_input_array(inputs[i + 1], 20 + i);
}
// Launch grid
@@ -279,8 +366,8 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder->setBytes(&idx_ndim, sizeof(int), 13);
// Set index buffers
for (int i = 1; i < nidx + 1; ++i) {
compute_encoder.set_input_array(inputs[i], 20 + i);
for (int i = 0; i < nidx; ++i) {
compute_encoder.set_input_array(inputs[i + 1], 20 + i);
}
// Launch grid

View File

@@ -0,0 +1,9 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view arange_kernels = R"(
template [[host_name("{0}")]] [[kernel]] void arange<{1}>(
constant const {1}& start,
constant const {1}& step,
device {1}* out,
uint index [[thread_position_in_grid]]);
)";

View File

@@ -0,0 +1,87 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view binary_kernels = R"(
template [[host_name("ss{0}")]] [[kernel]]
void binary_ss<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
uint index [[thread_position_in_grid]]);
template [[host_name("vs{0}")]] [[kernel]]
void binary_vs<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
uint index [[thread_position_in_grid]]);
template [[host_name("sv{0}")]] [[kernel]]
void binary_sv<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
uint index [[thread_position_in_grid]]);
template [[host_name("vv{0}")]] [[kernel]]
void binary_vv<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
uint index [[thread_position_in_grid]]);
template [[host_name("g4{0}")]] [[kernel]] void
binary_g_nd<{1}, {2}, {3}, 4>(
device const {1}* a,
device const {1}* b,
device {2}* c,
constant const int shape[4],
constant const size_t a_strides[4],
constant const size_t b_strides[4],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("g5{0}")]] [[kernel]] void
binary_g_nd<{1}, {2}, {3}, 5>(
device const {1}* a,
device const {1}* b,
device {2}* c,
constant const int shape[5],
constant const size_t a_strides[5],
constant const size_t b_strides[5],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("g1{0}")]] [[kernel]] void
binary_g_nd1<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
constant const size_t& a_stride,
constant const size_t& b_stride,
uint index [[thread_position_in_grid]]);
template [[host_name("g2{0}")]] [[kernel]] void
binary_g_nd2<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
constant const size_t a_strides[2],
constant const size_t b_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]);
template [[host_name("g3{0}")]] [[kernel]] void
binary_g_nd3<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
constant const size_t a_strides[3],
constant const size_t b_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("gn{0}")]] [[kernel]]
void binary_g<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
)";

View File

@@ -0,0 +1,98 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view binary_two_kernels = R"(
template [[host_name("ss{0}")]] [[kernel]]
void binary_ss<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
device {2}* d,
uint index [[thread_position_in_grid]]);
template [[host_name("vs{0}")]] [[kernel]]
void binary_vs<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
device {2}* d,
uint index [[thread_position_in_grid]]);
template [[host_name("sv{0}")]] [[kernel]]
void binary_sv<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
device {2}* d,
uint index [[thread_position_in_grid]]);
template [[host_name("vv{0}")]] [[kernel]]
void binary_vv<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
device {2}* d,
uint index [[thread_position_in_grid]]);
template [[host_name("g4{0}")]] [[kernel]] void
binary_g_nd<{1}, {2}, {3}, 4>(
device const {1}* a,
device const {1}* b,
device {2}* c,
device {2}* d,
constant const int shape[4],
constant const size_t a_strides[4],
constant const size_t b_strides[4],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("g5{0}")]] [[kernel]] void
binary_g_nd<{1}, {2}, {3}, 5>(
device const {1}* a,
device const {1}* b,
device {2}* c,
device {2}* d,
constant const int shape[5],
constant const size_t a_strides[5],
constant const size_t b_strides[5],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("g1{0}")]] [[kernel]] void
binary_g_nd1<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
device {2}* d,
constant const size_t& a_stride,
constant const size_t& b_stride,
uint index [[thread_position_in_grid]]);
template [[host_name("g2{0}")]] [[kernel]] void
binary_g_nd2<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
device {2}* d,
constant const size_t a_strides[2],
constant const size_t b_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]);
template [[host_name("g3{0}")]] [[kernel]] void
binary_g_nd3<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
device {2}* d,
constant const size_t a_strides[3],
constant const size_t b_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("gn{0}")]] [[kernel]]
void binary_g<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
device {2}* d,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
)";

View File

@@ -0,0 +1,100 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view copy_kernels = R"(
template [[host_name("s_{0}")]] [[kernel]] void copy_s<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
uint index [[thread_position_in_grid]]);
template [[host_name("v_{0}")]] [[kernel]] void copy_v<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
uint index [[thread_position_in_grid]]);
template [[host_name("g4_{0}")]] [[kernel]] void
copy_g_nd<{1}, {2}, 4>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("gg4_{0}")]] [[kernel]] void
copy_gg_nd<{1}, {2}, 4>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint3 index [[thread_position_in_grid]]);
template [[host_name("g5_{0}")]] [[kernel]] void
copy_g_nd<{1}, {2}, 5>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("gg5_{0}")]] [[kernel]] void
copy_gg_nd<{1}, {2}, 5>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint3 index [[thread_position_in_grid]]);
template [[host_name("g1_{0}")]] [[kernel]] void copy_g_nd1<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int64_t& src_stride [[buffer(3)]],
uint index [[thread_position_in_grid]]);
template [[host_name("g2_{0}")]] [[kernel]] void copy_g_nd2<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]);
template [[host_name("g3_{0}")]] [[kernel]] void copy_g_nd3<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("gg1_{0}")]] [[kernel]] void
copy_gg_nd1<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int64_t& src_stride [[buffer(3)]],
constant const int64_t& dst_stride [[buffer(4)]],
uint index [[thread_position_in_grid]]);
template [[host_name("gg2_{0}")]] [[kernel]] void
copy_gg_nd2<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint2 index [[thread_position_in_grid]]);
template [[host_name("gg3_{0}")]] [[kernel]] void
copy_gg_nd3<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint3 index [[thread_position_in_grid]]);
template [[host_name("g_{0}")]] [[kernel]] void copy_g<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int& ndim [[buffer(5)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("gg_{0}")]] [[kernel]] void copy_gg<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
constant const int& ndim [[buffer(5)]],
uint3 index [[thread_position_in_grid]]);
)";

View File

@@ -0,0 +1,34 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
namespace mlx::core::metal {
const char* utils();
const char* binary_ops();
const char* unary_ops();
const char* ternary_ops();
const char* reduce_utils();
const char* gather();
const char* scatter();
const char* arange();
const char* unary();
const char* binary();
const char* binary_two();
const char* copy();
const char* ternary();
const char* scan();
const char* softmax();
const char* sort();
const char* reduce();
const char* gemm();
const char* steel_gemm_fused();
const char* steel_gemm_masked();
const char* steel_gemm_splitk();
const char* conv();
const char* steel_conv();
const char* steel_conv_general();
} // namespace mlx::core::metal

View File

@@ -0,0 +1,81 @@
// Copyright © 2023-2024 Apple Inc.
constexpr std::string_view gather_kernels = R"(
[[kernel]] void gather{0}_{3}_{6}(
const device {1}* src [[buffer(0)]],
device {1}* out [[buffer(1)]],
const constant int* src_shape [[buffer(2)]],
const constant size_t* src_strides [[buffer(3)]],
const constant size_t& src_ndim [[buffer(4)]],
const constant int* slice_sizes [[buffer(5)]],
const constant int* axes [[buffer(6)]],
const constant int* idx_shapes [[buffer(7)]],
const constant size_t* idx_strides [[buffer(8)]],
const constant int& idx_ndim [[buffer(9)]],
{4}
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {{
Indices<{2}, {3}> idxs{{
{{ {5} }}, idx_shapes, idx_strides, idx_ndim}};
return gather_impl<{1}, {2}, {3}, {6}>(
src,
out,
src_shape,
src_strides,
src_ndim,
slice_sizes,
axes,
idxs,
index,
grid_dim);
}}
)";
constexpr std::string_view scatter_kernels = R"(
[[kernel]] void scatter_1d_index{0}_{4}(
const device {1}* updates [[buffer(1)]],
device mlx_atomic<{1}>* out [[buffer(2)]],
const constant int* out_shape [[buffer(3)]],
const constant size_t* out_strides [[buffer(4)]],
const constant size_t& upd_size [[buffer(5)]],
{5}
uint2 gid [[thread_position_in_grid]]) {{
const array<const device {2}*, {4}> idx_buffers = {{ {6} }};
return scatter_1d_index_impl<{1}, {2}, {3}, {4}>(
updates, out, out_shape, out_strides, upd_size, idx_buffers, gid);
}}
[[kernel]] void scatter{0}_{4}(
const device {1}* updates [[buffer(1)]],
device mlx_atomic<{1}>* out [[buffer(2)]],
const constant int* upd_shape [[buffer(3)]],
const constant size_t* upd_strides [[buffer(4)]],
const constant size_t& upd_ndim [[buffer(5)]],
const constant size_t& upd_size [[buffer(6)]],
const constant int* out_shape [[buffer(7)]],
const constant size_t* out_strides [[buffer(8)]],
const constant size_t& out_ndim [[buffer(9)]],
const constant int* axes [[buffer(10)]],
const constant int* idx_shapes [[buffer(11)]],
const constant size_t* idx_strides [[buffer(12)]],
const constant int& idx_ndim [[buffer(13)]],
{5}
uint2 gid [[thread_position_in_grid]]) {{
Indices<{2}, {4}> idxs{{ {{ {6} }}, idx_shapes, idx_strides, idx_ndim}};
return scatter_impl<{1}, {2}, {3}, {4}>(
updates,
out,
upd_shape,
upd_strides,
upd_ndim,
upd_size,
out_shape,
out_strides,
out_ndim,
axes,
idxs,
gid);
}}
)";

View File

@@ -0,0 +1,168 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view reduce_init_kernels = R"(
[[kernel]] void {0}(
device {1}* out [[buffer(0)]],
uint tid [[thread_position_in_grid]]) {{
out[tid] = {2}<{1}>::init;
}}
)";
constexpr std::string_view reduce_kernels = R"(
template [[host_name("all_{0}")]] [[kernel]] void
all_reduce<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device mlx_atomic<{2}>* out [[buffer(1)]],
const device size_t& in_size [[buffer(2)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint grid_size [[threads_per_grid]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
template [[host_name("colGeneral_{0}")]] [[kernel]] void
col_reduce_general<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device mlx_atomic<{2}>* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
threadgroup {2}* local_data [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]]);
template [[host_name("colSmall_{0}")]] [[kernel]] void
col_reduce_small<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
const constant size_t& non_col_reductions [[buffer(8)]],
const constant int* non_col_shapes [[buffer(9)]],
const constant size_t* non_col_strides [[buffer(10)]],
const constant int& non_col_ndim [[buffer(11)]],
uint tid [[thread_position_in_grid]]);
template [[host_name("rowGeneralSmall_{0}")]] [[kernel]] void
row_reduce_general_small<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant size_t& non_row_reductions [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
uint lid [[thread_position_in_grid]]);
template [[host_name("rowGeneralMed_{0}")]] [[kernel]] void
row_reduce_general_med<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant size_t& non_row_reductions [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
uint tid [[threadgroup_position_in_grid]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_per_group [[dispatch_simdgroups_per_threadgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
template [[host_name("rowGeneral_{0}")]] [[kernel]] void
row_reduce_general<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device mlx_atomic<{2}>* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant size_t& non_row_reductions [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
)";
constexpr std::string_view reduce_non_atomic_kernels = R"(
template [[host_name("allNoAtomics_{0}")]] [[kernel]] void
all_reduce_no_atomics<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const device size_t& in_size [[buffer(2)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint grid_size [[threads_per_grid]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint thread_group_id [[threadgroup_position_in_grid]]);
template [[host_name("colGeneralNoAtomics_{0}")]] [[kernel]] void
col_reduce_general_no_atomics<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
threadgroup {2}* local_data [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 gid [[thread_position_in_grid]],
uint3 lsize [[threads_per_threadgroup]],
uint3 gsize [[threads_per_grid]]);
template [[host_name("colSmall_{0}")]] [[kernel]] void
col_reduce_small<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
const constant size_t& non_col_reductions [[buffer(8)]],
const constant int* non_col_shapes [[buffer(9)]],
const constant size_t* non_col_strides [[buffer(10)]],
const constant int& non_col_ndim [[buffer(11)]],
uint tid [[thread_position_in_grid]]);
template [[host_name("rowGeneralSmall_{0}")]] [[kernel]] void
row_reduce_general_small<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant size_t& non_row_reductions [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
uint lid [[thread_position_in_grid]]);
template [[host_name("rowGeneralNoAtomics_{0}")]] [[kernel]] void
row_reduce_general_no_atomics<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant size_t& non_row_reductions [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]],
uint3 gsize [[threads_per_grid]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
)";

View File

@@ -0,0 +1,26 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view scan_kernels = R"(
template [[host_name("contig_{0}")]] [[kernel]] void
contiguous_scan<{1}, {2}, {3}<{2}>, 4, {4}, {5}>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant size_t& axis_size [[buffer(2)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
uint simd_size [[threads_per_simdgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
template [[host_name("strided_{0}")]] [[kernel]] void
strided_scan<{1}, {2}, {3}<{2}>, 4, {4}, {5}>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant size_t& axis_size [[buffer(2)]],
const constant size_t& stride [[buffer(3)]],
uint2 gid [[thread_position_in_grid]],
uint2 lid [[thread_position_in_threadgroup]],
uint2 lsize [[threads_per_threadgroup]],
uint simd_size [[threads_per_simdgroup]]);
)";

View File

@@ -0,0 +1,23 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view softmax_kernels = R"(
template [[host_name("block_{0}")]] [[kernel]] void
softmax_single_row<{1}, {2}>(
const device {1}* in,
device {1}* out,
constant int& axis_size,
uint gid [[thread_position_in_grid]],
uint _lid [[thread_position_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
template [[host_name("looped_{0}")]] [[kernel]] void
softmax_looped<{1}, {2}>(
const device {1}* in,
device {1}* out,
constant int& axis_size,
uint gid [[threadgroup_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
)";

View File

@@ -0,0 +1,81 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view block_sort_kernels = R"(
template [[host_name("carg_{0}")]] [[kernel]] void
block_sort<{1}, {2}, true, {3}, {4}>(
const device {1}* inp [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant int& size_sorted_axis [[buffer(2)]],
const constant int& stride_sorted_axis [[buffer(3)]],
const constant int& stride_segment_axis [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
template [[host_name("ncarg_{0}")]] [[kernel]] void
block_sort_nc<{1}, {2}, true, {3}, {4}>(
const device {1}* inp [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant int& size_sorted_axis [[buffer(2)]],
const constant int& stride_sorted_axis [[buffer(3)]],
const constant int& nc_dim [[buffer(4)]],
const device int* nc_shape [[buffer(5)]],
const device size_t* nc_strides [[buffer(6)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
template [[host_name("c_{0}")]] [[kernel]] void
block_sort<{1}, {2}, false, {3}, {4}>(
const device {1}* inp [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant int& size_sorted_axis [[buffer(2)]],
const constant int& stride_sorted_axis [[buffer(3)]],
const constant int& stride_segment_axis [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
template [[host_name("nc_{0}")]] [[kernel]] void
block_sort_nc<{1}, {2}, false, {3}, {4}>(
const device {1}* inp [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant int& size_sorted_axis [[buffer(2)]],
const constant int& stride_sorted_axis [[buffer(3)]],
const constant int& nc_dim [[buffer(4)]],
const device int* nc_shape [[buffer(5)]],
const device size_t* nc_strides [[buffer(6)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
)";
constexpr std::string_view multiblock_sort_kernels = R"(
template [[host_name("sort_{0}")]] [[kernel]] void
mb_block_sort<{1}, {2}, true, {3}, {4}>(
const device {1}* inp [[buffer(0)]],
device {1}* out_vals [[buffer(1)]],
device {2}* out_idxs [[buffer(2)]],
const constant int& size_sorted_axis [[buffer(3)]],
const constant int& stride_sorted_axis [[buffer(4)]],
const constant int& nc_dim [[buffer(5)]],
const device int* nc_shape [[buffer(6)]],
const device size_t* nc_strides [[buffer(7)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
template [[host_name("partition_{0}")]] [[kernel]] void
mb_block_partition<{1}, {2}, true, {3}, {4}>(
device {2}* block_partitions [[buffer(0)]],
const device {1}* dev_vals [[buffer(1)]],
const device {2}* dev_idxs [[buffer(2)]],
const constant int& size_sorted_axis [[buffer(3)]],
const constant int& merge_tiles [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 tgp_dims [[threads_per_threadgroup]]);
template [[host_name("merge_{0}")]] [[kernel]] void
mb_block_merge<{1}, {2}, true, {3}, {4}>(
const device {2}* block_partitions [[buffer(0)]],
const device {1}* dev_vals_in [[buffer(1)]],
const device {2}* dev_idxs_in [[buffer(2)]],
device {1}* dev_vals_out [[buffer(3)]],
device {2}* dev_idxs_out [[buffer(4)]],
const constant int& size_sorted_axis [[buffer(5)]],
const constant int& merge_tiles [[buffer(6)]],
const constant int& num_tiles [[buffer(7)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
)";

View File

@@ -0,0 +1,32 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view steel_conv_kernels = R"(
template [[host_name("{name}")]] [[kernel]] void
implicit_gemm_conv_2d<{itype}, {bm}, {bn}, {bk}, {wm}, {wn}, {n_channels}, {small_filter}>(
const device {itype}* A [[buffer(0)]],
const device {itype}* B [[buffer(1)]],
device {itype}* C [[buffer(2)]],
const constant MLXConvParams<2>* params [[buffer(3)]],
const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]);
)";
constexpr std::string_view steel_conv_general_kernels = R"(
template [[host_name("{name}")]] [[kernel]] void
implicit_gemm_conv_2d_general<{itype}, {bm}, {bn}, {bk}, {wm}, {wn}>(
const device {itype}* A [[buffer(0)]],
const device {itype}* B [[buffer(1)]],
device {itype}* C [[buffer(2)]],
const constant MLXConvParams<2>* params [[buffer(3)]],
const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
const constant Conv2DGeneralJumpParams* jump_params [[buffer(5)]],
const constant Conv2DGeneralBaseInfo* base_h [[buffer(6)]],
const constant Conv2DGeneralBaseInfo* base_w [[buffer(7)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]);
)";

View File

@@ -0,0 +1,106 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view steel_gemm_fused_kernels = R"(
template [[host_name("{name}")]]
[[kernel]] void gemm<{itype}, {bm}, {bn}, {bk}, {wm}, {wn}, {trans_a}, {trans_b}, float>(
const device {itype} *A [[buffer(0)]],
const device {itype} *B [[buffer(1)]],
const device {itype} *C [[buffer(2), function_constant(use_out_source)]],
device {itype} *D [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]],
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
const constant int* batch_shape [[buffer(6)]],
const constant size_t* batch_strides [[buffer(7)]],
const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]],
const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]],
const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]],
const constant int* operand_shape [[buffer(13), function_constant(do_gather)]],
const constant size_t* operand_strides [[buffer(14), function_constant(do_gather)]],
const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
)";
constexpr std::string_view steel_gemm_masked_kernels = R"(
template [[host_name("{name}")]] [[kernel]] void
block_masked_gemm<
{itype},
{outmasktype},
{opmasktype},
{bm},
{bn},
{bk},
{wm},
{wn},
{trans_a},
{trans_b},
{mn_aligned},
{k_aligned}>(
const device {itype}* A [[buffer(0)]],
const device {itype}* B [[buffer(1)]],
device {itype}* D [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]],
const constant int* batch_shape [[buffer(6)]],
const constant size_t* batch_strides [[buffer(7)]],
const device {outmasktype}* out_mask [[buffer(10)]],
const device {opmasktype}* lhs_mask [[buffer(11)]],
const device {opmasktype}* rhs_mask [[buffer(12)]],
const constant int* mask_strides [[buffer(13)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
)";
constexpr std::string_view steel_gemm_splitk_kernels = R"(
template [[host_name("{name}")]] [[kernel]] void
gemm_splitk<
{itype},
{otype},
{bm},
{bn},
{bk},
{wm},
{wn},
{trans_a},
{trans_b},
{mn_aligned},
{k_aligned}>(
const device {itype}* A [[buffer(0)]],
const device {itype}* B [[buffer(1)]],
device {otype}* C [[buffer(2)]],
const constant GEMMSpiltKParams* params [[buffer(3)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
)";
constexpr std::string_view steel_gemm_splitk_accum_kernels = R"(
template [[host_name("{name}")]] [[kernel]] void
gemm_splitk_accum<{atype}, {otype}>(
const device {atype}* C_split [[buffer(0)]],
device {otype}* D [[buffer(1)]],
const constant int& k_partitions [[buffer(2)]],
const constant int& partition_stride [[buffer(3)]],
const constant int& ldd [[buffer(4)]],
uint2 gid [[thread_position_in_grid]]);
)";
constexpr std::string_view steel_gemm_splitk_accum_axbpy_kernels = R"(
template [[host_name("{name}")]] [[kernel]] void
gemm_splitk_accum_axpby<{atype}, {otype}>(
const device {atype}* C_split [[buffer(0)]],
device {otype}* D [[buffer(1)]],
const constant int& k_partitions [[buffer(2)]],
const constant int& partition_stride [[buffer(3)]],
const constant int& ldd [[buffer(4)]],
const device {otype}* C [[buffer(5)]],
const constant int& ldc [[buffer(6)]],
const constant int& fdc [[buffer(7)]],
const constant float& alpha [[buffer(8)]],
const constant float& beta [[buffer(9)]],
uint2 gid [[thread_position_in_grid]]);
)";

View File

@@ -0,0 +1,80 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view ternary_kernels = R"(
template [[host_name("v_{0}")]] [[kernel]] void ternary_v<{1}, {2}>(
device const bool* a,
device const {1}* b,
device const {1}* c,
device {1}* d,
uint index [[thread_position_in_grid]]);
template [[host_name("g_{0}")]] [[kernel]] void ternary_g<{1}, {2}>(
device const bool* a,
device const {1}* b,
device const {1}* c,
device {1}* d,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const size_t* c_strides,
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("g1_{0}")]] [[kernel]] void
ternary_g_nd1<{1}, {2}>(
device const bool* a,
device const {1}* b,
device const {1}* c,
device {1}* d,
constant const size_t& a_strides,
constant const size_t& b_strides,
constant const size_t& c_strides,
uint index [[thread_position_in_grid]]);
template [[host_name("g2_{0}")]] [[kernel]] void
ternary_g_nd2<{1}, {2}>(
device const bool* a,
device const {1}* b,
device const {1}* c,
device {1}* d,
constant const size_t a_strides[2],
constant const size_t b_strides[2],
constant const size_t c_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]);
template [[host_name("g3_{0}")]] [[kernel]] void
ternary_g_nd3<{1}, {2}>(
device const bool* a,
device const {1}* b,
device const {1}* c,
device {1}* d,
constant const size_t a_strides[3],
constant const size_t b_strides[3],
constant const size_t c_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("g4_{0}")]] [[kernel]] void
ternary_g_nd<{1}, {2}, 4>(
device const bool* a,
device const {1}* b,
device const {1}* c,
device {1}* d,
constant const int shape[4],
constant const size_t a_strides[4],
constant const size_t b_strides[4],
constant const size_t c_strides[4],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("g5_{0}")]] [[kernel]] void
ternary_g_nd<{1}, {2}, 5>(
device const bool* a,
device const {1}* b,
device const {1}* c,
device {1}* d,
constant const int shape[5],
constant const size_t a_strides[5],
constant const size_t b_strides[5],
constant const size_t c_strides[5],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
)";

View File

@@ -0,0 +1,16 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view unary_kernels = R"(
template [[host_name("v{0}")]] [[kernel]] void unary_v<{1}, {2}>(
device const {1}* in,
device {1}* out,
uint index [[thread_position_in_grid]]);
template [[host_name("g{0}")]] [[kernel]] void unary_g<{1}, {2}>(
device const {1}* in,
device {1}* out,
device const int* in_shape,
device const size_t* in_strides,
device const int& ndim,
uint index [[thread_position_in_grid]]);
)";

View File

@@ -0,0 +1,489 @@
// Copyright © 2024 Apple Inc.
#include <fmt/format.h>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/jit/arange.h"
#include "mlx/backend/metal/jit/binary.h"
#include "mlx/backend/metal/jit/binary_two.h"
#include "mlx/backend/metal/jit/copy.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/jit/reduce.h"
#include "mlx/backend/metal/jit/scan.h"
#include "mlx/backend/metal/jit/softmax.h"
#include "mlx/backend/metal/jit/sort.h"
#include "mlx/backend/metal/jit/steel_conv.h"
#include "mlx/backend/metal/jit/steel_gemm.h"
#include "mlx/backend/metal/jit/ternary.h"
#include "mlx/backend/metal/jit/unary.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
using namespace fmt::literals;
namespace mlx::core {
std::string op_name(const array& arr) {
std::ostringstream op_t;
arr.primitive().print(op_t);
return op_t.str();
}
MTL::ComputePipelineState* get_arange_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source
<< metal::utils() << metal::arange()
<< fmt::format(arange_kernels, lib_name, get_type_string(out.dtype()));
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_unary_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out) {
std::string lib_name = kernel_name.substr(1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::unary_ops() << metal::unary()
<< fmt::format(
unary_kernels,
lib_name,
get_type_string(out.dtype()),
op_name(out));
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_binary_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out) {
std::string lib_name = kernel_name.substr(2);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::binary_ops() << metal::binary()
<< fmt::format(
binary_kernels,
lib_name,
get_type_string(in.dtype()),
get_type_string(out.dtype()),
op_name(out));
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_binary_two_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out) {
std::string lib_name = kernel_name.substr(2);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::binary_ops()
<< metal::binary_two()
<< fmt::format(
binary_two_kernels,
lib_name,
get_type_string(in.dtype()),
get_type_string(out.dtype()),
op_name(out));
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_ternary_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::ternary_ops() << metal::ternary()
<< fmt::format(
ternary_kernels,
lib_name,
get_type_string(out.dtype()),
op_name(out));
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_copy_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::copy()
<< fmt::format(
copy_kernels,
lib_name,
get_type_string(in.dtype()),
get_type_string(out.dtype()));
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_softmax_kernel(
metal::Device& d,
const std::string& kernel_name,
bool precise,
const array& out) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::softmax()
<< fmt::format(
softmax_kernels,
lib_name,
get_type_string(out.dtype()),
get_type_string(precise ? float32 : out.dtype()));
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_scan_kernel(
metal::Device& d,
const std::string& kernel_name,
bool reverse,
bool inclusive,
const array& in,
const array& out) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::scan()
<< fmt::format(
scan_kernels,
lib_name,
get_type_string(in.dtype()),
get_type_string(out.dtype()),
op_name(out),
inclusive,
reverse);
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_sort_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out,
int bn,
int tn) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::sort()
<< fmt::format(
block_sort_kernels,
lib_name,
get_type_string(in.dtype()),
get_type_string(out.dtype()),
bn,
tn);
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_mb_sort_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& idx,
int bn,
int tn) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::sort()
<< fmt::format(
multiblock_sort_kernels,
lib_name,
get_type_string(in.dtype()),
get_type_string(idx.dtype()),
bn,
tn);
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_reduce_init_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out) {
auto lib = d.get_library(kernel_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::reduce_utils()
<< fmt::format(
reduce_init_kernels,
kernel_name,
get_type_string(out.dtype()),
op_name(out));
lib = d.get_library(kernel_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_reduce_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& op_name,
const array& in,
const array& out) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::string op_type = op_name;
op_type[0] = std::toupper(op_name[0]);
bool non_atomic = out.dtype() == int64 || out.dtype() == uint64;
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce()
<< fmt::format(
non_atomic ? reduce_non_atomic_kernels
: reduce_kernels,
lib_name,
get_type_string(in.dtype()),
get_type_string(out.dtype()),
op_type);
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const array& out,
bool transpose_a,
bool transpose_b,
int bm,
int bn,
int bk,
int wm,
int wn) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::gemm()
<< metal::steel_gemm_fused()
<< fmt::format(
steel_gemm_fused_kernels,
"name"_a = lib_name,
"itype"_a = get_type_string(out.dtype()),
"bm"_a = bm,
"bn"_a = bn,
"bk"_a = bk,
"wm"_a = wm,
"wn"_a = wn,
"trans_a"_a = transpose_a,
"trans_b"_a = transpose_b);
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
}
MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out,
bool transpose_a,
bool transpose_b,
int bm,
int bn,
int bk,
int wm,
int wn,
bool mn_aligned,
bool k_aligned) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::gemm()
<< metal::steel_gemm_splitk()
<< fmt::format(
steel_gemm_splitk_kernels,
"name"_a = lib_name,
"itype"_a = get_type_string(in.dtype()),
"otype"_a = get_type_string(out.dtype()),
"bm"_a = bm,
"bn"_a = bn,
"bk"_a = bk,
"wm"_a = wm,
"wn"_a = wn,
"trans_a"_a = transpose_a,
"trans_b"_a = transpose_b,
"mn_aligned"_a = mn_aligned,
"k_aligned"_a = k_aligned);
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out,
bool axbpy) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::gemm()
<< metal::steel_gemm_splitk()
<< fmt::format(
axbpy ? steel_gemm_splitk_accum_axbpy_kernels
: steel_gemm_splitk_accum_kernels,
"name"_a = lib_name,
"atype"_a = get_type_string(in.dtype()),
"otype"_a = get_type_string(out.dtype()));
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out,
const std::optional<array>& mask_out,
const std::optional<array>& mask_op,
bool transpose_a,
bool transpose_b,
int bm,
int bn,
int bk,
int wm,
int wn,
bool mn_aligned,
bool k_aligned) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
auto out_mask_type = mask_out.has_value()
? get_type_string((*mask_out).dtype())
: "nomask_t";
auto op_mask_type =
mask_op.has_value() ? get_type_string((*mask_op).dtype()) : "nomask_t";
kernel_source << metal::utils() << metal::gemm()
<< metal::steel_gemm_masked()
<< fmt::format(
steel_gemm_masked_kernels,
"name"_a = lib_name,
"itype"_a = get_type_string(out.dtype()),
"outmasktype"_a = out_mask_type,
"opmasktype"_a = op_mask_type,
"bm"_a = bm,
"bn"_a = bn,
"bk"_a = bk,
"wm"_a = wm,
"wn"_a = wn,
"trans_a"_a = transpose_a,
"trans_b"_a = transpose_b,
"mn_aligned"_a = mn_aligned,
"k_aligned"_a = k_aligned);
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_steel_conv_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out,
int bm,
int bn,
int bk,
int wm,
int wn,
int n_channel_specialization,
bool small_filter) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::conv() << metal::steel_conv()
<< fmt::format(
steel_conv_kernels,
"name"_a = lib_name,
"itype"_a = get_type_string(out.dtype()),
"bm"_a = bm,
"bn"_a = bn,
"bk"_a = bk,
"wm"_a = wm,
"wn"_a = wn,
"n_channels"_a = n_channel_specialization,
"small_filter"_a = small_filter);
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_steel_conv_general_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out,
int bm,
int bn,
int bk,
int wm,
int wn) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::conv()
<< metal::steel_conv_general()
<< fmt::format(
steel_conv_general_kernels,
"name"_a = lib_name,
"itype"_a = get_type_string(out.dtype()),
"bm"_a = bm,
"bn"_a = bn,
"bk"_a = bk,
"wm"_a = wm,
"wn"_a = wn);
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
} // namespace mlx::core

157
mlx/backend/metal/kernels.h Normal file
View File

@@ -0,0 +1,157 @@
// Copyright © 2024 Apple Inc.
#include "mlx/array.h"
#include "mlx/backend/metal/device.h"
namespace mlx::core {
MTL::ComputePipelineState* get_arange_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out);
MTL::ComputePipelineState* get_unary_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out);
MTL::ComputePipelineState* get_binary_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out);
MTL::ComputePipelineState* get_binary_two_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out);
MTL::ComputePipelineState* get_ternary_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out);
MTL::ComputePipelineState* get_copy_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out);
MTL::ComputePipelineState* get_softmax_kernel(
metal::Device& d,
const std::string& kernel_name,
bool precise,
const array& out);
MTL::ComputePipelineState* get_scan_kernel(
metal::Device& d,
const std::string& kernel_name,
bool reverse,
bool inclusive,
const array& in,
const array& out);
MTL::ComputePipelineState* get_sort_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out,
int bn,
int tn);
MTL::ComputePipelineState* get_mb_sort_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& idx,
int bn,
int tn);
MTL::ComputePipelineState* get_reduce_init_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out);
MTL::ComputePipelineState* get_reduce_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& op_name,
const array& in,
const array& out);
MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const array& out,
bool transpose_a,
bool transpose_b,
int bm,
int bn,
int bk,
int wm,
int wn);
MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out,
bool transpose_a,
bool transpose_b,
int bm,
int bn,
int bk,
int wm,
int wn,
bool mn_aligned,
bool k_aligned);
MTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out,
bool axbpy);
MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out,
const std::optional<array>& mask_out,
const std::optional<array>& mask_op,
bool transpose_a,
bool transpose_b,
int bm,
int bn,
int bk,
int wm,
int wn,
bool mn_aligned,
bool k_aligned);
MTL::ComputePipelineState* get_steel_conv_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out,
int bm,
int bn,
int bk,
int wm,
int wn,
int n_channel_specialization,
bool small_filter);
MTL::ComputePipelineState* get_steel_conv_general_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out,
int bm,
int bn,
int bk,
int wm,
int wn);
} // namespace mlx::core

View File

@@ -1,26 +1,17 @@
set(
HEADERS
${CMAKE_CURRENT_SOURCE_DIR}/atomic.h
${CMAKE_CURRENT_SOURCE_DIR}/bf16.h
${CMAKE_CURRENT_SOURCE_DIR}/bf16_math.h
${CMAKE_CURRENT_SOURCE_DIR}/binary.h
${CMAKE_CURRENT_SOURCE_DIR}/complex.h
${CMAKE_CURRENT_SOURCE_DIR}/defines.h
${CMAKE_CURRENT_SOURCE_DIR}/erf.h
${CMAKE_CURRENT_SOURCE_DIR}/expm1f.h
${CMAKE_CURRENT_SOURCE_DIR}/indexing.h
${CMAKE_CURRENT_SOURCE_DIR}/unary.h
${CMAKE_CURRENT_SOURCE_DIR}/utils.h
bf16.h
bf16_math.h
complex.h
defines.h
utils.h
steel/conv/params.h
)
set(
KERNELS
"arange"
"arg_reduce"
"binary"
"binary_two"
"conv"
"copy"
"fft"
"gemv"
"quantized"
@@ -28,15 +19,45 @@ set(
"rms_norm"
"layer_norm"
"rope"
"scan"
"scaled_dot_product_attention"
)
if (NOT MLX_METAL_JIT)
set(
KERNELS
${KERNELS}
"arange"
"binary"
"binary_two"
"unary"
"ternary"
"copy"
"softmax"
"sort"
"ternary"
"unary"
"gather"
"scatter"
"scan"
"reduce"
)
set(
HEADERS
${HEADERS}
atomic.h
arange.h
unary_ops.h
unary.h
binary_ops.h
binary.h
ternary.h
copy.h
softmax.h
sort.h
scan.h
reduction/ops.h
reduction/reduce_init.h
reduction/reduce_all.h
reduction/reduce_col.h
reduction/reduce_row.h
)
endif()
function(build_kernel_base TARGET SRCFILE DEPS)
set(METAL_FLAGS -Wall -Wextra -fno-fast-math -D${MLX_METAL_VERSION})
@@ -68,23 +89,40 @@ foreach(KERNEL ${KERNELS})
set(KERNEL_AIR ${KERNEL}.air ${KERNEL_AIR})
endforeach()
file(GLOB_RECURSE STEEL_KERNELS ${CMAKE_CURRENT_SOURCE_DIR}/steel/*.metal)
file(GLOB_RECURSE STEEL_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/steel/*.h)
foreach(KERNEL ${STEEL_KERNELS})
cmake_path(GET KERNEL STEM TARGET)
build_kernel_base(${TARGET} ${KERNEL} "${STEEL_HEADERS}")
set(KERNEL_AIR ${TARGET}.air ${KERNEL_AIR})
endforeach()
file(GLOB_RECURSE REDUCE_KERNELS ${CMAKE_CURRENT_SOURCE_DIR}/reduction/*.metal)
file(GLOB_RECURSE REDUCE_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/reduction/*.h)
foreach(KERNEL ${REDUCE_KERNELS})
cmake_path(GET KERNEL STEM TARGET)
build_kernel_base(${TARGET} ${KERNEL} "${REDUCE_HEADERS}")
set(KERNEL_AIR ${TARGET}.air ${KERNEL_AIR})
endforeach()
if (NOT MLX_METAL_JIT)
set(
STEEL_KERNELS
${CMAKE_CURRENT_SOURCE_DIR}/steel/conv/kernels/steel_conv.metal
${CMAKE_CURRENT_SOURCE_DIR}/steel/conv/kernels/steel_conv_general.metal
${CMAKE_CURRENT_SOURCE_DIR}/steel/gemm/kernels/steel_gemm_fused.metal
${CMAKE_CURRENT_SOURCE_DIR}/steel/gemm/kernels/steel_gemm_masked.metal
${CMAKE_CURRENT_SOURCE_DIR}/steel/gemm/kernels/steel_gemm_splitk.metal
)
set(
STEEL_HEADERS
steel/defines.h
steel/utils.h
steel/conv/conv.h
steel/conv/loader.h
steel/conv/loaders/loader_channel_l.h
steel/conv/loaders/loader_channel_n.h
steel/conv/loaders/loader_general.h
steel/conv/kernels/steel_conv.h
steel/conv/kernels/steel_conv_general.h
steel/gemm/gemm.h
steel/gemm/mma.h
steel/gemm/loader.h
steel/gemm/transforms.h
steel/gemm/kernels/steel_gemm_fused.h
steel/gemm/kernels/steel_gemm_masked.h
steel/gemm/kernels/steel_gemm_splitk.h
)
foreach(KERNEL ${STEEL_KERNELS})
cmake_path(GET KERNEL STEM TARGET)
build_kernel_base(${TARGET} ${KERNEL} "${STEEL_HEADERS}")
set(KERNEL_AIR ${TARGET}.air ${KERNEL_AIR})
endforeach()
endif()
add_custom_command(
OUTPUT ${MLX_METAL_PATH}/mlx.metallib

View File

@@ -0,0 +1,9 @@
// Copyright © 2023-2024 Apple Inc.
template <typename T>
[[kernel]] void arange(
constant const T& start,
constant const T& step,
device T* out,
uint index [[thread_position_in_grid]]) {
out[index] = start + index * step;
}

View File

@@ -1,15 +1,8 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
// clang-format off
#include "mlx/backend/metal/kernels/bf16.h"
template <typename T>
[[kernel]] void arange(
constant const T& start,
constant const T& step,
device T* out,
uint index [[thread_position_in_grid]]) {
out[index] = start + index * step;
}
#include "mlx/backend/metal/kernels/arange.h"
#define instantiate_arange(tname, type) \
template [[host_name("arange" #tname)]] [[kernel]] void arange<type>( \
@@ -18,7 +11,6 @@ template <typename T>
device type* out, \
uint index [[thread_position_in_grid]]);
// clang-format off
instantiate_arange(uint8, uint8_t)
instantiate_arange(uint16, uint16_t)
instantiate_arange(uint32, uint32_t)

View File

@@ -1,6 +1,5 @@
// Copyright © 2023 Apple Inc.
#include <metal_atomic>
#include <metal_simdgroup>
#include "mlx/backend/metal/kernels/utils.h"
@@ -194,4 +193,4 @@ instantiate_arg_reduce(int32, int32_t)
instantiate_arg_reduce(int64, int64_t)
instantiate_arg_reduce(float16, half)
instantiate_arg_reduce(float32, float)
instantiate_arg_reduce(bfloat16, bfloat16_t) // clang-format on
instantiate_arg_reduce(bfloat16, bfloat16_t) // clang-format on

View File

@@ -4,7 +4,6 @@
#include <metal_atomic>
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/bf16.h"
using namespace metal;

View File

@@ -6,9 +6,7 @@
using namespace metal;
// No support for less than metal 3.0
// anything greater has native bfloat
#ifndef METAL_3_0
#if defined METAL_3_1 || (__METAL_VERSION__ >= 310)
typedef bfloat bfloat16_t;

View File

@@ -369,7 +369,7 @@ instantiate_metal_math_funcs(
return static_cast<otype>(__metal_simd_xor(static_cast<ctype>(data))); \
}
#ifndef METAL_3_0
#if defined METAL_3_1 || (__METAL_VERSION__ >= 310)
#define bfloat16_to_uint16(x) as_type<uint16_t>(x)
#define uint16_to_bfloat16(x) as_type<bfloat16_t>(x)

View File

@@ -1,273 +1,113 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2024 Apple Inc.
#pragma once
template <typename T, typename U, typename Op>
[[kernel]] void binary_ss(
device const T* a,
device const T* b,
device U* c,
uint index [[thread_position_in_grid]]) {
c[index] = Op()(a[0], b[0]);
}
#include <metal_integer>
#include <metal_math>
template <typename T, typename U, typename Op>
[[kernel]] void binary_sv(
device const T* a,
device const T* b,
device U* c,
uint index [[thread_position_in_grid]]) {
c[index] = Op()(a[0], b[index]);
}
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"
template <typename T, typename U, typename Op>
[[kernel]] void binary_vs(
device const T* a,
device const T* b,
device U* c,
uint index [[thread_position_in_grid]]) {
c[index] = Op()(a[index], b[0]);
}
struct Add {
template <typename T>
T operator()(T x, T y) {
return x + y;
}
};
template <typename T, typename U, typename Op>
[[kernel]] void binary_vv(
device const T* a,
device const T* b,
device U* c,
uint index [[thread_position_in_grid]]) {
c[index] = Op()(a[index], b[index]);
}
struct Divide {
template <typename T>
T operator()(T x, T y) {
return x / y;
}
};
template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd1(
device const T* a,
device const T* b,
device U* c,
constant const size_t& a_stride,
constant const size_t& b_stride,
uint index [[thread_position_in_grid]]) {
auto a_idx = elem_to_loc_1(index, a_stride);
auto b_idx = elem_to_loc_1(index, b_stride);
c[index] = Op()(a[a_idx], b[b_idx]);
}
struct Remainder {
template <typename T>
metal::enable_if_t<metal::is_integral_v<T> & !metal::is_signed_v<T>, T>
operator()(T x, T y) {
return x % y;
}
template <typename T>
metal::enable_if_t<metal::is_integral_v<T> & metal::is_signed_v<T>, T>
operator()(T x, T y) {
auto r = x % y;
if (r != 0 && (r < 0 != y < 0)) {
r += y;
}
return r;
}
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
T r = fmod(x, y);
if (r != 0 && (r < 0 != y < 0)) {
r += y;
}
return r;
}
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
return x % y;
}
};
template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd2(
device const T* a,
device const T* b,
device U* c,
constant const size_t a_strides[2],
constant const size_t b_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_2(index, a_strides);
auto b_idx = elem_to_loc_2(index, b_strides);
size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
c[out_idx] = Op()(a[a_idx], b[b_idx]);
}
struct Equal {
template <typename T>
bool operator()(T x, T y) {
return x == y;
}
};
template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd3(
device const T* a,
device const T* b,
device U* c,
constant const size_t a_strides[3],
constant const size_t b_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_3(index, a_strides);
auto b_idx = elem_to_loc_3(index, b_strides);
size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
c[out_idx] = Op()(a[a_idx], b[b_idx]);
}
struct NaNEqual {
template <typename T>
bool operator()(T x, T y) {
return x == y || (metal::isnan(x) && metal::isnan(y));
}
template <>
bool operator()(complex64_t x, complex64_t y) {
return x == y ||
(metal::isnan(x.real) && metal::isnan(y.real) && metal::isnan(x.imag) &&
metal::isnan(y.imag)) ||
(x.real == y.real && metal::isnan(x.imag) && metal::isnan(y.imag)) ||
(metal::isnan(x.real) && metal::isnan(y.real) && x.imag == y.imag);
}
};
template <typename T, typename U, typename Op, int DIM>
[[kernel]] void binary_g_nd(
device const T* a,
device const T* b,
device U* c,
constant const int shape[DIM],
constant const size_t a_strides[DIM],
constant const size_t b_strides[DIM],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
c[out_idx] = Op()(a[idx.x], b[idx.y]);
}
struct Greater {
template <typename T>
bool operator()(T x, T y) {
return x > y;
}
};
struct GreaterEqual {
template <typename T>
bool operator()(T x, T y) {
return x >= y;
}
};
struct Less {
template <typename T>
bool operator()(T x, T y) {
return x < y;
}
};
struct LessEqual {
template <typename T>
bool operator()(T x, T y) {
return x <= y;
}
};
struct LogAddExp {
template <typename T>
T operator()(T x, T y) {
if (metal::isnan(x) || metal::isnan(y)) {
return metal::numeric_limits<T>::quiet_NaN();
}
constexpr T inf = metal::numeric_limits<T>::infinity();
T maxval = metal::max(x, y);
T minval = metal::min(x, y);
return (minval == -inf || maxval == inf)
? maxval
: (maxval + log1p(metal::exp(minval - maxval)));
};
};
struct Maximum {
template <typename T>
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {
return metal::max(x, y);
}
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
if (metal::isnan(x)) {
return x;
}
return x > y ? x : y;
}
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
if (metal::isnan(x.real) || metal::isnan(x.imag)) {
return x;
}
return x > y ? x : y;
}
};
struct Minimum {
template <typename T>
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {
return metal::min(x, y);
}
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
if (metal::isnan(x)) {
return x;
}
return x < y ? x : y;
}
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
if (metal::isnan(x.real) || metal::isnan(x.imag)) {
return x;
}
return x < y ? x : y;
}
};
struct Multiply {
template <typename T>
T operator()(T x, T y) {
return x * y;
}
};
struct NotEqual {
template <typename T>
bool operator()(T x, T y) {
return x != y;
}
template <>
bool operator()(complex64_t x, complex64_t y) {
return x.real != y.real || x.imag != y.imag;
}
};
struct Power {
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T base, T exp) {
return metal::pow(base, exp);
}
template <typename T>
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T base, T exp) {
T res = 1;
while (exp) {
if (exp & 1) {
res *= base;
}
exp >>= 1;
base *= base;
}
return res;
}
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
auto x_theta = metal::atan(x.imag / x.real);
auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag);
auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta);
auto phase = y.imag * x_ln_r + y.real * x_theta;
return {mag * metal::cos(phase), mag * metal::sin(phase)};
}
};
struct Subtract {
template <typename T>
T operator()(T x, T y) {
return x - y;
}
};
struct LogicalAnd {
template <typename T>
T operator()(T x, T y) {
return x && y;
};
};
struct LogicalOr {
template <typename T>
T operator()(T x, T y) {
return x || y;
};
};
struct BitwiseAnd {
template <typename T>
T operator()(T x, T y) {
return x & y;
};
};
struct BitwiseOr {
template <typename T>
T operator()(T x, T y) {
return x | y;
};
};
struct BitwiseXor {
template <typename T>
T operator()(T x, T y) {
return x ^ y;
};
};
struct LeftShift {
template <typename T>
T operator()(T x, T y) {
return x << y;
};
};
struct RightShift {
template <typename T>
T operator()(T x, T y) {
return x >> y;
};
};
struct ArcTan2 {
template <typename T>
T operator()(T y, T x) {
return metal::precise::atan2(y, x);
}
};
template <typename T, typename U, typename Op>
[[kernel]] void binary_g(
device const T* a,
device const T* b,
device U* c,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim);
size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
c[out_idx] = Op()(a[idx.x], b[idx.y]);
}

View File

@@ -1,130 +1,24 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2024 Apple Inc.
#include <metal_integer>
#include <metal_math>
// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/binary_ops.h"
#include "mlx/backend/metal/kernels/binary.h"
template <typename T, typename U, typename Op>
[[kernel]] void binary_op_ss(
device const T* a,
device const T* b,
device U* c,
uint index [[thread_position_in_grid]]) {
c[index] = Op()(a[0], b[0]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_op_sv(
device const T* a,
device const T* b,
device U* c,
uint index [[thread_position_in_grid]]) {
c[index] = Op()(a[0], b[index]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_op_vs(
device const T* a,
device const T* b,
device U* c,
uint index [[thread_position_in_grid]]) {
c[index] = Op()(a[index], b[0]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_op_vv(
device const T* a,
device const T* b,
device U* c,
uint index [[thread_position_in_grid]]) {
c[index] = Op()(a[index], b[index]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_op_g_nd1(
device const T* a,
device const T* b,
device U* c,
constant const size_t& a_stride,
constant const size_t& b_stride,
uint index [[thread_position_in_grid]]) {
auto a_idx = elem_to_loc_1(index, a_stride);
auto b_idx = elem_to_loc_1(index, b_stride);
c[index] = Op()(a[a_idx], b[b_idx]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_op_g_nd2(
device const T* a,
device const T* b,
device U* c,
constant const size_t a_strides[2],
constant const size_t b_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_2(index, a_strides);
auto b_idx = elem_to_loc_2(index, b_strides);
size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
c[out_idx] = Op()(a[a_idx], b[b_idx]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_op_g_nd3(
device const T* a,
device const T* b,
device U* c,
constant const size_t a_strides[3],
constant const size_t b_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_3(index, a_strides);
auto b_idx = elem_to_loc_3(index, b_strides);
size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
c[out_idx] = Op()(a[a_idx], b[b_idx]);
}
template <typename T, typename U, typename Op, int DIM>
[[kernel]] void binary_op_g_nd(
device const T* a,
device const T* b,
device U* c,
constant const int shape[DIM],
constant const size_t a_strides[DIM],
constant const size_t b_strides[DIM],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
c[out_idx] = Op()(a[idx.x], b[idx.y]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_op_g(
device const T* a,
device const T* b,
device U* c,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim);
size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
c[out_idx] = Op()(a[idx.x], b[idx.y]);
}
#define instantiate_binary(name, itype, otype, op, bopt) \
template \
[[host_name(name)]] [[kernel]] void binary_op_##bopt<itype, otype, op>( \
[[host_name(name)]] [[kernel]] void binary_##bopt<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
uint index [[thread_position_in_grid]]);
#define instantiate_binary_g_dim(name, itype, otype, op, dims) \
template [[host_name(name "_" #dims)]] [[kernel]] void \
binary_op_g_nd<itype, otype, op, dims>( \
template [[host_name("g" #dims name)]] [[kernel]] void \
binary_g_nd<itype, otype, op, dims>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
@@ -135,16 +29,16 @@ template <typename T, typename U, typename Op>
uint3 grid_dim [[threads_per_grid]]);
#define instantiate_binary_g_nd(name, itype, otype, op) \
template [[host_name(name "_1")]] [[kernel]] void \
binary_op_g_nd1<itype, otype, op>( \
template [[host_name("g1" name)]] [[kernel]] void \
binary_g_nd1<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
constant const size_t& a_stride, \
constant const size_t& b_stride, \
uint index [[thread_position_in_grid]]); \
template [[host_name(name "_2")]] [[kernel]] void \
binary_op_g_nd2<itype, otype, op>( \
template [[host_name("g2" name)]] [[kernel]] void \
binary_g_nd2<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
@@ -152,8 +46,8 @@ template <typename T, typename U, typename Op>
constant const size_t b_strides[2], \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]); \
template [[host_name(name "_3")]] [[kernel]] void \
binary_op_g_nd3<itype, otype, op>( \
template [[host_name("g3" name)]] [[kernel]] void \
binary_g_nd3<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
@@ -162,30 +56,28 @@ template <typename T, typename U, typename Op>
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
instantiate_binary_g_dim(name, itype, otype, op, 4) \
instantiate_binary_g_dim(name, itype, otype, op, 5)
instantiate_binary_g_dim(name, itype, otype, op, 5)
#define instantiate_binary_g(name, itype, otype, op) \
template [[host_name(name)]] [[kernel]] void binary_op_g<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
constant const int* shape, \
constant const size_t* a_strides, \
constant const size_t* b_strides, \
constant const int& ndim, \
uint3 index [[thread_position_in_grid]], \
#define instantiate_binary_g(name, itype, otype, op) \
template [[host_name("gn" name)]] [[kernel]] void binary_g<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
constant const int* shape, \
constant const size_t* a_strides, \
constant const size_t* b_strides, \
constant const int& ndim, \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]);
// clang-format off
#define instantiate_binary_all(name, tname, itype, otype, op) \
instantiate_binary("ss" #name #tname, itype, otype, op, ss) \
instantiate_binary("sv" #name #tname, itype, otype, op, sv) \
instantiate_binary("vs" #name #tname, itype, otype, op, vs) \
instantiate_binary("vv" #name #tname, itype, otype, op, vv) \
instantiate_binary_g("g" #name #tname, itype, otype, op) \
instantiate_binary_g_nd("g" #name #tname, itype, otype, op) // clang-format on
instantiate_binary_g(#name #tname, itype, otype, op) \
instantiate_binary_g_nd(#name #tname, itype, otype, op)
// clang-format off
#define instantiate_binary_integer(name, op) \
instantiate_binary_all(name, uint8, uint8_t, uint8_t, op) \
instantiate_binary_all(name, uint16, uint16_t, uint16_t, op) \
@@ -194,22 +86,19 @@ template <typename T, typename U, typename Op>
instantiate_binary_all(name, int8, int8_t, int8_t, op) \
instantiate_binary_all(name, int16, int16_t, int16_t, op) \
instantiate_binary_all(name, int32, int32_t, int32_t, op) \
instantiate_binary_all(name, int64, int64_t, int64_t, op) // clang-format on
instantiate_binary_all(name, int64, int64_t, int64_t, op)
// clang-format off
#define instantiate_binary_float(name, op) \
instantiate_binary_all(name, float16, half, half, op) \
instantiate_binary_all(name, float32, float, float, op) \
instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op) // clang-format on
instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op)
// clang-format off
#define instantiate_binary_types(name, op) \
instantiate_binary_all(name, bool_, bool, bool, op) \
instantiate_binary_integer(name, op) \
instantiate_binary_all(name, complex64, complex64_t, complex64_t, op) \
instantiate_binary_float(name, op) // clang-format on
instantiate_binary_float(name, op)
// clang-format off
#define instantiate_binary_types_bool(name, op) \
instantiate_binary_all(name, bool_, bool, bool, op) \
instantiate_binary_all(name, uint8, uint8_t, bool, op) \
@@ -223,9 +112,8 @@ template <typename T, typename U, typename Op>
instantiate_binary_all(name, float16, half, bool, op) \
instantiate_binary_all(name, float32, float, bool, op) \
instantiate_binary_all(name, bfloat16, bfloat16_t, bool, op) \
instantiate_binary_all(name, complex64, complex64_t, bool, op) // clang-format on
instantiate_binary_all(name, complex64, complex64_t, bool, op)
// clang-format off
instantiate_binary_types(add, Add)
instantiate_binary_types(div, Divide)
instantiate_binary_types_bool(eq, Equal)

View File

@@ -0,0 +1,296 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
#include <metal_integer>
#include <metal_math>
struct Add {
template <typename T>
T operator()(T x, T y) {
return x + y;
}
};
struct FloorDivide {
template <typename T>
T operator()(T x, T y) {
return x / y;
}
template <>
float operator()(float x, float y) {
return trunc(x / y);
}
template <>
half operator()(half x, half y) {
return trunc(x / y);
}
template <>
bfloat16_t operator()(bfloat16_t x, bfloat16_t y) {
return trunc(x / y);
}
};
struct Divide {
template <typename T>
T operator()(T x, T y) {
return x / y;
}
};
struct Remainder {
template <typename T>
metal::enable_if_t<metal::is_integral_v<T> & !metal::is_signed_v<T>, T>
operator()(T x, T y) {
return x % y;
}
template <typename T>
metal::enable_if_t<metal::is_integral_v<T> & metal::is_signed_v<T>, T>
operator()(T x, T y) {
auto r = x % y;
if (r != 0 && (r < 0 != y < 0)) {
r += y;
}
return r;
}
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
T r = fmod(x, y);
if (r != 0 && (r < 0 != y < 0)) {
r += y;
}
return r;
}
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
return x % y;
}
};
struct Equal {
template <typename T>
bool operator()(T x, T y) {
return x == y;
}
};
struct NaNEqual {
template <typename T>
bool operator()(T x, T y) {
return x == y || (metal::isnan(x) && metal::isnan(y));
}
template <>
bool operator()(complex64_t x, complex64_t y) {
return x == y ||
(metal::isnan(x.real) && metal::isnan(y.real) && metal::isnan(x.imag) &&
metal::isnan(y.imag)) ||
(x.real == y.real && metal::isnan(x.imag) && metal::isnan(y.imag)) ||
(metal::isnan(x.real) && metal::isnan(y.real) && x.imag == y.imag);
}
};
struct Greater {
template <typename T>
bool operator()(T x, T y) {
return x > y;
}
};
struct GreaterEqual {
template <typename T>
bool operator()(T x, T y) {
return x >= y;
}
};
struct Less {
template <typename T>
bool operator()(T x, T y) {
return x < y;
}
};
struct LessEqual {
template <typename T>
bool operator()(T x, T y) {
return x <= y;
}
};
struct LogAddExp {
template <typename T>
T operator()(T x, T y) {
if (metal::isnan(x) || metal::isnan(y)) {
return metal::numeric_limits<T>::quiet_NaN();
}
constexpr T inf = metal::numeric_limits<T>::infinity();
T maxval = metal::max(x, y);
T minval = metal::min(x, y);
return (minval == -inf || maxval == inf)
? maxval
: (maxval + log1p(metal::exp(minval - maxval)));
};
};
struct Maximum {
template <typename T>
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {
return metal::max(x, y);
}
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
if (metal::isnan(x)) {
return x;
}
return x > y ? x : y;
}
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
if (metal::isnan(x.real) || metal::isnan(x.imag)) {
return x;
}
return x > y ? x : y;
}
};
struct Minimum {
template <typename T>
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {
return metal::min(x, y);
}
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
if (metal::isnan(x)) {
return x;
}
return x < y ? x : y;
}
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
if (metal::isnan(x.real) || metal::isnan(x.imag)) {
return x;
}
return x < y ? x : y;
}
};
struct Multiply {
template <typename T>
T operator()(T x, T y) {
return x * y;
}
};
struct NotEqual {
template <typename T>
bool operator()(T x, T y) {
return x != y;
}
template <>
bool operator()(complex64_t x, complex64_t y) {
return x.real != y.real || x.imag != y.imag;
}
};
struct Power {
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T base, T exp) {
return metal::pow(base, exp);
}
template <typename T>
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T base, T exp) {
T res = 1;
while (exp) {
if (exp & 1) {
res *= base;
}
exp >>= 1;
base *= base;
}
return res;
}
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
auto x_theta = metal::atan(x.imag / x.real);
auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag);
auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta);
auto phase = y.imag * x_ln_r + y.real * x_theta;
return {mag * metal::cos(phase), mag * metal::sin(phase)};
}
};
struct Subtract {
template <typename T>
T operator()(T x, T y) {
return x - y;
}
};
struct LogicalAnd {
template <typename T>
T operator()(T x, T y) {
return x && y;
};
};
struct LogicalOr {
template <typename T>
T operator()(T x, T y) {
return x || y;
};
};
struct BitwiseAnd {
template <typename T>
T operator()(T x, T y) {
return x & y;
};
};
struct BitwiseOr {
template <typename T>
T operator()(T x, T y) {
return x | y;
};
};
struct BitwiseXor {
template <typename T>
T operator()(T x, T y) {
return x ^ y;
};
};
struct LeftShift {
template <typename T>
T operator()(T x, T y) {
return x << y;
};
};
struct RightShift {
template <typename T>
T operator()(T x, T y) {
return x >> y;
};
};
struct ArcTan2 {
template <typename T>
T operator()(T y, T x) {
return metal::precise::atan2(y, x);
}
};
struct DivMod {
template <typename T>
metal::array<T, 2> operator()(T x, T y) {
return {FloorDivide{}(x, y), Remainder{}(x, y)};
};
};

View File

@@ -0,0 +1,140 @@
// Copyright © 2024 Apple Inc.
template <typename T, typename U, typename Op>
[[kernel]] void binary_ss(
device const T* a,
device const T* b,
device U* c,
device U* d,
uint index [[thread_position_in_grid]]) {
auto out = Op()(a[0], b[0]);
c[index] = out[0];
d[index] = out[1];
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_sv(
device const T* a,
device const T* b,
device U* c,
device U* d,
uint index [[thread_position_in_grid]]) {
auto out = Op()(a[0], b[index]);
c[index] = out[0];
d[index] = out[1];
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_vs(
device const T* a,
device const T* b,
device U* c,
device U* d,
uint index [[thread_position_in_grid]]) {
auto out = Op()(a[index], b[0]);
c[index] = out[0];
d[index] = out[1];
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_vv(
device const T* a,
device const T* b,
device U* c,
device U* d,
uint index [[thread_position_in_grid]]) {
auto out = Op()(a[index], b[index]);
c[index] = out[0];
d[index] = out[1];
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd1(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant const size_t& a_stride,
constant const size_t& b_stride,
uint index [[thread_position_in_grid]]) {
auto a_idx = elem_to_loc_1(index, a_stride);
auto b_idx = elem_to_loc_1(index, b_stride);
auto out = Op()(a[a_idx], b[b_idx]);
c[index] = out[0];
d[index] = out[1];
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd2(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant const size_t a_strides[2],
constant const size_t b_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_2(index, a_strides);
auto b_idx = elem_to_loc_2(index, b_strides);
size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
auto out = Op()(a[a_idx], b[b_idx]);
c[out_idx] = out[0];
d[out_idx] = out[1];
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd3(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant const size_t a_strides[3],
constant const size_t b_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_3(index, a_strides);
auto b_idx = elem_to_loc_3(index, b_strides);
size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
auto out = Op()(a[a_idx], b[b_idx]);
c[out_idx] = out[0];
d[out_idx] = out[1];
}
template <typename T, typename U, typename Op, int DIM>
[[kernel]] void binary_g_nd(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant const int shape[DIM],
constant const size_t a_strides[DIM],
constant const size_t b_strides[DIM],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
auto out = Op()(a[idx.x], b[idx.y]);
c[out_idx] = out[0];
d[out_idx] = out[1];
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_g(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim);
size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
auto out = Op()(a[idx.x], b[idx.y]);
c[out_idx] = out[0];
d[out_idx] = out[1];
}

View File

@@ -1,212 +1,24 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2024 Apple Inc.
#include <metal_integer>
#include <metal_math>
#include "mlx/backend/metal/kernels/bf16.h"
// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/binary_ops.h"
#include "mlx/backend/metal/kernels/binary_two.h"
struct FloorDivide {
template <typename T>
T operator()(T x, T y) {
return x / y;
}
template <>
float operator()(float x, float y) {
return trunc(x / y);
}
template <>
half operator()(half x, half y) {
return trunc(x / y);
}
template <>
bfloat16_t operator()(bfloat16_t x, bfloat16_t y) {
return trunc(x / y);
}
};
struct Remainder {
template <typename T>
metal::enable_if_t<metal::is_integral_v<T> & !metal::is_signed_v<T>, T>
operator()(T x, T y) {
return x % y;
}
template <typename T>
metal::enable_if_t<metal::is_integral_v<T> & metal::is_signed_v<T>, T>
operator()(T x, T y) {
auto r = x % y;
if (r != 0 && (r < 0 != y < 0)) {
r += y;
}
return r;
}
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
T r = fmod(x, y);
if (r != 0 && (r < 0 != y < 0)) {
r += y;
}
return r;
}
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
return x % y;
}
};
template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_s2s(
device const T* a,
device const T* b,
device U* c,
device U* d,
uint index [[thread_position_in_grid]]) {
c[index] = Op1()(a[0], b[0]);
d[index] = Op2()(a[0], b[0]);
}
template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_ss(
device const T* a,
device const T* b,
device U* c,
device U* d,
uint index [[thread_position_in_grid]]) {
c[index] = Op1()(a[0], b[0]);
d[index] = Op2()(a[0], b[0]);
}
template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_sv(
device const T* a,
device const T* b,
device U* c,
device U* d,
uint index [[thread_position_in_grid]]) {
c[index] = Op1()(a[0], b[index]);
d[index] = Op2()(a[0], b[index]);
}
template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_vs(
device const T* a,
device const T* b,
device U* c,
device U* d,
uint index [[thread_position_in_grid]]) {
c[index] = Op1()(a[index], b[0]);
d[index] = Op2()(a[index], b[0]);
}
template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_vv(
device const T* a,
device const T* b,
device U* c,
device U* d,
uint index [[thread_position_in_grid]]) {
c[index] = Op1()(a[index], b[index]);
d[index] = Op2()(a[index], b[index]);
}
template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_g_nd1(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant const size_t& a_stride,
constant const size_t& b_stride,
uint index [[thread_position_in_grid]]) {
auto a_idx = elem_to_loc_1(index, a_stride);
auto b_idx = elem_to_loc_1(index, b_stride);
c[index] = Op1()(a[a_idx], b[b_idx]);
d[index] = Op2()(a[a_idx], b[b_idx]);
}
template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_g_nd2(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant const size_t a_strides[2],
constant const size_t b_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_2(index, a_strides);
auto b_idx = elem_to_loc_2(index, b_strides);
size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
c[out_idx] = Op1()(a[a_idx], b[b_idx]);
d[out_idx] = Op2()(a[a_idx], b[b_idx]);
}
template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_g_nd3(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant const size_t a_strides[3],
constant const size_t b_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_3(index, a_strides);
auto b_idx = elem_to_loc_3(index, b_strides);
size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
c[out_idx] = Op1()(a[a_idx], b[b_idx]);
d[out_idx] = Op2()(a[a_idx], b[b_idx]);
}
template <typename T, typename U, typename Op1, typename Op2, int DIM>
[[kernel]] void binary_op_g_nd(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant const int shape[DIM],
constant const size_t a_strides[DIM],
constant const size_t b_strides[DIM],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
c[out_idx] = Op1()(a[idx.x], b[idx.y]);
d[out_idx] = Op2()(a[idx.x], b[idx.y]);
}
template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_g(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim);
size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
c[out_idx] = Op1()(a[idx.x], b[idx.y]);
d[out_idx] = Op2()(a[idx.x], b[idx.y]);
}
#define instantiate_binary(name, itype, otype, op1, op2, bopt) \
#define instantiate_binary(name, itype, otype, op, bopt) \
template [[host_name(name)]] [[kernel]] void \
binary_op_##bopt<itype, otype, op1, op2>( \
binary_##bopt<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
device otype* d, \
uint index [[thread_position_in_grid]]);
#define instantiate_binary_g_dim(name, itype, otype, op1, op2, dims) \
template [[host_name(name "_" #dims)]] [[kernel]] void \
binary_op_g_nd<itype, otype, op1, op2, dims>( \
#define instantiate_binary_g_dim(name, itype, otype, op, dims) \
template [[host_name("g" #dims name)]] [[kernel]] void \
binary_g_nd<itype, otype, op, dims>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
@@ -217,10 +29,9 @@ template <typename T, typename U, typename Op1, typename Op2>
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]);
// clang-format off
#define instantiate_binary_g_nd(name, itype, otype, op1, op2) \
template [[host_name(name "_1")]] [[kernel]] void \
binary_op_g_nd1<itype, otype, op1, op2>( \
#define instantiate_binary_g_nd(name, itype, otype, op) \
template [[host_name("g1" name)]] [[kernel]] void \
binary_g_nd1<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
@@ -228,8 +39,8 @@ template <typename T, typename U, typename Op1, typename Op2>
constant const size_t& a_stride, \
constant const size_t& b_stride, \
uint index [[thread_position_in_grid]]); \
template [[host_name(name "_2")]] [[kernel]] void \
binary_op_g_nd2<itype, otype, op1, op2>( \
template [[host_name("g2" name)]] [[kernel]] void \
binary_g_nd2<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
@@ -238,8 +49,8 @@ template <typename T, typename U, typename Op1, typename Op2>
constant const size_t b_strides[2], \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]); \
template [[host_name(name "_3")]] [[kernel]] void \
binary_op_g_nd3<itype, otype, op1, op2>( \
template [[host_name("g3" name)]] [[kernel]] void \
binary_g_nd3<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
@@ -248,12 +59,12 @@ template <typename T, typename U, typename Op1, typename Op2>
constant const size_t b_strides[3], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
instantiate_binary_g_dim(name, itype, otype, op1, op2, 4) \
instantiate_binary_g_dim(name, itype, otype, op1, op2, 5) // clang-format on
instantiate_binary_g_dim(name, itype, otype, op, 4) \
instantiate_binary_g_dim(name, itype, otype, op, 5)
#define instantiate_binary_g(name, itype, otype, op1, op2) \
template [[host_name(name)]] [[kernel]] void \
binary_op_g<itype, otype, op2, op2>( \
#define instantiate_binary_g(name, itype, otype, op) \
template [[host_name("gn" name)]] [[kernel]] void \
binary_g<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
@@ -265,33 +76,30 @@ template <typename T, typename U, typename Op1, typename Op2>
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]);
// clang-format off
#define instantiate_binary_all(name, tname, itype, otype, op1, op2) \
instantiate_binary("ss" #name #tname, itype, otype, op1, op2, ss) \
instantiate_binary("sv" #name #tname, itype, otype, op1, op2, sv) \
instantiate_binary("vs" #name #tname, itype, otype, op1, op2, vs) \
instantiate_binary("vv" #name #tname, itype, otype, op1, op2, vv) \
instantiate_binary_g("g" #name #tname, itype, otype, op1, op2) \
instantiate_binary_g_nd("g" #name #tname, itype, otype, op1, op2) // clang-format on
#define instantiate_binary_all(name, tname, itype, otype, op) \
instantiate_binary("ss" #name #tname, itype, otype, op, ss) \
instantiate_binary("sv" #name #tname, itype, otype, op, sv) \
instantiate_binary("vs" #name #tname, itype, otype, op, vs) \
instantiate_binary("vv" #name #tname, itype, otype, op, vv) \
instantiate_binary_g(#name #tname, itype, otype, op) \
instantiate_binary_g_nd(#name #tname, itype, otype, op)
// clang-format off
#define instantiate_binary_float(name, op1, op2) \
instantiate_binary_all(name, float16, half, half, op1, op2) \
instantiate_binary_all(name, float32, float, float, op1, op2) \
instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op1, op2) // clang-format on
#define instantiate_binary_float(name, op) \
instantiate_binary_all(name, float16, half, half, op) \
instantiate_binary_all(name, float32, float, float, op) \
instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op)
// clang-format off
#define instantiate_binary_types(name, op1, op2) \
instantiate_binary_all(name, bool_, bool, bool, op1, op2) \
instantiate_binary_all(name, uint8, uint8_t, uint8_t, op1, op2) \
instantiate_binary_all(name, uint16, uint16_t, uint16_t, op1, op2) \
instantiate_binary_all(name, uint32, uint32_t, uint32_t, op1, op2) \
instantiate_binary_all(name, uint64, uint64_t, uint64_t, op1, op2) \
instantiate_binary_all(name, int8, int8_t, int8_t, op1, op2) \
instantiate_binary_all(name, int16, int16_t, int16_t, op1, op2) \
instantiate_binary_all(name, int32, int32_t, int32_t, op1, op2) \
instantiate_binary_all(name, int64, int64_t, int64_t, op1, op2) \
instantiate_binary_all(name, complex64, complex64_t, complex64_t, op1, op2) \
instantiate_binary_float(name, op1, op2)
#define instantiate_binary_types(name, op) \
instantiate_binary_all(name, bool_, bool, bool, op) \
instantiate_binary_all(name, uint8, uint8_t, uint8_t, op) \
instantiate_binary_all(name, uint16, uint16_t, uint16_t, op) \
instantiate_binary_all(name, uint32, uint32_t, uint32_t, op) \
instantiate_binary_all(name, uint64, uint64_t, uint64_t, op) \
instantiate_binary_all(name, int8, int8_t, int8_t, op) \
instantiate_binary_all(name, int16, int16_t, int16_t, op) \
instantiate_binary_all(name, int32, int32_t, int32_t, op) \
instantiate_binary_all(name, int64, int64_t, int64_t, op) \
instantiate_binary_all(name, complex64, complex64_t, complex64_t, op) \
instantiate_binary_float(name, op)
instantiate_binary_types(divmod, FloorDivide, Remainder) // clang-format on
instantiate_binary_types(divmod, DivMod) // clang-format on

View File

@@ -1,7 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/kernels/binary.h"
#include "mlx/backend/metal/kernels/ternary.h"
#include "mlx/backend/metal/kernels/unary.h"
typedef half float16_t;

View File

@@ -109,6 +109,7 @@ template <typename T, int N>
bool valid = n < params->N;
// Unroll dimensions
int kernel_stride = 1;
for (int i = N - 1; i >= 0; --i) {
int os_ = (oS % params->oS[i]);
int ws_ = (wS % params->wS[i]);
@@ -125,7 +126,8 @@ template <typename T, int N>
oS /= params->oS[i];
wS /= params->wS[i];
out += ws_ * params->str[i];
out += ws_ * kernel_stride;
kernel_stride *= params->wS[i];
}
if (valid) {
@@ -648,4 +650,4 @@ winograd_conv_2d_output_transform(
// clang-format off
instantiate_winograd_conv_2d(float32, float);
instantiate_winograd_conv_2d(float16, half); // clang-format on
instantiate_winograd_conv_2d(float16, half); // clang-format on

View File

@@ -0,0 +1,144 @@
// Copyright © 2024 Apple Inc.
template <typename T, typename U>
[[kernel]] void copy_s(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
uint index [[thread_position_in_grid]]) {
dst[index] = static_cast<U>(src[0]);
}
template <typename T, typename U>
[[kernel]] void copy_v(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
uint index [[thread_position_in_grid]]) {
dst[index] = static_cast<U>(src[index]);
}
template <typename T, typename U>
[[kernel]] void copy_g_nd1(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t& src_stride [[buffer(3)]],
uint index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_1(index, src_stride);
dst[index] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_g_nd2(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_2(index, src_strides);
int64_t dst_idx = index.x + (int64_t)grid_dim.x * index.y;
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_g_nd3(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_3(index, src_strides);
int64_t dst_idx =
index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U, int DIM>
[[kernel]] void copy_g_nd(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
int64_t dst_idx =
index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_g(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int& ndim [[buffer(5)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim);
int64_t dst_idx =
index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_gg_nd1(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t& src_stride [[buffer(3)]],
constant const int64_t& dst_stride [[buffer(4)]],
uint index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_1(index, src_stride);
auto dst_idx = elem_to_loc_1(index, dst_stride);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_gg_nd2(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint2 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_2(index, src_strides);
auto dst_idx = elem_to_loc_2(index, dst_strides);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_gg_nd3(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint3 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_3(index, src_strides);
auto dst_idx = elem_to_loc_3(index, dst_strides);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U, int DIM>
[[kernel]] void copy_gg_nd(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint3 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
auto dst_idx = elem_to_loc_nd<DIM>(index, src_shape, dst_strides);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_gg(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
constant const int& ndim [[buffer(5)]],
uint3 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim);
auto dst_idx = elem_to_loc(index, src_shape, dst_strides, ndim);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}

View File

@@ -1,150 +1,9 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2024 Apple Inc.
#include "mlx/backend/metal/kernels/bf16.h"
// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
template <typename T, typename U>
[[kernel]] void copy_s(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
uint index [[thread_position_in_grid]]) {
dst[index] = static_cast<U>(src[0]);
}
template <typename T, typename U>
[[kernel]] void copy_v(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
uint index [[thread_position_in_grid]]) {
dst[index] = static_cast<U>(src[index]);
}
template <typename T, typename U>
[[kernel]] void copy_g_nd1(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t& src_stride [[buffer(3)]],
uint index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_1(index, src_stride);
dst[index] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_g_nd2(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_2(index, src_strides);
int64_t dst_idx = index.x + (int64_t)grid_dim.x * index.y;
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_g_nd3(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_3(index, src_strides);
int64_t dst_idx =
index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U, int DIM>
[[kernel]] void copy_g_nd(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
int64_t dst_idx =
index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_g(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int& ndim [[buffer(5)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim);
int64_t dst_idx =
index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_gg_nd1(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t& src_stride [[buffer(3)]],
constant const int64_t& dst_stride [[buffer(4)]],
uint index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_1(index, src_stride);
auto dst_idx = elem_to_loc_1(index, dst_stride);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_gg_nd2(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint2 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_2(index, src_strides);
auto dst_idx = elem_to_loc_2(index, dst_strides);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_gg_nd3(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint3 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_3(index, src_strides);
auto dst_idx = elem_to_loc_3(index, dst_strides);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U, int DIM>
[[kernel]] void copy_gg_nd(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint3 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
auto dst_idx = elem_to_loc_nd<DIM>(index, src_shape, dst_strides);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_gg(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
constant const int& ndim [[buffer(5)]],
uint3 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim);
auto dst_idx = elem_to_loc(index, src_shape, dst_strides, ndim);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/copy.h"
#define instantiate_copy(name, itype, otype, ctype) \
template [[host_name(name)]] [[kernel]] void copy_##ctype<itype, otype>( \
@@ -152,92 +11,90 @@ template <typename T, typename U>
device otype* dst [[buffer(1)]], \
uint index [[thread_position_in_grid]]);
#define instantiate_copy_g_dim(name, itype, otype, dims) \
template [[host_name(name "_" #dims)]] [[kernel]] void \
copy_g_nd<itype, otype, dims>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \
constant const int64_t* src_strides [[buffer(3)]], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
template [[host_name("g" name "_" #dims)]] [[kernel]] void \
copy_gg_nd<itype, otype, dims>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int64_t* dst_strides [[buffer(4)]], \
#define instantiate_copy_g_dim(name, itype, otype, dims) \
template [[host_name("g" #dims "_" name)]] [[kernel]] void \
copy_g_nd<itype, otype, dims>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \
constant const int64_t* src_strides [[buffer(3)]], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
template [[host_name("gg" #dims "_" name)]] [[kernel]] void \
copy_gg_nd<itype, otype, dims>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int64_t* dst_strides [[buffer(4)]], \
uint3 index [[thread_position_in_grid]]);
#define instantiate_copy_g_nd(name, itype, otype) \
template [[host_name(name "_1")]] [[kernel]] void copy_g_nd1<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t& src_stride [[buffer(3)]], \
uint index [[thread_position_in_grid]]); \
template [[host_name(name "_2")]] [[kernel]] void copy_g_nd2<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]); \
template [[host_name(name "_3")]] [[kernel]] void copy_g_nd3<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
template [[host_name("g" name "_1")]] [[kernel]] void \
copy_gg_nd1<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t& src_stride [[buffer(3)]], \
constant const int64_t& dst_stride [[buffer(4)]], \
uint index [[thread_position_in_grid]]); \
template [[host_name("g" name "_2")]] [[kernel]] void \
copy_gg_nd2<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int64_t* dst_strides [[buffer(4)]], \
uint2 index [[thread_position_in_grid]]); \
template [[host_name("g" name "_3")]] [[kernel]] void \
copy_gg_nd3<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int64_t* dst_strides [[buffer(4)]], \
uint3 index [[thread_position_in_grid]]); \
instantiate_copy_g_dim(name, itype, otype, 4) \
instantiate_copy_g_dim(name, itype, otype, 5)
#define instantiate_copy_g_nd(name, itype, otype) \
template [[host_name("g1_" name)]] [[kernel]] void copy_g_nd1<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t& src_stride [[buffer(3)]], \
uint index [[thread_position_in_grid]]); \
template [[host_name("g2_" name)]] [[kernel]] void copy_g_nd2<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]); \
template [[host_name("g3_" name)]] [[kernel]] void copy_g_nd3<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
template [[host_name("gg1_" name )]] [[kernel]] void \
copy_gg_nd1<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t& src_stride [[buffer(3)]], \
constant const int64_t& dst_stride [[buffer(4)]], \
uint index [[thread_position_in_grid]]); \
template [[host_name("gg2_" name)]] [[kernel]] void \
copy_gg_nd2<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int64_t* dst_strides [[buffer(4)]], \
uint2 index [[thread_position_in_grid]]); \
template [[host_name("gg3_" name)]] [[kernel]] void \
copy_gg_nd3<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int64_t* dst_strides [[buffer(4)]], \
uint3 index [[thread_position_in_grid]]); \
instantiate_copy_g_dim(name, itype, otype, 4) \
instantiate_copy_g_dim(name, itype, otype, 5)
#define instantiate_copy_g(name, itype, otype) \
template [[host_name(name)]] [[kernel]] void copy_g<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int& ndim [[buffer(5)]], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
template [[host_name("g" name)]] [[kernel]] void copy_gg<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int64_t* dst_strides [[buffer(4)]], \
constant const int& ndim [[buffer(5)]], \
#define instantiate_copy_g(name, itype, otype) \
template [[host_name("g_" name)]] [[kernel]] void copy_g<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int& ndim [[buffer(5)]], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
template [[host_name("gg_" name)]] [[kernel]] void copy_gg<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int64_t* dst_strides [[buffer(4)]], \
constant const int& ndim [[buffer(5)]], \
uint3 index [[thread_position_in_grid]]);
// clang-format off
#define instantiate_copy_all(tname, itype, otype) \
instantiate_copy("scopy" #tname, itype, otype, s) \
instantiate_copy("vcopy" #tname, itype, otype, v) \
instantiate_copy_g("gcopy" #tname, itype, otype) \
instantiate_copy_g_nd("gcopy" #tname, itype, otype) // clang-format on
#define instantiate_copy_all(tname, itype, otype) \
instantiate_copy("s_copy" #tname, itype, otype, s) \
instantiate_copy("v_copy" #tname, itype, otype, v) \
instantiate_copy_g("copy" #tname, itype, otype) \
instantiate_copy_g_nd("copy" #tname, itype, otype)
// clang-format off
#define instantiate_copy_itype(itname, itype) \
instantiate_copy_all(itname ##bool_, itype, bool) \
instantiate_copy_all(itname ##uint8, itype, uint8_t) \

View File

@@ -2,17 +2,14 @@
#pragma once
#ifdef __METAL__
#if defined __METAL__ || defined MLX_METAL_JIT
#define MTL_CONST constant
#else
#define MTL_CONST
#endif
static MTL_CONST constexpr int MAX_BINARY_SPECIALIZED_DIMS = 5;
static MTL_CONST constexpr int MAX_COPY_SPECIALIZED_DIMS = 5;
static MTL_CONST constexpr int MAX_REDUCE_SPECIALIZED_DIMS = 4;
static MTL_CONST constexpr int REDUCE_N_READS = 16;
static MTL_CONST constexpr int SOFTMAX_N_READS = 4;
static MTL_CONST constexpr int SOFTMAX_LOOPED_LIMIT = 4096;
static MTL_CONST constexpr int RMS_N_READS = 4;
static MTL_CONST constexpr int RMS_LOOPED_LIMIT = 4096;

View File

@@ -1,7 +1,6 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include <metal_math>
/*
@@ -67,4 +66,4 @@ float erfinv(float a) {
p = metal::fma(p, t, 8.86226892e-1f); // 0x1.c5bf88p-1
}
return a * p;
}
}

View File

@@ -0,0 +1,45 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/backend/metal/kernels/indexing.h"
template <typename T, typename IdxT, int NIDX, int IDX_NDIM>
METAL_FUNC void gather_impl(
const device T* src [[buffer(0)]],
device T* out [[buffer(1)]],
const constant int* src_shape [[buffer(2)]],
const constant size_t* src_strides [[buffer(3)]],
const constant size_t& src_ndim [[buffer(4)]],
const constant int* slice_sizes [[buffer(5)]],
const constant int* axes [[buffer(6)]],
const thread Indices<IdxT, NIDX>& indices,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto ind_idx = index.x;
auto ind_offset = index.y;
size_t src_idx = 0;
for (int i = 0; i < NIDX; ++i) {
size_t idx_loc;
if (IDX_NDIM == 0) {
idx_loc = 0;
} else if (IDX_NDIM == 1) {
idx_loc = ind_idx * indices.strides[indices.ndim * i];
} else {
idx_loc = elem_to_loc(
ind_idx,
&indices.shapes[indices.ndim * i],
&indices.strides[indices.ndim * i],
indices.ndim);
}
auto ax = axes[i];
auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]);
src_idx += idx_val * src_strides[ax];
}
auto src_offset = elem_to_loc(ind_offset, slice_sizes, src_strides, src_ndim);
size_t out_idx = index.y + static_cast<size_t>(grid_dim.y) * index.x;
out[out_idx] = src[src_offset + src_idx];
}

View File

@@ -1,173 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#include <metal_atomic>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/indexing.h"
#include "mlx/backend/metal/kernels/utils.h"
using namespace metal;
/////////////////////////////////////////////////////////////////////
// Gather kernel
/////////////////////////////////////////////////////////////////////
template <typename T, typename IdxT, int NIDX, int IDX_NDIM>
METAL_FUNC void gather_impl(
const device T* src [[buffer(0)]],
device T* out [[buffer(1)]],
const constant int* src_shape [[buffer(2)]],
const constant size_t* src_strides [[buffer(3)]],
const constant size_t& src_ndim [[buffer(4)]],
const constant int* slice_sizes [[buffer(5)]],
const constant int* axes [[buffer(6)]],
const thread Indices<IdxT, NIDX>& indices,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto ind_idx = index.x;
auto ind_offset = index.y;
size_t src_idx = 0;
for (int i = 0; i < NIDX; ++i) {
size_t idx_loc;
if (IDX_NDIM == 0) {
idx_loc = 0;
} else if (IDX_NDIM == 1) {
idx_loc = ind_idx * indices.strides[indices.ndim * i];
} else {
idx_loc = elem_to_loc(
ind_idx,
&indices.shapes[indices.ndim * i],
&indices.strides[indices.ndim * i],
indices.ndim);
}
auto ax = axes[i];
auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]);
src_idx += idx_val * src_strides[ax];
}
auto src_offset = elem_to_loc(ind_offset, slice_sizes, src_strides, src_ndim);
size_t out_idx = index.y + static_cast<size_t>(grid_dim.y) * index.x;
out[out_idx] = src[src_offset + src_idx];
}
#define make_gather_impl(IDX_ARG, IDX_ARR) \
template <typename T, typename IdxT, int NIDX, int IDX_NDIM> \
[[kernel]] void gather( \
const device T* src [[buffer(0)]], \
device T* out [[buffer(1)]], \
const constant int* src_shape [[buffer(2)]], \
const constant size_t* src_strides [[buffer(3)]], \
const constant size_t& src_ndim [[buffer(4)]], \
const constant int* slice_sizes [[buffer(5)]], \
const constant int* axes [[buffer(6)]], \
const constant int* idx_shapes [[buffer(7)]], \
const constant size_t* idx_strides [[buffer(8)]], \
const constant int& idx_ndim [[buffer(9)]], \
IDX_ARG(IdxT) uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]) { \
Indices<IdxT, NIDX> idxs{ \
{{IDX_ARR()}}, idx_shapes, idx_strides, idx_ndim}; \
\
return gather_impl<T, IdxT, NIDX, IDX_NDIM>( \
src, \
out, \
src_shape, \
src_strides, \
src_ndim, \
slice_sizes, \
axes, \
idxs, \
index, \
grid_dim); \
}
#define make_gather(n) make_gather_impl(IDX_ARG_##n, IDX_ARR_##n)
make_gather(0) make_gather(1) make_gather(2) make_gather(3) make_gather(4)
make_gather(5) make_gather(6) make_gather(7) make_gather(8) make_gather(9)
make_gather(10)
/////////////////////////////////////////////////////////////////////
// Gather instantiations
/////////////////////////////////////////////////////////////////////
#define instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG, nd, nd_name) \
template [[host_name("gather" name "_" #nidx "" #nd_name)]] [[kernel]] void \
gather<src_t, idx_t, nidx, nd>( \
const device src_t* src [[buffer(0)]], \
device src_t* out [[buffer(1)]], \
const constant int* src_shape [[buffer(2)]], \
const constant size_t* src_strides [[buffer(3)]], \
const constant size_t& src_ndim [[buffer(4)]], \
const constant int* slice_sizes [[buffer(5)]], \
const constant int* axes [[buffer(6)]], \
const constant int* idx_shapes [[buffer(7)]], \
const constant size_t* idx_strides [[buffer(8)]], \
const constant int& idx_ndim [[buffer(9)]], \
IDX_ARG(idx_t) uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]);
// clang-format off
#define instantiate_gather5(name, src_t, idx_t, nidx, nd, nd_name) \
instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG_ ##nidx, nd, nd_name) // clang-format on
// clang-format off
#define instantiate_gather4(name, src_t, idx_t, nidx) \
instantiate_gather5(name, src_t, idx_t, nidx, 0, _0) \
instantiate_gather5(name, src_t, idx_t, nidx, 1, _1) \
instantiate_gather5(name, src_t, idx_t, nidx, 2, )
// Special for case NIDX=0
instantiate_gather4("bool_", bool, bool, 0)
instantiate_gather4("uint8", uint8_t, bool, 0)
instantiate_gather4("uint16", uint16_t, bool, 0)
instantiate_gather4("uint32", uint32_t, bool, 0)
instantiate_gather4("uint64", uint64_t, bool, 0)
instantiate_gather4("int8", int8_t, bool, 0)
instantiate_gather4("int16", int16_t, bool, 0)
instantiate_gather4("int32", int32_t, bool, 0)
instantiate_gather4("int64", int64_t, bool, 0)
instantiate_gather4("float16", half, bool, 0)
instantiate_gather4("float32", float, bool, 0)
instantiate_gather4("bfloat16", bfloat16_t, bool, 0) // clang-format on
// clang-format off
#define instantiate_gather3(name, src_type, ind_type) \
instantiate_gather4(name, src_type, ind_type, 1) \
instantiate_gather4(name, src_type, ind_type, 2) \
instantiate_gather4(name, src_type, ind_type, 3) \
instantiate_gather4(name, src_type, ind_type, 4) \
instantiate_gather4(name, src_type, ind_type, 5) \
instantiate_gather4(name, src_type, ind_type, 6) \
instantiate_gather4(name, src_type, ind_type, 7) \
instantiate_gather4(name, src_type, ind_type, 8) \
instantiate_gather4(name, src_type, ind_type, 9) \
instantiate_gather4(name, src_type, ind_type, 10) // clang-format on
// clang-format off
#define instantiate_gather(name, src_type) \
instantiate_gather3(#name "bool_", src_type, bool) \
instantiate_gather3(#name "uint8", src_type, uint8_t) \
instantiate_gather3(#name "uint16", src_type, uint16_t) \
instantiate_gather3(#name "uint32", src_type, uint32_t) \
instantiate_gather3(#name "uint64", src_type, uint64_t) \
instantiate_gather3(#name "int8", src_type, int8_t) \
instantiate_gather3(#name "int16", src_type, int16_t) \
instantiate_gather3(#name "int32", src_type, int32_t) \
instantiate_gather3(#name "int64", src_type, int64_t)
instantiate_gather(bool_, bool)
instantiate_gather(uint8, uint8_t)
instantiate_gather(uint16, uint16_t)
instantiate_gather(uint32, uint32_t)
instantiate_gather(uint64, uint64_t)
instantiate_gather(int8, int8_t)
instantiate_gather(int16, int16_t)
instantiate_gather(int32, int32_t)
instantiate_gather(int64, int64_t)
instantiate_gather(float16, half)
instantiate_gather(float32, float)
instantiate_gather(bfloat16, bfloat16_t) // clang-format on

View File

@@ -1,13 +1,9 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
#include <metal_stdlib>
using namespace metal;
/////////////////////////////////////////////////////////////////////
// Indexing utils
/////////////////////////////////////////////////////////////////////
template <typename IdxT, int NIDX>
struct Indices {
const array<const device IdxT*, NIDX> buffers;
@@ -24,31 +20,3 @@ METAL_FUNC size_t offset_neg_idx(IdxT idx, size_t size) {
return (idx < 0) ? idx + size : idx;
}
}
#define IDX_ARG_N(idx_t, n) const device idx_t *idx##n [[buffer(n)]],
#define IDX_ARG_0(idx_t)
#define IDX_ARG_1(idx_t) IDX_ARG_0(idx_t) IDX_ARG_N(idx_t, 21)
#define IDX_ARG_2(idx_t) IDX_ARG_1(idx_t) IDX_ARG_N(idx_t, 22)
#define IDX_ARG_3(idx_t) IDX_ARG_2(idx_t) IDX_ARG_N(idx_t, 23)
#define IDX_ARG_4(idx_t) IDX_ARG_3(idx_t) IDX_ARG_N(idx_t, 24)
#define IDX_ARG_5(idx_t) IDX_ARG_4(idx_t) IDX_ARG_N(idx_t, 25)
#define IDX_ARG_6(idx_t) IDX_ARG_5(idx_t) IDX_ARG_N(idx_t, 26)
#define IDX_ARG_7(idx_t) IDX_ARG_6(idx_t) IDX_ARG_N(idx_t, 27)
#define IDX_ARG_8(idx_t) IDX_ARG_7(idx_t) IDX_ARG_N(idx_t, 28)
#define IDX_ARG_9(idx_t) IDX_ARG_8(idx_t) IDX_ARG_N(idx_t, 29)
#define IDX_ARG_10(idx_t) IDX_ARG_9(idx_t) IDX_ARG_N(idx_t, 30)
#define IDX_ARR_N(n) idx##n,
#define IDX_ARR_0()
#define IDX_ARR_1() IDX_ARR_0() IDX_ARR_N(21)
#define IDX_ARR_2() IDX_ARR_1() IDX_ARR_N(22)
#define IDX_ARR_3() IDX_ARR_2() IDX_ARR_N(23)
#define IDX_ARR_4() IDX_ARR_3() IDX_ARR_N(24)
#define IDX_ARR_5() IDX_ARR_4() IDX_ARR_N(25)
#define IDX_ARR_6() IDX_ARR_5() IDX_ARR_N(26)
#define IDX_ARR_7() IDX_ARR_6() IDX_ARR_N(27)
#define IDX_ARR_8() IDX_ARR_7() IDX_ARR_N(28)
#define IDX_ARR_9() IDX_ARR_8() IDX_ARR_N(29)
#define IDX_ARR_10() IDX_ARR_9() IDX_ARR_N(30)

View File

@@ -601,14 +601,18 @@ METAL_FUNC void qvm_impl(
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int num_simdgroups = 8;
constexpr int num_simdgroups = 2;
constexpr int pack_factor = 32 / bits;
constexpr int tn = 32 / pack_factor;
constexpr int blocksize = SIMD_SIZE;
typedef float U;
typedef struct {
uint32_t wi[tn];
} vec_w;
thread uint32_t w_local;
thread U result[pack_factor] = {0};
thread vec_w w_local;
thread U result[tn * pack_factor] = {0};
thread U scale = 1;
thread U bias = 0;
thread U x_local = 0;
@@ -616,11 +620,12 @@ METAL_FUNC void qvm_impl(
// Adjust positions
const int out_vec_size_w = out_vec_size / pack_factor;
const int out_vec_size_g = out_vec_size / group_size;
int out_col = tid.x * (num_simdgroups * pack_factor) + simd_gid * pack_factor;
w += out_col / pack_factor;
scales += out_col / group_size;
biases += out_col / group_size;
x += tid.y * in_vec_size;
int out_col =
tid.x * (num_simdgroups * pack_factor * tn) + simd_gid * pack_factor * tn;
w += out_col / pack_factor + simd_lid * out_vec_size_w;
scales += out_col / group_size + simd_lid * out_vec_size_g;
biases += out_col / group_size + simd_lid * out_vec_size_g;
x += tid.y * in_vec_size + simd_lid;
y += tid.y * out_vec_size + out_col;
if (out_col >= out_vec_size) {
@@ -628,40 +633,61 @@ METAL_FUNC void qvm_impl(
}
// Loop over in_vec in blocks of blocksize
int i = 0;
for (; i + blocksize <= in_vec_size; i += blocksize) {
x_local = x[i + simd_lid];
scale = scales[(i + simd_lid) * out_vec_size_g];
bias = biases[(i + simd_lid) * out_vec_size_g];
w_local = w[(i + simd_lid) * out_vec_size_w];
int remaining = in_vec_size % blocksize;
if (remaining == 0) {
for (int i = 0; i < in_vec_size; i += blocksize) {
x_local = *x;
scale = *scales;
bias = *biases;
w_local = *((device vec_w*)w);
qouter<U, pack_factor, bits>(
qouter<U, tn * pack_factor, bits>(
(thread uint8_t*)&w_local, x_local, scale, bias, result);
x += blocksize;
scales += blocksize * out_vec_size_g;
biases += blocksize * out_vec_size_g;
w += blocksize * out_vec_size_w;
}
} else {
for (int i = blocksize; i < in_vec_size; i += blocksize) {
x_local = *x;
scale = *scales;
bias = *biases;
w_local = *((device vec_w*)w);
qouter<U, tn * pack_factor, bits>(
(thread uint8_t*)&w_local, x_local, scale, bias, result);
x += blocksize;
scales += blocksize * out_vec_size_g;
biases += blocksize * out_vec_size_g;
w += blocksize * out_vec_size_w;
}
if (static_cast<int>(simd_lid) < remaining) {
x_local = *x;
scale = *scales;
bias = *biases;
w_local = *((device vec_w*)w);
} else {
x_local = 0;
scale = 0;
bias = 0;
}
qouter<U, tn * pack_factor, bits>(
(thread uint8_t*)&w_local, x_local, scale, bias, result);
}
if (static_cast<int>(i + simd_lid) < in_vec_size) {
x_local = x[i + simd_lid];
scale = scales[(i + simd_lid) * out_vec_size_g];
bias = biases[(i + simd_lid) * out_vec_size_g];
w_local = w[(i + simd_lid) * out_vec_size_w];
} else {
x_local = 0;
scale = 0;
bias = 0;
w_local = 0;
}
qouter<U, pack_factor, bits>(
(thread uint8_t*)&w_local, x_local, scale, bias, result);
// Accumulate in the simdgroup
#pragma clang loop unroll(full)
for (int k = 0; k < pack_factor; k++) {
for (int k = 0; k < tn * pack_factor; k++) {
result[k] = simd_sum(result[k]);
}
// Store the result
if (simd_lid == 0) {
#pragma clang loop unroll(full)
for (int k = 0; k < pack_factor; k++) {
for (int k = 0; k < tn * pack_factor; k++) {
y[k] = static_cast<T>(result[k]);
}
}

View File

@@ -0,0 +1,4 @@
#pragma once
#include "mlx/backend/metal/kernels/reduction/reduce_all.h"
#include "mlx/backend/metal/kernels/reduction/reduce_col.h"
#include "mlx/backend/metal/kernels/reduction/reduce_row.h"

View File

@@ -0,0 +1,293 @@
// Copyright © 2024 Apple Inc.
#include <metal_atomic>
#include <metal_simdgroup>
// clang-format off
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/atomic.h"
#include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/reduction/reduce_init.h"
#include "mlx/backend/metal/kernels/reduce.h"
#define instantiate_reduce_helper_floats(inst_f, name, op) \
inst_f(name, float16, half, op) \
inst_f(name, float32, float, op) \
inst_f(name, bfloat16, bfloat16_t, op)
#define instantiate_reduce_helper_uints(inst_f, name, op) \
inst_f(name, uint8, uint8_t, op) \
inst_f(name, uint16, uint16_t, op) \
inst_f(name, uint32, uint32_t, op)
#define instantiate_reduce_helper_ints(inst_f, name, op) \
inst_f(name, int8, int8_t, op) \
inst_f(name, int16, int16_t, op) \
inst_f(name, int32, int32_t, op)
#define instantiate_reduce_helper_64b(inst_f, name, op) \
inst_f(name, int64, int64_t, op) \
inst_f(name, uint64, uint64_t, op)
#define instantiate_reduce_helper_types(inst_f, name, op) \
instantiate_reduce_helper_floats(inst_f, name, op) \
instantiate_reduce_helper_uints(inst_f, name, op) \
instantiate_reduce_helper_ints(inst_f, name, op)
#define instantiate_reduce_ops(inst_f, type_f) \
type_f(inst_f, sum, Sum) \
type_f(inst_f, prod, Prod) \
type_f(inst_f, min, Min) \
type_f(inst_f, max, Max)
// Special case for bool reductions
#define instantiate_reduce_from_types_helper( \
inst_f, name, tname, itype, otype, op) \
inst_f(name##tname, itype, otype, op)
#define instantiate_reduce_from_types(inst_f, name, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, bool_, bool, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, uint8, uint8_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, uint16, uint16_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, uint32, uint32_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, uint64, uint64_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, int8, int8_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, int16, int16_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, int32, int32_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, int64, int64_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, float16, half, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, \
name, \
float32, \
float, \
otype, \
op) \
instantiate_reduce_from_types_helper( \
inst_f, \
name, \
bfloat16, \
bfloat16_t, \
otype, \
op)
#define instantiate_init_reduce(name, otype, op) \
template [[host_name("i_reduce_" #name)]] [[kernel]] void \
init_reduce<otype, op>( \
device otype * out [[buffer(1)]], uint tid [[thread_position_in_grid]]);
#define instantiate_init_reduce_helper(name, tname, type, op) \
instantiate_init_reduce(name##tname, type, op<type>)
instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_64b)
instantiate_init_reduce(andbool_, bool, And<bool>)
instantiate_init_reduce(orbool_, bool, Or<bool>)
#define instantiate_all_reduce(name, itype, otype, op) \
template [[host_name("all_reduce_" #name)]] [[kernel]] void \
all_reduce<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device mlx_atomic<otype>* out [[buffer(1)]], \
const device size_t& in_size [[buffer(2)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint grid_size [[threads_per_grid]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_all_reduce_no_atomics(name, itype, otype, op) \
template [[host_name("allNoAtomics_reduce_" #name)]] [[kernel]] void \
all_reduce_no_atomics<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const device size_t& in_size [[buffer(2)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint grid_size [[threads_per_grid]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint thread_group_id [[threadgroup_position_in_grid]]);
#define instantiate_same_all_reduce_helper(name, tname, type, op) \
instantiate_all_reduce(name##tname, type, type, op<type>)
#define instantiate_same_all_reduce_na_helper(name, tname, type, op) \
instantiate_all_reduce_no_atomics(name##tname, type, type, op<type>)
instantiate_reduce_ops(instantiate_same_all_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_all_reduce_na_helper, instantiate_reduce_helper_64b)
instantiate_reduce_from_types(instantiate_all_reduce, and, bool, And<bool>)
instantiate_reduce_from_types(instantiate_all_reduce, or, bool, Or<bool>)
// special case bool with larger output type
instantiate_all_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>)
#define instantiate_col_reduce_general(name, itype, otype, op) \
template [[host_name("colGeneral_reduce_" #name)]] [[kernel]] void \
col_reduce_general<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device mlx_atomic<otype>* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
threadgroup otype* local_data [[threadgroup(0)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]]);
#define instantiate_col_reduce_general_no_atomics(name, itype, otype, op) \
template \
[[host_name("colGeneralNoAtomics_reduce_" #name)]] [[kernel]] void \
col_reduce_general_no_atomics<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
threadgroup otype* local_data [[threadgroup(0)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 gid [[thread_position_in_grid]], \
uint3 lsize [[threads_per_threadgroup]], \
uint3 gsize [[threads_per_grid]]);
#define instantiate_col_reduce_small(name, itype, otype, op) \
template [[host_name("colSmall_reduce_" #name)]] [[kernel]] void \
col_reduce_small<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
const constant size_t& non_col_reductions [[buffer(8)]], \
const constant int* non_col_shapes [[buffer(9)]], \
const constant size_t* non_col_strides [[buffer(10)]], \
const constant int& non_col_ndim [[buffer(11)]], \
uint tid [[thread_position_in_grid]]);
#define instantiate_same_col_reduce_helper(name, tname, type, op) \
instantiate_col_reduce_small(name ##tname, type, type, op<type>) \
instantiate_col_reduce_general(name ##tname, type, type, op<type>)
#define instantiate_same_col_reduce_na_helper(name, tname, type, op) \
instantiate_col_reduce_small(name ##tname, type, type, op<type>) \
instantiate_col_reduce_general_no_atomics(name ##tname, type, type, op<type>)
instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_col_reduce_na_helper, instantiate_reduce_helper_64b)
instantiate_col_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>)
instantiate_reduce_from_types(instantiate_col_reduce_general, and, bool, And<bool>)
instantiate_reduce_from_types(instantiate_col_reduce_general, or, bool, Or<bool>)
instantiate_col_reduce_small(sumbool_, bool, uint32_t, Sum<uint32_t>)
instantiate_reduce_from_types(instantiate_col_reduce_small, and, bool, And<bool>)
instantiate_reduce_from_types(instantiate_col_reduce_small, or, bool, Or<bool>)
#define instantiate_row_reduce_small(name, itype, otype, op) \
template [[host_name("rowGeneralSmall_reduce_" #name)]] [[kernel]] void \
row_reduce_general_small<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant size_t& non_row_reductions [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
uint lid [[thread_position_in_grid]]); \
template [[host_name("rowGeneralMed_reduce_" #name)]] [[kernel]] void \
row_reduce_general_med<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant size_t& non_row_reductions [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
uint tid [[threadgroup_position_in_grid]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_per_group [[dispatch_simdgroups_per_threadgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_row_reduce_general(name, itype, otype, op) \
instantiate_row_reduce_small(name, itype, otype, op) \
template \
[[host_name("rowGeneral_reduce_" #name)]] [[kernel]] void \
row_reduce_general<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device mlx_atomic<otype>* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant size_t& non_row_reductions [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_row_reduce_general_no_atomics(name, itype, otype, op) \
instantiate_row_reduce_small(name, itype, otype, op) \
template \
[[host_name("rowGeneralNoAtomics_reduce_" #name)]] [[kernel]] void \
row_reduce_general_no_atomics<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant size_t& non_row_reductions [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]], \
uint3 gsize [[threads_per_grid]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_same_row_reduce_helper(name, tname, type, op) \
instantiate_row_reduce_general(name##tname, type, type, op<type>)
#define instantiate_same_row_reduce_na_helper(name, tname, type, op) \
instantiate_row_reduce_general_no_atomics(name##tname, type, type, op<type>)
instantiate_reduce_ops(instantiate_same_row_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_row_reduce_na_helper, instantiate_reduce_helper_64b)
instantiate_reduce_from_types(instantiate_row_reduce_general, and, bool, And<bool>)
instantiate_reduce_from_types(instantiate_row_reduce_general, or, bool, Or<bool>)
instantiate_row_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>)
// clang-format on

View File

@@ -0,0 +1,6 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/backend/metal/kernels/atomic.h"
#include "mlx/backend/metal/kernels/reduction/ops.h"

View File

@@ -1,32 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/reduction/reduce_inst.h"
#include "mlx/backend/metal/kernels/reduction/utils.h"
using namespace metal;
///////////////////////////////////////////////////////////////////////////////
// Reduce init
///////////////////////////////////////////////////////////////////////////////
template <typename T, typename Op>
[[kernel]] void init_reduce(
device T* out [[buffer(0)]],
uint tid [[thread_position_in_grid]]) {
out[tid] = Op::init;
}
#define instantiate_init_reduce(name, otype, op) \
template [[host_name("i" #name)]] [[kernel]] void init_reduce<otype, op>( \
device otype * out [[buffer(1)]], uint tid [[thread_position_in_grid]]);
#define instantiate_init_reduce_helper(name, tname, type, op) \
instantiate_init_reduce(name##tname, type, op<type>)
// clang-format off
instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_64b)
instantiate_init_reduce(andbool_, bool, And)
instantiate_init_reduce(orbool_, bool, Or) // clang-format on

View File

@@ -5,9 +5,7 @@
#include <metal_atomic>
#include <metal_simdgroup>
#include "mlx/backend/metal/kernels/atomic.h"
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"
static constant constexpr const uint8_t simd_size = 32;
union bool4_or_uint {
bool4 b;
@@ -21,6 +19,7 @@ struct None {
}
};
template <typename U = bool>
struct And {
bool simd_reduce(bool val) {
return simd_all(val);
@@ -58,6 +57,7 @@ struct And {
}
};
template <typename U = bool>
struct Or {
bool simd_reduce(bool val) {
return simd_any(val);

View File

@@ -1,11 +1,5 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/reduction/reduce_inst.h"
#include "mlx/backend/metal/kernels/reduction/utils.h"
using namespace metal;
///////////////////////////////////////////////////////////////////////////////
// All reduce helper
///////////////////////////////////////////////////////////////////////////////
@@ -139,50 +133,3 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
out[thread_group_id] = total_val;
}
}
#define instantiate_all_reduce(name, itype, otype, op) \
template [[host_name("all_reduce_" #name)]] [[kernel]] void \
all_reduce<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device mlx_atomic<otype>* out [[buffer(1)]], \
const device size_t& in_size [[buffer(2)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint grid_size [[threads_per_grid]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_all_reduce_no_atomics(name, itype, otype, op) \
template [[host_name("all_reduce_no_atomics_" #name)]] [[kernel]] void \
all_reduce_no_atomics<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const device size_t& in_size [[buffer(2)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint grid_size [[threads_per_grid]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint thread_group_id [[threadgroup_position_in_grid]]);
///////////////////////////////////////////////////////////////////////////////
// Instantiations
///////////////////////////////////////////////////////////////////////////////
#define instantiate_same_all_reduce_helper(name, tname, type, op) \
instantiate_all_reduce(name##tname, type, type, op<type>)
#define instantiate_same_all_reduce_na_helper(name, tname, type, op) \
instantiate_all_reduce_no_atomics(name##tname, type, type, op<type>)
// clang-format off
instantiate_reduce_ops(instantiate_same_all_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_all_reduce_na_helper, instantiate_reduce_helper_64b)
instantiate_reduce_from_types(instantiate_all_reduce, and, bool, And)
instantiate_reduce_from_types(instantiate_all_reduce, or, bool, Or)
// special case bool with larger output type
instantiate_all_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>) // clang-format on

View File

@@ -1,11 +1,5 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/reduction/reduce_inst.h"
#include "mlx/backend/metal/kernels/reduction/utils.h"
using namespace metal;
///////////////////////////////////////////////////////////////////////////////
// Small column reduce kernel
///////////////////////////////////////////////////////////////////////////////
@@ -52,23 +46,6 @@ template <typename T, typename U, typename Op>
out[out_idx] = total_val;
}
#define instantiate_col_reduce_small(name, itype, otype, op) \
template [[host_name("col_reduce_small_" #name)]] [[kernel]] void \
col_reduce_small<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
const constant size_t& non_col_reductions [[buffer(8)]], \
const constant int* non_col_shapes [[buffer(9)]], \
const constant size_t* non_col_strides [[buffer(10)]], \
const constant int& non_col_ndim [[buffer(11)]], \
uint tid [[thread_position_in_grid]]);
///////////////////////////////////////////////////////////////////////////////
// Column reduce helper
///////////////////////////////////////////////////////////////////////////////
@@ -186,64 +163,3 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
}
}
}
#define instantiate_col_reduce_general(name, itype, otype, op) \
template [[host_name("col_reduce_general_" #name)]] [[kernel]] void \
col_reduce_general<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device mlx_atomic<otype>* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
threadgroup otype* local_data [[threadgroup(0)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]]);
#define instantiate_col_reduce_general_no_atomics(name, itype, otype, op) \
template \
[[host_name("col_reduce_general_no_atomics_" #name)]] [[kernel]] void \
col_reduce_general_no_atomics<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
threadgroup otype* local_data [[threadgroup(0)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 gid [[thread_position_in_grid]], \
uint3 lsize [[threads_per_threadgroup]], \
uint3 gsize [[threads_per_grid]]);
///////////////////////////////////////////////////////////////////////////////
// Instantiations
///////////////////////////////////////////////////////////////////////////////
// clang-format off
#define instantiate_same_col_reduce_helper(name, tname, type, op) \
instantiate_col_reduce_small(name ##tname, type, type, op<type>) \
instantiate_col_reduce_general(name ##tname, type, type, op<type>) // clang-format on
// clang-format off
#define instantiate_same_col_reduce_na_helper(name, tname, type, op) \
instantiate_col_reduce_small(name ##tname, type, type, op<type>) \
instantiate_col_reduce_general_no_atomics(name ##tname, type, type, op<type>) // clang-format on
// clang-format off
instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_col_reduce_na_helper, instantiate_reduce_helper_64b)
instantiate_col_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>)
instantiate_reduce_from_types(instantiate_col_reduce_general, and, bool, And)
instantiate_reduce_from_types(instantiate_col_reduce_general, or, bool, Or)
instantiate_col_reduce_small(sumbool_, bool, uint32_t, Sum<uint32_t>)
instantiate_reduce_from_types(instantiate_col_reduce_small, and, bool, And)
instantiate_reduce_from_types(instantiate_col_reduce_small, or, bool, Or) // clang-format on

View File

@@ -0,0 +1,8 @@
// Copyright © 2023-2024 Apple Inc.
template <typename T, typename Op>
[[kernel]] void init_reduce(
device T* out [[buffer(0)]],
uint tid [[thread_position_in_grid]]) {
out[tid] = Op::init;
}

View File

@@ -1,74 +0,0 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include <metal_atomic>
#include <metal_simdgroup>
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/reduction/ops.h"
// clang-format off
#define instantiate_reduce_helper_floats(inst_f, name, op) \
inst_f(name, float16, half, op) inst_f(name, float32, float, op) \
inst_f(name, bfloat16, bfloat16_t, op)
#define instantiate_reduce_helper_uints(inst_f, name, op) \
inst_f(name, uint8, uint8_t, op) inst_f(name, uint16, uint16_t, op) \
inst_f(name, uint32, uint32_t, op)
#define instantiate_reduce_helper_ints(inst_f, name, op) \
inst_f(name, int8, int8_t, op) inst_f(name, int16, int16_t, op) \
inst_f(name, int32, int32_t, op)
#define instantiate_reduce_helper_64b(inst_f, name, op) \
inst_f(name, int64, int64_t, op) inst_f(name, uint64, uint64_t, op)
#define instantiate_reduce_helper_types(inst_f, name, op) \
instantiate_reduce_helper_floats(inst_f, name, op) \
instantiate_reduce_helper_uints(inst_f, name, op) \
instantiate_reduce_helper_ints(inst_f, name, op)
#define instantiate_reduce_ops(inst_f, type_f) \
type_f(inst_f, sum, Sum) type_f(inst_f, prod, Prod) \
type_f(inst_f, min_, Min) type_f(inst_f, max_, Max)
// Special case for bool reductions
#define instantiate_reduce_from_types_helper( \
inst_f, name, tname, itype, otype, op) \
inst_f(name##tname, itype, otype, op)
#define instantiate_reduce_from_types(inst_f, name, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, bool_, bool, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, uint8, uint8_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, uint16, uint16_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, uint32, uint32_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, int8, int8_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, int16, int16_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, int32, int32_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, int64, int64_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, float16, half, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, \
name, \
float32, \
float, \
otype, \
op) \
instantiate_reduce_from_types_helper( \
inst_f, \
name, \
bfloat16, \
bfloat16_t, \
otype, \
op)
// clang-format on

View File

@@ -1,11 +1,5 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/reduction/reduce_inst.h"
#include "mlx/backend/metal/kernels/reduction/utils.h"
using namespace metal;
///////////////////////////////////////////////////////////////////////////////
// Small row reductions
///////////////////////////////////////////////////////////////////////////////
@@ -123,33 +117,6 @@ template <typename T, typename U, typename Op>
}
}
#define instantiate_row_reduce_small(name, itype, otype, op) \
template [[host_name("row_reduce_general_small_" #name)]] [[kernel]] void \
row_reduce_general_small<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant size_t& non_row_reductions [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
uint lid [[thread_position_in_grid]]); \
template [[host_name("row_reduce_general_med_" #name)]] [[kernel]] void \
row_reduce_general_med<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant size_t& non_row_reductions [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
uint tid [[threadgroup_position_in_grid]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_per_group [[dispatch_simdgroups_per_threadgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
///////////////////////////////////////////////////////////////////////////////
// Large row reductions
///////////////////////////////////////////////////////////////////////////////
@@ -318,61 +285,3 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
out[(ceildiv(gsize.y, lsize.y) * tid.x) + tid.y] = total_val;
}
}
#define instantiate_row_reduce_general(name, itype, otype, op) \
instantiate_row_reduce_small(name, itype, otype, op) template \
[[host_name("row_reduce_general_" #name)]] [[kernel]] void \
row_reduce_general<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device mlx_atomic<otype>* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant size_t& non_row_reductions [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_row_reduce_general_no_atomics(name, itype, otype, op) \
instantiate_row_reduce_small(name, itype, otype, op) template \
[[host_name("row_reduce_general_no_atomics_" #name)]] [[kernel]] void \
row_reduce_general_no_atomics<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant size_t& non_row_reductions [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]], \
uint3 gsize [[threads_per_grid]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
///////////////////////////////////////////////////////////////////////////////
// Instantiations
///////////////////////////////////////////////////////////////////////////////
#define instantiate_same_row_reduce_helper(name, tname, type, op) \
instantiate_row_reduce_general(name##tname, type, type, op<type>)
#define instantiate_same_row_reduce_na_helper(name, tname, type, op) \
instantiate_row_reduce_general_no_atomics(name##tname, type, type, op<type>)
// clang-format off
instantiate_reduce_ops(instantiate_same_row_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_row_reduce_na_helper, instantiate_reduce_helper_64b)
instantiate_reduce_from_types(instantiate_row_reduce_general, and, bool, And)
instantiate_reduce_from_types(instantiate_row_reduce_general, or, bool, Or)
instantiate_row_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>) // clang-format on

View File

@@ -1,14 +0,0 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include <metal_atomic>
#include <metal_simdgroup>
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/steel/utils.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/reduction/ops.h"
static constant constexpr const uint8_t simd_size = 32;

View File

@@ -0,0 +1,440 @@
// Copyright © 2023-2024 Apple Inc.
template <typename U>
struct CumSum {
static constexpr constant U init = static_cast<U>(0);
template <typename T>
U operator()(U a, T b) {
return a + b;
}
U simd_scan(U x) {
return simd_prefix_inclusive_sum(x);
}
U simd_exclusive_scan(U x) {
return simd_prefix_exclusive_sum(x);
}
};
template <typename U>
struct CumProd {
static constexpr constant U init = static_cast<U>(1.0f);
template <typename T>
U operator()(U a, T b) {
return a * b;
}
U simd_scan(U x) {
return simd_prefix_inclusive_product(x);
}
U simd_exclusive_scan(U x) {
return simd_prefix_exclusive_product(x);
}
};
template <>
struct CumProd<bool> {
static constexpr constant bool init = true;
template <typename T>
bool operator()(bool a, T b) {
return a & static_cast<bool>(b);
}
bool simd_scan(bool x) {
for (int i = 1; i <= 16; i *= 2) {
bool other = simd_shuffle_up(x, i);
x &= other;
}
return x;
}
bool simd_exclusive_scan(bool x) {
x = simd_scan(x);
return simd_shuffle_and_fill_up(x, init, 1);
}
};
template <typename U>
struct CumMax {
static constexpr constant U init = Limits<U>::min;
template <typename T>
U operator()(U a, T b) {
return (a >= b) ? a : b;
}
U simd_scan(U x) {
for (int i = 1; i <= 16; i *= 2) {
U other = simd_shuffle_up(x, i);
x = (x >= other) ? x : other;
}
return x;
}
U simd_exclusive_scan(U x) {
x = simd_scan(x);
return simd_shuffle_and_fill_up(x, init, 1);
}
};
template <typename U>
struct CumMin {
static constexpr constant U init = Limits<U>::max;
template <typename T>
U operator()(U a, T b) {
return (a <= b) ? a : b;
}
U simd_scan(U x) {
for (int i = 1; i <= 16; i *= 2) {
U other = simd_shuffle_up(x, i);
x = (x <= other) ? x : other;
}
return x;
}
U simd_exclusive_scan(U x) {
x = simd_scan(x);
return simd_shuffle_and_fill_up(x, init, 1);
}
};
template <typename T, typename U, int N_READS, bool reverse>
inline void load_unsafe(U values[N_READS], const device T* input) {
if (reverse) {
for (int i = 0; i < N_READS; i++) {
values[N_READS - i - 1] = input[i];
}
} else {
for (int i = 0; i < N_READS; i++) {
values[i] = input[i];
}
}
}
template <typename T, typename U, int N_READS, bool reverse>
inline void load_safe(
U values[N_READS],
const device T* input,
int start,
int total,
U init) {
if (reverse) {
for (int i = 0; i < N_READS; i++) {
values[N_READS - i - 1] =
(start + N_READS - i - 1 < total) ? input[i] : init;
}
} else {
for (int i = 0; i < N_READS; i++) {
values[i] = (start + i < total) ? input[i] : init;
}
}
}
template <typename U, int N_READS, bool reverse>
inline void write_unsafe(U values[N_READS], device U* out) {
if (reverse) {
for (int i = 0; i < N_READS; i++) {
out[i] = values[N_READS - i - 1];
}
} else {
for (int i = 0; i < N_READS; i++) {
out[i] = values[i];
}
}
}
template <typename U, int N_READS, bool reverse>
inline void write_safe(U values[N_READS], device U* out, int start, int total) {
if (reverse) {
for (int i = 0; i < N_READS; i++) {
if (start + N_READS - i - 1 < total) {
out[i] = values[N_READS - i - 1];
}
}
} else {
for (int i = 0; i < N_READS; i++) {
if (start + i < total) {
out[i] = values[i];
}
}
}
}
template <
typename T,
typename U,
typename Op,
int N_READS,
bool inclusive,
bool reverse>
[[kernel]] void contiguous_scan(
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const constant size_t& axis_size [[buffer(2)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
uint simd_size [[threads_per_simdgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
Op op;
// Position the pointers
in += (gid / lsize) * axis_size;
out += (gid / lsize) * axis_size;
// Compute the number of simd_groups
uint simd_groups = lsize / simd_size;
// Allocate memory
U prefix = Op::init;
U values[N_READS];
threadgroup U simdgroup_sums[32];
// Loop over the reduced axis in blocks of size ceildiv(axis_size,
// N_READS*lsize)
// Read block
// Compute inclusive scan of the block
// Compute inclusive scan per thread
// Compute exclusive scan of thread sums in simdgroup
// Write simdgroup sums in SM
// Compute exclusive scan of simdgroup sums
// Compute the output by scanning prefix, prev_simdgroup, prev_thread,
// value
// Write block
for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize); r++) {
// Compute the block offset
uint offset = r * lsize * N_READS + lid * N_READS;
// Read the values
if (reverse) {
if ((offset + N_READS) < axis_size) {
load_unsafe<T, U, N_READS, reverse>(
values, in + axis_size - offset - N_READS);
} else {
load_safe<T, U, N_READS, reverse>(
values,
in + axis_size - offset - N_READS,
offset,
axis_size,
Op::init);
}
} else {
if ((offset + N_READS) < axis_size) {
load_unsafe<T, U, N_READS, reverse>(values, in + offset);
} else {
load_safe<T, U, N_READS, reverse>(
values, in + offset, offset, axis_size, Op::init);
}
}
// Compute an inclusive scan per thread
for (int i = 1; i < N_READS; i++) {
values[i] = op(values[i], values[i - 1]);
}
// Compute exclusive scan of thread sums
U prev_thread = op.simd_exclusive_scan(values[N_READS - 1]);
// Write simdgroup_sums to SM
if (simd_lane_id == simd_size - 1) {
simdgroup_sums[simd_group_id] = op(prev_thread, values[N_READS - 1]);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Compute exclusive scan of simdgroup_sums
if (simd_group_id == 0) {
U prev_simdgroup = op.simd_exclusive_scan(simdgroup_sums[simd_lane_id]);
simdgroup_sums[simd_lane_id] = prev_simdgroup;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Compute the output
for (int i = 0; i < N_READS; i++) {
values[i] = op(values[i], prefix);
values[i] = op(values[i], simdgroup_sums[simd_group_id]);
values[i] = op(values[i], prev_thread);
}
// Write the values
if (reverse) {
if (inclusive) {
if ((offset + N_READS) < axis_size) {
write_unsafe<U, N_READS, reverse>(
values, out + axis_size - offset - N_READS);
} else {
write_safe<U, N_READS, reverse>(
values, out + axis_size - offset - N_READS, offset, axis_size);
}
} else {
if (lid == 0 && offset == 0) {
out[axis_size - 1] = Op::init;
}
if ((offset + N_READS + 1) < axis_size) {
write_unsafe<U, N_READS, reverse>(
values, out + axis_size - offset - 1 - N_READS);
} else {
write_safe<U, N_READS, reverse>(
values,
out + axis_size - offset - 1 - N_READS,
offset + 1,
axis_size);
}
}
} else {
if (inclusive) {
if ((offset + N_READS) < axis_size) {
write_unsafe<U, N_READS, reverse>(values, out + offset);
} else {
write_safe<U, N_READS, reverse>(
values, out + offset, offset, axis_size);
}
} else {
if (lid == 0 && offset == 0) {
out[0] = Op::init;
}
if ((offset + N_READS + 1) < axis_size) {
write_unsafe<U, N_READS, reverse>(values, out + offset + 1);
} else {
write_safe<U, N_READS, reverse>(
values, out + offset + 1, offset + 1, axis_size);
}
}
}
// Share the prefix
if (simd_group_id == simd_groups - 1 && simd_lane_id == simd_size - 1) {
simdgroup_sums[0] = values[N_READS - 1];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
prefix = simdgroup_sums[0];
}
}
template <
typename T,
typename U,
typename Op,
int N_READS,
bool inclusive,
bool reverse>
[[kernel]] void strided_scan(
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const constant size_t& axis_size [[buffer(2)]],
const constant size_t& stride [[buffer(3)]],
uint2 gid [[threadgroup_position_in_grid]],
uint2 lid [[thread_position_in_threadgroup]],
uint2 lsize [[threads_per_threadgroup]],
uint simd_size [[threads_per_simdgroup]]) {
Op op;
// Allocate memory
threadgroup U read_buffer[N_READS * 32 * 32 + N_READS * 32];
U values[N_READS];
U prefix[N_READS];
for (int i = 0; i < N_READS; i++) {
prefix[i] = Op::init;
}
// Compute offsets
int offset = gid.y * axis_size * stride;
int global_index_x = gid.x * lsize.y * N_READS;
for (uint j = 0; j < axis_size; j += simd_size) {
// Calculate the indices for the current thread
uint index_y = j + lid.y;
uint check_index_y = index_y;
uint index_x = global_index_x + lid.x * N_READS;
if (reverse) {
index_y = axis_size - 1 - index_y;
}
// Read in SM
if (check_index_y < axis_size && (index_x + N_READS) < stride) {
for (int i = 0; i < N_READS; i++) {
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] =
in[offset + index_y * stride + index_x + i];
}
} else {
for (int i = 0; i < N_READS; i++) {
if (check_index_y < axis_size && (index_x + i) < stride) {
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] =
in[offset + index_y * stride + index_x + i];
} else {
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] =
Op::init;
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Read strided into registers
for (int i = 0; i < N_READS; i++) {
values[i] =
read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i];
}
// Do we need the following barrier? Shouldn't all simd threads execute
// simultaneously?
simdgroup_barrier(mem_flags::mem_threadgroup);
// Perform the scan
for (int i = 0; i < N_READS; i++) {
values[i] = op.simd_scan(values[i]);
values[i] = op(values[i], prefix[i]);
prefix[i] = simd_shuffle(values[i], simd_size - 1);
}
// Write to SM
for (int i = 0; i < N_READS; i++) {
read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i] =
values[i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Write to device memory
if (!inclusive) {
if (check_index_y == 0) {
if ((index_x + N_READS) < stride) {
for (int i = 0; i < N_READS; i++) {
out[offset + index_y * stride + index_x + i] = Op::init;
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((index_x + i) < stride) {
out[offset + index_y * stride + index_x + i] = Op::init;
}
}
}
}
if (reverse) {
index_y -= 1;
check_index_y += 1;
} else {
index_y += 1;
check_index_y += 1;
}
}
if (check_index_y < axis_size && (index_x + N_READS) < stride) {
for (int i = 0; i < N_READS; i++) {
out[offset + index_y * stride + index_x + i] =
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i];
}
} else {
for (int i = 0; i < N_READS; i++) {
if (check_index_y < axis_size && (index_x + i) < stride) {
out[offset + index_y * stride + index_x + i] =
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i];
}
}
}
}
}

View File

@@ -1,455 +1,19 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <metal_math>
#include <metal_simdgroup>
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"
// clang-format off
using namespace metal;
template <typename U>
struct CumSum {
static constexpr constant U init = static_cast<U>(0);
template <typename T>
U operator()(U a, T b) {
return a + b;
}
U simd_scan(U x) {
return simd_prefix_inclusive_sum(x);
}
U simd_exclusive_scan(U x) {
return simd_prefix_exclusive_sum(x);
}
};
template <typename U>
struct CumProd {
static constexpr constant U init = static_cast<U>(1.0f);
template <typename T>
U operator()(U a, T b) {
return a * b;
}
U simd_scan(U x) {
return simd_prefix_inclusive_product(x);
}
U simd_exclusive_scan(U x) {
return simd_prefix_exclusive_product(x);
}
};
template <>
struct CumProd<bool> {
static constexpr constant bool init = true;
template <typename T>
bool operator()(bool a, T b) {
return a & static_cast<bool>(b);
}
bool simd_scan(bool x) {
for (int i = 1; i <= 16; i *= 2) {
bool other = simd_shuffle_up(x, i);
x &= other;
}
return x;
}
bool simd_exclusive_scan(bool x) {
x = simd_scan(x);
return simd_shuffle_and_fill_up(x, init, 1);
}
};
template <typename U>
struct CumMax {
static constexpr constant U init = Limits<U>::min;
template <typename T>
U operator()(U a, T b) {
return (a >= b) ? a : b;
}
U simd_scan(U x) {
for (int i = 1; i <= 16; i *= 2) {
U other = simd_shuffle_up(x, i);
x = (x >= other) ? x : other;
}
return x;
}
U simd_exclusive_scan(U x) {
x = simd_scan(x);
return simd_shuffle_and_fill_up(x, init, 1);
}
};
template <typename U>
struct CumMin {
static constexpr constant U init = Limits<U>::max;
template <typename T>
U operator()(U a, T b) {
return (a <= b) ? a : b;
}
U simd_scan(U x) {
for (int i = 1; i <= 16; i *= 2) {
U other = simd_shuffle_up(x, i);
x = (x <= other) ? x : other;
}
return x;
}
U simd_exclusive_scan(U x) {
x = simd_scan(x);
return simd_shuffle_and_fill_up(x, init, 1);
}
};
template <typename T, typename U, int N_READS, bool reverse>
inline void load_unsafe(U values[N_READS], const device T* input) {
if (reverse) {
for (int i = 0; i < N_READS; i++) {
values[N_READS - i - 1] = input[i];
}
} else {
for (int i = 0; i < N_READS; i++) {
values[i] = input[i];
}
}
}
template <typename T, typename U, int N_READS, bool reverse>
inline void load_safe(
U values[N_READS],
const device T* input,
int start,
int total,
U init) {
if (reverse) {
for (int i = 0; i < N_READS; i++) {
values[N_READS - i - 1] =
(start + N_READS - i - 1 < total) ? input[i] : init;
}
} else {
for (int i = 0; i < N_READS; i++) {
values[i] = (start + i < total) ? input[i] : init;
}
}
}
template <typename U, int N_READS, bool reverse>
inline void write_unsafe(U values[N_READS], device U* out) {
if (reverse) {
for (int i = 0; i < N_READS; i++) {
out[i] = values[N_READS - i - 1];
}
} else {
for (int i = 0; i < N_READS; i++) {
out[i] = values[i];
}
}
}
template <typename U, int N_READS, bool reverse>
inline void write_safe(U values[N_READS], device U* out, int start, int total) {
if (reverse) {
for (int i = 0; i < N_READS; i++) {
if (start + N_READS - i - 1 < total) {
out[i] = values[N_READS - i - 1];
}
}
} else {
for (int i = 0; i < N_READS; i++) {
if (start + i < total) {
out[i] = values[i];
}
}
}
}
template <
typename T,
typename U,
typename Op,
int N_READS,
bool inclusive,
bool reverse>
[[kernel]] void contiguous_scan(
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const constant size_t& axis_size [[buffer(2)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
uint simd_size [[threads_per_simdgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
Op op;
// Position the pointers
in += (gid / lsize) * axis_size;
out += (gid / lsize) * axis_size;
// Compute the number of simd_groups
uint simd_groups = lsize / simd_size;
// Allocate memory
U prefix = Op::init;
U values[N_READS];
threadgroup U simdgroup_sums[32];
// Loop over the reduced axis in blocks of size ceildiv(axis_size,
// N_READS*lsize)
// Read block
// Compute inclusive scan of the block
// Compute inclusive scan per thread
// Compute exclusive scan of thread sums in simdgroup
// Write simdgroup sums in SM
// Compute exclusive scan of simdgroup sums
// Compute the output by scanning prefix, prev_simdgroup, prev_thread,
// value
// Write block
for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize); r++) {
// Compute the block offset
uint offset = r * lsize * N_READS + lid * N_READS;
// Read the values
if (reverse) {
if ((offset + N_READS) < axis_size) {
load_unsafe<T, U, N_READS, reverse>(
values, in + axis_size - offset - N_READS);
} else {
load_safe<T, U, N_READS, reverse>(
values,
in + axis_size - offset - N_READS,
offset,
axis_size,
Op::init);
}
} else {
if ((offset + N_READS) < axis_size) {
load_unsafe<T, U, N_READS, reverse>(values, in + offset);
} else {
load_safe<T, U, N_READS, reverse>(
values, in + offset, offset, axis_size, Op::init);
}
}
// Compute an inclusive scan per thread
for (int i = 1; i < N_READS; i++) {
values[i] = op(values[i], values[i - 1]);
}
// Compute exclusive scan of thread sums
U prev_thread = op.simd_exclusive_scan(values[N_READS - 1]);
// Write simdgroup_sums to SM
if (simd_lane_id == simd_size - 1) {
simdgroup_sums[simd_group_id] = op(prev_thread, values[N_READS - 1]);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Compute exclusive scan of simdgroup_sums
if (simd_group_id == 0) {
U prev_simdgroup = op.simd_exclusive_scan(simdgroup_sums[simd_lane_id]);
simdgroup_sums[simd_lane_id] = prev_simdgroup;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Compute the output
for (int i = 0; i < N_READS; i++) {
values[i] = op(values[i], prefix);
values[i] = op(values[i], simdgroup_sums[simd_group_id]);
values[i] = op(values[i], prev_thread);
}
// Write the values
if (reverse) {
if (inclusive) {
if ((offset + N_READS) < axis_size) {
write_unsafe<U, N_READS, reverse>(
values, out + axis_size - offset - N_READS);
} else {
write_safe<U, N_READS, reverse>(
values, out + axis_size - offset - N_READS, offset, axis_size);
}
} else {
if (lid == 0 && offset == 0) {
out[axis_size - 1] = Op::init;
}
if ((offset + N_READS + 1) < axis_size) {
write_unsafe<U, N_READS, reverse>(
values, out + axis_size - offset - 1 - N_READS);
} else {
write_safe<U, N_READS, reverse>(
values,
out + axis_size - offset - 1 - N_READS,
offset + 1,
axis_size);
}
}
} else {
if (inclusive) {
if ((offset + N_READS) < axis_size) {
write_unsafe<U, N_READS, reverse>(values, out + offset);
} else {
write_safe<U, N_READS, reverse>(
values, out + offset, offset, axis_size);
}
} else {
if (lid == 0 && offset == 0) {
out[0] = Op::init;
}
if ((offset + N_READS + 1) < axis_size) {
write_unsafe<U, N_READS, reverse>(values, out + offset + 1);
} else {
write_safe<U, N_READS, reverse>(
values, out + offset + 1, offset + 1, axis_size);
}
}
}
// Share the prefix
if (simd_group_id == simd_groups - 1 && simd_lane_id == simd_size - 1) {
simdgroup_sums[0] = values[N_READS - 1];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
prefix = simdgroup_sums[0];
}
}
template <
typename T,
typename U,
typename Op,
int N_READS,
bool inclusive,
bool reverse>
[[kernel]] void strided_scan(
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const constant size_t& axis_size [[buffer(2)]],
const constant size_t& stride [[buffer(3)]],
uint2 gid [[threadgroup_position_in_grid]],
uint2 lid [[thread_position_in_threadgroup]],
uint2 lsize [[threads_per_threadgroup]],
uint simd_size [[threads_per_simdgroup]]) {
Op op;
// Allocate memory
threadgroup U read_buffer[N_READS * 32 * 32 + N_READS * 32];
U values[N_READS];
U prefix[N_READS];
for (int i = 0; i < N_READS; i++) {
prefix[i] = Op::init;
}
// Compute offsets
int offset = gid.y * axis_size * stride;
int global_index_x = gid.x * lsize.y * N_READS;
for (uint j = 0; j < axis_size; j += simd_size) {
// Calculate the indices for the current thread
uint index_y = j + lid.y;
uint check_index_y = index_y;
uint index_x = global_index_x + lid.x * N_READS;
if (reverse) {
index_y = axis_size - 1 - index_y;
}
// Read in SM
if (check_index_y < axis_size && (index_x + N_READS) < stride) {
for (int i = 0; i < N_READS; i++) {
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] =
in[offset + index_y * stride + index_x + i];
}
} else {
for (int i = 0; i < N_READS; i++) {
if (check_index_y < axis_size && (index_x + i) < stride) {
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] =
in[offset + index_y * stride + index_x + i];
} else {
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] =
Op::init;
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Read strided into registers
for (int i = 0; i < N_READS; i++) {
values[i] =
read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i];
}
// Do we need the following barrier? Shouldn't all simd threads execute
// simultaneously?
simdgroup_barrier(mem_flags::mem_threadgroup);
// Perform the scan
for (int i = 0; i < N_READS; i++) {
values[i] = op.simd_scan(values[i]);
values[i] = op(values[i], prefix[i]);
prefix[i] = simd_shuffle(values[i], simd_size - 1);
}
// Write to SM
for (int i = 0; i < N_READS; i++) {
read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i] =
values[i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Write to device memory
if (!inclusive) {
if (check_index_y == 0) {
if ((index_x + N_READS) < stride) {
for (int i = 0; i < N_READS; i++) {
out[offset + index_y * stride + index_x + i] = Op::init;
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((index_x + i) < stride) {
out[offset + index_y * stride + index_x + i] = Op::init;
}
}
}
}
if (reverse) {
index_y -= 1;
check_index_y += 1;
} else {
index_y += 1;
check_index_y += 1;
}
}
if (check_index_y < axis_size && (index_x + N_READS) < stride) {
for (int i = 0; i < N_READS; i++) {
out[offset + index_y * stride + index_x + i] =
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i];
}
} else {
for (int i = 0; i < N_READS; i++) {
if (check_index_y < axis_size && (index_x + i) < stride) {
out[offset + index_y * stride + index_x + i] =
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i];
}
}
}
}
}
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/scan.h"
#define instantiate_contiguous_scan( \
name, itype, otype, op, inclusive, reverse, nreads) \
template [[host_name("contiguous_scan_" #name)]] [[kernel]] void \
template [[host_name("contig_scan_" #name)]] [[kernel]] void \
contiguous_scan<itype, otype, op<otype>, nreads, inclusive, reverse>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
@@ -474,7 +38,6 @@ template <
uint2 lsize [[threads_per_threadgroup]], \
uint simd_size [[threads_per_simdgroup]]);
// clang-format off
#define instantiate_scan_helper(name, itype, otype, op, nreads) \
instantiate_contiguous_scan(inclusive_##name, itype, otype, op, true, false, nreads) \
instantiate_contiguous_scan(exclusive_##name, itype, otype, op, false, false, nreads) \
@@ -483,9 +46,8 @@ template <
instantiate_strided_scan(inclusive_##name, itype, otype, op, true, false, nreads) \
instantiate_strided_scan(exclusive_##name, itype, otype, op, false, false, nreads) \
instantiate_strided_scan(reverse_inclusive_##name, itype, otype, op, true, true, nreads) \
instantiate_strided_scan(reverse_exclusive_##name, itype, otype, op, false, true, nreads) // clang-format on
instantiate_strided_scan(reverse_exclusive_##name, itype, otype, op, false, true, nreads)
// clang-format off
instantiate_scan_helper(sum_bool__int32, bool, int32_t, CumSum, 4)
instantiate_scan_helper(sum_uint8_uint8, uint8_t, uint8_t, CumSum, 4)
instantiate_scan_helper(sum_uint16_uint16, uint16_t, uint16_t, CumSum, 4)
@@ -537,4 +99,4 @@ instantiate_scan_helper(min_int32_int32, int32_t, int32_t, CumMi
instantiate_scan_helper(min_float16_float16, half, half, CumMin, 4)
instantiate_scan_helper(min_float32_float32, float, float, CumMin, 4)
instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMin, 4)
//instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin) // clang-format on
//instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin) // clang-format on

View File

@@ -0,0 +1,66 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/backend/metal/kernels/indexing.h"
template <typename T, typename IdxT, typename Op, int NIDX>
METAL_FUNC void scatter_1d_index_impl(
const device T* updates [[buffer(1)]],
device mlx_atomic<T>* out [[buffer(2)]],
const constant int* out_shape [[buffer(3)]],
const constant size_t* out_strides [[buffer(4)]],
const constant size_t& upd_size [[buffer(5)]],
const thread array<const device IdxT*, NIDX>& idx_buffers,
uint2 gid [[thread_position_in_grid]]) {
Op op;
uint out_idx = 0;
for (int i = 0; i < NIDX; i++) {
auto idx_val = offset_neg_idx(idx_buffers[i][gid.y], out_shape[i]);
out_idx += idx_val * out_strides[i];
}
op.atomic_update(out, updates[gid.y * upd_size + gid.x], out_idx + gid.x);
}
template <typename T, typename IdxT, typename Op, int NIDX>
METAL_FUNC void scatter_impl(
const device T* updates [[buffer(1)]],
device mlx_atomic<T>* out [[buffer(2)]],
const constant int* upd_shape [[buffer(3)]],
const constant size_t* upd_strides [[buffer(4)]],
const constant size_t& upd_ndim [[buffer(5)]],
const constant size_t& upd_size [[buffer(6)]],
const constant int* out_shape [[buffer(7)]],
const constant size_t* out_strides [[buffer(8)]],
const constant size_t& out_ndim [[buffer(9)]],
const constant int* axes [[buffer(10)]],
const thread Indices<IdxT, NIDX>& indices,
uint2 gid [[thread_position_in_grid]]) {
Op op;
auto ind_idx = gid.y;
auto ind_offset = gid.x;
size_t out_idx = 0;
for (int i = 0; i < NIDX; ++i) {
auto idx_loc = elem_to_loc(
ind_idx,
&indices.shapes[indices.ndim * i],
&indices.strides[indices.ndim * i],
indices.ndim);
auto ax = axes[i];
auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], out_shape[ax]);
out_idx += idx_val * out_strides[ax];
}
if (upd_size > 1) {
auto out_offset = elem_to_loc(
ind_offset, upd_shape + indices.ndim, out_strides, out_ndim);
out_idx += out_offset;
}
auto upd_idx =
elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim);
op.atomic_update(out, updates[upd_idx], out_idx);
}

View File

@@ -1,236 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#include <metal_atomic>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/indexing.h"
#include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/utils.h"
using namespace metal;
/////////////////////////////////////////////////////////////////////
// Scatter kernel
/////////////////////////////////////////////////////////////////////
template <typename T, typename IdxT, typename Op, int NIDX>
METAL_FUNC void scatter_1d_index_impl(
const device T* updates [[buffer(1)]],
device mlx_atomic<T>* out [[buffer(2)]],
const constant int* out_shape [[buffer(3)]],
const constant size_t* out_strides [[buffer(4)]],
const constant size_t& upd_size [[buffer(5)]],
const thread array<const device IdxT*, NIDX>& idx_buffers,
uint2 gid [[thread_position_in_grid]]) {
Op op;
uint out_idx = 0;
for (int i = 0; i < NIDX; i++) {
auto idx_val = offset_neg_idx(idx_buffers[i][gid.y], out_shape[i]);
out_idx += idx_val * out_strides[i];
}
op.atomic_update(out, updates[gid.y * upd_size + gid.x], out_idx + gid.x);
}
#define make_scatter_1d_index(IDX_ARG, IDX_ARR) \
template <typename T, typename IdxT, typename Op, int NIDX> \
[[kernel]] void scatter_1d_index( \
const device T* updates [[buffer(1)]], \
device mlx_atomic<T>* out [[buffer(2)]], \
const constant int* out_shape [[buffer(3)]], \
const constant size_t* out_strides [[buffer(4)]], \
const constant size_t& upd_size [[buffer(5)]], \
IDX_ARG(IdxT) uint2 gid [[thread_position_in_grid]]) { \
const array<const device IdxT*, NIDX> idx_buffers = {IDX_ARR()}; \
\
return scatter_1d_index_impl<T, IdxT, Op, NIDX>( \
updates, out, out_shape, out_strides, upd_size, idx_buffers, gid); \
}
template <typename T, typename IdxT, typename Op, int NIDX>
METAL_FUNC void scatter_impl(
const device T* updates [[buffer(1)]],
device mlx_atomic<T>* out [[buffer(2)]],
const constant int* upd_shape [[buffer(3)]],
const constant size_t* upd_strides [[buffer(4)]],
const constant size_t& upd_ndim [[buffer(5)]],
const constant size_t& upd_size [[buffer(6)]],
const constant int* out_shape [[buffer(7)]],
const constant size_t* out_strides [[buffer(8)]],
const constant size_t& out_ndim [[buffer(9)]],
const constant int* axes [[buffer(10)]],
const thread Indices<IdxT, NIDX>& indices,
uint2 gid [[thread_position_in_grid]]) {
Op op;
auto ind_idx = gid.y;
auto ind_offset = gid.x;
size_t out_idx = 0;
for (int i = 0; i < NIDX; ++i) {
auto idx_loc = elem_to_loc(
ind_idx,
&indices.shapes[indices.ndim * i],
&indices.strides[indices.ndim * i],
indices.ndim);
auto ax = axes[i];
auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], out_shape[ax]);
out_idx += idx_val * out_strides[ax];
}
if (upd_size > 1) {
auto out_offset = elem_to_loc(
ind_offset, upd_shape + indices.ndim, out_strides, out_ndim);
out_idx += out_offset;
}
auto upd_idx =
elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim);
op.atomic_update(out, updates[upd_idx], out_idx);
}
#define make_scatter_impl(IDX_ARG, IDX_ARR) \
template <typename T, typename IdxT, typename Op, int NIDX> \
[[kernel]] void scatter( \
const device T* updates [[buffer(1)]], \
device mlx_atomic<T>* out [[buffer(2)]], \
const constant int* upd_shape [[buffer(3)]], \
const constant size_t* upd_strides [[buffer(4)]], \
const constant size_t& upd_ndim [[buffer(5)]], \
const constant size_t& upd_size [[buffer(6)]], \
const constant int* out_shape [[buffer(7)]], \
const constant size_t* out_strides [[buffer(8)]], \
const constant size_t& out_ndim [[buffer(9)]], \
const constant int* axes [[buffer(10)]], \
const constant int* idx_shapes [[buffer(11)]], \
const constant size_t* idx_strides [[buffer(12)]], \
const constant int& idx_ndim [[buffer(13)]], \
IDX_ARG(IdxT) uint2 gid [[thread_position_in_grid]]) { \
Indices<IdxT, NIDX> idxs{ \
{{IDX_ARR()}}, idx_shapes, idx_strides, idx_ndim}; \
\
return scatter_impl<T, IdxT, Op, NIDX>( \
updates, \
out, \
upd_shape, \
upd_strides, \
upd_ndim, \
upd_size, \
out_shape, \
out_strides, \
out_ndim, \
axes, \
idxs, \
gid); \
}
#define make_scatter(n) \
make_scatter_impl(IDX_ARG_##n, IDX_ARR_##n) \
make_scatter_1d_index(IDX_ARG_##n, IDX_ARR_##n)
make_scatter(0) make_scatter(1) make_scatter(2) make_scatter(3) make_scatter(4)
make_scatter(5) make_scatter(6) make_scatter(7) make_scatter(8)
make_scatter(9) make_scatter(10)
/////////////////////////////////////////////////////////////////////
// Scatter instantiations
/////////////////////////////////////////////////////////////////////
#define instantiate_scatter5(name, src_t, idx_t, op_t, nidx, IDX_ARG) \
template [[host_name("scatter" name "_" #nidx)]] [[kernel]] void \
scatter<src_t, idx_t, op_t, nidx>( \
const device src_t* updates [[buffer(1)]], \
device mlx_atomic<src_t>* out [[buffer(2)]], \
const constant int* upd_shape [[buffer(3)]], \
const constant size_t* upd_strides [[buffer(4)]], \
const constant size_t& upd_ndim [[buffer(5)]], \
const constant size_t& upd_size [[buffer(6)]], \
const constant int* out_shape [[buffer(7)]], \
const constant size_t* out_strides [[buffer(8)]], \
const constant size_t& out_ndim [[buffer(9)]], \
const constant int* axes [[buffer(10)]], \
const constant int* idx_shapes [[buffer(11)]], \
const constant size_t* idx_strides [[buffer(12)]], \
const constant int& idx_ndim [[buffer(13)]], \
IDX_ARG(idx_t) uint2 gid [[thread_position_in_grid]]);
#define instantiate_scatter6(name, src_t, idx_t, op_t, nidx, IDX_ARG) \
template [[host_name("scatter_1d_index" name "_" #nidx)]] [[kernel]] void \
scatter_1d_index<src_t, idx_t, op_t, nidx>( \
const device src_t* updates [[buffer(1)]], \
device mlx_atomic<src_t>* out [[buffer(2)]], \
const constant int* out_shape [[buffer(3)]], \
const constant size_t* out_strides [[buffer(4)]], \
const constant size_t& upd_size [[buffer(5)]], \
IDX_ARG(idx_t) uint2 gid [[thread_position_in_grid]]);
// clang-format off
#define instantiate_scatter4(name, src_t, idx_t, op_t, nidx) \
instantiate_scatter5(name, src_t, idx_t, op_t, nidx, IDX_ARG_ ##nidx) \
instantiate_scatter6(name, src_t, idx_t, op_t, nidx, IDX_ARG_ ##nidx) // clang-format on
// clang-format off
// Special case NINDEX=0
#define instantiate_scatter_nd0(name, type) \
instantiate_scatter4(#name "none", type, bool, None, 0) \
instantiate_scatter4(#name "_sum", type, bool, Sum<type>, 0) \
instantiate_scatter4(#name "_prod", type, bool, Prod<type>, 0) \
instantiate_scatter4(#name "_max", type, bool, Max<type>, 0) \
instantiate_scatter4(#name "_min", type, bool, Min<type>, 0) // clang-format on
// clang-format off
#define instantiate_scatter3(name, type, ind_type, op_type) \
instantiate_scatter4(name, type, ind_type, op_type, 1) \
instantiate_scatter4(name, type, ind_type, op_type, 2) \
instantiate_scatter4(name, type, ind_type, op_type, 3) \
instantiate_scatter4(name, type, ind_type, op_type, 4) \
instantiate_scatter4(name, type, ind_type, op_type, 5) \
instantiate_scatter4(name, type, ind_type, op_type, 6) \
instantiate_scatter4(name, type, ind_type, op_type, 7) \
instantiate_scatter4(name, type, ind_type, op_type, 8) \
instantiate_scatter4(name, type, ind_type, op_type, 9) \
instantiate_scatter4(name, type, ind_type, op_type, 10) // clang-format on
// clang-format off
#define instantiate_scatter2(name, type, ind_type) \
instantiate_scatter3(name "_none", type, ind_type, None) \
instantiate_scatter3(name "_sum", type, ind_type, Sum<type>) \
instantiate_scatter3(name "_prod", type, ind_type, Prod<type>) \
instantiate_scatter3(name "_max", type, ind_type, Max<type>) \
instantiate_scatter3(name "_min", type, ind_type, Min<type>) // clang-format on
// clang-format off
#define instantiate_scatter(name, type) \
instantiate_scatter2(#name "bool_", type, bool) \
instantiate_scatter2(#name "uint8", type, uint8_t) \
instantiate_scatter2(#name "uint16", type, uint16_t) \
instantiate_scatter2(#name "uint32", type, uint32_t) \
instantiate_scatter2(#name "uint64", type, uint64_t) \
instantiate_scatter2(#name "int8", type, int8_t) \
instantiate_scatter2(#name "int16", type, int16_t) \
instantiate_scatter2(#name "int32", type, int32_t) \
instantiate_scatter2(#name "int64", type, int64_t) // clang-format on
// clang-format off
// TODO uint64 and int64 unsupported
instantiate_scatter_nd0(bool_, bool)
instantiate_scatter_nd0(uint8, uint8_t)
instantiate_scatter_nd0(uint16, uint16_t)
instantiate_scatter_nd0(uint32, uint32_t)
instantiate_scatter_nd0(int8, int8_t)
instantiate_scatter_nd0(int16, int16_t)
instantiate_scatter_nd0(int32, int32_t)
instantiate_scatter_nd0(float16, half)
instantiate_scatter_nd0(float32, float)
instantiate_scatter_nd0(bfloat16, bfloat16_t)
instantiate_scatter(bool_, bool)
instantiate_scatter(uint8, uint8_t)
instantiate_scatter(uint16, uint16_t)
instantiate_scatter(uint32, uint32_t)
instantiate_scatter(int8, int8_t)
instantiate_scatter(int16, int16_t)
instantiate_scatter(int32, int32_t)
instantiate_scatter(float16, half)
instantiate_scatter(float32, float)
instantiate_scatter(bfloat16, bfloat16_t) // clang-format on

View File

@@ -0,0 +1,190 @@
// Copyright © 2023-2024 Apple Inc.
template <typename T>
inline T softmax_exp(T x) {
// Softmax doesn't need high precision exponential cause x is gonna be in
// (-oo, 0] anyway and subsequently it will be divided by sum(exp(x_i)).
return fast::exp(x);
}
template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
[[kernel]] void softmax_single_row(
const device T* in,
device T* out,
constant int& axis_size,
uint gid [[threadgroup_position_in_grid]],
uint _lid [[thread_position_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
int lid = _lid;
constexpr int SIMD_SIZE = 32;
threadgroup AccT local_max[SIMD_SIZE];
threadgroup AccT local_normalizer[SIMD_SIZE];
AccT ld[N_READS];
in += gid * axis_size + lid * N_READS;
if (lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
ld[i] = AccT(in[i]);
}
} else {
for (int i = 0; i < N_READS; i++) {
ld[i] = ((lid * N_READS + i) < axis_size) ? AccT(in[i])
: Limits<AccT>::finite_min;
}
}
if (simd_group_id == 0) {
local_max[simd_lane_id] = Limits<AccT>::finite_min;
local_normalizer[simd_lane_id] = 0;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Get the max
AccT maxval = Limits<AccT>::finite_min;
for (int i = 0; i < N_READS; i++) {
maxval = (maxval < ld[i]) ? ld[i] : maxval;
}
maxval = simd_max(maxval);
if (simd_lane_id == 0) {
local_max[simd_group_id] = maxval;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_group_id == 0) {
maxval = simd_max(local_max[simd_lane_id]);
if (simd_lane_id == 0) {
local_max[0] = maxval;
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
maxval = local_max[0];
// Compute exp(x_i - maxval) and store the partial sums in local_normalizer
AccT normalizer = 0;
for (int i = 0; i < N_READS; i++) {
AccT exp_x = softmax_exp(ld[i] - maxval);
ld[i] = exp_x;
normalizer += exp_x;
}
normalizer = simd_sum(normalizer);
if (simd_lane_id == 0) {
local_normalizer[simd_group_id] = normalizer;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_group_id == 0) {
normalizer = simd_sum(local_normalizer[simd_lane_id]);
if (simd_lane_id == 0) {
local_normalizer[0] = normalizer;
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
normalizer = 1 / local_normalizer[0];
// Normalize and write to the output
out += gid * axis_size + lid * N_READS;
if (lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
out[i] = T(ld[i] * normalizer);
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((lid * N_READS + i) < axis_size) {
out[i] = T(ld[i] * normalizer);
}
}
}
}
template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
[[kernel]] void softmax_looped(
const device T* in,
device T* out,
constant int& axis_size,
uint gid [[threadgroup_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
in += gid * axis_size;
constexpr int SIMD_SIZE = 32;
threadgroup AccT local_max[SIMD_SIZE];
threadgroup AccT local_normalizer[SIMD_SIZE];
// Get the max and the normalizer in one go
AccT prevmax;
AccT maxval = Limits<AccT>::finite_min;
AccT normalizer = 0;
for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize));
r++) {
int offset = r * lsize * N_READS + lid * N_READS;
AccT vals[N_READS];
if (offset + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
vals[i] = AccT(in[offset + i]);
}
} else {
for (int i = 0; i < N_READS; i++) {
vals[i] = (offset + i < axis_size) ? AccT(in[offset + i])
: Limits<AccT>::finite_min;
}
}
prevmax = maxval;
for (int i = 0; i < N_READS; i++) {
maxval = (maxval < vals[i]) ? vals[i] : maxval;
}
normalizer *= softmax_exp(prevmax - maxval);
for (int i = 0; i < N_READS; i++) {
normalizer += softmax_exp(vals[i] - maxval);
}
}
// Now we got partial normalizer of N_READS * ceildiv(axis_size, N_READS *
// lsize) parts. We need to combine them.
// 1. We start by finding the max across simd groups
// 2. We then change the partial normalizers to account for a possible
// change in max
// 3. We sum all normalizers
prevmax = maxval;
maxval = simd_max(maxval);
normalizer *= softmax_exp(prevmax - maxval);
normalizer = simd_sum(normalizer);
// Now the normalizer and max value is correct for each simdgroup. We write
// them shared memory and combine them.
prevmax = maxval;
if (simd_lane_id == 0) {
local_max[simd_group_id] = maxval;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
maxval = simd_max(local_max[simd_lane_id]);
normalizer *= softmax_exp(prevmax - maxval);
if (simd_lane_id == 0) {
local_normalizer[simd_group_id] = normalizer;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
normalizer = simd_sum(local_normalizer[simd_lane_id]);
normalizer = 1 / normalizer;
// Finally given the normalizer and max value we can directly write the
// softmax output
out += gid * axis_size;
for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize));
r++) {
int offset = r * lsize * N_READS + lid * N_READS;
if (offset + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
out[offset + i] = T(softmax_exp(in[offset + i] - maxval) * normalizer);
}
} else {
for (int i = 0; i < N_READS; i++) {
if (offset + i < axis_size) {
out[offset + i] =
T(softmax_exp(in[offset + i] - maxval) * normalizer);
}
}
}
}
}

View File

@@ -1,205 +1,18 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <metal_common>
#include <metal_simdgroup>
using namespace metal;
// clang-format off
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"
using namespace metal;
template <typename T>
inline T softmax_exp(T x) {
// Softmax doesn't need high precision exponential cause x is gonna be in
// (-oo, 0] anyway and subsequently it will be divided by sum(exp(x_i)).
return fast::exp(x);
}
template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
[[kernel]] void softmax_single_row(
const device T* in,
device T* out,
constant int& axis_size,
uint gid [[threadgroup_position_in_grid]],
uint _lid [[thread_position_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
int lid = _lid;
constexpr int SIMD_SIZE = 32;
threadgroup AccT local_max[SIMD_SIZE];
threadgroup AccT local_normalizer[SIMD_SIZE];
AccT ld[N_READS];
in += gid * axis_size + lid * N_READS;
if (lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
ld[i] = AccT(in[i]);
}
} else {
for (int i = 0; i < N_READS; i++) {
ld[i] = ((lid * N_READS + i) < axis_size) ? AccT(in[i])
: Limits<AccT>::finite_min;
}
}
if (simd_group_id == 0) {
local_max[simd_lane_id] = Limits<AccT>::finite_min;
local_normalizer[simd_lane_id] = 0;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Get the max
AccT maxval = Limits<AccT>::finite_min;
for (int i = 0; i < N_READS; i++) {
maxval = (maxval < ld[i]) ? ld[i] : maxval;
}
maxval = simd_max(maxval);
if (simd_lane_id == 0) {
local_max[simd_group_id] = maxval;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_group_id == 0) {
maxval = simd_max(local_max[simd_lane_id]);
if (simd_lane_id == 0) {
local_max[0] = maxval;
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
maxval = local_max[0];
// Compute exp(x_i - maxval) and store the partial sums in local_normalizer
AccT normalizer = 0;
for (int i = 0; i < N_READS; i++) {
AccT exp_x = softmax_exp(ld[i] - maxval);
ld[i] = exp_x;
normalizer += exp_x;
}
normalizer = simd_sum(normalizer);
if (simd_lane_id == 0) {
local_normalizer[simd_group_id] = normalizer;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_group_id == 0) {
normalizer = simd_sum(local_normalizer[simd_lane_id]);
if (simd_lane_id == 0) {
local_normalizer[0] = normalizer;
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
normalizer = 1 / local_normalizer[0];
// Normalize and write to the output
out += gid * axis_size + lid * N_READS;
if (lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
out[i] = T(ld[i] * normalizer);
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((lid * N_READS + i) < axis_size) {
out[i] = T(ld[i] * normalizer);
}
}
}
}
template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
[[kernel]] void softmax_looped(
const device T* in,
device T* out,
constant int& axis_size,
uint gid [[threadgroup_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
in += gid * axis_size;
constexpr int SIMD_SIZE = 32;
threadgroup AccT local_max[SIMD_SIZE];
threadgroup AccT local_normalizer[SIMD_SIZE];
// Get the max and the normalizer in one go
AccT prevmax;
AccT maxval = Limits<AccT>::finite_min;
AccT normalizer = 0;
for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize));
r++) {
int offset = r * lsize * N_READS + lid * N_READS;
AccT vals[N_READS];
if (offset + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
vals[i] = AccT(in[offset + i]);
}
} else {
for (int i = 0; i < N_READS; i++) {
vals[i] = (offset + i < axis_size) ? AccT(in[offset + i])
: Limits<AccT>::finite_min;
}
}
prevmax = maxval;
for (int i = 0; i < N_READS; i++) {
maxval = (maxval < vals[i]) ? vals[i] : maxval;
}
normalizer *= softmax_exp(prevmax - maxval);
for (int i = 0; i < N_READS; i++) {
normalizer += softmax_exp(vals[i] - maxval);
}
}
// Now we got partial normalizer of N_READS * ceildiv(axis_size, N_READS *
// lsize) parts. We need to combine them.
// 1. We start by finding the max across simd groups
// 2. We then change the partial normalizers to account for a possible
// change in max
// 3. We sum all normalizers
prevmax = maxval;
maxval = simd_max(maxval);
normalizer *= softmax_exp(prevmax - maxval);
normalizer = simd_sum(normalizer);
// Now the normalizer and max value is correct for each simdgroup. We write
// them shared memory and combine them.
prevmax = maxval;
if (simd_lane_id == 0) {
local_max[simd_group_id] = maxval;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
maxval = simd_max(local_max[simd_lane_id]);
normalizer *= softmax_exp(prevmax - maxval);
if (simd_lane_id == 0) {
local_normalizer[simd_group_id] = normalizer;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
normalizer = simd_sum(local_normalizer[simd_lane_id]);
normalizer = 1 / normalizer;
// Finally given the normalizer and max value we can directly write the
// softmax output
out += gid * axis_size;
for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize));
r++) {
int offset = r * lsize * N_READS + lid * N_READS;
if (offset + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
out[offset + i] = T(softmax_exp(in[offset + i] - maxval) * normalizer);
}
} else {
for (int i = 0; i < N_READS; i++) {
if (offset + i < axis_size) {
out[offset + i] =
T(softmax_exp(in[offset + i] - maxval) * normalizer);
}
}
}
}
}
#include "mlx/backend/metal/kernels/softmax.h"
#define instantiate_softmax(name, itype) \
template [[host_name("softmax_" #name)]] [[kernel]] void \
template [[host_name("block_softmax_" #name)]] [[kernel]] void \
softmax_single_row<itype>( \
const device itype* in, \
device itype* out, \
@@ -208,7 +21,7 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
uint _lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
template [[host_name("softmax_looped_" #name)]] [[kernel]] void \
template [[host_name("looped_softmax_" #name)]] [[kernel]] void \
softmax_looped<itype>( \
const device itype* in, \
device itype* out, \
@@ -220,7 +33,7 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_softmax_precise(name, itype) \
template [[host_name("softmax_precise_" #name)]] [[kernel]] void \
template [[host_name("block_softmax_precise_" #name)]] [[kernel]] void \
softmax_single_row<itype, float>( \
const device itype* in, \
device itype* out, \
@@ -229,7 +42,7 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
uint _lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
template [[host_name("softmax_looped_precise_" #name)]] [[kernel]] void \
template [[host_name("looped_softmax_precise_" #name)]] [[kernel]] void \
softmax_looped<itype, float>( \
const device itype* in, \
device itype* out, \
@@ -240,7 +53,6 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
// clang-format off
instantiate_softmax(float32, float)
instantiate_softmax(float16, half)
instantiate_softmax(bfloat16, bfloat16_t)

View File

@@ -0,0 +1,674 @@
// Copyright © 2023-2024 Apple Inc.
#define MLX_MTL_CONST static constant constexpr const
#define MLX_MTL_LOOP_UNROLL _Pragma("clang loop unroll(full)")
using namespace metal;
// Based on GPU merge sort algorithm at
// https://github.com/NVIDIA/cccl/tree/main/cub/cub
///////////////////////////////////////////////////////////////////////////////
// Thread-level sort
///////////////////////////////////////////////////////////////////////////////
template <typename T>
METAL_FUNC void thread_swap(thread T& a, thread T& b) {
T w = a;
a = b;
b = w;
}
template <typename T>
struct LessThan {
static constexpr constant T init = Limits<T>::max;
METAL_FUNC bool operator()(T a, T b) {
return a < b;
}
};
template <
typename val_t,
typename idx_t,
bool ARG_SORT,
short N_PER_THREAD,
typename CompareOp>
struct ThreadSort {
static METAL_FUNC void sort(
thread val_t (&vals)[N_PER_THREAD],
thread idx_t (&idxs)[N_PER_THREAD]) {
CompareOp op;
MLX_MTL_LOOP_UNROLL
for (short i = 0; i < N_PER_THREAD; ++i) {
MLX_MTL_LOOP_UNROLL
for (short j = i & 1; j < N_PER_THREAD - 1; j += 2) {
if (op(vals[j + 1], vals[j])) {
thread_swap(vals[j + 1], vals[j]);
thread_swap(idxs[j + 1], idxs[j]);
}
}
}
}
};
///////////////////////////////////////////////////////////////////////////////
// Threadgroup-level sort
///////////////////////////////////////////////////////////////////////////////
template <
typename val_t,
typename idx_t,
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD,
typename CompareOp>
struct BlockMergeSort {
using thread_sort_t =
ThreadSort<val_t, idx_t, ARG_SORT, N_PER_THREAD, CompareOp>;
static METAL_FUNC int merge_partition(
const threadgroup val_t* As,
const threadgroup val_t* Bs,
short A_sz,
short B_sz,
short sort_md) {
CompareOp op;
short A_st = max(0, sort_md - B_sz);
short A_ed = min(sort_md, A_sz);
while (A_st < A_ed) {
short md = A_st + (A_ed - A_st) / 2;
auto a = As[md];
auto b = Bs[sort_md - 1 - md];
if (op(b, a)) {
A_ed = md;
} else {
A_st = md + 1;
}
}
return A_ed;
}
static METAL_FUNC void merge_step(
const threadgroup val_t* As,
const threadgroup val_t* Bs,
const threadgroup idx_t* As_idx,
const threadgroup idx_t* Bs_idx,
short A_sz,
short B_sz,
thread val_t (&vals)[N_PER_THREAD],
thread idx_t (&idxs)[N_PER_THREAD]) {
CompareOp op;
short a_idx = 0;
short b_idx = 0;
for (int i = 0; i < N_PER_THREAD; ++i) {
auto a = As[a_idx];
auto b = Bs[b_idx];
bool pred = (b_idx < B_sz) && (a_idx >= A_sz || op(b, a));
vals[i] = pred ? b : a;
idxs[i] = pred ? Bs_idx[b_idx] : As_idx[a_idx];
b_idx += short(pred);
a_idx += short(!pred);
}
}
static METAL_FUNC void sort(
threadgroup val_t* tgp_vals [[threadgroup(0)]],
threadgroup idx_t* tgp_idxs [[threadgroup(1)]],
int size_sorted_axis,
uint3 lid [[thread_position_in_threadgroup]]) {
// Get thread location
int idx = lid.x * N_PER_THREAD;
// Load from shared memory
thread val_t thread_vals[N_PER_THREAD];
thread idx_t thread_idxs[N_PER_THREAD];
for (int i = 0; i < N_PER_THREAD; ++i) {
thread_vals[i] = tgp_vals[idx + i];
if (ARG_SORT) {
thread_idxs[i] = tgp_idxs[idx + i];
}
}
// Per thread sort
if (idx < size_sorted_axis) {
thread_sort_t::sort(thread_vals, thread_idxs);
}
// Do merges using threadgroup memory
for (int merge_threads = 2; merge_threads <= BLOCK_THREADS;
merge_threads *= 2) {
// Update threadgroup memory
threadgroup_barrier(mem_flags::mem_threadgroup);
for (int i = 0; i < N_PER_THREAD; ++i) {
tgp_vals[idx + i] = thread_vals[i];
if (ARG_SORT) {
tgp_idxs[idx + i] = thread_idxs[i];
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Find location in merge step
int merge_group = lid.x / merge_threads;
int merge_lane = lid.x % merge_threads;
int sort_sz = N_PER_THREAD * merge_threads;
int sort_st = N_PER_THREAD * merge_threads * merge_group;
// As = tgp_vals[A_st:A_ed] is sorted
// Bs = tgp_vals[B_st:B_ed] is sorted
int A_st = sort_st;
int A_ed = sort_st + sort_sz / 2;
int B_st = sort_st + sort_sz / 2;
int B_ed = sort_st + sort_sz;
const threadgroup val_t* As = tgp_vals + A_st;
const threadgroup val_t* Bs = tgp_vals + B_st;
int A_sz = A_ed - A_st;
int B_sz = B_ed - B_st;
// Find a partition of merge elements
// Ci = merge(As[partition:], Bs[sort_md - partition:])
// of size N_PER_THREAD for each merge lane i
// C = [Ci] is sorted
int sort_md = N_PER_THREAD * merge_lane;
int partition = merge_partition(As, Bs, A_sz, B_sz, sort_md);
As += partition;
Bs += sort_md - partition;
A_sz -= partition;
B_sz -= sort_md - partition;
const threadgroup idx_t* As_idx =
ARG_SORT ? tgp_idxs + A_st + partition : nullptr;
const threadgroup idx_t* Bs_idx =
ARG_SORT ? tgp_idxs + B_st + sort_md - partition : nullptr;
// Merge starting at the partition and store results in thread registers
merge_step(As, Bs, As_idx, Bs_idx, A_sz, B_sz, thread_vals, thread_idxs);
}
// Write out to shared memory
threadgroup_barrier(mem_flags::mem_threadgroup);
for (int i = 0; i < N_PER_THREAD; ++i) {
tgp_vals[idx + i] = thread_vals[i];
if (ARG_SORT) {
tgp_idxs[idx + i] = thread_idxs[i];
}
}
}
};
///////////////////////////////////////////////////////////////////////////////
// Kernel sort
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
typename U,
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD,
typename CompareOp = LessThan<T>>
struct KernelMergeSort {
using val_t = T;
using idx_t = uint;
using block_merge_sort_t = BlockMergeSort<
val_t,
idx_t,
ARG_SORT,
BLOCK_THREADS,
N_PER_THREAD,
CompareOp>;
MLX_MTL_CONST short N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD;
static METAL_FUNC void block_sort(
const device T* inp,
device U* out,
const constant int& size_sorted_axis,
const constant int& stride_sorted_axis,
const constant int& stride_segment_axis,
threadgroup val_t* tgp_vals,
threadgroup idx_t* tgp_idxs,
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
// tid.y tells us the segment index
inp += tid.y * stride_segment_axis;
out += tid.y * stride_segment_axis;
// Copy into threadgroup memory
for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) {
tgp_vals[i] = i < size_sorted_axis ? inp[i * stride_sorted_axis]
: val_t(CompareOp::init);
if (ARG_SORT) {
tgp_idxs[i] = i;
}
}
// Sort elements within the block
threadgroup_barrier(mem_flags::mem_threadgroup);
block_merge_sort_t::sort(tgp_vals, tgp_idxs, size_sorted_axis, lid);
threadgroup_barrier(mem_flags::mem_threadgroup);
// Write output
for (int i = lid.x; i < size_sorted_axis; i += BLOCK_THREADS) {
if (ARG_SORT) {
out[i * stride_sorted_axis] = tgp_idxs[i];
} else {
out[i * stride_sorted_axis] = tgp_vals[i];
}
}
}
};
template <
typename T,
typename U,
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD>
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void block_sort(
const device T* inp [[buffer(0)]],
device U* out [[buffer(1)]],
const constant int& size_sorted_axis [[buffer(2)]],
const constant int& stride_sorted_axis [[buffer(3)]],
const constant int& stride_segment_axis [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using sort_kernel =
KernelMergeSort<T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
using val_t = typename sort_kernel::val_t;
using idx_t = typename sort_kernel::idx_t;
if (ARG_SORT) {
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
sort_kernel::block_sort(
inp,
out,
size_sorted_axis,
stride_sorted_axis,
stride_segment_axis,
tgp_vals,
tgp_idxs,
tid,
lid);
} else {
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
sort_kernel::block_sort(
inp,
out,
size_sorted_axis,
stride_sorted_axis,
stride_segment_axis,
tgp_vals,
nullptr,
tid,
lid);
}
}
constant constexpr const int zero_helper = 0;
template <
typename T,
typename U,
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD>
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void block_sort_nc(
const device T* inp [[buffer(0)]],
device U* out [[buffer(1)]],
const constant int& size_sorted_axis [[buffer(2)]],
const constant int& stride_sorted_axis [[buffer(3)]],
const constant int& nc_dim [[buffer(4)]],
const device int* nc_shape [[buffer(5)]],
const device size_t* nc_strides [[buffer(6)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using sort_kernel =
KernelMergeSort<T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
using val_t = typename sort_kernel::val_t;
using idx_t = typename sort_kernel::idx_t;
auto block_idx = elem_to_loc(tid.y, nc_shape, nc_strides, nc_dim);
inp += block_idx;
out += block_idx;
if (ARG_SORT) {
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
sort_kernel::block_sort(
inp,
out,
size_sorted_axis,
stride_sorted_axis,
zero_helper,
tgp_vals,
tgp_idxs,
tid,
lid);
} else {
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
sort_kernel::block_sort(
inp,
out,
size_sorted_axis,
stride_sorted_axis,
zero_helper,
tgp_vals,
nullptr,
tid,
lid);
}
}
template <
typename val_t,
typename idx_t,
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD,
typename CompareOp = LessThan<val_t>>
struct KernelMultiBlockMergeSort {
using block_merge_sort_t = BlockMergeSort<
val_t,
idx_t,
ARG_SORT,
BLOCK_THREADS,
N_PER_THREAD,
CompareOp>;
MLX_MTL_CONST short N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD;
static METAL_FUNC void block_sort(
const device val_t* inp,
device val_t* out_vals,
device idx_t* out_idxs,
const constant int& size_sorted_axis,
const constant int& stride_sorted_axis,
threadgroup val_t* tgp_vals,
threadgroup idx_t* tgp_idxs,
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
// tid.y tells us the segment index
int base_idx = tid.x * N_PER_BLOCK;
// Copy into threadgroup memory
for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) {
int idx = base_idx + i;
tgp_vals[i] = idx < size_sorted_axis ? inp[idx * stride_sorted_axis]
: val_t(CompareOp::init);
tgp_idxs[i] = idx;
}
// Sort elements within the block
threadgroup_barrier(mem_flags::mem_threadgroup);
block_merge_sort_t::sort(tgp_vals, tgp_idxs, size_sorted_axis, lid);
threadgroup_barrier(mem_flags::mem_threadgroup);
// Write output
for (int i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) {
int idx = base_idx + i;
if (idx < size_sorted_axis) {
out_vals[idx] = tgp_vals[i];
out_idxs[idx] = tgp_idxs[i];
}
}
}
static METAL_FUNC int merge_partition(
const device val_t* As,
const device val_t* Bs,
int A_sz,
int B_sz,
int sort_md) {
CompareOp op;
int A_st = max(0, sort_md - B_sz);
int A_ed = min(sort_md, A_sz);
while (A_st < A_ed) {
int md = A_st + (A_ed - A_st) / 2;
auto a = As[md];
auto b = Bs[sort_md - 1 - md];
if (op(b, a)) {
A_ed = md;
} else {
A_st = md + 1;
}
}
return A_ed;
}
};
template <
typename val_t,
typename idx_t,
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD>
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_sort(
const device val_t* inp [[buffer(0)]],
device val_t* out_vals [[buffer(1)]],
device idx_t* out_idxs [[buffer(2)]],
const constant int& size_sorted_axis [[buffer(3)]],
const constant int& stride_sorted_axis [[buffer(4)]],
const constant int& nc_dim [[buffer(5)]],
const device int* nc_shape [[buffer(6)]],
const device size_t* nc_strides [[buffer(7)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using sort_kernel = KernelMultiBlockMergeSort<
val_t,
idx_t,
ARG_SORT,
BLOCK_THREADS,
N_PER_THREAD>;
auto block_idx = elem_to_loc(tid.y, nc_shape, nc_strides, nc_dim);
inp += block_idx;
out_vals += tid.y * size_sorted_axis;
out_idxs += tid.y * size_sorted_axis;
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
sort_kernel::block_sort(
inp,
out_vals,
out_idxs,
size_sorted_axis,
stride_sorted_axis,
tgp_vals,
tgp_idxs,
tid,
lid);
}
template <
typename val_t,
typename idx_t,
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD>
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void
mb_block_partition(
device idx_t* block_partitions [[buffer(0)]],
const device val_t* dev_vals [[buffer(1)]],
const device idx_t* dev_idxs [[buffer(2)]],
const constant int& size_sorted_axis [[buffer(3)]],
const constant int& merge_tiles [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 tgp_dims [[threads_per_threadgroup]]) {
using sort_kernel = KernelMultiBlockMergeSort<
val_t,
idx_t,
ARG_SORT,
BLOCK_THREADS,
N_PER_THREAD>;
block_partitions += tid.y * tgp_dims.x;
dev_vals += tid.y * size_sorted_axis;
dev_idxs += tid.y * size_sorted_axis;
// Find location in merge step
int merge_group = lid.x / merge_tiles;
int merge_lane = lid.x % merge_tiles;
int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;
int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;
int A_st = min(size_sorted_axis, sort_st);
int A_ed = min(size_sorted_axis, sort_st + sort_sz / 2);
int B_st = A_ed;
int B_ed = min(size_sorted_axis, B_st + sort_sz / 2);
int partition_at = min(B_ed - A_st, sort_kernel::N_PER_BLOCK * merge_lane);
int partition = sort_kernel::merge_partition(
dev_vals + A_st, dev_vals + B_st, A_ed - A_st, B_ed - B_st, partition_at);
block_partitions[lid.x] = A_st + partition;
}
template <
typename val_t,
typename idx_t,
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD,
typename CompareOp = LessThan<val_t>>
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void
mb_block_merge(
const device idx_t* block_partitions [[buffer(0)]],
const device val_t* dev_vals_in [[buffer(1)]],
const device idx_t* dev_idxs_in [[buffer(2)]],
device val_t* dev_vals_out [[buffer(3)]],
device idx_t* dev_idxs_out [[buffer(4)]],
const constant int& size_sorted_axis [[buffer(5)]],
const constant int& merge_tiles [[buffer(6)]],
const constant int& num_tiles [[buffer(7)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using sort_kernel = KernelMultiBlockMergeSort<
val_t,
idx_t,
ARG_SORT,
BLOCK_THREADS,
N_PER_THREAD,
CompareOp>;
using block_sort_t = typename sort_kernel::block_merge_sort_t;
block_partitions += tid.y * (num_tiles + 1);
dev_vals_in += tid.y * size_sorted_axis;
dev_idxs_in += tid.y * size_sorted_axis;
dev_vals_out += tid.y * size_sorted_axis;
dev_idxs_out += tid.y * size_sorted_axis;
int block_idx = tid.x;
int merge_group = block_idx / merge_tiles;
int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;
int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;
int sort_md = sort_kernel::N_PER_BLOCK * block_idx - sort_st;
int A_st = block_partitions[block_idx + 0];
int A_ed = block_partitions[block_idx + 1];
int B_st = min(size_sorted_axis, 2 * sort_st + sort_sz / 2 + sort_md - A_st);
int B_ed = min(
size_sorted_axis,
2 * sort_st + sort_sz / 2 + sort_md + sort_kernel::N_PER_BLOCK - A_ed);
if ((block_idx % merge_tiles) == merge_tiles - 1) {
A_ed = min(size_sorted_axis, sort_st + sort_sz / 2);
B_ed = min(size_sorted_axis, sort_st + sort_sz);
}
int A_sz = A_ed - A_st;
int B_sz = B_ed - B_st;
// Load from global memory
thread val_t thread_vals[N_PER_THREAD];
thread idx_t thread_idxs[N_PER_THREAD];
for (int i = 0; i < N_PER_THREAD; i++) {
int idx = BLOCK_THREADS * i + lid.x;
if (idx < (A_sz + B_sz)) {
thread_vals[i] = (idx < A_sz) ? dev_vals_in[A_st + idx]
: dev_vals_in[B_st + idx - A_sz];
thread_idxs[i] = (idx < A_sz) ? dev_idxs_in[A_st + idx]
: dev_idxs_in[B_st + idx - A_sz];
} else {
thread_vals[i] = CompareOp::init;
thread_idxs[i] = 0;
}
}
// Write to shared memory
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
threadgroup_barrier(mem_flags::mem_threadgroup);
for (int i = 0; i < N_PER_THREAD; i++) {
int idx = BLOCK_THREADS * i + lid.x;
tgp_vals[idx] = thread_vals[i];
tgp_idxs[idx] = thread_idxs[i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Merge
int sort_md_local = min(A_sz + B_sz, N_PER_THREAD * int(lid.x));
int A_st_local = block_sort_t::merge_partition(
tgp_vals, tgp_vals + A_sz, A_sz, B_sz, sort_md_local);
int A_ed_local = A_sz;
int B_st_local = sort_md_local - A_st_local;
int B_ed_local = B_sz;
int A_sz_local = A_ed_local - A_st_local;
int B_sz_local = B_ed_local - B_st_local;
// Do merge
block_sort_t::merge_step(
tgp_vals + A_st_local,
tgp_vals + A_ed_local + B_st_local,
tgp_idxs + A_st_local,
tgp_idxs + A_ed_local + B_st_local,
A_sz_local,
B_sz_local,
thread_vals,
thread_idxs);
threadgroup_barrier(mem_flags::mem_threadgroup);
for (int i = 0; i < N_PER_THREAD; ++i) {
int idx = lid.x * N_PER_THREAD;
tgp_vals[idx + i] = thread_vals[i];
tgp_idxs[idx + i] = thread_idxs[i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Write output
int base_idx = tid.x * sort_kernel::N_PER_BLOCK;
for (int i = lid.x; i < sort_kernel::N_PER_BLOCK; i += BLOCK_THREADS) {
int idx = base_idx + i;
if (idx < size_sorted_axis) {
dev_vals_out[idx] = tgp_vals[i];
dev_idxs_out[idx] = tgp_idxs[i];
}
}
}

View File

@@ -1,392 +1,16 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <metal_stdlib>
// clang-format off
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"
#define MLX_MTL_CONST static constant constexpr const
#define MLX_MTL_LOOP_UNROLL _Pragma("clang loop unroll(full)")
using namespace metal;
// Based on GPU merge sort algorithm at
// https://github.com/NVIDIA/cccl/tree/main/cub/cub
///////////////////////////////////////////////////////////////////////////////
// Thread-level sort
///////////////////////////////////////////////////////////////////////////////
template <typename T>
METAL_FUNC void thread_swap(thread T& a, thread T& b) {
T w = a;
a = b;
b = w;
}
template <typename T>
struct LessThan {
static constexpr constant T init = Limits<T>::max;
METAL_FUNC bool operator()(T a, T b) {
return a < b;
}
};
template <
typename val_t,
typename idx_t,
bool ARG_SORT,
short N_PER_THREAD,
typename CompareOp>
struct ThreadSort {
static METAL_FUNC void sort(
thread val_t (&vals)[N_PER_THREAD],
thread idx_t (&idxs)[N_PER_THREAD]) {
CompareOp op;
MLX_MTL_LOOP_UNROLL
for (short i = 0; i < N_PER_THREAD; ++i) {
MLX_MTL_LOOP_UNROLL
for (short j = i & 1; j < N_PER_THREAD - 1; j += 2) {
if (op(vals[j + 1], vals[j])) {
thread_swap(vals[j + 1], vals[j]);
thread_swap(idxs[j + 1], idxs[j]);
}
}
}
}
};
///////////////////////////////////////////////////////////////////////////////
// Threadgroup-level sort
///////////////////////////////////////////////////////////////////////////////
template <
typename val_t,
typename idx_t,
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD,
typename CompareOp>
struct BlockMergeSort {
using thread_sort_t =
ThreadSort<val_t, idx_t, ARG_SORT, N_PER_THREAD, CompareOp>;
static METAL_FUNC int merge_partition(
const threadgroup val_t* As,
const threadgroup val_t* Bs,
short A_sz,
short B_sz,
short sort_md) {
CompareOp op;
short A_st = max(0, sort_md - B_sz);
short A_ed = min(sort_md, A_sz);
while (A_st < A_ed) {
short md = A_st + (A_ed - A_st) / 2;
auto a = As[md];
auto b = Bs[sort_md - 1 - md];
if (op(b, a)) {
A_ed = md;
} else {
A_st = md + 1;
}
}
return A_ed;
}
static METAL_FUNC void merge_step(
const threadgroup val_t* As,
const threadgroup val_t* Bs,
const threadgroup idx_t* As_idx,
const threadgroup idx_t* Bs_idx,
short A_sz,
short B_sz,
thread val_t (&vals)[N_PER_THREAD],
thread idx_t (&idxs)[N_PER_THREAD]) {
CompareOp op;
short a_idx = 0;
short b_idx = 0;
for (int i = 0; i < N_PER_THREAD; ++i) {
auto a = As[a_idx];
auto b = Bs[b_idx];
bool pred = (b_idx < B_sz) && (a_idx >= A_sz || op(b, a));
vals[i] = pred ? b : a;
idxs[i] = pred ? Bs_idx[b_idx] : As_idx[a_idx];
b_idx += short(pred);
a_idx += short(!pred);
}
}
static METAL_FUNC void sort(
threadgroup val_t* tgp_vals [[threadgroup(0)]],
threadgroup idx_t* tgp_idxs [[threadgroup(1)]],
int size_sorted_axis,
uint3 lid [[thread_position_in_threadgroup]]) {
// Get thread location
int idx = lid.x * N_PER_THREAD;
// Load from shared memory
thread val_t thread_vals[N_PER_THREAD];
thread idx_t thread_idxs[N_PER_THREAD];
for (int i = 0; i < N_PER_THREAD; ++i) {
thread_vals[i] = tgp_vals[idx + i];
if (ARG_SORT) {
thread_idxs[i] = tgp_idxs[idx + i];
}
}
// Per thread sort
if (idx < size_sorted_axis) {
thread_sort_t::sort(thread_vals, thread_idxs);
}
// Do merges using threadgroup memory
for (int merge_threads = 2; merge_threads <= BLOCK_THREADS;
merge_threads *= 2) {
// Update threadgroup memory
threadgroup_barrier(mem_flags::mem_threadgroup);
for (int i = 0; i < N_PER_THREAD; ++i) {
tgp_vals[idx + i] = thread_vals[i];
if (ARG_SORT) {
tgp_idxs[idx + i] = thread_idxs[i];
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Find location in merge step
int merge_group = lid.x / merge_threads;
int merge_lane = lid.x % merge_threads;
int sort_sz = N_PER_THREAD * merge_threads;
int sort_st = N_PER_THREAD * merge_threads * merge_group;
// As = tgp_vals[A_st:A_ed] is sorted
// Bs = tgp_vals[B_st:B_ed] is sorted
int A_st = sort_st;
int A_ed = sort_st + sort_sz / 2;
int B_st = sort_st + sort_sz / 2;
int B_ed = sort_st + sort_sz;
const threadgroup val_t* As = tgp_vals + A_st;
const threadgroup val_t* Bs = tgp_vals + B_st;
int A_sz = A_ed - A_st;
int B_sz = B_ed - B_st;
// Find a partition of merge elements
// Ci = merge(As[partition:], Bs[sort_md - partition:])
// of size N_PER_THREAD for each merge lane i
// C = [Ci] is sorted
int sort_md = N_PER_THREAD * merge_lane;
int partition = merge_partition(As, Bs, A_sz, B_sz, sort_md);
As += partition;
Bs += sort_md - partition;
A_sz -= partition;
B_sz -= sort_md - partition;
const threadgroup idx_t* As_idx =
ARG_SORT ? tgp_idxs + A_st + partition : nullptr;
const threadgroup idx_t* Bs_idx =
ARG_SORT ? tgp_idxs + B_st + sort_md - partition : nullptr;
// Merge starting at the partition and store results in thread registers
merge_step(As, Bs, As_idx, Bs_idx, A_sz, B_sz, thread_vals, thread_idxs);
}
// Write out to shared memory
threadgroup_barrier(mem_flags::mem_threadgroup);
for (int i = 0; i < N_PER_THREAD; ++i) {
tgp_vals[idx + i] = thread_vals[i];
if (ARG_SORT) {
tgp_idxs[idx + i] = thread_idxs[i];
}
}
}
};
///////////////////////////////////////////////////////////////////////////////
// Kernel sort
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
typename U,
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD,
typename CompareOp = LessThan<T>>
struct KernelMergeSort {
using val_t = T;
using idx_t = uint;
using block_merge_sort_t = BlockMergeSort<
val_t,
idx_t,
ARG_SORT,
BLOCK_THREADS,
N_PER_THREAD,
CompareOp>;
MLX_MTL_CONST short N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD;
static METAL_FUNC void block_sort(
const device T* inp,
device U* out,
const constant int& size_sorted_axis,
const constant int& stride_sorted_axis,
const constant int& stride_segment_axis,
threadgroup val_t* tgp_vals,
threadgroup idx_t* tgp_idxs,
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
// tid.y tells us the segment index
inp += tid.y * stride_segment_axis;
out += tid.y * stride_segment_axis;
// Copy into threadgroup memory
for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) {
tgp_vals[i] = i < size_sorted_axis ? inp[i * stride_sorted_axis]
: val_t(CompareOp::init);
if (ARG_SORT) {
tgp_idxs[i] = i;
}
}
// Sort elements within the block
threadgroup_barrier(mem_flags::mem_threadgroup);
block_merge_sort_t::sort(tgp_vals, tgp_idxs, size_sorted_axis, lid);
threadgroup_barrier(mem_flags::mem_threadgroup);
// Write output
for (int i = lid.x; i < size_sorted_axis; i += BLOCK_THREADS) {
if (ARG_SORT) {
out[i * stride_sorted_axis] = tgp_idxs[i];
} else {
out[i * stride_sorted_axis] = tgp_vals[i];
}
}
}
};
template <
typename T,
typename U,
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD>
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void block_sort(
const device T* inp [[buffer(0)]],
device U* out [[buffer(1)]],
const constant int& size_sorted_axis [[buffer(2)]],
const constant int& stride_sorted_axis [[buffer(3)]],
const constant int& stride_segment_axis [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using sort_kernel =
KernelMergeSort<T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
using val_t = typename sort_kernel::val_t;
using idx_t = typename sort_kernel::idx_t;
if (ARG_SORT) {
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
sort_kernel::block_sort(
inp,
out,
size_sorted_axis,
stride_sorted_axis,
stride_segment_axis,
tgp_vals,
tgp_idxs,
tid,
lid);
} else {
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
sort_kernel::block_sort(
inp,
out,
size_sorted_axis,
stride_sorted_axis,
stride_segment_axis,
tgp_vals,
nullptr,
tid,
lid);
}
}
constant constexpr const int zero_helper = 0;
template <
typename T,
typename U,
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD>
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void block_sort_nc(
const device T* inp [[buffer(0)]],
device U* out [[buffer(1)]],
const constant int& size_sorted_axis [[buffer(2)]],
const constant int& stride_sorted_axis [[buffer(3)]],
const constant int& nc_dim [[buffer(4)]],
const device int* nc_shape [[buffer(5)]],
const device size_t* nc_strides [[buffer(6)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using sort_kernel =
KernelMergeSort<T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
using val_t = typename sort_kernel::val_t;
using idx_t = typename sort_kernel::idx_t;
auto block_idx = elem_to_loc(tid.y, nc_shape, nc_strides, nc_dim);
inp += block_idx;
out += block_idx;
if (ARG_SORT) {
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
sort_kernel::block_sort(
inp,
out,
size_sorted_axis,
stride_sorted_axis,
zero_helper,
tgp_vals,
tgp_idxs,
tid,
lid);
} else {
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
sort_kernel::block_sort(
inp,
out,
size_sorted_axis,
stride_sorted_axis,
zero_helper,
tgp_vals,
nullptr,
tid,
lid);
}
}
///////////////////////////////////////////////////////////////////////////////
// Instantiations
///////////////////////////////////////////////////////////////////////////////
#include "mlx/backend/metal/kernels/sort.h"
#define instantiate_block_sort( \
name, itname, itype, otname, otype, arg_sort, bn, tn) \
template [[host_name(#name "_" #itname "_" #otname "_bn" #bn \
template [[host_name("c" #name "_" #itname "_" #otname "_bn" #bn \
"_tn" #tn)]] [[kernel]] void \
block_sort<itype, otype, arg_sort, bn, tn>( \
const device itype* inp [[buffer(0)]], \
@@ -396,8 +20,8 @@ template <
const constant int& stride_segment_axis [[buffer(4)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]); \
template [[host_name(#name "_" #itname "_" #otname "_bn" #bn "_tn" #tn \
"_nc")]] [[kernel]] void \
template [[host_name("nc" #name "_" #itname "_" #otname "_bn" #bn "_tn" #tn \
)]] [[kernel]] void \
block_sort_nc<itype, otype, arg_sort, bn, tn>( \
const device itype* inp [[buffer(0)]], \
device otype* out [[buffer(1)]], \
@@ -411,22 +35,20 @@ template <
#define instantiate_arg_block_sort_base(itname, itype, bn, tn) \
instantiate_block_sort( \
arg_block_merge_sort, itname, itype, uint32, uint32_t, true, bn, tn)
arg_block_sort, itname, itype, uint32, uint32_t, true, bn, tn)
#define instantiate_block_sort_base(itname, itype, bn, tn) \
instantiate_block_sort( \
block_merge_sort, itname, itype, itname, itype, false, bn, tn)
_block_sort, itname, itype, itname, itype, false, bn, tn)
// clang-format off
#define instantiate_block_sort_tn(itname, itype, bn) \
instantiate_block_sort_base(itname, itype, bn, 8) \
instantiate_arg_block_sort_base(itname, itype, bn, 8) // clang-format on
instantiate_arg_block_sort_base(itname, itype, bn, 8)
// clang-format off
#define instantiate_block_sort_bn(itname, itype) \
instantiate_block_sort_tn(itname, itype, 128) \
instantiate_block_sort_tn(itname, itype, 256) \
instantiate_block_sort_tn(itname, itype, 512)
instantiate_block_sort_tn(itname, itype, 512)
instantiate_block_sort_bn(uint8, uint8_t)
instantiate_block_sort_bn(uint16, uint16_t)
@@ -436,321 +58,18 @@ instantiate_block_sort_bn(int16, int16_t)
instantiate_block_sort_bn(int32, int32_t)
instantiate_block_sort_bn(float16, half)
instantiate_block_sort_bn(float32, float)
instantiate_block_sort_bn(bfloat16, bfloat16_t) // clang-format on
// clang-format off
instantiate_block_sort_bn(bfloat16, bfloat16_t)
#define instantiate_block_sort_long(itname, itype) \
instantiate_block_sort_tn(itname, itype, 128) \
instantiate_block_sort_tn(itname, itype, 256)
instantiate_block_sort_long(uint64, uint64_t)
instantiate_block_sort_long(int64, int64_t) // clang-format on
///////////////////////////////////////////////////////////////////////////////
// Multi block merge sort
///////////////////////////////////////////////////////////////////////////////
template <
typename val_t,
typename idx_t,
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD,
typename CompareOp = LessThan<val_t>>
struct KernelMultiBlockMergeSort {
using block_merge_sort_t = BlockMergeSort<
val_t,
idx_t,
ARG_SORT,
BLOCK_THREADS,
N_PER_THREAD,
CompareOp>;
MLX_MTL_CONST short N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD;
static METAL_FUNC void block_sort(
const device val_t* inp,
device val_t* out_vals,
device idx_t* out_idxs,
const constant int& size_sorted_axis,
const constant int& stride_sorted_axis,
threadgroup val_t* tgp_vals,
threadgroup idx_t* tgp_idxs,
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
// tid.y tells us the segment index
int base_idx = tid.x * N_PER_BLOCK;
// Copy into threadgroup memory
for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) {
int idx = base_idx + i;
tgp_vals[i] = idx < size_sorted_axis ? inp[idx * stride_sorted_axis]
: val_t(CompareOp::init);
tgp_idxs[i] = idx;
}
// Sort elements within the block
threadgroup_barrier(mem_flags::mem_threadgroup);
block_merge_sort_t::sort(tgp_vals, tgp_idxs, size_sorted_axis, lid);
threadgroup_barrier(mem_flags::mem_threadgroup);
// Write output
for (int i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) {
int idx = base_idx + i;
if (idx < size_sorted_axis) {
out_vals[idx] = tgp_vals[i];
out_idxs[idx] = tgp_idxs[i];
}
}
}
static METAL_FUNC int merge_partition(
const device val_t* As,
const device val_t* Bs,
int A_sz,
int B_sz,
int sort_md) {
CompareOp op;
int A_st = max(0, sort_md - B_sz);
int A_ed = min(sort_md, A_sz);
while (A_st < A_ed) {
int md = A_st + (A_ed - A_st) / 2;
auto a = As[md];
auto b = Bs[sort_md - 1 - md];
if (op(b, a)) {
A_ed = md;
} else {
A_st = md + 1;
}
}
return A_ed;
}
};
template <
typename val_t,
typename idx_t,
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD>
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_sort(
const device val_t* inp [[buffer(0)]],
device val_t* out_vals [[buffer(1)]],
device idx_t* out_idxs [[buffer(2)]],
const constant int& size_sorted_axis [[buffer(3)]],
const constant int& stride_sorted_axis [[buffer(4)]],
const constant int& nc_dim [[buffer(5)]],
const device int* nc_shape [[buffer(6)]],
const device size_t* nc_strides [[buffer(7)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using sort_kernel = KernelMultiBlockMergeSort<
val_t,
idx_t,
ARG_SORT,
BLOCK_THREADS,
N_PER_THREAD>;
auto block_idx = elem_to_loc(tid.y, nc_shape, nc_strides, nc_dim);
inp += block_idx;
out_vals += tid.y * size_sorted_axis;
out_idxs += tid.y * size_sorted_axis;
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
sort_kernel::block_sort(
inp,
out_vals,
out_idxs,
size_sorted_axis,
stride_sorted_axis,
tgp_vals,
tgp_idxs,
tid,
lid);
}
template <
typename val_t,
typename idx_t,
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD>
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void
mb_block_partition(
device idx_t* block_partitions [[buffer(0)]],
const device val_t* dev_vals [[buffer(1)]],
const device idx_t* dev_idxs [[buffer(2)]],
const constant int& size_sorted_axis [[buffer(3)]],
const constant int& merge_tiles [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 tgp_dims [[threads_per_threadgroup]]) {
using sort_kernel = KernelMultiBlockMergeSort<
val_t,
idx_t,
ARG_SORT,
BLOCK_THREADS,
N_PER_THREAD>;
block_partitions += tid.y * tgp_dims.x;
dev_vals += tid.y * size_sorted_axis;
dev_idxs += tid.y * size_sorted_axis;
// Find location in merge step
int merge_group = lid.x / merge_tiles;
int merge_lane = lid.x % merge_tiles;
int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;
int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;
int A_st = min(size_sorted_axis, sort_st);
int A_ed = min(size_sorted_axis, sort_st + sort_sz / 2);
int B_st = A_ed;
int B_ed = min(size_sorted_axis, B_st + sort_sz / 2);
int partition_at = min(B_ed - A_st, sort_kernel::N_PER_BLOCK * merge_lane);
int partition = sort_kernel::merge_partition(
dev_vals + A_st, dev_vals + B_st, A_ed - A_st, B_ed - B_st, partition_at);
block_partitions[lid.x] = A_st + partition;
}
template <
typename val_t,
typename idx_t,
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD,
typename CompareOp = LessThan<val_t>>
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void
mb_block_merge(
const device idx_t* block_partitions [[buffer(0)]],
const device val_t* dev_vals_in [[buffer(1)]],
const device idx_t* dev_idxs_in [[buffer(2)]],
device val_t* dev_vals_out [[buffer(3)]],
device idx_t* dev_idxs_out [[buffer(4)]],
const constant int& size_sorted_axis [[buffer(5)]],
const constant int& merge_tiles [[buffer(6)]],
const constant int& num_tiles [[buffer(7)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using sort_kernel = KernelMultiBlockMergeSort<
val_t,
idx_t,
ARG_SORT,
BLOCK_THREADS,
N_PER_THREAD,
CompareOp>;
using block_sort_t = typename sort_kernel::block_merge_sort_t;
block_partitions += tid.y * (num_tiles + 1);
dev_vals_in += tid.y * size_sorted_axis;
dev_idxs_in += tid.y * size_sorted_axis;
dev_vals_out += tid.y * size_sorted_axis;
dev_idxs_out += tid.y * size_sorted_axis;
int block_idx = tid.x;
int merge_group = block_idx / merge_tiles;
int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;
int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;
int sort_md = sort_kernel::N_PER_BLOCK * block_idx - sort_st;
int A_st = block_partitions[block_idx + 0];
int A_ed = block_partitions[block_idx + 1];
int B_st = min(size_sorted_axis, 2 * sort_st + sort_sz / 2 + sort_md - A_st);
int B_ed = min(
size_sorted_axis,
2 * sort_st + sort_sz / 2 + sort_md + sort_kernel::N_PER_BLOCK - A_ed);
if ((block_idx % merge_tiles) == merge_tiles - 1) {
A_ed = min(size_sorted_axis, sort_st + sort_sz / 2);
B_ed = min(size_sorted_axis, sort_st + sort_sz);
}
int A_sz = A_ed - A_st;
int B_sz = B_ed - B_st;
// Load from global memory
thread val_t thread_vals[N_PER_THREAD];
thread idx_t thread_idxs[N_PER_THREAD];
for (int i = 0; i < N_PER_THREAD; i++) {
int idx = BLOCK_THREADS * i + lid.x;
if (idx < (A_sz + B_sz)) {
thread_vals[i] = (idx < A_sz) ? dev_vals_in[A_st + idx]
: dev_vals_in[B_st + idx - A_sz];
thread_idxs[i] = (idx < A_sz) ? dev_idxs_in[A_st + idx]
: dev_idxs_in[B_st + idx - A_sz];
} else {
thread_vals[i] = CompareOp::init;
thread_idxs[i] = 0;
}
}
// Write to shared memory
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
threadgroup_barrier(mem_flags::mem_threadgroup);
for (int i = 0; i < N_PER_THREAD; i++) {
int idx = BLOCK_THREADS * i + lid.x;
tgp_vals[idx] = thread_vals[i];
tgp_idxs[idx] = thread_idxs[i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Merge
int sort_md_local = min(A_sz + B_sz, N_PER_THREAD * int(lid.x));
int A_st_local = block_sort_t::merge_partition(
tgp_vals, tgp_vals + A_sz, A_sz, B_sz, sort_md_local);
int A_ed_local = A_sz;
int B_st_local = sort_md_local - A_st_local;
int B_ed_local = B_sz;
int A_sz_local = A_ed_local - A_st_local;
int B_sz_local = B_ed_local - B_st_local;
// Do merge
block_sort_t::merge_step(
tgp_vals + A_st_local,
tgp_vals + A_ed_local + B_st_local,
tgp_idxs + A_st_local,
tgp_idxs + A_ed_local + B_st_local,
A_sz_local,
B_sz_local,
thread_vals,
thread_idxs);
threadgroup_barrier(mem_flags::mem_threadgroup);
for (int i = 0; i < N_PER_THREAD; ++i) {
int idx = lid.x * N_PER_THREAD;
tgp_vals[idx + i] = thread_vals[i];
tgp_idxs[idx + i] = thread_idxs[i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Write output
int base_idx = tid.x * sort_kernel::N_PER_BLOCK;
for (int i = lid.x; i < sort_kernel::N_PER_BLOCK; i += BLOCK_THREADS) {
int idx = base_idx + i;
if (idx < size_sorted_axis) {
dev_vals_out[idx] = tgp_vals[i];
dev_idxs_out[idx] = tgp_idxs[i];
}
}
}
instantiate_block_sort_long(int64, int64_t)
#define instantiate_multi_block_sort( \
vtname, vtype, itname, itype, arg_sort, bn, tn) \
template [[host_name("mb_block_sort_" #vtname "_" #itname "_bn" #bn \
template [[host_name("sort_mbsort_" #vtname "_" #itname "_bn" #bn \
"_tn" #tn)]] [[kernel]] void \
mb_block_sort<vtype, itype, arg_sort, bn, tn>( \
const device vtype* inp [[buffer(0)]], \
@@ -763,7 +82,7 @@ mb_block_merge(
const device size_t* nc_strides [[buffer(7)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]); \
template [[host_name("mb_block_partition_" #vtname "_" #itname "_bn" #bn \
template [[host_name("partition_mbsort_" #vtname "_" #itname "_bn" #bn \
"_tn" #tn)]] [[kernel]] void \
mb_block_partition<vtype, itype, arg_sort, bn, tn>( \
device itype * block_partitions [[buffer(0)]], \
@@ -774,7 +93,7 @@ mb_block_merge(
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 tgp_dims [[threads_per_threadgroup]]); \
template [[host_name("mb_block_merge_" #vtname "_" #itname "_bn" #bn \
template [[host_name("merge_mbsort_" #vtname "_" #itname "_bn" #bn \
"_tn" #tn)]] [[kernel]] void \
mb_block_merge<vtype, itype, arg_sort, bn, tn>( \
const device itype* block_partitions [[buffer(0)]], \
@@ -788,7 +107,6 @@ mb_block_merge(
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
// clang-format off
#define instantiate_multi_block_sort_base(vtname, vtype) \
instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 512, 8)
@@ -800,11 +118,10 @@ instantiate_multi_block_sort_base(int16, int16_t)
instantiate_multi_block_sort_base(int32, int32_t)
instantiate_multi_block_sort_base(float16, half)
instantiate_multi_block_sort_base(float32, float)
instantiate_multi_block_sort_base(bfloat16, bfloat16_t) // clang-format on
instantiate_multi_block_sort_base(bfloat16, bfloat16_t)
// clang-format off
#define instantiate_multi_block_sort_long(vtname, vtype) \
instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 256, 8)
instantiate_multi_block_sort_long(uint64, uint64_t)
instantiate_multi_block_sort_long(int64, int64_t) // clang-format on
instantiate_multi_block_sort_long(int64, int64_t) // clang-format on

View File

@@ -2,10 +2,12 @@
#pragma once
#include "mlx/backend/metal/kernels/steel/defines.h"
#include "mlx/backend/metal/kernels/steel/utils.h"
#include "mlx/backend/metal/kernels/steel/conv/loader.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
using namespace metal;
using namespace mlx::steel;
using namespace mlx::steel;

View File

@@ -0,0 +1,176 @@
// Copyright © 2024 Apple Inc.
#include <metal_stdlib>
using namespace metal;
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
int N_CHANNELS = 0,
bool SMALL_FILTER = false>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
implicit_gemm_conv_2d(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
device T* C [[buffer(2)]],
const constant MLXConvParams<2>* params [[buffer(3)]],
const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using namespace mlx::steel;
(void)lid;
constexpr bool transpose_a = false;
constexpr bool transpose_b = true;
constexpr short tgp_padding_a = 16 / sizeof(T);
constexpr short tgp_padding_b = 16 / sizeof(T);
constexpr short shape_a_cols = (transpose_a ? BM : BK) + tgp_padding_a;
constexpr short shape_b_cols = (transpose_b ? BK : BN) + tgp_padding_b;
constexpr short shape_a_rows = (transpose_a ? BK : BM);
constexpr short shape_b_rows = (transpose_b ? BN : BK);
constexpr short tgp_mem_size_a = shape_a_cols * shape_a_rows;
constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows;
constexpr short tgp_size = WM * WN * 32;
// Input loader
using loader_a_t = typename metal::conditional_t<
// Check for small channel specialization
N_CHANNELS != 0 && N_CHANNELS <= 4,
// Go to small channel specialization
Conv2DInputBlockLoaderSmallChannels<
T,
BM,
BN,
BK,
tgp_size,
N_CHANNELS,
tgp_padding_a>,
// Else go to general loader
typename metal::conditional_t<
// Check if filter size is small enough
SMALL_FILTER,
// Go to small filter specialization
Conv2DInputBlockLoaderSmallFilter<
T,
BM,
BN,
BK,
tgp_size,
tgp_padding_a>,
// Else go to large filter generalization
Conv2DInputBlockLoaderLargeFilter<
T,
BM,
BN,
BK,
tgp_size,
tgp_padding_a>>>;
// Weight loader
using loader_b_t = typename metal::conditional_t<
// Check for small channel specialization
N_CHANNELS != 0 && N_CHANNELS <= 4,
// Go to small channel specialization
Conv2DWeightBlockLoaderSmallChannels<
T,
BM,
BN,
BK,
tgp_size,
N_CHANNELS,
tgp_padding_b>,
// Else go to general loader
Conv2DWeightBlockLoader<T, BM, BN, BK, tgp_size, tgp_padding_b>>;
using mma_t = BlockMMA<
T,
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
shape_a_cols,
shape_b_cols>;
threadgroup T As[tgp_mem_size_a];
threadgroup T Bs[tgp_mem_size_b];
const int tid_y = ((tid.y) << gemm_params->swizzle_log) +
((tid.x) & ((1 << gemm_params->swizzle_log) - 1));
const int tid_x = (tid.x) >> gemm_params->swizzle_log;
if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) {
return;
}
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
const int K = gemm_params->K;
const int N = gemm_params->N;
const int C_per_group = params->C / params->groups;
// Groups
A += tid.z * C_per_group;
B += tid.z * N * K;
C += tid.z * N;
B += c_col * K;
C += c_row * (N * params->groups) + c_col;
const int2 offsets_a(0, c_row);
const int2 offsets_b(0, c_col);
// Prepare threadgroup loading operations
loader_a_t loader_a(
A, As, offsets_a, params, gemm_params, simd_gid, simd_lid);
loader_b_t loader_b(
B, Bs, offsets_b, params, gemm_params, simd_gid, simd_lid);
// Prepare threadgroup mma operation
mma_t mma_op(simd_gid, simd_lid);
int gemm_k_iterations = gemm_params->gemm_k_iterations;
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
threadgroup_barrier(mem_flags::mem_none);
// Store results to device memory
short tgp_bm = min(BM, gemm_params->M - c_row);
short tgp_bn = min(BN, gemm_params->N - c_col);
const int ldc = N * params->groups;
mma_op.store_result_safe(C, ldc, short2(tgp_bn, tgp_bm));
}

View File

@@ -2,177 +2,13 @@
#include <metal_stdlib>
// clang-format off
#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/steel/conv/conv.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
using namespace metal;
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
int N_CHANNELS = 0,
bool SMALL_FILTER = false>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
implicit_gemm_conv_2d(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
device T* C [[buffer(2)]],
const constant MLXConvParams<2>* params [[buffer(3)]],
const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using namespace mlx::steel;
(void)lid;
constexpr bool transpose_a = false;
constexpr bool transpose_b = true;
constexpr short tgp_padding_a = 16 / sizeof(T);
constexpr short tgp_padding_b = 16 / sizeof(T);
constexpr short shape_a_cols = (transpose_a ? BM : BK) + tgp_padding_a;
constexpr short shape_b_cols = (transpose_b ? BK : BN) + tgp_padding_b;
constexpr short shape_a_rows = (transpose_a ? BK : BM);
constexpr short shape_b_rows = (transpose_b ? BN : BK);
constexpr short tgp_mem_size_a = shape_a_cols * shape_a_rows;
constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows;
constexpr short tgp_size = WM * WN * 32;
// Input loader
using loader_a_t = typename metal::conditional_t<
// Check for small channel specialization
N_CHANNELS != 0 && N_CHANNELS <= 4,
// Go to small channel specialization
Conv2DInputBlockLoaderSmallChannels<
T,
BM,
BN,
BK,
tgp_size,
N_CHANNELS,
tgp_padding_a>,
// Else go to general loader
typename metal::conditional_t<
// Check if filter size is small enough
SMALL_FILTER,
// Go to small filter specialization
Conv2DInputBlockLoaderSmallFilter<
T,
BM,
BN,
BK,
tgp_size,
tgp_padding_a>,
// Else go to large filter generalization
Conv2DInputBlockLoaderLargeFilter<
T,
BM,
BN,
BK,
tgp_size,
tgp_padding_a>>>;
// Weight loader
using loader_b_t = typename metal::conditional_t<
// Check for small channel specialization
N_CHANNELS != 0 && N_CHANNELS <= 4,
// Go to small channel specialization
Conv2DWeightBlockLoaderSmallChannels<
T,
BM,
BN,
BK,
tgp_size,
N_CHANNELS,
tgp_padding_b>,
// Else go to general loader
Conv2DWeightBlockLoader<T, BM, BN, BK, tgp_size, tgp_padding_b>>;
using mma_t = BlockMMA<
T,
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
shape_a_cols,
shape_b_cols>;
threadgroup T As[tgp_mem_size_a];
threadgroup T Bs[tgp_mem_size_b];
const int tid_y = ((tid.y) << gemm_params->swizzle_log) +
((tid.x) & ((1 << gemm_params->swizzle_log) - 1));
const int tid_x = (tid.x) >> gemm_params->swizzle_log;
if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) {
return;
}
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
const int K = gemm_params->K;
const int N = gemm_params->N;
B += c_col * K;
C += c_row * N + c_col;
const int2 offsets_a(0, c_row);
const int2 offsets_b(0, c_col);
// Prepare threadgroup loading operations
loader_a_t loader_a(
A, As, offsets_a, params, gemm_params, simd_gid, simd_lid);
loader_b_t loader_b(
B, Bs, offsets_b, params, gemm_params, simd_gid, simd_lid);
// Prepare threadgroup mma operation
mma_t mma_op(simd_gid, simd_lid);
int gemm_k_iterations = gemm_params->gemm_k_iterations;
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
threadgroup_barrier(mem_flags::mem_none);
// Store results to device memory
short tgp_bm = min(BM, gemm_params->M - c_row);
short tgp_bn = min(BN, gemm_params->N - c_col);
mma_op.store_result_safe(C, N, short2(tgp_bn, tgp_bm));
}
#include "mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h"
#define instantiate_implicit_conv_2d( \
name, \
@@ -200,25 +36,22 @@ implicit_gemm_conv_2d(
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
// clang-format off
#define instantiate_implicit_2d_filter(name, itype, bm, bn, bk, wm, wn) \
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, l, 0, s, true) \
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, l, 0, l, false) \
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 1, 1, l, false) \
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 2, 2, l, false) \
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 3, 3, l, false) \
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 4, 4, l, false) // clang-format on
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 4, 4, l, false)
// clang-format off
#define instantiate_implicit_2d_blocks(name, itype) \
instantiate_implicit_2d_filter(name, itype, 32, 8, 16, 4, 1) \
instantiate_implicit_2d_filter(name, itype, 64, 8, 16, 4, 1) \
instantiate_implicit_2d_filter(name, itype, 32, 32, 16, 2, 2) \
instantiate_implicit_2d_filter(name, itype, 32, 64, 16, 2, 2) \
instantiate_implicit_2d_filter(name, itype, 64, 32, 16, 2, 2) \
instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2) // clang-format on
instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2)
// clang-format off
instantiate_implicit_2d_blocks(float32, float);
instantiate_implicit_2d_blocks(float16, half);
instantiate_implicit_2d_blocks(bfloat16, bfloat16_t); // clang-format on
instantiate_implicit_2d_blocks(bfloat16, bfloat16_t); // clang-format on

View File

@@ -0,0 +1,188 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h"
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
typename AccumType = float,
typename Epilogue = TransformNone<T, AccumType>>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
implicit_gemm_conv_2d_general(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
device T* C [[buffer(2)]],
const constant MLXConvParams<2>* params [[buffer(3)]],
const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
const constant Conv2DGeneralJumpParams* jump_params [[buffer(5)]],
const constant Conv2DGeneralBaseInfo* base_h [[buffer(6)]],
const constant Conv2DGeneralBaseInfo* base_w [[buffer(7)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
(void)lid;
constexpr bool transpose_a = false;
constexpr bool transpose_b = true;
constexpr short tgp_padding_a = 16 / sizeof(T);
constexpr short tgp_padding_b = 16 / sizeof(T);
constexpr short shape_a_cols = (transpose_a ? BM : BK) + tgp_padding_a;
constexpr short shape_b_cols = (transpose_b ? BK : BN) + tgp_padding_b;
constexpr short shape_a_rows = (transpose_a ? BK : BM);
constexpr short shape_b_rows = (transpose_b ? BN : BK);
constexpr short tgp_mem_size_a = shape_a_cols * shape_a_rows;
constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows;
constexpr short tgp_size = WM * WN * 32;
// Input loader
using loader_a_t =
Conv2DInputBlockLoaderGeneral<T, BM, BN, BK, tgp_size, tgp_padding_a>;
// Weight loader
using loader_b_t =
Conv2DWeightBlockLoaderGeneral<T, BM, BN, BK, tgp_size, tgp_padding_b>;
using mma_t = BlockMMA<
T,
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
shape_a_cols,
shape_b_cols>;
threadgroup T As[tgp_mem_size_a];
threadgroup T Bs[tgp_mem_size_b];
const int tid_y = ((tid.y) << gemm_params->swizzle_log) +
((tid.x) & ((1 << gemm_params->swizzle_log) - 1));
const int tid_x = (tid.x) >> gemm_params->swizzle_log;
if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) {
return;
}
const int tid_z = tid.z;
const int base_oh = tid_z / jump_params->f_out_jump_w;
const int base_ow = tid_z % jump_params->f_out_jump_w;
const int base_wh = base_h[base_oh].weight_base;
const int base_ww = base_w[base_ow].weight_base;
const int base_wh_size = base_h[base_oh].weight_size;
const int base_ww_size = base_w[base_ow].weight_size;
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
const int K = gemm_params->K;
B += c_col * K;
const int4 offsets_a(0, c_row, base_oh, base_ow);
const int2 offsets_b(0, c_col);
// Prepare threadgroup loading operations
loader_a_t loader_a(
A,
As,
offsets_a,
params,
jump_params,
base_wh,
base_ww,
simd_gid,
simd_lid);
loader_b_t loader_b(
B,
Bs,
offsets_b,
params,
jump_params,
base_wh,
base_ww,
simd_gid,
simd_lid);
// Prepare threadgroup mma operation
mma_t mma_op(simd_gid, simd_lid);
int gemm_k_iterations =
base_wh_size * base_ww_size * gemm_params->gemm_k_iterations;
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
threadgroup_barrier(mem_flags::mem_none);
// Store results to device memory
{
// Adjust for simdgroup and thread locatio
int offset_m = c_row + mma_op.sm + mma_op.tm;
int offset_n = c_col + mma_op.sn + mma_op.tn;
C += offset_n;
if (offset_n >= gemm_params->N)
return;
short diff = gemm_params->N - offset_n;
STEEL_PRAGMA_UNROLL
for (int i = 0; i < mma_t::TM; i++) {
int cm = offset_m + i * mma_t::TM_stride;
int n = cm / jump_params->adj_out_hw;
int hw = cm % jump_params->adj_out_hw;
int oh =
(hw / jump_params->adj_out_w) * jump_params->f_out_jump_h + base_oh;
int ow =
(hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + base_ow;
if (n < params->N && oh < params->oS[0] && ow < params->oS[1]) {
int offset_cm = n * params->out_strides[0] +
oh * params->out_strides[1] + ow * params->out_strides[2];
STEEL_PRAGMA_UNROLL
for (int j = 0; j < mma_t::TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum =
mma_op.results[i * mma_t::TN + j].thread_elements();
int offset = offset_cm + (j * mma_t::TN_stride);
// Apply epilogue and output C
if (j * mma_t::TN_stride < diff) {
C[offset] = Epilogue::apply(accum[0]);
}
if (j * mma_t::TN_stride + 1 < diff) {
C[offset + 1] = Epilogue::apply(accum[1]);
}
}
}
}
}
}

View File

@@ -2,201 +2,18 @@
#include <metal_stdlib>
// clang-format off
#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/steel/conv/conv.h"
#include "mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
#include "mlx/backend/metal/kernels/steel/utils.h"
#include "mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h"
using namespace metal;
using namespace mlx::steel;
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
typename AccumType = float,
typename Epilogue = TransformNone<T, AccumType>>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
implicit_gemm_conv_2d_general(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
device T* C [[buffer(2)]],
const constant MLXConvParams<2>* params [[buffer(3)]],
const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
const constant Conv2DGeneralJumpParams* jump_params [[buffer(5)]],
const constant Conv2DGeneralBaseInfo* base_h [[buffer(6)]],
const constant Conv2DGeneralBaseInfo* base_w [[buffer(7)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
(void)lid;
constexpr bool transpose_a = false;
constexpr bool transpose_b = true;
constexpr short tgp_padding_a = 16 / sizeof(T);
constexpr short tgp_padding_b = 16 / sizeof(T);
constexpr short shape_a_cols = (transpose_a ? BM : BK) + tgp_padding_a;
constexpr short shape_b_cols = (transpose_b ? BK : BN) + tgp_padding_b;
constexpr short shape_a_rows = (transpose_a ? BK : BM);
constexpr short shape_b_rows = (transpose_b ? BN : BK);
constexpr short tgp_mem_size_a = shape_a_cols * shape_a_rows;
constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows;
constexpr short tgp_size = WM * WN * 32;
// Input loader
using loader_a_t =
Conv2DInputBlockLoaderGeneral<T, BM, BN, BK, tgp_size, tgp_padding_a>;
// Weight loader
using loader_b_t =
Conv2DWeightBlockLoaderGeneral<T, BM, BN, BK, tgp_size, tgp_padding_b>;
using mma_t = BlockMMA<
T,
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
shape_a_cols,
shape_b_cols>;
threadgroup T As[tgp_mem_size_a];
threadgroup T Bs[tgp_mem_size_b];
const int tid_y = ((tid.y) << gemm_params->swizzle_log) +
((tid.x) & ((1 << gemm_params->swizzle_log) - 1));
const int tid_x = (tid.x) >> gemm_params->swizzle_log;
if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) {
return;
}
const int tid_z = tid.z;
const int base_oh = tid_z / jump_params->f_out_jump_w;
const int base_ow = tid_z % jump_params->f_out_jump_w;
const int base_wh = base_h[base_oh].weight_base;
const int base_ww = base_w[base_ow].weight_base;
const int base_wh_size = base_h[base_oh].weight_size;
const int base_ww_size = base_w[base_ow].weight_size;
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
const int K = gemm_params->K;
B += c_col * K;
const int4 offsets_a(0, c_row, base_oh, base_ow);
const int2 offsets_b(0, c_col);
// Prepare threadgroup loading operations
loader_a_t loader_a(
A,
As,
offsets_a,
params,
jump_params,
base_wh,
base_ww,
simd_gid,
simd_lid);
loader_b_t loader_b(
B,
Bs,
offsets_b,
params,
jump_params,
base_wh,
base_ww,
simd_gid,
simd_lid);
// Prepare threadgroup mma operation
mma_t mma_op(simd_gid, simd_lid);
int gemm_k_iterations =
base_wh_size * base_ww_size * gemm_params->gemm_k_iterations;
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
threadgroup_barrier(mem_flags::mem_none);
// Store results to device memory
{
// Adjust for simdgroup and thread locatio
int offset_m = c_row + mma_op.sm + mma_op.tm;
int offset_n = c_col + mma_op.sn + mma_op.tn;
C += offset_n;
if (offset_n >= gemm_params->N)
return;
short diff = gemm_params->N - offset_n;
STEEL_PRAGMA_UNROLL
for (int i = 0; i < mma_t::TM; i++) {
int cm = offset_m + i * mma_t::TM_stride;
int n = cm / jump_params->adj_out_hw;
int hw = cm % jump_params->adj_out_hw;
int oh =
(hw / jump_params->adj_out_w) * jump_params->f_out_jump_h + base_oh;
int ow =
(hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + base_ow;
if (n < params->N && oh < params->oS[0] && ow < params->oS[1]) {
int offset_cm = n * params->out_strides[0] +
oh * params->out_strides[1] + ow * params->out_strides[2];
STEEL_PRAGMA_UNROLL
for (int j = 0; j < mma_t::TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum =
mma_op.results[i * mma_t::TN + j].thread_elements();
int offset = offset_cm + (j * mma_t::TN_stride);
// Apply epilogue and output C
if (j * mma_t::TN_stride < diff) {
C[offset] = Epilogue::apply(accum[0]);
}
if (j * mma_t::TN_stride + 1 < diff) {
C[offset + 1] = Epilogue::apply(accum[1]);
}
}
}
}
}
}
#define instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn) \
template \
[[host_name("implicit_gemm_conv_2d_general_" #name "_bm" #bm "_bn" #bn \
@@ -218,16 +35,14 @@ implicit_gemm_conv_2d_general(
#define instantiate_implicit_2d_filter(name, itype, bm, bn, bk, wm, wn) \
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn)
// clang-format off
#define instantiate_implicit_2d_blocks(name, itype) \
instantiate_implicit_2d_filter(name, itype, 32, 8, 16, 4, 1) \
instantiate_implicit_2d_filter(name, itype, 64, 8, 16, 4, 1) \
instantiate_implicit_2d_filter(name, itype, 32, 32, 16, 2, 2) \
instantiate_implicit_2d_filter(name, itype, 32, 64, 16, 2, 2) \
instantiate_implicit_2d_filter(name, itype, 64, 32, 16, 2, 2) \
instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2) // clang-format on
instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2)
// clang-format off
instantiate_implicit_2d_blocks(float32, float);
instantiate_implicit_2d_blocks(float16, half);
instantiate_implicit_2d_blocks(bfloat16, bfloat16_t); // clang-format on
instantiate_implicit_2d_blocks(bfloat16, bfloat16_t); // clang-format on

View File

@@ -2,9 +2,7 @@
#pragma once
#include "mlx/backend/metal/kernels/steel/utils.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
#include "mlx/backend/metal/kernels/steel/defines.h"
///////////////////////////////////////////////////////////////////////////////
// Loading helper
@@ -285,4 +283,4 @@ struct Conv2DWeightBlockLoaderGeneral {
};
} // namespace steel
} // namespace mlx
} // namespace mlx

View File

@@ -0,0 +1,4 @@
// Copyright © 2024 Apple Inc.
#define STEEL_CONST static constant constexpr const
#define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")

View File

@@ -0,0 +1,415 @@
// Copyright © 2024 Apple Inc.
using namespace mlx::steel;
///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////
constant bool has_batch [[function_constant(10)]];
constant bool use_out_source [[function_constant(100)]];
constant bool do_axpby [[function_constant(110)]];
constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]];
constant bool align_K [[function_constant(202)]];
constant bool do_gather [[function_constant(300)]];
constant bool gather_bias = do_gather && use_out_source;
// clang-format off
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
const device T* C [[buffer(2), function_constant(use_out_source)]],
device T* D [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]],
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
const constant int* batch_shape [[buffer(6)]],
const constant size_t* batch_strides [[buffer(7)]],
const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]],
const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]],
const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]],
const constant int* operand_shape [[buffer(13), function_constant(do_gather)]],
const constant size_t* operand_strides [[buffer(14), function_constant(do_gather)]],
const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on
// Pacifying compiler
(void)lid;
using gemm_kernel = GEMMKernel<
T,
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
true,
true,
AccumType>;
using loader_a_t = typename gemm_kernel::loader_a_t;
using loader_b_t = typename gemm_kernel::loader_b_t;
using mma_t = typename gemm_kernel::mma_t;
// Find block
const int tid_y = ((tid.y) << params->swizzle_log) +
((tid.x) & ((1 << params->swizzle_log) - 1));
const int tid_x = (tid.x) >> params->swizzle_log;
// Exit early if out of bounds
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
return;
}
// Adjust for batch
// Handle gather
if (do_gather) {
// Read indices
uint32_t indx_A, indx_B, indx_C;
if (has_batch) {
const constant size_t* indx_A_bstrides = batch_strides;
const constant size_t* indx_B_bstrides =
batch_strides + params->batch_ndim;
ulong2 indx_offsets = elem_to_loc_broadcast(
tid.z,
batch_shape,
indx_A_bstrides,
indx_B_bstrides,
params->batch_ndim);
indx_A = lhs_indices[indx_offsets.x];
indx_B = rhs_indices[indx_offsets.y];
if (use_out_source) {
const constant size_t* indx_C_bstrides =
indx_B_bstrides + params->batch_ndim;
auto indx_offset_C = elem_to_loc(
tid.z, batch_shape, indx_C_bstrides, params->batch_ndim);
indx_C = C_indices[indx_offset_C];
}
} else {
indx_A = lhs_indices[params->batch_stride_a * tid.z];
indx_B = rhs_indices[params->batch_stride_b * tid.z];
if (use_out_source) {
indx_C = C_indices[addmm_params->batch_stride_c * tid.z];
}
}
// Translate indices to offsets
int batch_ndim_A = operand_batch_ndim.x;
const constant int* batch_shape_A = operand_shape;
const constant size_t* batch_strides_A = operand_strides;
A += elem_to_loc(indx_A, batch_shape_A, batch_strides_A, batch_ndim_A);
int batch_ndim_B = operand_batch_ndim.y;
const constant int* batch_shape_B = batch_shape_A + batch_ndim_A;
const constant size_t* batch_strides_B = batch_strides_A + batch_ndim_A;
B += elem_to_loc(indx_B, batch_shape_B, batch_strides_B, batch_ndim_B);
if (use_out_source) {
int batch_ndim_C = operand_batch_ndim.z;
const constant int* batch_shape_C = batch_shape_B + batch_ndim_B;
const constant size_t* batch_strides_C = batch_strides_B + batch_ndim_B;
C += elem_to_loc(indx_C, batch_shape_C, batch_strides_C, batch_ndim_C);
}
}
// Handle regular batch
else {
if (has_batch) {
const constant size_t* A_bstrides = batch_strides;
const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
A += batch_offsets.x;
B += batch_offsets.y;
if (use_out_source) {
const constant size_t* C_bstrides = B_bstrides + params->batch_ndim;
C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim);
}
} else {
A += params->batch_stride_a * tid.z;
B += params->batch_stride_b * tid.z;
if (use_out_source) {
C += addmm_params->batch_stride_c * tid.z;
}
}
}
D += params->batch_stride_d * tid.z;
// Prepare threadgroup memory
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
threadgroup_barrier(mem_flags::mem_none);
// Find block in A, B, C
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
const size_t c_row_long = size_t(c_row);
const size_t c_col_long = size_t(c_col);
A += transpose_a ? c_row_long : c_row_long * params->lda;
B += transpose_b ? c_col_long * params->ldb : c_col_long;
D += c_row_long * params->ldd + c_col_long;
if (use_out_source) {
C += c_row_long * addmm_params->ldc + c_col_long * addmm_params->fdc;
}
// Prepare threadgroup mma operation
thread mma_t mma_op(simd_group_id, simd_lane_id);
// Prepare threadgroup loading operations
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
// Prepare threadgroup bounds
const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row));
const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col));
// Prepare iterations
int gemm_k_iterations = params->gemm_k_iterations_aligned;
// Do unaligned K iterations first
if (!align_K) {
const int k_last = params->gemm_k_iterations_aligned * BK;
const int k_remain = params->K - k_last;
const size_t k_jump_a =
transpose_a ? params->lda * size_t(k_last) : size_t(k_last);
const size_t k_jump_b =
transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);
// Move loader source ahead to end
loader_a.src += k_jump_a;
loader_b.src += k_jump_b;
// Load tile
const short2 tile_dims_A =
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
const short2 tile_dims_B =
transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
loader_a.load_safe(tile_dims_A);
loader_b.load_safe(tile_dims_B);
threadgroup_barrier(mem_flags::mem_threadgroup);
// Do matmul
mma_op.mma(As, Bs);
// Reset source back to start
loader_a.src -= k_jump_a;
loader_b.src -= k_jump_b;
}
const TransformAdd<AccumType, AccumType> epilogue_op_add(
addmm_params->alpha, addmm_params->beta);
const TransformAxpby<AccumType, AccumType> epilogue_op_axpby(
addmm_params->alpha, addmm_params->beta);
///////////////////////////////////////////////////////////////////////////////
// MNK aligned loop
if (align_M && align_N) {
// Do gemm
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
threadgroup_barrier(mem_flags::mem_none);
// Do epilogue
if (use_out_source) {
if (do_axpby) {
mma_op.apply_epilogue(
C, addmm_params->ldc, addmm_params->fdc, epilogue_op_axpby);
} else {
mma_op.apply_epilogue(
C, addmm_params->ldc, addmm_params->fdc, epilogue_op_add);
}
}
// Store results to device memory
return mma_op.store_result(D, params->ldd);
}
///////////////////////////////////////////////////////////////////////////////
// MN unaligned loop
else { // Loop over K - unaligned case
const int leftover_bk = 0;
if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
// Do gemm
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<true, true, true>{});
// Do epilogue
if (use_out_source) {
if (do_axpby) {
mma_op.apply_epilogue(
C, addmm_params->ldc, addmm_params->fdc, epilogue_op_axpby);
} else {
mma_op.apply_epilogue(
C, addmm_params->ldc, addmm_params->fdc, epilogue_op_add);
}
}
// Store results to device memory
return mma_op.store_result(D, params->ldd);
} else if (align_N || tgp_bn == BN) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<false, true, true>{});
// Do epilogue
if (use_out_source) {
if (do_axpby) {
mma_op.apply_epilogue_safe(
C,
addmm_params->ldc,
addmm_params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op_axpby);
} else {
mma_op.apply_epilogue_safe(
C,
addmm_params->ldc,
addmm_params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op_add);
}
}
// Store results to device memory
return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
} else if (align_M || tgp_bm == BM) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<true, false, true>{});
// Do epilogue
if (use_out_source) {
if (do_axpby) {
mma_op.apply_epilogue_safe(
C,
addmm_params->ldc,
addmm_params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op_axpby);
} else {
mma_op.apply_epilogue_safe(
C,
addmm_params->ldc,
addmm_params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op_add);
}
}
// Store results to device memory
return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
} else {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<false, false, true>{});
// Do epilogue
if (use_out_source) {
if (do_axpby) {
mma_op.apply_epilogue_safe(
C,
addmm_params->ldc,
addmm_params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op_axpby);
} else {
mma_op.apply_epilogue_safe(
C,
addmm_params->ldc,
addmm_params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op_add);
}
}
// Store results to device memory
return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
}
}
}

View File

@@ -1,430 +1,12 @@
// Copyright © 2024 Apple Inc.
// clang-format off
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h"
using namespace metal;
using namespace mlx::steel;
///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////
constant bool has_batch [[function_constant(10)]];
constant bool use_out_source [[function_constant(100)]];
constant bool do_axpby [[function_constant(110)]];
constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]];
constant bool align_K [[function_constant(202)]];
constant bool do_gather [[function_constant(300)]];
constant bool gather_bias = do_gather && use_out_source;
// clang-format off
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
const device T* C [[buffer(2), function_constant(use_out_source)]],
device T* D [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]],
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
const constant int* batch_shape [[buffer(6)]],
const constant size_t* batch_strides [[buffer(7)]],
const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]],
const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]],
const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]],
const constant int* operand_shape [[buffer(13), function_constant(do_gather)]],
const constant size_t* operand_strides [[buffer(14), function_constant(do_gather)]],
const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on
// Pacifying compiler
(void)lid;
using gemm_kernel = GEMMKernel<
T,
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
true,
true,
AccumType>;
using loader_a_t = typename gemm_kernel::loader_a_t;
using loader_b_t = typename gemm_kernel::loader_b_t;
using mma_t = typename gemm_kernel::mma_t;
// Find block
const int tid_y = ((tid.y) << params->swizzle_log) +
((tid.x) & ((1 << params->swizzle_log) - 1));
const int tid_x = (tid.x) >> params->swizzle_log;
// Exit early if out of bounds
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
return;
}
// Adjust for batch
// Handle gather
if (do_gather) {
// Read indices
uint32_t indx_A, indx_B, indx_C;
if (has_batch) {
const constant size_t* indx_A_bstrides = batch_strides;
const constant size_t* indx_B_bstrides =
batch_strides + params->batch_ndim;
ulong2 indx_offsets = elem_to_loc_broadcast(
tid.z,
batch_shape,
indx_A_bstrides,
indx_B_bstrides,
params->batch_ndim);
indx_A = lhs_indices[indx_offsets.x];
indx_B = rhs_indices[indx_offsets.y];
if (use_out_source) {
const constant size_t* indx_C_bstrides =
indx_B_bstrides + params->batch_ndim;
auto indx_offset_C = elem_to_loc(
tid.z, batch_shape, indx_C_bstrides, params->batch_ndim);
indx_C = C_indices[indx_offset_C];
}
} else {
indx_A = lhs_indices[params->batch_stride_a * tid.z];
indx_B = rhs_indices[params->batch_stride_b * tid.z];
if (use_out_source) {
indx_C = C_indices[addmm_params->batch_stride_c * tid.z];
}
}
// Translate indices to offsets
int batch_ndim_A = operand_batch_ndim.x;
const constant int* batch_shape_A = operand_shape;
const constant size_t* batch_strides_A = operand_strides;
A += elem_to_loc(indx_A, batch_shape_A, batch_strides_A, batch_ndim_A);
int batch_ndim_B = operand_batch_ndim.y;
const constant int* batch_shape_B = batch_shape_A + batch_ndim_A;
const constant size_t* batch_strides_B = batch_strides_A + batch_ndim_A;
B += elem_to_loc(indx_B, batch_shape_B, batch_strides_B, batch_ndim_B);
if (use_out_source) {
int batch_ndim_C = operand_batch_ndim.z;
const constant int* batch_shape_C = batch_shape_B + batch_ndim_B;
const constant size_t* batch_strides_C = batch_strides_B + batch_ndim_B;
C += elem_to_loc(indx_C, batch_shape_C, batch_strides_C, batch_ndim_C);
}
}
// Handle regular batch
else {
if (has_batch) {
const constant size_t* A_bstrides = batch_strides;
const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
A += batch_offsets.x;
B += batch_offsets.y;
if (use_out_source) {
const constant size_t* C_bstrides = B_bstrides + params->batch_ndim;
C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim);
}
} else {
A += params->batch_stride_a * tid.z;
B += params->batch_stride_b * tid.z;
if (use_out_source) {
C += addmm_params->batch_stride_c * tid.z;
}
}
}
D += params->batch_stride_d * tid.z;
// Prepare threadgroup memory
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
threadgroup_barrier(mem_flags::mem_none);
// Find block in A, B, C
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
const size_t c_row_long = size_t(c_row);
const size_t c_col_long = size_t(c_col);
A += transpose_a ? c_row_long : c_row_long * params->lda;
B += transpose_b ? c_col_long * params->ldb : c_col_long;
D += c_row_long * params->ldd + c_col_long;
if (use_out_source) {
C += c_row_long * addmm_params->ldc + c_col_long * addmm_params->fdc;
}
// Prepare threadgroup mma operation
thread mma_t mma_op(simd_group_id, simd_lane_id);
// Prepare threadgroup loading operations
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
// Prepare threadgroup bounds
const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row));
const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col));
// Prepare iterations
int gemm_k_iterations = params->gemm_k_iterations_aligned;
// Do unaligned K iterations first
if (!align_K) {
const int k_last = params->gemm_k_iterations_aligned * BK;
const int k_remain = params->K - k_last;
const size_t k_jump_a =
transpose_a ? params->lda * size_t(k_last) : size_t(k_last);
const size_t k_jump_b =
transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);
// Move loader source ahead to end
loader_a.src += k_jump_a;
loader_b.src += k_jump_b;
// Load tile
const short2 tile_dims_A =
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
const short2 tile_dims_B =
transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
loader_a.load_safe(tile_dims_A);
loader_b.load_safe(tile_dims_B);
threadgroup_barrier(mem_flags::mem_threadgroup);
// Do matmul
mma_op.mma(As, Bs);
// Reset source back to start
loader_a.src -= k_jump_a;
loader_b.src -= k_jump_b;
}
const TransformAdd<AccumType, AccumType> epilogue_op_add(
addmm_params->alpha, addmm_params->beta);
const TransformAxpby<AccumType, AccumType> epilogue_op_axpby(
addmm_params->alpha, addmm_params->beta);
///////////////////////////////////////////////////////////////////////////////
// MNK aligned loop
if (align_M && align_N) {
// Do gemm
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
threadgroup_barrier(mem_flags::mem_none);
// Do epilogue
if (use_out_source) {
if (do_axpby) {
mma_op.apply_epilogue(
C, addmm_params->ldc, addmm_params->fdc, epilogue_op_axpby);
} else {
mma_op.apply_epilogue(
C, addmm_params->ldc, addmm_params->fdc, epilogue_op_add);
}
}
// Store results to device memory
return mma_op.store_result(D, params->ldd);
}
///////////////////////////////////////////////////////////////////////////////
// MN unaligned loop
else { // Loop over K - unaligned case
const int leftover_bk = 0;
if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
// Do gemm
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<true, true, true>{});
// Do epilogue
if (use_out_source) {
if (do_axpby) {
mma_op.apply_epilogue(
C, addmm_params->ldc, addmm_params->fdc, epilogue_op_axpby);
} else {
mma_op.apply_epilogue(
C, addmm_params->ldc, addmm_params->fdc, epilogue_op_add);
}
}
// Store results to device memory
return mma_op.store_result(D, params->ldd);
} else if (align_N || tgp_bn == BN) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<false, true, true>{});
// Do epilogue
if (use_out_source) {
if (do_axpby) {
mma_op.apply_epilogue_safe(
C,
addmm_params->ldc,
addmm_params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op_axpby);
} else {
mma_op.apply_epilogue_safe(
C,
addmm_params->ldc,
addmm_params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op_add);
}
}
// Store results to device memory
return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
} else if (align_M || tgp_bm == BM) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<true, false, true>{});
// Do epilogue
if (use_out_source) {
if (do_axpby) {
mma_op.apply_epilogue_safe(
C,
addmm_params->ldc,
addmm_params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op_axpby);
} else {
mma_op.apply_epilogue_safe(
C,
addmm_params->ldc,
addmm_params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op_add);
}
}
// Store results to device memory
return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
} else {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<false, false, true>{});
// Do epilogue
if (use_out_source) {
if (do_axpby) {
mma_op.apply_epilogue_safe(
C,
addmm_params->ldc,
addmm_params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op_axpby);
} else {
mma_op.apply_epilogue_safe(
C,
addmm_params->ldc,
addmm_params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op_add);
}
}
// Store results to device memory
return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
}
}
}
///////////////////////////////////////////////////////////////////////////////
// GEMM kernel initializations
///////////////////////////////////////////////////////////////////////////////
// clang-format off
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
template [[host_name("steel_gemm_fused_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn)]] \
[[kernel]] void gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, float>( \
@@ -445,24 +27,23 @@ template <
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]); // clang-format on
uint3 lid [[thread_position_in_threadgroup]]);
// clang-format off
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) // clang-format on
instantiate_gemm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
// clang-format off
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2) // clang-format on
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2)
instantiate_gemm_shapes_helper(float16, half, float16, half);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
instantiate_gemm_shapes_helper(float32, float, float32, float);
instantiate_gemm_shapes_helper(float32, float, float32, float);
// clang-format on

View File

@@ -0,0 +1,719 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/metal/kernels/steel/defines.h"
using namespace metal;
using namespace mlx::steel;
///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////
struct _NoMask {
char x;
constexpr METAL_FUNC operator bool() {
return true;
}
constexpr METAL_FUNC operator bool() const threadgroup {
return true;
}
constexpr METAL_FUNC operator bool() const device {
return true;
}
constexpr METAL_FUNC operator bool() const constant {
return true;
}
};
template <typename OutT, typename InT = OutT>
struct ScaleOp {
OutT scale;
METAL_FUNC OutT apply(InT x) const {
return static_cast<OutT>(x) * scale;
}
};
typedef struct _NoMask nomask_t;
template <
typename T,
typename out_mask_t,
typename op_mask_t,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
bool MN_aligned,
bool K_aligned>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
block_masked_gemm(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
device T* D [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]],
const constant int* batch_shape [[buffer(6)]],
const constant size_t* batch_strides [[buffer(7)]],
const device out_mask_t* out_mask [[buffer(10)]],
const device op_mask_t* lhs_mask [[buffer(11)]],
const device op_mask_t* rhs_mask [[buffer(12)]],
const constant int* mask_strides [[buffer(13)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
// Appease the compiler
(void)lid;
static_assert(
BM == BN,
"block_masked_gemm must have the same block M and block N size");
static_assert(BM % BK == 0, "block_masked_gemm must have BM % BK == 0");
constexpr bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
constexpr bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;
constexpr bool has_mul_operand_mask =
has_operand_mask && !metal::is_same_v<op_mask_t, bool>;
constexpr bool has_mul_output_mask =
has_output_mask && !metal::is_same_v<out_mask_t, bool>;
constexpr short k_mask_factor = short(BM / BK);
using gemm_kernel = GEMMKernel<
T,
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
MN_aligned,
K_aligned>;
const int tid_y = ((tid.y) << params->swizzle_log) +
((tid.x) & ((1 << params->swizzle_log) - 1));
const int tid_x = (tid.x) >> params->swizzle_log;
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
return;
}
const constant size_t* mask_batch_strides =
batch_strides + 2 * params->batch_ndim;
if (params->batch_ndim > 1) {
if (has_output_mask) {
out_mask += elem_to_loc(
tid.z, batch_shape, mask_batch_strides, params->batch_ndim);
mask_batch_strides += params->batch_ndim;
}
if (has_operand_mask) {
const constant size_t* mask_strides_lhs = mask_batch_strides;
const constant size_t* mask_strides_rhs =
mask_strides_lhs + params->batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z,
batch_shape,
mask_strides_lhs,
mask_strides_rhs,
params->batch_ndim);
lhs_mask += batch_offsets.x;
rhs_mask += batch_offsets.y;
}
} else {
if (has_output_mask) {
out_mask += tid.z * mask_batch_strides[0];
mask_batch_strides += params->batch_ndim;
}
if (has_operand_mask) {
lhs_mask += tid.z * mask_batch_strides[0];
rhs_mask += tid.z * mask_batch_strides[params->batch_ndim];
}
}
// Adjust for batch
if (params->batch_ndim > 1) {
const constant size_t* A_bstrides = batch_strides;
const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
A += batch_offsets.x;
B += batch_offsets.y;
} else {
A += params->batch_stride_a * tid.z;
B += params->batch_stride_b * tid.z;
}
D += params->batch_stride_d * tid.z;
// Find block in A, B, C
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
const size_t c_row_long = size_t(c_row);
const size_t c_col_long = size_t(c_col);
A += transpose_a ? c_row_long : c_row_long * params->lda;
B += transpose_b ? c_col_long * params->ldb : c_col_long;
D += c_row_long * params->ldd + c_col_long;
const constant int* out_mask_strides = mask_strides;
const constant int* lhs_mask_strides =
mask_strides + (has_output_mask ? 2 : 0);
const constant int* rhs_mask_strides =
lhs_mask_strides + (has_operand_mask ? 2 : 0);
const int out_mask_offset = !has_output_mask
? 0
: tid_y * out_mask_strides[1] + tid_x * out_mask_strides[0];
int lhs_mask_offset = !has_operand_mask ? 0 : tid_y * lhs_mask_strides[1];
int rhs_mask_offset = !has_operand_mask ? 0 : tid_x * rhs_mask_strides[0];
const int lhs_mask_step = !has_operand_mask ? 0 : lhs_mask_strides[0];
const int rhs_mask_step = !has_operand_mask ? 0 : rhs_mask_strides[1];
short k_factor_cnt = k_mask_factor;
ScaleOp<float> out_mask_op;
ScaleOp<T> lhs_mask_op;
ScaleOp<T> rhs_mask_op;
if (has_output_mask) {
auto mask_out = out_mask[out_mask_offset];
if (has_mul_output_mask) {
out_mask_op.scale = float(mask_out);
}
// Write zeros and return
if (!mask_out) {
constexpr short tgp_size = WM * WN * 32;
constexpr short vec_size = 4;
// Tile threads in threadgroup
constexpr short TN = BN / vec_size;
constexpr short TM = tgp_size / TN;
const short thread_idx = simd_group_id * 32 + simd_lane_id;
const short bi = thread_idx / TN;
const short bj = vec_size * (thread_idx % TN);
D += bi * params->ldd + bj;
short tgp_bm = min(BM, params->M - c_row);
short tgp_bn = min(BN, params->N - c_col);
if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
for (short ti = 0; ti < BM; ti += TM) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
D[ti * params->ldd + j] = T(0.);
}
}
} else {
short jmax = tgp_bn - bj;
jmax = jmax < vec_size ? jmax : vec_size;
for (short ti = 0; (bi + ti) < tgp_bm; ti += TM) {
for (short j = 0; j < jmax; j++) {
D[ti * params->ldd + j] = T(0.);
}
}
}
return;
}
}
threadgroup_barrier(mem_flags::mem_none);
// Prepare threadgroup mma operation
thread typename gemm_kernel::mma_t mma_op(simd_group_id, simd_lane_id);
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
// Prepare threadgroup loading operations
thread typename gemm_kernel::loader_a_t loader_a(
A, params->lda, As, simd_group_id, simd_lane_id);
thread typename gemm_kernel::loader_b_t loader_b(
B, params->ldb, Bs, simd_group_id, simd_lane_id);
// Prepare threadgroup bounds
const short tgp_bm =
MN_aligned ? short(BM) : short(min(BM, params->M - c_row));
const short tgp_bn =
MN_aligned ? short(BN) : short(min(BN, params->N - c_col));
int gemm_k_iterations = params->gemm_k_iterations_aligned;
///////////////////////////////////////////////////////////////////////////////
// Do unaligned K iterations first
if (!K_aligned) {
const int k_last = params->gemm_k_iterations_aligned * BK;
const int mask_idx_last = k_last / BM;
if (!has_operand_mask ||
(bool(lhs_mask[lhs_mask_offset + mask_idx_last * lhs_mask_step]) &&
bool(rhs_mask[rhs_mask_offset + mask_idx_last * rhs_mask_step]))) {
if (has_mul_operand_mask) {
lhs_mask_op.scale =
lhs_mask[lhs_mask_offset + mask_idx_last * lhs_mask_step];
rhs_mask_op.scale =
rhs_mask[rhs_mask_offset + mask_idx_last * rhs_mask_step];
}
// Move loader source ahead to end
const int k_remain = params->K - k_last;
const size_t k_jump_a =
transpose_a ? params->lda * size_t(k_last) : size_t(k_last);
const size_t k_jump_b =
transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);
loader_a.src += k_jump_a;
loader_b.src += k_jump_b;
// Load tile
const short2 tile_dims_A =
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
const short2 tile_dims_B =
transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
loader_a.load_safe(tile_dims_A);
loader_b.load_safe(tile_dims_B);
if (has_mul_operand_mask) {
loader_a.apply_inplace_op(lhs_mask_op);
loader_b.apply_inplace_op(rhs_mask_op);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Do matmul
mma_op.mma(As, Bs);
// Reset source back to start
loader_a.src -= k_jump_a;
loader_b.src -= k_jump_b;
}
}
///////////////////////////////////////////////////////////////////////////////
// MNK aligned loop
if (MN_aligned) {
for (; gemm_k_iterations > 0; gemm_k_iterations--) {
threadgroup_barrier(mem_flags::mem_threadgroup);
if (!has_operand_mask ||
(bool(lhs_mask[lhs_mask_offset]) &&
bool(rhs_mask[rhs_mask_offset]))) {
if (has_mul_operand_mask) {
lhs_mask_op.scale = lhs_mask[lhs_mask_offset];
rhs_mask_op.scale = rhs_mask[rhs_mask_offset];
}
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
if (has_mul_operand_mask) {
loader_a.apply_inplace_op(lhs_mask_op);
loader_b.apply_inplace_op(rhs_mask_op);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
}
// Prepare for next iteration
loader_a.next();
loader_b.next();
k_factor_cnt--;
lhs_mask_offset += k_factor_cnt == 0 ? lhs_mask_step : 0;
rhs_mask_offset += k_factor_cnt == 0 ? rhs_mask_step : 0;
k_factor_cnt = k_factor_cnt == 0 ? k_mask_factor : k_factor_cnt;
}
if (has_mul_output_mask) {
mma_op.apply_epilogue(out_mask_op);
}
// Store results to device memory
mma_op.store_result(D, params->ldd);
return;
}
///////////////////////////////////////////////////////////////////////////////
// MN unaligned loop
else {
const bool M_aligned = (tgp_bm == BM);
const bool N_aligned = (tgp_bn == BN);
const short2 tile_dims_A =
transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
const short2 tile_dims_B =
transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
for (; gemm_k_iterations > 0; gemm_k_iterations--) {
threadgroup_barrier(mem_flags::mem_threadgroup);
if (!has_operand_mask ||
(bool(lhs_mask[lhs_mask_offset]) &&
bool(rhs_mask[rhs_mask_offset]))) {
if (has_mul_operand_mask) {
lhs_mask_op.scale = lhs_mask[lhs_mask_offset];
rhs_mask_op.scale = rhs_mask[rhs_mask_offset];
}
// Load elements into threadgroup
if (M_aligned) {
loader_a.load_unsafe();
} else {
loader_a.load_safe(tile_dims_A);
}
if (N_aligned) {
loader_b.load_unsafe();
} else {
loader_b.load_safe(tile_dims_B);
}
if (has_mul_operand_mask) {
loader_a.apply_inplace_op(lhs_mask_op);
loader_b.apply_inplace_op(rhs_mask_op);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
}
// Prepare for next iteration
loader_a.next();
loader_b.next();
k_factor_cnt--;
lhs_mask_offset += k_factor_cnt == 0 ? lhs_mask_step : 0;
rhs_mask_offset += k_factor_cnt == 0 ? rhs_mask_step : 0;
k_factor_cnt = k_factor_cnt == 0 ? k_mask_factor : k_factor_cnt;
}
if (has_mul_output_mask) {
mma_op.apply_epilogue(out_mask_op);
}
if (M_aligned && N_aligned) {
mma_op.store_result(D, params->ldd);
} else {
mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
}
}
}
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
bool MN_aligned,
bool K_aligned,
bool has_operand_mask = false>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
block_masked_gemm(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
device T* D [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]],
const constant int* batch_shape [[buffer(6)]],
const constant size_t* batch_strides [[buffer(7)]],
const device bool* out_mask [[buffer(10)]],
const device bool* lhs_mask [[buffer(11)]],
const device bool* rhs_mask [[buffer(12)]],
const constant int* mask_strides [[buffer(13)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
// Appease the compiler
(void)lid;
using gemm_kernel = GEMMKernel<
T,
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
MN_aligned,
K_aligned>;
const int tid_y = ((tid.y) << params->swizzle_log) +
((tid.x) & ((1 << params->swizzle_log) - 1));
const int tid_x = (tid.x) >> params->swizzle_log;
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
return;
}
if (params->batch_ndim > 1) {
const constant size_t* mask_batch_strides =
batch_strides + 2 * params->batch_ndim;
out_mask +=
elem_to_loc(tid.z, batch_shape, mask_batch_strides, params->batch_ndim);
if (has_operand_mask) {
const constant size_t* mask_strides_lhs =
mask_batch_strides + params->batch_ndim;
const constant size_t* mask_strides_rhs =
mask_strides_lhs + params->batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z,
batch_shape,
mask_strides_lhs,
mask_strides_rhs,
params->batch_ndim);
lhs_mask += batch_offsets.x;
rhs_mask += batch_offsets.y;
}
} else {
out_mask += tid.z * batch_strides[2 * params->batch_ndim];
if (has_operand_mask) {
lhs_mask += tid.z * batch_strides[3 * params->batch_ndim];
rhs_mask += tid.z * batch_strides[4 * params->batch_ndim];
}
}
// Adjust for batch
if (params->batch_ndim > 1) {
const constant size_t* A_bstrides = batch_strides;
const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
A += batch_offsets.x;
B += batch_offsets.y;
} else {
A += params->batch_stride_a * tid.z;
B += params->batch_stride_b * tid.z;
}
D += params->batch_stride_d * tid.z;
// Find block in A, B, C
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
const size_t c_row_long = size_t(c_row);
const size_t c_col_long = size_t(c_col);
A += transpose_a ? c_row_long : c_row_long * params->lda;
B += transpose_b ? c_col_long * params->ldb : c_col_long;
D += c_row_long * params->ldd + c_col_long;
bool mask_out = out_mask[tid_y * mask_strides[1] + tid_x * mask_strides[0]];
// Write zeros and return
if (!mask_out) {
constexpr short tgp_size = WM * WN * 32;
constexpr short vec_size = 4;
// Tile threads in threadgroup
constexpr short TN = BN / vec_size;
constexpr short TM = tgp_size / TN;
const short thread_idx = simd_group_id * 32 + simd_lane_id;
const short bi = thread_idx / TN;
const short bj = vec_size * (thread_idx % TN);
D += bi * params->ldd + bj;
short tgp_bm = min(BM, params->M - c_row);
short tgp_bn = min(BN, params->N - c_col);
if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
for (short ti = 0; ti < BM; ti += TM) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
D[ti * params->ldd + j] = T(0.);
}
}
} else {
short jmax = tgp_bn - bj;
jmax = jmax < vec_size ? jmax : vec_size;
for (short ti = 0; (bi + ti) < tgp_bm; ti += TM) {
for (short j = 0; j < jmax; j++) {
D[ti * params->ldd + j] = T(0.);
}
}
}
return;
}
threadgroup_barrier(mem_flags::mem_none);
// Prepare threadgroup mma operation
thread typename gemm_kernel::mma_t mma_op(simd_group_id, simd_lane_id);
int gemm_k_iterations = params->gemm_k_iterations_aligned;
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
// Prepare threadgroup loading operations
thread typename gemm_kernel::loader_a_t loader_a(
A, params->lda, As, simd_group_id, simd_lane_id);
thread typename gemm_kernel::loader_b_t loader_b(
B, params->ldb, Bs, simd_group_id, simd_lane_id);
///////////////////////////////////////////////////////////////////////////////
// MNK aligned loop
if (MN_aligned) {
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
if (!has_operand_mask ||
(lhs_mask
[tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] &&
rhs_mask
[((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
}
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
threadgroup_barrier(mem_flags::mem_none);
// Loop tail
if (!K_aligned) {
if (!has_operand_mask ||
(lhs_mask
[tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] &&
rhs_mask
[(params->K / BM) * mask_strides[5] +
tid_x * mask_strides[4]])) {
int lbk = params->K - params->gemm_k_iterations_aligned * BK;
short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
loader_a.load_safe(tile_dims_A);
loader_b.load_safe(tile_dims_B);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
}
// Store results to device memory
mma_op.store_result(D, params->ldd);
return;
}
///////////////////////////////////////////////////////////////////////////////
// MN unaligned loop
else { // Loop over K - unaligned case
short tgp_bm = min(BM, params->M - c_row);
short tgp_bn = min(BN, params->N - c_col);
short lbk = params->K - params->gemm_k_iterations_aligned * BK;
bool M_aligned = (tgp_bm == BM);
bool N_aligned = (tgp_bn == BN);
short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
if (!has_operand_mask ||
(lhs_mask
[tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] &&
rhs_mask
[((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
// Load elements into threadgroup
if (M_aligned) {
loader_a.load_unsafe();
} else {
loader_a.load_safe(tile_dims_A);
}
if (N_aligned) {
loader_b.load_unsafe();
} else {
loader_b.load_safe(tile_dims_B);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
}
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
if (!K_aligned) {
threadgroup_barrier(mem_flags::mem_threadgroup);
if (!has_operand_mask ||
(lhs_mask
[tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] &&
rhs_mask
[(params->K / BM) * mask_strides[5] +
tid_x * mask_strides[4]])) {
short2 tile_dims_A_last =
transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm);
short2 tile_dims_B_last =
transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk);
loader_a.load_safe(tile_dims_A_last);
loader_b.load_safe(tile_dims_B_last);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
}
if (M_aligned && N_aligned) {
mma_op.store_result(D, params->ldd);
} else {
mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
}
}
}

View File

@@ -1,316 +1,16 @@
// Copyright © 2024 Apple Inc.
// clang-format off
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/utils.h"
using namespace metal;
using namespace mlx::steel;
///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
bool MN_aligned,
bool K_aligned,
bool has_operand_mask = false>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
block_masked_gemm(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
device T* D [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]],
const constant int* batch_shape [[buffer(6)]],
const constant size_t* batch_strides [[buffer(7)]],
const device bool* out_mask [[buffer(10)]],
const device bool* lhs_mask [[buffer(11)]],
const device bool* rhs_mask [[buffer(12)]],
const constant int* mask_strides [[buffer(13)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
// Appease the compiler
(void)lid;
using gemm_kernel = GEMMKernel<
T,
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
MN_aligned,
K_aligned>;
const int tid_y = ((tid.y) << params->swizzle_log) +
((tid.x) & ((1 << params->swizzle_log) - 1));
const int tid_x = (tid.x) >> params->swizzle_log;
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
return;
}
if (params->batch_ndim > 1) {
const constant size_t* mask_batch_strides =
batch_strides + 2 * params->batch_ndim;
out_mask +=
elem_to_loc(tid.z, batch_shape, mask_batch_strides, params->batch_ndim);
if (has_operand_mask) {
const constant size_t* mask_strides_lhs =
mask_batch_strides + params->batch_ndim;
const constant size_t* mask_strides_rhs =
mask_strides_lhs + params->batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z,
batch_shape,
mask_strides_lhs,
mask_strides_rhs,
params->batch_ndim);
lhs_mask += batch_offsets.x;
rhs_mask += batch_offsets.y;
}
} else {
out_mask += tid.z * batch_strides[2 * params->batch_ndim];
if (has_operand_mask) {
lhs_mask += tid.z * batch_strides[3 * params->batch_ndim];
rhs_mask += tid.z * batch_strides[4 * params->batch_ndim];
}
}
// Adjust for batch
if (params->batch_ndim > 1) {
const constant size_t* A_bstrides = batch_strides;
const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
A += batch_offsets.x;
B += batch_offsets.y;
} else {
A += params->batch_stride_a * tid.z;
B += params->batch_stride_b * tid.z;
}
D += params->batch_stride_d * tid.z;
// Find block in A, B, C
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
const size_t c_row_long = size_t(c_row);
const size_t c_col_long = size_t(c_col);
A += transpose_a ? c_row_long : c_row_long * params->lda;
B += transpose_b ? c_col_long * params->ldb : c_col_long;
D += c_row_long * params->ldd + c_col_long;
bool mask_out = out_mask[tid_y * mask_strides[1] + tid_x * mask_strides[0]];
// Write zeros and return
if (!mask_out) {
constexpr short tgp_size = WM * WN * 32;
constexpr short vec_size = 4;
// Tile threads in threadgroup
constexpr short TN = BN / vec_size;
constexpr short TM = tgp_size / TN;
const short thread_idx = simd_group_id * 32 + simd_lane_id;
const short bi = thread_idx / TN;
const short bj = vec_size * (thread_idx % TN);
D += bi * params->ldd + bj;
short tgp_bm = min(BM, params->M - c_row);
short tgp_bn = min(BN, params->N - c_col);
if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
for (short ti = 0; ti < BM; ti += TM) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
D[ti * params->ldd + j] = T(0.);
}
}
} else {
short jmax = tgp_bn - bj;
jmax = jmax < vec_size ? jmax : vec_size;
for (short ti = 0; (bi + ti) < tgp_bm; ti += TM) {
for (short j = 0; j < jmax; j++) {
D[ti * params->ldd + j] = T(0.);
}
}
}
return;
}
threadgroup_barrier(mem_flags::mem_none);
// Prepare threadgroup mma operation
thread typename gemm_kernel::mma_t mma_op(simd_group_id, simd_lane_id);
int gemm_k_iterations = params->gemm_k_iterations_aligned;
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
// Prepare threadgroup loading operations
thread typename gemm_kernel::loader_a_t loader_a(
A, params->lda, As, simd_group_id, simd_lane_id);
thread typename gemm_kernel::loader_b_t loader_b(
B, params->ldb, Bs, simd_group_id, simd_lane_id);
///////////////////////////////////////////////////////////////////////////////
// MNK aligned loop
if (MN_aligned) {
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
if (!has_operand_mask ||
(lhs_mask
[tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] &&
rhs_mask
[((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
}
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
threadgroup_barrier(mem_flags::mem_none);
// Loop tail
if (!K_aligned) {
if (!has_operand_mask ||
(lhs_mask
[tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] &&
rhs_mask
[(params->K / BM) * mask_strides[5] +
tid_x * mask_strides[4]])) {
int lbk = params->K - params->gemm_k_iterations_aligned * BK;
short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
loader_a.load_safe(tile_dims_A);
loader_b.load_safe(tile_dims_B);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
}
// Store results to device memory
mma_op.store_result(D, params->ldd);
return;
}
///////////////////////////////////////////////////////////////////////////////
// MN unaligned loop
else { // Loop over K - unaligned case
short tgp_bm = min(BM, params->M - c_row);
short tgp_bn = min(BN, params->N - c_col);
short lbk = params->K - params->gemm_k_iterations_aligned * BK;
bool M_aligned = (tgp_bm == BM);
bool N_aligned = (tgp_bn == BN);
short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
if (!has_operand_mask ||
(lhs_mask
[tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] &&
rhs_mask
[((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
// Load elements into threadgroup
if (M_aligned) {
loader_a.load_unsafe();
} else {
loader_a.load_safe(tile_dims_A);
}
if (N_aligned) {
loader_b.load_unsafe();
} else {
loader_b.load_safe(tile_dims_B);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
}
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
if (!K_aligned) {
threadgroup_barrier(mem_flags::mem_threadgroup);
if (!has_operand_mask ||
(lhs_mask
[tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] &&
rhs_mask
[(params->K / BM) * mask_strides[5] +
tid_x * mask_strides[4]])) {
short2 tile_dims_A_last =
transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm);
short2 tile_dims_B_last =
transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk);
loader_a.load_safe(tile_dims_A_last);
loader_b.load_safe(tile_dims_B_last);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
}
if (M_aligned && N_aligned) {
mma_op.store_result(D, params->ldd);
} else {
mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
}
}
}
///////////////////////////////////////////////////////////////////////////////
// GEMM kernel initializations
///////////////////////////////////////////////////////////////////////////////
#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h"
#define instantiate_gemm( \
outmaskname, \
outmasktype, \
opmaskname, \
opmasktype, \
tname, \
trans_a, \
trans_b, \
@@ -326,15 +26,15 @@ block_masked_gemm(
aname, \
mn_aligned, \
kname, \
k_aligned, \
omname, \
op_mask) \
template [[host_name("steel_block_masked_gemm_" #tname "_" #iname "_" #oname \
k_aligned) \
template [[host_name("steel_gemm_block_outmask_" #outmaskname \
"_opmask_" #opmaskname "_" #tname "_" #iname "_" #oname \
"_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn \
"_MN_" #aname "_K_" #kname \
"_op_mask_" #omname)]] [[kernel]] void \
"_MN_" #aname "_K_" #kname)]] [[kernel]] void \
block_masked_gemm< \
itype, \
outmasktype, \
opmasktype, \
bm, \
bn, \
bk, \
@@ -343,48 +43,48 @@ block_masked_gemm(
trans_a, \
trans_b, \
mn_aligned, \
k_aligned, \
op_mask>( \
k_aligned>( \
const device itype* A [[buffer(0)]], \
const device itype* B [[buffer(1)]], \
device itype* D [[buffer(3)]], \
const constant GEMMParams* params [[buffer(4)]], \
const constant int* batch_shape [[buffer(6)]], \
const constant size_t* batch_strides [[buffer(7)]], \
const device bool* out_mask [[buffer(10)]], \
const device bool* lhs_mask [[buffer(11)]], \
const device bool* rhs_mask [[buffer(12)]], \
const device outmasktype* out_mask [[buffer(10)]], \
const device opmasktype* lhs_mask [[buffer(11)]], \
const device opmasktype* rhs_mask [[buffer(12)]], \
const constant int* mask_strides [[buffer(13)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
// clang-format off
#define instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, N, false) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, T, true) // clang-format on
#define instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
instantiate_gemm(bool_, bool, bool_, bool, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
instantiate_gemm(iname, itype, iname, itype, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
instantiate_gemm(bool_, bool, iname, itype, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
instantiate_gemm(iname, itype, bool_, bool, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
instantiate_gemm(nomask, nomask_t, bool_, bool, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
instantiate_gemm(nomask, nomask_t, iname, itype, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
instantiate_gemm(bool_, bool, nomask, nomask_t, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
instantiate_gemm(iname, itype, nomask, nomask_t, tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned)
// clang-format off
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) // clang-format on
instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)
// clang-format off
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) // clang-format on
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
// clang-format off
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) // clang-format on
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2)
// clang-format off
instantiate_gemm_shapes_helper(float16, half, float16, half);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
instantiate_gemm_shapes_helper(float32, float, float32, float); // clang-format on

Some files were not shown because too many files have changed in this diff Show More