Compare commits

...

46 Commits

Author SHA1 Message Date
Awni Hannun
6a9b584f3d patch bump (#1131) 2024-05-16 20:51:33 -07:00
Awni Hannun
81dd33af66 allow conversion to dlpack (#1120) 2024-05-16 16:11:37 -07:00
Awni Hannun
8b76571896 Fix extensions (#1126)
* fix extensions

* title

* enable circle

* fix nanobind tag

* fix bug in doc

* try to fix config

* typo
2024-05-16 15:36:25 -07:00
Angelos Katharopoulos
e78a6518fa Block sparse qmm (#1124) 2024-05-16 15:24:14 -07:00
Awni Hannun
1873ffda01 Detect metal version and propagate correctly for JIT (#1109)
* detect metal version and propagate correctly for JIT

* remove softmax

* fix versions
2024-05-15 17:42:09 -07:00
Jacket
c417e42116 [Fix] minor typo in default argument for argpartition's "axis" parameter (#1125)
According to the document, argpartition's axis parameter can be None, but due to a previous typo it can't really accepts a None value.
2024-05-15 15:25:25 -07:00
Jagrit Digani
358e1fd6ab Fused GEMM (#1123)
* Basic gemm working

* Update addmm

* Clear out steel_gemm and steel_addmm kernels

* Fuse and clear out gather gemm

* Update objc releases
2024-05-15 10:30:41 -07:00
Awni Hannun
631dfbe673 fix scatter index bug (#1122) 2024-05-14 15:04:58 -07:00
Cheng
56a4eaed72 Pass missing stream arg in array.flatten (#1111) 2024-05-14 06:50:16 -07:00
Cheng
bf925d9dc7 Move args in conv_general (#1118)
Also fix a typo that padding_lo is passed as padding_hi.
2024-05-14 06:50:09 -07:00
Cheng
1a7ed5dcb6 Fill vector with constructor instead of fill_n (#1113) 2024-05-14 06:28:55 -07:00
Cheng
5be5daa6ef Use compiled function in Sigmoid module (#1116) 2024-05-14 06:25:57 -07:00
Cheng
60cb11764e Use correct module type in quantized.py (#1115) 2024-05-14 06:25:42 -07:00
Cheng
cbd5445ea7 The tile op does not accept None as reps (#1117) 2024-05-14 06:25:25 -07:00
Cheng
2c7e9b5158 Add missing docs for some ops (#1110) 2024-05-14 06:09:05 -07:00
Mike Drob
2263e4b279 Experiment with medium machines for CI (#1000) 2024-05-13 19:40:19 -07:00
Awni Hannun
863039da4c Allow scatter type exception to be caught by checking in op (#1077)
* allow exception to be caught in main thread

* only for gpu

* more detailed scatter error
2024-05-13 17:43:53 -07:00
Awni Hannun
7178ac0111 No CPU option for binary minimization (#1105)
* no cpu build option

* docs

* fix
2024-05-13 16:08:11 -07:00
Ravindra R. Jaju
e7f9710499 Fix typo in a variable name in example code. (#1104)
* Fix typo in a variable name in example code.

* Rename df2dx2 to d2fdx2 - the appropriate naming for the second derivative

* Update CONTRIBUTING.md - add needed python packages, and a virtual-env hint

* Revert "Fix typo in a variable name in example code."

This reverts commit bc10a17534.

* Rename df2dx2 to d2fdx2
2024-05-13 06:04:23 -07:00
Max-Heinrich Laves
ff4223904d Conv3d (#993)
* added conv3d

added conv3d

implemented explicit_gemm_conv_ND_cpu and bounds checks for slow_conv_3D

* incorporated reviewer comments

* fixed test

* reduced tensor shapes in test for conv3d

* Reviewer suggestion

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

Reviewer suggestion

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

Reviewer suggestion

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

Reviewer suggestion
2024-05-11 06:15:02 -07:00
Awni Hannun
a9f80d60f6 improve error messaging in eval (#1101) 2024-05-10 10:04:07 -07:00
Alex Barron
2e158cf6d0 Add conjugate operator (#1100)
* cpu and gpu impl

* add mx.conj and array.conj()

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
2024-05-10 07:22:20 -07:00
Awni Hannun
8bd6bfa4b5 version (#1099) 2024-05-09 17:52:39 -07:00
Awni Hannun
8b1906abd0 Add compiler flags to disable safetensors and gguf (#1098)
* with docs

* nit
2024-05-09 17:39:44 -07:00
Awni Hannun
06375e6605 Split encoders in non-concurrent context with a max ops per encoder (#1085)
* split encoders

* fix race
2024-05-09 16:21:02 -07:00
Awni Hannun
b21242faf1 Allow unary ops to accept array like (#1093) 2024-05-09 09:36:02 -07:00
Rahul Yedida
cc05a281c4 Added ArcTan2 operation (#1079)
* Added ArcTan2 operation

* Cleanup, bug fixes from code review

* Minor cleanup, fixed Linux tests
2024-05-08 08:35:15 -07:00
Jagrit Digani
fe96ceee66 Update block offset adjustment to be in size_t (#1087) 2024-05-08 08:10:23 -07:00
Awni Hannun
9814a2ae12 fix conversion to array (#1070) 2024-05-06 16:02:49 -07:00
Shubham
6992498e7a add keyword positonal (#1081) 2024-05-06 07:18:49 -07:00
Awni Hannun
21623156a3 Reset peak memory (#1074)
* reset peak memory

* fix linux

* nits in docs
2024-05-03 17:12:51 -07:00
Nripesh Niketan
79c859e2e0 feat: implement clip_grad_norm (#1043)
* feat: implement `clip_grad_norm`

* pre-commit

* Add test for clip_grad_norm function in test_optimizers.py

* small fixes

* fix

* lint

* Update tree_reduce

* Update python/mlx/utils.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/mlx/utils.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/mlx/utils.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/mlx/utils.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/mlx/utils.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/mlx/utils.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Refactor clip_grad_norm function to include documentation and improve readability

* format docstring

* Add acknowlegements

* text wrap

* pre-commit

* nits in docs

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
2024-05-03 09:07:02 -07:00
Awni Hannun
b00ac960b4 change initial memory limits and add memory size to device info (#1064) 2024-05-03 06:50:15 -07:00
Awni Hannun
02a9fc7bfa Patch bump (#1067)
* version

* use 0.12.2
2024-05-02 16:37:31 -07:00
Jagrit Digani
f390957685 Block sparse mm (#1058) 2024-05-02 14:03:58 -07:00
Angelos Katharopoulos
17f57df797 Improvements in the quantizer and dequantization kernel (#1061) 2024-05-01 18:19:11 -07:00
Awni Hannun
7f7b9662ea Fix leak for multi-output primitives which are never detached (#1059)
* fix multi output leak

* ignore arrays that will be detached

* add some comments

* stray print
2024-05-01 07:31:45 -07:00
Awni Hannun
19bef39f5c Add a mx.metal.device_info (#1060)
* device inof

* add variant

* fix linux

* fix doc
2024-04-30 15:47:27 -07:00
Nripesh Niketan
a30e7ed2da feat: metal formatting and pre-commit bump (#1038)
* feat: metal formatting and pre-commit bump

* add guards

* update

* more guards

* more guards

* smakk fix

* Refactor instantiation of ternary types in ternary.metal

* fix scan.metal
2024-04-30 07:18:09 -07:00
Angelos Katharopoulos
8db7161c94 Bug fix in quantize (#1054) 2024-04-29 20:55:04 -07:00
Awni Hannun
09f1777896 fix slice update indexing (#1053) 2024-04-29 12:17:40 -07:00
Jacket
490c0c4fdc [Fix] expand axes for dimension with integer indices in mlx_slice_update (#1035)
* Not sure if this is correct

* Format

* Edit tests

* Add negative test

* Format

* add one more test

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-04-29 07:57:28 -07:00
Rifur13
c4a471c99d Add groups to Conv1d (#948)
* Add conv1d grouped convs on CPU

* Add GPU support

* Parallelize inside metal kernel

* clenaup

* Update mlx/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* New unfold kernel + remove unused code

* Remove copy and refactor

* Update vjp and reuse steel gemm

* Fixed groups on cpu

* Fix metal validation

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-04-27 06:24:57 -07:00
Awni Hannun
86f495985b Add bitwise ops (#1037)
* bitwise ops

* fix tests
2024-04-26 22:03:42 -07:00
Awni Hannun
67d1894759 fix order device -> scheduler (#1039) 2024-04-26 13:46:41 -07:00
Awni Hannun
5bfe89bdb1 Cpp docs (#1036)
* start of C++ docs

* fix stream doc

* only include ops for now
2024-04-26 12:56:05 -07:00
162 changed files with 11094 additions and 4634 deletions

View File

@@ -49,11 +49,6 @@ jobs:
name: Run Python tests
command: |
python3 -m unittest discover python/tests -v
# TODO: Reenable when extension api becomes stable
# - run:
# name: Build example extension
# command: |
# cd examples/extensions && python3 -m pip install .
- run:
name: Build CPP only
command: |
@@ -69,7 +64,7 @@ jobs:
default: "15.2.0"
macos:
xcode: << parameters.xcode_version >>
resource_class: macos.m1.large.gen1
resource_class: macos.m1.medium.gen1
steps:
- checkout
- run:
@@ -101,11 +96,10 @@ jobs:
source env/bin/activate
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
# TODO: Reenable when extension api becomes stable
# - run:
# name: Build example extension
# command: |
# cd examples/extensions && python3.11 -m pip install .
- run:
name: Build example extension
command: |
cd examples/extensions && python3.8 -m pip install .
- store_test_results:
path: test-results
- run:
@@ -132,7 +126,7 @@ jobs:
default: ""
macos:
xcode: << parameters.xcode_version >>
resource_class: macos.m1.large.gen1
resource_class: macos.m1.medium.gen1
steps:
- checkout
- run:

View File

@@ -1,11 +1,11 @@
repos:
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v18.1.3
rev: v18.1.4
hooks:
- id: clang-format
# Using this mirror lets us use mypyc-compiled black, which is about 2x faster
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 24.3.0
rev: 24.4.2
hooks:
- id: black
- repo: https://github.com/pycqa/isort

View File

@@ -7,7 +7,7 @@ with a short description of your contribution(s) below. For example:
MLX was developed with contributions from the following individuals:
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops.
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`.
- Juarez Bochi: Fixed bug in cross attention.
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support.

View File

@@ -15,12 +15,15 @@ option(MLX_BUILD_EXAMPLES "Build examples for mlx" ON)
option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
option(MLX_BUILD_METAL "Build metal backend" ON)
option(MLX_BUILD_CPU "Build cpu backend" ON)
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
if(NOT MLX_VERSION)
set(MLX_VERSION 0.12.0)
set(MLX_VERSION 0.13.1)
endif()
# --------------------- Processor tests -------------------------
@@ -84,9 +87,11 @@ elseif (MLX_BUILD_METAL)
if (${MACOS_VERSION} GREATER_EQUAL 14.2)
set(METAL_CPP_PATCH ${CMAKE_CURRENT_SOURCE_DIR}/cmake/metal.14.2.diff)
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14.2_iOS17.2.zip)
set(MLX_METAL_VERSION METAL_3_1)
elseif (${MACOS_VERSION} GREATER_EQUAL 14.0)
set(METAL_CPP_PATCH ${CMAKE_CURRENT_SOURCE_DIR}/cmake/metal.14.0.diff)
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14_iOS17-beta.zip)
set(MLX_METAL_VERSION METAL_3_0)
else()
message(FATAL_ERROR "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON" )
endif()
@@ -94,7 +99,7 @@ elseif (MLX_BUILD_METAL)
FetchContent_Declare(
metal_cpp
URL ${METAL_CPP_URL}
PATCH_COMMAND patch -N -i ${METAL_CPP_PATCH} || true
PATCH_COMMAND /usr/bin/patch -N -i ${METAL_CPP_PATCH} || true
)
FetchContent_MakeAvailable(metal_cpp)
@@ -108,51 +113,57 @@ elseif (MLX_BUILD_METAL)
${METAL_LIB}
${FOUNDATION_LIB}
${QUARTZ_LIB})
add_compile_definitions(${MLX_METAL_VERSION})
endif()
find_library(ACCELERATE_LIBRARY Accelerate)
if (MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
set(MLX_BUILD_ACCELERATE ON)
target_link_libraries(mlx ${ACCELERATE_LIBRARY})
add_compile_definitions(ACCELERATE_NEW_LAPACK)
if (MLX_BUILD_CPU)
find_library(ACCELERATE_LIBRARY Accelerate)
if (MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
set(MLX_BUILD_ACCELERATE ON)
target_link_libraries(mlx ${ACCELERATE_LIBRARY})
add_compile_definitions(ACCELERATE_NEW_LAPACK)
else()
message(STATUS "Accelerate or arm neon not found, using default backend.")
set(MLX_BUILD_ACCELERATE OFF)
if(${CMAKE_HOST_APPLE})
# The blas shipped in macOS SDK is not supported, search homebrew for
# openblas instead.
set(BLA_VENDOR OpenBLAS)
set(LAPACK_ROOT "${LAPACK_ROOT};$ENV{LAPACK_ROOT};/usr/local/opt/openblas")
endif()
# Search and link with lapack.
find_package(LAPACK REQUIRED)
if (NOT LAPACK_FOUND)
message(FATAL_ERROR "Must have LAPACK installed")
endif()
find_path(LAPACK_INCLUDE_DIRS lapacke.h
/usr/include
/usr/local/include
/usr/local/opt/openblas/include)
message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
target_link_libraries(mlx ${LAPACK_LIBRARIES})
# List blas after lapack otherwise we may accidentally incldue an old version
# of lapack.h from the include dirs of blas.
find_package(BLAS REQUIRED)
if (NOT BLAS_FOUND)
message(FATAL_ERROR "Must have BLAS installed")
endif()
# TODO find a cleaner way to do this
find_path(BLAS_INCLUDE_DIRS cblas.h
/usr/include
/usr/local/include
$ENV{BLAS_HOME}/include)
message(STATUS "Blas lib " ${BLAS_LIBRARIES})
message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
target_link_libraries(mlx ${BLAS_LIBRARIES})
endif()
else()
message(STATUS "Accelerate or arm neon not found, using default backend.")
set(MLX_BUILD_ACCELERATE OFF)
if(${CMAKE_HOST_APPLE})
# The blas shipped in macOS SDK is not supported, search homebrew for
# openblas instead.
set(BLA_VENDOR OpenBLAS)
set(LAPACK_ROOT "${LAPACK_ROOT};$ENV{LAPACK_ROOT};/usr/local/opt/openblas")
endif()
# Search and link with lapack.
find_package(LAPACK REQUIRED)
if (NOT LAPACK_FOUND)
message(FATAL_ERROR "Must have LAPACK installed")
endif()
find_path(LAPACK_INCLUDE_DIRS lapacke.h
/usr/include
/usr/local/include
/usr/local/opt/openblas/include)
message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
target_link_libraries(mlx ${LAPACK_LIBRARIES})
# List blas after lapack otherwise we may accidentally incldue an old version
# of lapack.h from the include dirs of blas.
find_package(BLAS REQUIRED)
if (NOT BLAS_FOUND)
message(FATAL_ERROR "Must have BLAS installed")
endif()
# TODO find a cleaner way to do this
find_path(BLAS_INCLUDE_DIRS cblas.h
/usr/include
/usr/local/include
$ENV{BLAS_HOME}/include)
message(STATUS "Blas lib " ${BLAS_LIBRARIES})
message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
target_link_libraries(mlx ${BLAS_LIBRARIES})
endif()
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)

View File

@@ -0,0 +1,123 @@
import argparse
import math
import os
import subprocess
import time
import mlx.core as mx
import numpy as np
import torch
device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
device_name = device_name.decode("utf-8").strip("\n")
N_warmup = 10
N_iter_bench = 100
N_iter_func = 5
def bench(f, a, b):
for i in range(N_warmup):
f(a, b)
torch.mps.synchronize()
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(a, b)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def make_mx_conv_1D(strides=1, padding=0, groups=1):
def mx_conv_1D(a, b):
ys = []
for _ in range(N_iter_func):
y = mx.conv1d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
mx.eval(ys)
return ys
return mx_conv_1D
def make_pt_conv_1D(strides=1, padding=0, groups=1):
@torch.no_grad()
def pt_conv_1D(a, b):
ys = []
for _ in range(N_iter_func):
y = torch.conv1d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
torch.mps.synchronize()
return ys
return pt_conv_1D
def bench_shape(N, iH, C, wH, O, strides, padding, np_dtype, groups):
scale = 1.0 / math.sqrt(wH * C)
a_np = np.random.uniform(0, 0.5, (N, iH, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (O, wH, int(C / groups))).astype(np_dtype)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
a_pt = torch.from_numpy(a_np.transpose((0, 2, 1))).to("mps")
b_pt = torch.from_numpy(b_np.transpose((0, 2, 1))).to("mps")
torch.mps.synchronize()
f_mx = make_mx_conv_1D(strides, padding, groups)
f_pt = make_pt_conv_1D(strides, padding, groups)
time_torch = bench(f_pt, a_pt, b_pt)
time_mlx = bench(f_mx, a_mx, b_mx)
out_mx = mx.conv1d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
out_pt = torch.conv1d(
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
)
out_pt = torch.permute(out_pt, (0, 2, 1))
out_pt = out_pt.numpy(force=True)
atol = 2e-5 if np_dtype == np.float32 else 1e-4
if not np.allclose(out_pt, out_mx, atol=atol):
print(
f"Failed at {(N, iH, C)}, {(O, wH, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
)
return time_mlx, time_torch
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run conv benchmarks")
dtypes = ("float32",)
shapes = (
(4, 32, 32, 5, 32, 1, 2, 1),
(4, 32, 32, 5, 32, 1, 2, 2),
(4, 32, 32, 5, 32, 1, 2, 4),
(4, 32, 32, 5, 32, 1, 2, 8),
(4, 32, 32, 5, 32, 1, 2, 8),
(4, 32, 32, 5, 32, 1, 2, 16),
(4, 32, 32, 5, 32, 1, 2, 32),
(4, 32, 256, 5, 512, 1, 2, 2),
(4, 32, 256, 5, 512, 1, 2, 128),
(4, 32, 256, 5, 512, 1, 2, 256),
)
for dtype in dtypes:
print("(N, iH, C), (O, wH, C), dtype, stride, pads, groups, diff%")
for N, iH, C, wH, O, strides, padding, groups in shapes:
np_dtype = getattr(np, dtype)
time_mlx, time_torch = bench_shape(
N, iH, C, wH, O, strides, padding, np_dtype, groups
)
diff = time_torch / time_mlx - 1.0
print(
f"({N}, {iH:3d}, {C:3d}), ({O:3d}, {wH:2d}, {C:3d}), {dtype}, {strides:5d}, {padding:4d}, {groups:6d}, {100. * diff:+5.2f}%"
)
if time_mlx >= 2.0 * time_torch:
print("ATTENTION ^^^^^^^")

50
docs/Doxyfile Normal file
View File

@@ -0,0 +1,50 @@
################################################################################
# Primary project setup. #
################################################################################
PROJECT_NAME = "MLX"
OUTPUT_DIRECTORY = build
XML_OUTPUT = xml
HTML_OUTPUT = html
STRIP_FROM_PATH = ../
INPUT = ../mlx
FILE_PATTERNS = *.h
EXCLUDE_PATTERNS = */private/*
CREATE_SUBDIRS = NO
FULL_PATH_NAMES = YES
RECURSIVE = YES
GENERATE_HTML = YES
GENERATE_LATEX = NO
GENERATE_XML = YES
XML_PROGRAMLISTING = YES
################################################################################
# Doxygen preprocessor / parser control. #
################################################################################
ENABLE_PREPROCESSING = YES
MACRO_EXPANSION = YES
EXPAND_ONLY_PREDEF = NO
SKIP_FUNCTION_MACROS = NO
################################################################################
# Compound extraction control. #
################################################################################
EXTRACT_ALL = YES
EXTRACT_PACKAGE = YES
EXTRACT_STATIC = YES
CASE_SENSE_NAMES = NO
################################################################################
# Docstring control / customization. #
################################################################################
JAVADOC_AUTOBRIEF = YES
################################################################################
# Warning suppression. #
################################################################################
QUIET = YES
WARN_IF_UNDOCUMENTED = NO

View File

@@ -2,12 +2,16 @@
### Setup (do once)
Install [sphinx](https://www.sphinx-doc.org/en/master/usage/installation.html)
for example with `conda`:
Install Doxygen:
```
conda install sphinx
pip install sphinx-book-theme
brew install doxygen
```
Install Python packages:
```
pip install -r requirements.txt
```
### Build
@@ -15,7 +19,7 @@ pip install sphinx-book-theme
Build the docs from `mlx/docs/`
```
make html
doxygen && make html
```
View the docs by running a server in `mlx/docs/build/html/`:

3
docs/requirements.txt Normal file
View File

@@ -0,0 +1,3 @@
sphinx
breathe
sphinx-book-theme

View File

@@ -22,6 +22,7 @@ extensions = [
"sphinx.ext.autosummary",
"sphinx.ext.intersphinx",
"sphinx.ext.napoleon",
"breathe",
]
python_use_unqualified_type_names = True
@@ -33,6 +34,9 @@ intersphinx_mapping = {
"numpy": ("https://numpy.org/doc/stable/", None),
}
breathe_projects = {"mlx": "../build/xml"}
breathe_default_project = "mlx"
templates_path = ["_templates"]
html_static_path = ["_static"]
source_suffix = ".rst"

View File

@@ -3,4 +3,5 @@
Operations
==========
.. doxygengroup:: ops
:content-only:

View File

@@ -1,5 +1,5 @@
Developer Documentation
=======================
Custom Extensions in MLX
========================
You can extend MLX with custom operations on the CPU or GPU. This guide
explains how to do that with a simple example.
@@ -494,7 +494,7 @@ below.
auto kernel = d.get_kernel(kname.str(), "mlx_ext");
// Prepare to encode kernel
auto compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
// Kernel parameters are registered with buffer indices corresponding to
@@ -503,11 +503,11 @@ below.
size_t nelem = out.size();
// Encode input arrays to kernel
set_array_buffer(compute_encoder, x, 0);
set_array_buffer(compute_encoder, y, 1);
compute_encoder.set_input_array(x, 0);
compute_encoder.set_input_array(y, 1);
// Encode output arrays to kernel
set_array_buffer(compute_encoder, out, 2);
compute_encoder.set_output_array(out, 2);
// Encode alpha and beta
compute_encoder->setBytes(&alpha_, sizeof(float), 3);
@@ -531,7 +531,7 @@ below.
// Launch the grid with the given number of threads divided among
// the given threadgroups
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
We can now call the :meth:`axpby` operation on both the CPU and the GPU!
@@ -825,7 +825,7 @@ Let's look at a simple script and its results:
print(f"c shape: {c.shape}")
print(f"c dtype: {c.dtype}")
print(f"c correctness: {mx.all(c == 6.0).item()}")
print(f"c correct: {mx.all(c == 6.0).item()}")
Output:

View File

@@ -153,11 +153,16 @@ should point to the path to the built metal library.
- OFF
* - MLX_BUILD_METAL
- ON
* - MLX_BUILD_CPU
- ON
* - MLX_BUILD_PYTHON_BINDINGS
- OFF
* - MLX_METAL_DEBUG
- OFF
* - MLX_BUILD_SAFETENSORS
- ON
* - MLX_BUILD_GGUF
- ON
.. note::
@@ -176,10 +181,28 @@ should point to the path to the built metal library.
xcrun -sdk macosx --show-sdk-version
Binary Size Minimization
~~~~~~~~~~~~~~~~~~~~~~~~
To produce a smaller binary use the CMake flags `CMAKE_BUILD_TYPE=MinSizeRel`
and `BUILD_SHARED_LIBS=ON`.
The MLX CMake build has several additional options to make smaller binaries.
For example, if you don't need the CPU backend or support for safetensors and
GGUF, you can do:
```shell
cmake .. \
-DCMAKE_BUILD_TYPE=MinSizeRel \
-DBUILD_SHARED_LIBS=ON \
-DMLX_BUILD_CPU=ON \
-DMLX_BUILD_SAFETENSORS=OFF \
-DMLX_BUILD_GGUF=OFF
```
Troubleshooting
^^^^^^^^^^^^^^^
Metal not found
~~~~~~~~~~~~~~~

View File

@@ -8,5 +8,7 @@ Linear Algebra
.. autosummary::
:toctree: _autosummary
inv
norm
qr
svd

View File

@@ -7,8 +7,10 @@ Metal
:toctree: _autosummary
is_available
device_info
get_active_memory
get_peak_memory
reset_peak_memory
get_cache_memory
set_memory_limit
set_cache_limit

View File

@@ -15,6 +15,7 @@ Layers
BatchNorm
Conv1d
Conv2d
Conv3d
Dropout
Dropout2d
Dropout3d

View File

@@ -10,6 +10,7 @@ Operations
abs
add
addmm
all
allclose
any
@@ -19,20 +20,28 @@ Operations
arcsin
arcsinh
arctan
arctan2
arctanh
argmax
argmin
argpartition
argsort
array_equal
as_strided
atleast_1d
atleast_2d
atleast_3d
broadcast_to
bitwise_and
bitwise_or
bitwise_xor
block_masked_mm
block_sparse_mm
broadcast_to
ceil
clip
concatenate
conj
conjugate
convolve
conv1d
conv2d
@@ -69,6 +78,8 @@ Operations
isnan
isneginf
isposinf
issubdtype
left_shift
less
less_equal
linspace
@@ -98,13 +109,16 @@ Operations
outer
partition
pad
power
prod
quantize
quantized_matmul
radians
reciprocal
remainder
repeat
reshape
right_shift
round
rsqrt
save

View File

@@ -1,5 +1,7 @@
.. _optimizers:
.. currentmodule:: mlx.optimizers
Optimizers
==========
@@ -34,3 +36,8 @@ model's parameters and the **optimizer state**.
optimizers/optimizer
optimizers/common_optimizers
optimizers/schedulers
.. autosummary::
:toctree: _autosummary
clip_grad_norm

View File

@@ -20,3 +20,4 @@ return python trees will be using the default python ``dict``, ``list`` and
tree_unflatten
tree_map
tree_map_with_path
tree_reduce

View File

@@ -89,8 +89,8 @@ void automatic_differentiation() {
// dfdx is 2 * x
// Get the second derivative by composing grad with grad
auto df2dx2 = grad(grad(fn))(x);
// df2dx2 is 2
auto d2fdx2 = grad(grad(fn))(x);
// d2fdx2 is 2
}
int main() {

View File

@@ -1,5 +1,5 @@
## Build the extensions
## Build
```
pip install -e .
@@ -16,3 +16,9 @@ And then run:
```
python setup.py build_ext -j8 --inplace
```
## Test
```
python test.py
`

View File

@@ -257,7 +257,7 @@ void Axpby::eval_gpu(
auto kernel = d.get_kernel(kname.str(), "mlx_ext");
// Prepare to encode kernel
auto compute_encoder = d.get_command_encoder(s.index);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
// Kernel parameters are registered with buffer indices corresponding to
@@ -266,11 +266,11 @@ void Axpby::eval_gpu(
size_t nelem = out.size();
// Encode input arrays to kernel
set_array_buffer(compute_encoder, x, 0);
set_array_buffer(compute_encoder, y, 1);
compute_encoder.set_input_array(x, 0);
compute_encoder.set_input_array(y, 1);
// Encode output arrays to kernel
set_array_buffer(compute_encoder, out, 2);
compute_encoder.set_output_array(out, 2);
// Encode alpha and beta
compute_encoder->setBytes(&alpha_, sizeof(float), 3);
@@ -296,7 +296,7 @@ void Axpby::eval_gpu(
// Launch the grid with the given number of threads divided among
// the given threadgroups
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
#else // Metal is not available

View File

@@ -33,7 +33,7 @@ array axpby(
class Axpby : public Primitive {
public:
explicit Axpby(Stream stream, float alpha, float beta)
: Primitive(stream), alpha_(alpha), beta_(beta){};
: Primitive(stream), alpha_(alpha), beta_(beta) {};
/**
* A primitive must know how to evaluate itself on the CPU/GPU

View File

@@ -19,7 +19,7 @@ template <typename T>
uint index [[thread_position_in_grid]]) {
auto x_offset = elem_to_loc(index, shape, x_strides, ndim);
auto y_offset = elem_to_loc(index, shape, y_strides, ndim);
out[index] =
out[index] =
static_cast<T>(alpha) * x[x_offset] + static_cast<T>(beta) * y[y_offset];
}
@@ -31,30 +31,30 @@ template <typename T>
constant const float& alpha [[buffer(3)]],
constant const float& beta [[buffer(4)]],
uint index [[thread_position_in_grid]]) {
out[index] =
out[index] =
static_cast<T>(alpha) * x[index] + static_cast<T>(beta) * y[index];
}
#define instantiate_axpby(type_name, type) \
template [[host_name("axpby_general_" #type_name)]] \
[[kernel]] void axpby_general<type>( \
device const type* x [[buffer(0)]], \
device const type* y [[buffer(1)]], \
device type* out [[buffer(2)]], \
constant const float& alpha [[buffer(3)]], \
constant const float& beta [[buffer(4)]], \
constant const int* shape [[buffer(5)]], \
constant const size_t* x_strides [[buffer(6)]], \
constant const size_t* y_strides [[buffer(7)]], \
constant const int& ndim [[buffer(8)]], \
uint index [[thread_position_in_grid]]); \
template [[host_name("axpby_contiguous_" #type_name)]] \
[[kernel]] void axpby_contiguous<type>( \
device const type* x [[buffer(0)]], \
device const type* y [[buffer(1)]], \
device type* out [[buffer(2)]], \
constant const float& alpha [[buffer(3)]], \
constant const float& beta [[buffer(4)]], \
#define instantiate_axpby(type_name, type) \
template [[host_name("axpby_general_" #type_name)]] [[kernel]] void \
axpby_general<type>( \
device const type* x [[buffer(0)]], \
device const type* y [[buffer(1)]], \
device type* out [[buffer(2)]], \
constant const float& alpha [[buffer(3)]], \
constant const float& beta [[buffer(4)]], \
constant const int* shape [[buffer(5)]], \
constant const size_t* x_strides [[buffer(6)]], \
constant const size_t* y_strides [[buffer(7)]], \
constant const int& ndim [[buffer(8)]], \
uint index [[thread_position_in_grid]]); \
template [[host_name("axpby_contiguous_" #type_name)]] [[kernel]] void \
axpby_contiguous<type>( \
device const type* x [[buffer(0)]], \
device const type* y [[buffer(1)]], \
device type* out [[buffer(2)]], \
constant const float& alpha [[buffer(3)]], \
constant const float& beta [[buffer(4)]], \
uint index [[thread_position_in_grid]]);
instantiate_axpby(float32, float);

View File

@@ -2,4 +2,4 @@
import mlx.core as mx
from .mlx_sample_extensions import *
from ._ext import axpby

View File

@@ -1,4 +1,4 @@
setuptools>=42
cmake>=3.24
mlx>=0.9.0
nanobind@git+https://github.com/wjakob/nanobind.git#egg=4148debcf91f5ccab0c3b8d67b5c3cabd61f407f
nanobind@git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4

View File

@@ -0,0 +1,10 @@
import mlx.core as mx
from mlx_sample_extensions import axpby
a = mx.ones((3, 4))
b = mx.ones((3, 4))
c = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
print(f"c shape: {c.shape}")
print(f"c dtype: {c.dtype}")
print(f"c correct: {mx.all(c == 6.0).item()}")

View File

@@ -19,11 +19,16 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h
)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
if (MLX_BUILD_CPU)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common)
else()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_cpu)
endif()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
if (MLX_BUILD_ACCELERATE)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate)
else()
elseif(MLX_BUILD_CPU)
target_sources(
mlx
PRIVATE

View File

@@ -14,7 +14,7 @@ class Buffer {
void* ptr_;
public:
Buffer(void* ptr) : ptr_(ptr){};
Buffer(void* ptr) : ptr_(ptr) {};
// Get the raw data pointer from the buffer
void* raw_ptr();

View File

@@ -1,5 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
#include <functional>
#include "mlx/array.h"
@@ -167,6 +166,39 @@ void array::move_shared_buffer(array other) {
move_shared_buffer(other, other.strides(), other.flags(), other.data_size());
}
array::~array() {
if (array_desc_ == nullptr) {
return;
}
// Ignore arrays that will be detached
if (status() != array::Status::unscheduled) {
return;
}
// Break circular reference for non-detached arrays with siblings
if (auto n = siblings().size(); n > 0) {
bool do_detach = true;
// If all siblings have siblings.size() references except
// the one we are currently destroying (which has siblings.size() + 1)
// then there are no more external references
do_detach &= (array_desc_.use_count() == (n + 1));
for (auto& s : siblings()) {
do_detach &= (s.array_desc_.use_count() == n);
if (!do_detach) {
break;
}
}
if (do_detach) {
for (auto& s : siblings()) {
for (auto& ss : s.siblings()) {
ss.array_desc_ = nullptr;
}
s.array_desc_->siblings.clear();
}
}
}
}
void array::ArrayDesc::init() {
strides.resize(shape.size());
size = 1;

View File

@@ -114,6 +114,15 @@ class array {
return array_desc_->strides;
};
/**
* Get the stride of the corresponding dimension.
*
* This function supports negative indexing and provides
* bounds checking. */
size_t strides(int dim) const {
return strides().at(dim < 0 ? dim + ndim() : dim);
};
/** Get the arrays data type. */
Dtype dtype() const {
return array_desc_->dtype;
@@ -200,7 +209,7 @@ class array {
allocator::Buffer buffer;
deleter_t d;
Data(allocator::Buffer buffer, deleter_t d = allocator::free)
: buffer(buffer), d(d){};
: buffer(buffer), d(d) {};
// Not copyable
Data(const Data& d) = delete;
Data& operator=(const Data& d) = delete;
@@ -252,22 +261,16 @@ class array {
return array_desc_->siblings;
};
/** The array's siblings. */
std::vector<array>& siblings() {
return array_desc_->siblings;
};
void set_siblings(std::vector<array> siblings, uint16_t position) {
array_desc_->siblings = std::move(siblings);
array_desc_->position = position;
}
/** The i-th output of the array's primitive. */
const array& output(int i) const {
if (i == array_desc_->position) {
return *this;
} else if (i < array_desc_->position) {
return siblings()[i];
} else {
return siblings()[i + 1];
}
};
/** The outputs of the array's primitive (i.e. this array and
* its siblings) in the order the primitive expects. */
std::vector<array> outputs() const {
@@ -377,6 +380,8 @@ class array {
array_desc_ = other.array_desc_;
}
~array();
private:
// Initialize the arrays data
template <typename It>

View File

@@ -32,9 +32,12 @@ DEFAULT(ArgReduce)
DEFAULT(ArgSort)
DEFAULT(AsStrided)
DEFAULT(BlockMaskedMM)
DEFAULT(BlockSparseMM)
DEFAULT(BlockSparseQMM)
DEFAULT(Broadcast)
DEFAULT(Ceil)
DEFAULT(Concatenate)
DEFAULT(Conjugate)
DEFAULT(Copy)
DEFAULT_MULTI(CustomVJP)
DEFAULT_MULTI(Depends)
@@ -192,6 +195,26 @@ void ArcTan::eval_cpu(const std::vector<array>& inputs, array& out) {
}
}
void ArcTan2::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
if (out.dtype() == float32 && a.flags().row_contiguous &&
b.flags().row_contiguous) {
if (a.is_donatable()) {
out.copy_shared_buffer(a);
} else if (b.is_donatable()) {
out.copy_shared_buffer(b);
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
int size = a.data_size();
vvatan2f(out.data<float>(), a.data<float>(), b.data<float>(), &size);
} else {
eval(inputs, out);
}
}
void ArcTanh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];

View File

@@ -37,6 +37,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp

View File

@@ -236,4 +236,82 @@ void Subtract::eval(const std::vector<array>& inputs, array& out) {
binary(a, b, out, detail::Subtract());
}
void BitwiseBinary::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
auto dispatch_type = [&a, &b, &out](auto op) {
switch (out.dtype()) {
case bool_:
binary_op<bool>(a, b, out, op);
case uint8:
binary_op<uint8_t>(a, b, out, op);
break;
case uint16:
binary_op<uint16_t>(a, b, out, op);
break;
case uint32:
binary_op<uint32_t>(a, b, out, op);
break;
case uint64:
binary_op<uint64_t>(a, b, out, op);
break;
case int8:
binary_op<int8_t>(a, b, out, op);
break;
case int16:
binary_op<int16_t>(a, b, out, op);
break;
case int32:
binary_op<int32_t>(a, b, out, op);
break;
case int64:
binary_op<int64_t>(a, b, out, op);
break;
default:
throw std::runtime_error(
"[BitwiseBinary::eval_cpu] Type not supported");
break;
}
};
switch (op_) {
case BitwiseBinary::And:
dispatch_type(detail::BitwiseAnd());
break;
case BitwiseBinary::Or:
dispatch_type(detail::BitwiseOr());
break;
case BitwiseBinary::Xor:
dispatch_type(detail::BitwiseXor());
break;
case BitwiseBinary::LeftShift:
dispatch_type(detail::LeftShift());
break;
case BitwiseBinary::RightShift:
dispatch_type(detail::RightShift());
break;
}
}
void ArcTan2::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
const auto& a = inputs[0];
const auto& b = inputs[1];
if (out.dtype() == float32) {
binary_op<float>(a, b, out, detail::ArcTan2());
} else if (out.dtype() == float16) {
binary_op<float16_t>(a, b, out, detail::ArcTan2());
} else if (out.dtype() == bfloat16) {
binary_op<bfloat16_t>(a, b, out, detail::ArcTan2());
} else if (issubdtype(out.dtype(), inexact)) {
std::ostringstream err;
err << "[arctan2] Does not support " << out.dtype();
throw std::invalid_argument(err.str());
} else {
throw std::invalid_argument(
"[arctan2] Cannot compute inverse tangent for arrays"
" with non floating point type.");
}
}
} // namespace mlx::core

View File

@@ -0,0 +1,347 @@
// Copyright © 2024 Apple Inc.
#include <cassert>
#include "mlx/backend/common/utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
void AsStrided::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (!in.flags().row_contiguous) {
// Just ensuring that inputs[0] came from the ops which would ensure the
// input is row contiguous.
throw std::runtime_error(
"AsStrided must be used with row contiguous arrays only.");
}
// Compute the flags given the shape and strides
bool row_contiguous = true, col_contiguous = true;
size_t r = 1, c = 1;
for (int i = strides_.size() - 1, j = 0; i >= 0; i--, j++) {
row_contiguous &= (r == strides_[i]) || (shape_[i] == 1);
col_contiguous &= (c == strides_[j]) || (shape_[j] == 1);
r *= shape_[i];
c *= shape_[j];
}
auto flags = in.flags();
// TODO: Compute the contiguous flag in a better way cause now we are
// unnecessarily strict.
flags.contiguous = row_contiguous || col_contiguous;
flags.row_contiguous = row_contiguous;
flags.col_contiguous = col_contiguous;
// There is no easy way to compute the actual data size so we use out.size().
// The contiguous flag will almost certainly not be set so no code should
// rely on data_size anyway.
size_t data_size = out.size();
return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
}
void Broadcast::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
std::vector<size_t> strides(out.ndim(), 0);
int diff = out.ndim() - in.ndim();
for (int i = in.ndim() - 1; i >= 0; --i) {
strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
}
auto flags = in.flags();
if (out.size() > in.size()) {
flags.row_contiguous = flags.col_contiguous = false;
}
out.copy_shared_buffer(in, strides, flags, in.data_size());
}
void Copy::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
out.copy_shared_buffer(inputs[0]);
}
void CustomVJP::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() > outputs.size());
for (int i = 0, j = inputs.size() - outputs.size(); i < outputs.size();
i++, j++) {
outputs[i].copy_shared_buffer(inputs[j]);
}
}
void Depends::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() > outputs.size());
for (int i = 0; i < outputs.size(); i++) {
outputs[i].copy_shared_buffer(inputs[i]);
}
}
void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
double numel = 1;
for (auto ax : axes_) {
numel *= inputs[0].shape(ax);
}
if (inverted_) {
numel = 1.0 / numel;
}
switch (out.dtype()) {
case bool_:
*out.data<bool>() = static_cast<bool>(numel);
break;
case uint8:
*out.data<uint8_t>() = static_cast<uint8_t>(numel);
break;
case uint16:
*out.data<uint16_t>() = static_cast<uint16_t>(numel);
break;
case uint32:
*out.data<uint32_t>() = static_cast<uint32_t>(numel);
break;
case uint64:
*out.data<uint64_t>() = static_cast<uint64_t>(numel);
break;
case int8:
*out.data<int8_t>() = static_cast<int8_t>(numel);
break;
case int16:
*out.data<int16_t>() = static_cast<int16_t>(numel);
break;
case int32:
*out.data<int32_t>() = static_cast<int32_t>(numel);
break;
case int64:
*out.data<int64_t>() = static_cast<int64_t>(numel);
break;
case float16:
*out.data<float16_t>() = static_cast<float16_t>(numel);
break;
case float32:
*out.data<float>() = static_cast<float>(numel);
break;
case bfloat16:
*out.data<bfloat16_t>() = static_cast<bfloat16_t>(numel);
break;
case complex64:
*out.data<complex64_t>() = static_cast<complex64_t>(numel);
break;
}
}
std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(
const array& in,
const array& out) {
// Special case for empty arrays or row contiguous arrays
if (in.size() == 0 || in.flags().row_contiguous) {
return {false, out.strides()};
}
// Special case for scalars
if (in.ndim() == 0) {
std::vector<size_t> out_strides(out.ndim(), 0);
return {false, out_strides};
}
// Firstly let's collapse all the contiguous dimensions of the input
auto [shape, _strides] = collapse_contiguous_dims(in);
auto& strides = _strides[0];
// If shapes fit exactly in the contiguous dims then no copy is necessary so
// let's check.
std::vector<size_t> out_strides;
bool copy_necessary = false;
int j = 0;
for (int i = 0; i < out.ndim(); i++) {
int N = out.shape(i);
if (j < shape.size() && shape[j] % N == 0) {
shape[j] /= N;
out_strides.push_back(shape[j] * strides[j]);
j += (shape[j] == 1);
} else if (N == 1) {
// i > 0 because otherwise j < shape.size() && shape[j] % 1 == 0
out_strides.push_back(out_strides.back());
} else {
copy_necessary = true;
break;
}
}
return {copy_necessary, out_strides};
}
void Reshape::shared_buffer_reshape(
const array& in,
const std::vector<size_t>& out_strides,
array& out) {
auto flags = in.flags();
if (flags.row_contiguous) {
// For row contiguous reshapes:
// - Shallow copy the buffer
// - If reshaping into a vector (all singleton dimensions except one) it
// becomes col contiguous again.
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
}
out.copy_shared_buffer(in, out_strides, flags, in.data_size());
}
void Split::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
auto& in = inputs[0];
auto compute_new_flags = [](const auto& shape,
const auto& strides,
size_t in_data_size,
auto flags) {
size_t data_size = 1;
size_t f_stride = 1;
size_t b_stride = 1;
flags.row_contiguous = true;
flags.col_contiguous = true;
for (int i = 0, ri = shape.size() - 1; ri >= 0; i++, ri--) {
flags.col_contiguous &= strides[i] == f_stride || shape[i] == 1;
flags.row_contiguous &= strides[ri] == b_stride || shape[ri] == 1;
f_stride *= shape[i];
b_stride *= shape[ri];
if (strides[i] > 0) {
data_size *= shape[i];
}
}
if (data_size == 1) {
// Broadcasted scalar array is contiguous.
flags.contiguous = true;
} else if (data_size == in_data_size) {
// Means we sliced a broadcasted dimension so leave the "no holes" flag
// alone.
} else {
// We sliced something. So either we are row or col contiguous or we
// punched a hole.
flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
}
return std::pair<decltype(flags), size_t>{flags, data_size};
};
std::vector<int> indices(1, 0);
indices.insert(indices.end(), indices_.begin(), indices_.end());
for (int i = 0; i < indices.size(); i++) {
size_t offset = indices[i] * in.strides()[axis_];
auto [new_flags, data_size] = compute_new_flags(
outputs[i].shape(), in.strides(), in.data_size(), in.flags());
outputs[i].copy_shared_buffer(
in, in.strides(), new_flags, data_size, offset);
}
}
std::tuple<bool, int64_t, std::vector<int64_t>> Slice::prepare_slice(
const array& in) {
int64_t data_offset = 0;
bool copy_needed = false;
std::vector<int64_t> inp_strides(in.ndim(), 0);
for (int i = 0; i < in.ndim(); ++i) {
data_offset += start_indices_[i] * in.strides()[i];
inp_strides[i] = in.strides()[i] * strides_[i];
copy_needed |= strides_[i] < 0;
}
return std::make_tuple(copy_needed, data_offset, inp_strides);
}
void Slice::shared_buffer_slice(
const array& in,
const std::vector<size_t>& out_strides,
size_t data_offset,
array& out) {
// Compute row/col contiguity
auto [data_size, is_row_contiguous, is_col_contiguous] =
check_contiguity(out.shape(), out_strides);
auto flags = in.flags();
flags.row_contiguous = is_row_contiguous;
flags.col_contiguous = is_col_contiguous;
if (data_size == 1) {
// Broadcasted scalar array is contiguous.
flags.contiguous = true;
} else if (data_size == in.data_size()) {
// Means we sliced a broadcasted dimension so leave the "no holes" flag
// alone.
} else {
// We sliced something. So either we are row or col contiguous or we
// punched a hole.
flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
}
out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset);
}
std::tuple<int64_t, std::vector<int64_t>> SliceUpdate::prepare_slice(
const array& in) {
int64_t data_offset = 0;
std::vector<int64_t> inp_strides(in.ndim(), 0);
for (int i = 0; i < in.ndim(); ++i) {
data_offset += start_indices_[i] * in.strides()[i];
inp_strides[i] = in.strides()[i] * strides_[i];
}
return std::make_tuple(data_offset, inp_strides);
}
void StopGradient::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
out.copy_shared_buffer(inputs[0]);
}
void Transpose::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
std::vector<size_t> out_strides(out.ndim());
auto& in = inputs[0];
for (int ax = 0; ax < axes_.size(); ++ax) {
out_strides[ax] = in.strides()[axes_[ax]];
}
// Conditions for {row/col}_contiguous
// - array must be contiguous (no gaps)
// - underlying buffer size should have the same size as the array
// - cumulative product of shapes is equal to the strides (we can ignore axes
// with size == 1)
// - in the forward direction (column contiguous)
// - in the reverse direction (row contiguous)
// - vectors are both row and col contiguous (hence if both row/col are
// true, they stay true)
auto flags = in.flags();
if (flags.contiguous && in.data_size() == in.size()) {
size_t f_stride = 1;
size_t b_stride = 1;
flags.col_contiguous = true;
flags.row_contiguous = true;
for (int i = 0, ri = out.ndim() - 1; i < out.ndim(); ++i, --ri) {
flags.col_contiguous &= (out_strides[i] == f_stride || out.shape(i) == 1);
f_stride *= out.shape(i);
flags.row_contiguous &=
(out_strides[ri] == b_stride || out.shape(ri) == 1);
b_stride *= out.shape(ri);
}
}
out.copy_shared_buffer(in, out_strides, flags, in.data_size());
}
} // namespace mlx::core

View File

@@ -38,11 +38,15 @@ void slow_conv_1D(
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
const int iH = 1 + in_dilation[0] * (in.shape(1) - 1); // Input spatial dim
const int C = in.shape(2); // Input channels
const int oH = out.shape(1); // Output spatial dim
const int O = wt.shape(0); // Out channels
const int C = wt.shape(2); // In channels
const int wH = wt.shape(1); // Weight spatial dim
const int groups = C / wt.shape(2);
const int C_per_group = wt.shape(2);
const int O_per_group = O / groups;
const size_t in_stride_N = in.strides()[0];
const size_t in_stride_H = in.strides()[1];
const size_t in_stride_C = in.strides()[2];
@@ -57,35 +61,36 @@ void slow_conv_1D(
for (int n = 0; n < N; ++n) {
for (int oh = 0; oh < oH; ++oh) {
for (int o = 0; o < O; ++o) {
const T* filter_wt_ptr = start_wt_ptr + o * wt_stride_O;
float r = 0.;
for (int g = 0; g < groups; ++g) {
for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
const T* filter_wt_ptr = start_wt_ptr + o * wt_stride_O;
float r = 0.;
for (int wh = 0; wh < wH; ++wh) {
const T* wt_ptr = filter_wt_ptr + wh * wt_stride_H;
for (int wh = 0; wh < wH; ++wh) {
const T* wt_ptr = filter_wt_ptr + wh * wt_stride_H;
int wh_flip = flip ? (wH - wh - 1) : wh;
int ih = oh * wt_strides[0] - padding[0] + wh_flip * wt_dilation[0];
int wh_flip = flip ? (wH - wh - 1) : wh;
int ih = oh * wt_strides[0] - padding[0] + wh_flip * wt_dilation[0];
auto ih_div = std::div(ih, in_dilation[0]);
auto ih_div = std::div(ih, in_dilation[0]);
if (ih >= 0 && ih < iH && ih_div.rem == 0) {
for (int c = 0; c < C; ++c) {
r += static_cast<float>(
in_ptr[ih_div.quot * in_stride_H + c * in_stride_C]) *
static_cast<float>(wt_ptr[c * wt_stride_C]);
} // c
if (ih >= 0 && ih < iH && ih_div.rem == 0) {
for (int c = g * C_per_group; c < (g + 1) * C_per_group; ++c) {
r += static_cast<float>(
in_ptr[ih_div.quot * in_stride_H + c * in_stride_C]) *
static_cast<float>(wt_ptr[(c % C_per_group) * wt_stride_C]);
} // c
} // ih check
} // wh
} // ih check
} // wh
out_ptr[oh * out_stride_H + o * out_stride_O] = static_cast<T>(r);
} // o
out_ptr[oh * out_stride_H + o * out_stride_O] = static_cast<T>(r);
} // o
} // g
} // oh
in_ptr += in_stride_N;
out_ptr += out_stride_N;
} // n
}
@@ -305,6 +310,296 @@ void slow_conv_2D(
} // n
}
template <typename T>
void slow_conv_3D(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
bool flip) {
const T* st_wt_ptr = wt.data<T>();
const T* st_in_ptr = in.data<T>();
T* st_out_ptr = out.data<T>();
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
const int iD = 1 + in_dilation[0] * (in.shape(1) - 1); // Input spatial dim
const int iH = 1 + in_dilation[1] * (in.shape(2) - 1); // Input spatial dim
const int iW = 1 + in_dilation[2] * (in.shape(3) - 1); // Input spatial dim
const int oD = out.shape(1); // Output spatial dim
const int oH = out.shape(2); // Output spatial dim
const int oW = out.shape(3); // Output spatial dim
const int O = wt.shape(0); // Out channels
const int C = wt.shape(4); // In channels
const int wD = wt.shape(1); // Weight spatial dim
const int wH = wt.shape(2); // Weight spatial dim
const int wW = wt.shape(3); // Weight spatial dim
const size_t in_stride_N = in.strides()[0];
const size_t in_stride_D = in.strides()[1];
const size_t in_stride_H = in.strides()[2];
const size_t in_stride_W = in.strides()[3];
const size_t in_stride_C = in.strides()[4];
const size_t wt_stride_O = wt.strides()[0];
const size_t wt_stride_D = wt.strides()[1];
const size_t wt_stride_H = wt.strides()[2];
const size_t wt_stride_W = wt.strides()[3];
const size_t wt_stride_C = wt.strides()[4];
const size_t out_stride_N = out.strides()[0];
const size_t out_stride_D = out.strides()[1];
const size_t out_stride_H = out.strides()[2];
const size_t out_stride_W = out.strides()[3];
const size_t out_stride_O = out.strides()[4];
bool is_idil_one =
in_dilation[0] == 1 && in_dilation[1] == 1 && in_dilation[2] == 1;
auto pt_conv_no_checks = [&](const T* in_ptr,
const T* wt_ptr,
T* out_ptr,
int od,
int oh,
int ow) {
out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W;
int id_base = od * wt_strides[0] - padding[0];
int ih_base = oh * wt_strides[1] - padding[1];
int iw_base = ow * wt_strides[2] - padding[2];
for (int o = 0; o < O; ++o) {
float r = 0.;
for (int wd = 0; wd < wD; ++wd) {
for (int wh = 0; wh < wH; ++wh) {
for (int ww = 0; ww < wW; ++ww) {
int wd_flip = flip ? wD - wd - 1 : wd;
int wh_flip = flip ? wH - wh - 1 : wh;
int ww_flip = flip ? wW - ww - 1 : ww;
int id = id_base + wd_flip * wt_dilation[0];
int ih = ih_base + wh_flip * wt_dilation[1];
int iw = iw_base + ww_flip * wt_dilation[2];
const T* wt_ptr_pt =
wt_ptr + wd * wt_stride_D + wh * wt_stride_H + ww * wt_stride_W;
const T* in_ptr_pt =
in_ptr + id * in_stride_D + ih * in_stride_H + iw * in_stride_W;
for (int c = 0; c < C; ++c) {
r += static_cast<float>(in_ptr_pt[0]) *
static_cast<float>(wt_ptr_pt[0]);
in_ptr_pt += in_stride_C;
wt_ptr_pt += wt_stride_C;
} // c
} // ww
} // wh
} // wd
out_ptr[0] = static_cast<T>(r);
out_ptr += out_stride_O;
wt_ptr += wt_stride_O;
} // o
};
int jump_d = flip ? -wt_dilation[0] : wt_dilation[0];
int jump_h = flip ? -wt_dilation[1] : wt_dilation[1];
int jump_w = flip ? -wt_dilation[2] : wt_dilation[2];
int init_d = (flip ? (wD - 1) * wt_dilation[0] : 0);
int init_h = (flip ? (wH - 1) * wt_dilation[1] : 0);
int init_w = (flip ? (wW - 1) * wt_dilation[2] : 0);
int f_wgt_jump_d = std::lcm(in_dilation[0], wt_dilation[0]) / wt_dilation[0];
int f_wgt_jump_h = std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1];
int f_wgt_jump_w = std::lcm(in_dilation[2], wt_dilation[2]) / wt_dilation[2];
int f_out_jump_d = std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0];
int f_out_jump_h = std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1];
int f_out_jump_w = std::lcm(in_dilation[2], wt_strides[2]) / wt_strides[2];
std::vector<int> base_d(f_out_jump_d);
std::vector<int> base_h(f_out_jump_h);
std::vector<int> base_w(f_out_jump_w);
for (int i = 0; i < f_out_jump_d; ++i) {
int id_loop = i * wt_strides[0] - padding[0] + init_d;
int wd_base = 0;
while (wd_base < wD && id_loop % in_dilation[0] != 0) {
wd_base++;
id_loop += jump_d;
}
base_d[i] = wd_base;
}
for (int i = 0; i < f_out_jump_h; ++i) {
int ih_loop = i * wt_strides[1] - padding[1] + init_h;
int wh_base = 0;
while (wh_base < wH && ih_loop % in_dilation[1] != 0) {
wh_base++;
ih_loop += jump_h;
}
base_h[i] = wh_base;
}
for (int j = 0; j < f_out_jump_w; ++j) {
int iw_loop = j * wt_strides[2] - padding[2] + init_w;
int ww_base = 0;
while (ww_base < wW && iw_loop % in_dilation[2] != 0) {
ww_base++;
iw_loop += jump_w;
}
base_w[j] = ww_base;
}
auto pt_conv_all_checks = [&](const T* in_ptr,
const T* wt_ptr,
T* out_ptr,
int od,
int oh,
int ow) {
out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W;
int id_base = od * wt_strides[0] - padding[0];
int ih_base = oh * wt_strides[1] - padding[1];
int iw_base = ow * wt_strides[2] - padding[2];
int wd_base = base_d[od % f_out_jump_d];
int wh_base = base_h[oh % f_out_jump_h];
int ww_base = base_w[ow % f_out_jump_w];
for (int o = 0; o < O; ++o) {
float r = 0.;
for (int wd = wd_base; wd < wD; wd += f_wgt_jump_d) {
for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) {
for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) {
int wd_flip = flip ? wD - wd - 1 : wd;
int wh_flip = flip ? wH - wh - 1 : wh;
int ww_flip = flip ? wW - ww - 1 : ww;
int id = id_base + wd_flip * wt_dilation[0];
int ih = ih_base + wh_flip * wt_dilation[1];
int iw = iw_base + ww_flip * wt_dilation[2];
if (id >= 0 && id < iD && ih >= 0 && ih < iH && iw >= 0 &&
iw < iW) {
const T* wt_ptr_pt = wt_ptr + wd * wt_stride_D +
wh * wt_stride_H + ww * wt_stride_W;
int id_dil = !is_idil_one ? (id / in_dilation[0]) : id;
int ih_dil = !is_idil_one ? (ih / in_dilation[1]) : ih;
int iw_dil = !is_idil_one ? (iw / in_dilation[2]) : iw;
const T* in_ptr_pt = in_ptr + id_dil * in_stride_D +
ih_dil * in_stride_H + iw_dil * in_stride_W;
for (int c = 0; c < C; ++c) {
r += static_cast<float>(in_ptr_pt[0]) *
static_cast<float>(wt_ptr_pt[0]);
in_ptr_pt += in_stride_C;
wt_ptr_pt += wt_stride_C;
} // c
} // iD, ih, iw check
} // ww
} // wh
} // wd
out_ptr[0] = static_cast<T>(r);
out_ptr += out_stride_O;
wt_ptr += wt_stride_O;
} // o
};
int oD_border_0 = 0;
int oD_border_1 =
is_idil_one ? ((padding[0] + wt_strides[0] - 1) / wt_strides[0]) : oD;
int oD_border_2 = std::max(
oD_border_1, (iD + padding[0] - wD * wt_dilation[0]) / wt_strides[0]);
int oD_border_3 = oD;
int oH_border_0 = 0;
int oH_border_1 =
is_idil_one ? ((padding[1] + wt_strides[1] - 1) / wt_strides[1]) : oH;
int oH_border_2 = std::max(
oH_border_1, (iH + padding[1] - wH * wt_dilation[1]) / wt_strides[1]);
int oH_border_3 = oH;
int oW_border_0 = 0;
int oW_border_1 =
is_idil_one ? ((padding[2] + wt_strides[2] - 1) / wt_strides[2]) : oW;
int oW_border_2 = std::max(
oW_border_1, (iW + padding[2] - wW * wt_dilation[2]) / wt_strides[2]);
int oW_border_3 = oW;
for (int n = 0; n < N; ++n) {
// Case 1: od might put us out of bounds
for (int od = oD_border_0; od < oD_border_1; ++od) {
for (int oh = 0; oh < oH; ++oh) {
for (int ow = 0; ow < oW; ++ow) {
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
} // ow
} // oh
} // od
// Case 2: od in bounds
for (int od = oD_border_1; od < oD_border_2; ++od) {
// Case 2.1: oh might put us out of bounds
for (int oh = oH_border_0; oh < oH_border_1; ++oh) {
for (int ow = 0; ow < oW; ++ow) {
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
} // ow
} // oh
// Case 2.2: oh in bounds
for (int oh = oH_border_1; oh < oH_border_2; ++oh) {
// Case 2.2.1: ow might put us out of bounds
for (int ow = oW_border_0; ow < oW_border_1; ++ow) {
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
} // ow
// Case 2.2.2: ow in bounds
for (int ow = oW_border_1; ow < oW_border_2; ++ow) {
pt_conv_no_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
} // ow
// Case 2.2.3: ow might put us out of bounds
for (int ow = oW_border_2; ow < oW_border_3; ++ow) {
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
} // ow
} // oh
// Case 2.3: oh might put us out of bounds
for (int oh = oH_border_2; oh < oH_border_3; ++oh) {
for (int ow = 0; ow < oW; ++ow) {
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
} // ow
} // oh
} // od
// Case 3: od might put us out of bounds
for (int od = oD_border_2; od < oD_border_3; ++od) {
for (int oh = 0; oh < oH; ++oh) {
for (int ow = 0; ow < oW; ++ow) {
pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
} // ow
} // oh
} // od
st_in_ptr += in_stride_N;
st_out_ptr += out_stride_N;
} // n
}
void dispatch_slow_conv_1D(
const array& in,
const array& wt,
@@ -353,6 +648,30 @@ void dispatch_slow_conv_2D(
}
}
void dispatch_slow_conv_3D(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
bool flip) {
if (in.dtype() == float32) {
return slow_conv_3D<float>(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
} else if (in.dtype() == float16) {
return slow_conv_3D<float16_t>(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
} else if (in.dtype() == bfloat16) {
return slow_conv_3D<bfloat16_t>(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
} else {
throw std::invalid_argument(
"[Convolution::eval] got unsupported data type.");
}
}
///////////////////////////////////////////////////////////////////////////////
// Explicit gemm conv
///////////////////////////////////////////////////////////////////////////////
@@ -366,11 +685,15 @@ void explicit_gemm_conv_1D_cpu(
const std::vector<int>& wt_dilation) {
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
const int iH = in.shape(1); // Input spatial dim
const int C = in.shape(2); // Input channels
const int oH = out.shape(1); // Output spatial dim
const int O = wt.shape(0); // Out channels
const int C = wt.shape(2); // In channels
const int wH = wt.shape(1); // Weight spatial dim
const int groups = C / wt.shape(2);
const int C_per_group = wt.shape(2);
const int O_per_group = O / groups;
auto conv_dtype = float32;
// Pad input
@@ -402,6 +725,11 @@ void explicit_gemm_conv_1D_cpu(
in_padded.strides()[1],
in_padded.strides()[2]};
auto flags = in_padded.flags();
if (groups > 1) {
// Transpose the last two dimensions for grouped convolutions
std::swap(strided_shape[2], strided_shape[3]);
std::swap(strided_strides[2], strided_strides[3]);
}
array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
in_strided_view.copy_shared_buffer(
@@ -416,7 +744,19 @@ void explicit_gemm_conv_1D_cpu(
auto gemm_wt = wt;
auto gemm_out = out;
if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
if (groups > 1) {
// Transpose the last two dimensions for grouped convolutions
array wt_transpose(
{wt.shape(0), wt.shape(2), wt.shape(1)}, wt.dtype(), nullptr, {});
wt_transpose.copy_shared_buffer(
wt,
{wt.strides(0), wt.strides(2), wt.strides(1)},
wt.flags(),
wt.size(),
0);
gemm_wt = array(wt_transpose.shape(), float32, nullptr, {});
copy(wt_transpose, gemm_wt, CopyType::General);
} else if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
auto ctype =
wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
gemm_wt = array(wt.shape(), float32, nullptr, {});
@@ -428,27 +768,29 @@ void explicit_gemm_conv_1D_cpu(
gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
}
// Perform gemm
cblas_sgemm(
CblasRowMajor,
CblasNoTrans, // no trans A
CblasTrans, // transB
strided_reshape[0], // M
O, // N
strided_reshape[1], // K
1.0f, // alpha
in_strided.data<float>(),
strided_reshape[1], // lda
gemm_wt.data<float>(),
strided_reshape[1], // ldb
0.0f, // beta
gemm_out.data<float>(),
O // ldc
);
for (int g = 0; g < groups; ++g) {
// Perform gemm
cblas_sgemm(
CblasRowMajor,
CblasNoTrans, // no trans A
CblasTrans, // transB
strided_reshape[0], // M
O_per_group, // N
C_per_group * wH, // K
1.0f, // alpha
in_strided.data<float>() + g * C_per_group * wH, // A
wH * C, // lda
gemm_wt.data<float>() + g * O_per_group * C_per_group * wH, // B
wH * C_per_group, // ldb
0.0f, // beta
gemm_out.data<float>() + g * O_per_group, // C
O // ldc
);
// Copy results if needed
if (out.dtype() != float32) {
copy(gemm_out, out, CopyType::Vector);
// Copy results if needed
if (out.dtype() != float32) {
copy(gemm_out, out, CopyType::Vector);
}
}
}
@@ -554,6 +896,131 @@ void explicit_gemm_conv_2D_cpu(
}
}
void explicit_gemm_conv_ND_cpu(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation) {
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
const auto iDim = std::vector<int>(
in.shape().begin() + 1, in.shape().end() - 1); // Input spatial dim
const auto oDim = std::vector<int>(
out.shape().begin() + 1, out.shape().end() - 1); // Output spatial dim
const int O = wt.shape(0); // Out channels
const int C = wt.shape(-1); // In channels
const auto wDim = std::vector<int>(
wt.shape().begin() + 1, wt.shape().end() - 1); // Weight spatial dim
auto conv_dtype = float32;
// Pad input
std::vector<int> padded_shape(in.shape().size());
padded_shape.front() = N;
for (size_t i = 0; i < iDim.size(); i++) {
padded_shape[i + 1] = iDim[i] + 2 * padding[i];
}
padded_shape.back() = C;
array in_padded(padded_shape, conv_dtype, nullptr, {});
// Fill with zeros
copy(array(0, conv_dtype), in_padded, CopyType::Scalar);
// Pick input slice from padded
size_t data_offset = 0;
for (size_t i = 0; i < padding.size(); i++) {
data_offset += padding[i] * in_padded.strides()[i + 1];
}
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
in_padded_slice.copy_shared_buffer(
in_padded,
in_padded.strides(),
in_padded.flags(),
in_padded_slice.size(),
data_offset);
// Copy input values into the slice
copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral);
// Make strided view
std::vector<int> strided_shape(oDim.size() + wDim.size() + 2);
strided_shape.front() = N;
for (size_t i = 0; i < oDim.size(); i++) {
strided_shape[i + 1] = oDim[i];
}
for (size_t i = 0; i < wDim.size(); i++) {
strided_shape[i + 1 + oDim.size()] = wDim[i];
}
strided_shape.back() = C;
std::vector<size_t> strided_strides(in.shape().size() * 2 - 2);
strided_strides[0] = in_padded.strides()[0];
for (size_t i = 0; i < wt_strides.size(); i++) {
strided_strides[i + 1] = in_padded.strides()[i + 1] * wt_strides[i];
}
for (size_t i = 1; i < in_padded.strides().size(); i++) {
strided_strides[i + wt_strides.size()] = in_padded.strides()[i];
}
auto flags = in_padded.flags();
array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
in_strided_view.copy_shared_buffer(
in_padded, strided_strides, flags, in_strided_view.size(), 0);
// Materialize strided view
std::vector<int> strided_reshape = {N, C};
for (const auto& o : oDim) {
strided_reshape[0] *= o;
}
for (const auto& w : wDim) {
strided_reshape[1] *= w;
}
array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
copy(in_strided_view, in_strided, CopyType::General);
// Check wt dtype and prepare
auto gemm_wt = wt;
auto gemm_out = out;
if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
auto ctype =
wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
gemm_wt = array(wt.shape(), float32, nullptr, {});
copy(wt, gemm_wt, ctype);
}
if (out.dtype() != float32) {
gemm_out = array(out.shape(), float32, nullptr, {});
gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
}
// Perform gemm
cblas_sgemm(
CblasRowMajor,
CblasNoTrans, // no trans A
CblasTrans, // transB
strided_reshape[0], // M
O, // N
strided_reshape[1], // K
1.0f, // alpha
in_strided.data<float>(),
strided_reshape[1], // lda
gemm_wt.data<float>(),
strided_reshape[1], // ldb
0.0f, // beta
gemm_out.data<float>(),
O // ldc
);
// Copy results if needed
if (out.dtype() != float32) {
copy(gemm_out, out, CopyType::Vector);
}
}
///////////////////////////////////////////////////////////////////////////////
// Conv routing
///////////////////////////////////////////////////////////////////////////////
@@ -589,6 +1056,19 @@ void conv_2D_cpu(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
}
void conv_3D_cpu(
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
bool flip) {
return dispatch_slow_conv_3D(
in, wt, out, padding, wt_strides, wt_dilation, in_dilation, flip);
}
} // namespace
void Convolution::eval(const std::vector<array>& inputs, array& out) {
@@ -597,8 +1077,20 @@ void Convolution::eval(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0];
auto& wt = inputs[1];
// 3D convolution
if (in.ndim() == (3 + 2)) {
return conv_3D_cpu(
in,
wt,
out,
padding_,
kernel_strides_,
kernel_dilation_,
input_dilation_,
flip_);
}
// 2D convolution
if (in.ndim() == (2 + 2)) {
else if (in.ndim() == (2 + 2)) {
return conv_2D_cpu(
in,
wt,

View File

@@ -34,6 +34,7 @@ DEFAULT(ArcCosh)
DEFAULT(ArcSin)
DEFAULT(ArcSinh)
DEFAULT(ArcTan)
DEFAULT(ArcTan2)
DEFAULT(ArcTanh)
DEFAULT(ArgPartition)
DEFAULT(ArgReduce)
@@ -42,9 +43,12 @@ DEFAULT(AsType)
DEFAULT(AsStrided)
DEFAULT(Broadcast)
DEFAULT(BlockMaskedMM)
DEFAULT(BlockSparseMM)
DEFAULT(BlockSparseQMM)
DEFAULT_MULTI(DivMod)
DEFAULT(Ceil)
DEFAULT(Concatenate)
DEFAULT(Conjugate)
DEFAULT(Convolution)
DEFAULT(Copy)
DEFAULT(Cos)

View File

@@ -2,7 +2,6 @@
#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"
#ifdef ACCELERATE_NEW_LAPACK
@@ -93,12 +92,4 @@ void Inverse::eval(const std::vector<array>& inputs, array& output) {
inverse_impl(inputs[0], output);
}
std::pair<std::vector<array>, std::vector<int>> Inverse::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
auto ax = axes[0] >= 0 ? 0 : -1;
auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0];
return {{linalg::inv(a, stream())}, {ax}};
}
} // namespace mlx::core

View File

@@ -11,7 +11,7 @@ GCC=$2
SRCDIR=$3
CLANG=$4
if [ $CLANG = "TRUE" ]; then
if [ "$CLANG" = "TRUE" ]; then
read -r -d '' INCLUDES <<- EOM
#include <cmath>
#include <complex>

View File

@@ -190,4 +190,91 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
}
}
void BlockSparseMM::eval(const std::vector<array>& inputs, array& out) {
if (out.dtype() != float32) {
throw std::runtime_error(
"[BlockSparseMM::eval] Currently only supports float32.");
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& a_pre = inputs[0];
auto& b_pre = inputs[1];
auto check_transpose = [](const array& arr) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (stx == arr.shape(-1) && sty == 1) {
return std::make_tuple(false, stx, arr);
} else if (stx == 1 && sty == arr.shape(-2)) {
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
size_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
}
};
auto [a_transposed, lda, a] = check_transpose(a_pre);
auto [b_transposed, ldb, b] = check_transpose(b_pre);
size_t M = a.shape(-2);
size_t N = b.shape(-1);
size_t K = a.shape(-1);
if (M == 0 || N == 0) {
return;
}
if (K == 0) {
std::memset(static_cast<void*>(out.data<float>()), 0, out.nbytes());
return;
}
// Get batch dims
auto batch_size_out = out.size() / (M * N);
size_t matrix_stride_out = M * N;
auto get_batch_dims = [](const auto& v) {
return decltype(v){v.begin(), v.end() - 2};
};
auto& lhs_indices = inputs[2];
auto& rhs_indices = inputs[3];
std::vector<int> batch_shape = get_batch_dims(out.shape());
int batch_ndim = batch_shape.size();
std::vector<int> batch_shape_A = get_batch_dims(a.shape());
std::vector<size_t> batch_strides_A = get_batch_dims(a.strides());
std::vector<int> batch_shape_B = get_batch_dims(b.shape());
std::vector<size_t> batch_strides_B = get_batch_dims(b.strides());
const uint32_t* lhs_indices_ptr = lhs_indices.data<uint32_t>();
const uint32_t* rhs_indices_ptr = rhs_indices.data<uint32_t>();
for (int i = 0; i < batch_size_out; i++) {
// Get index
uint32_t indx_A = lhs_indices_ptr[elem_to_loc(i, lhs_indices)];
uint32_t indx_B = rhs_indices_ptr[elem_to_loc(i, rhs_indices)];
cblas_sgemm(
CblasRowMajor,
a_transposed ? CblasTrans : CblasNoTrans, // transA
b_transposed ? CblasTrans : CblasNoTrans, // transB
M,
N,
K,
1.0f, // alpha
a.data<float>() + elem_to_loc(indx_A, batch_shape_A, batch_strides_A),
lda,
b.data<float>() + elem_to_loc(indx_B, batch_shape_B, batch_strides_B),
ldb,
0.0f, // beta
out.data<float>() + matrix_stride_out * i,
out.shape(-1) // ldc
);
}
}
} // namespace mlx::core

View File

@@ -161,6 +161,13 @@ struct ArcTan {
};
};
struct ArcTan2 {
template <typename T>
T operator()(T y, T x) {
return std::atan2(y, x);
};
};
struct ArcTanh {
template <typename T>
T operator()(T x) {
@@ -202,6 +209,12 @@ struct Ceil {
};
};
struct Conjugate {
complex64_t operator()(complex64_t x) {
return std::conj(x);
}
};
struct Cos {
template <typename T>
T operator()(T x) {
@@ -606,4 +619,39 @@ struct Select {
}
};
struct BitwiseAnd {
template <typename T>
T operator()(T x, T y) {
return x & y;
};
};
struct BitwiseOr {
template <typename T>
T operator()(T x, T y) {
return x | y;
};
};
struct BitwiseXor {
template <typename T>
T operator()(T x, T y) {
return x ^ y;
};
};
struct LeftShift {
template <typename T>
T operator()(T x, T y) {
return x << y;
};
};
struct RightShift {
template <typename T>
T operator()(T x, T y) {
return x >> y;
};
};
} // namespace mlx::core::detail

View File

@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include <cassert>
@@ -113,61 +113,6 @@ void AsType::eval(const std::vector<array>& inputs, array& out) {
copy(in, out, ctype);
}
void AsStrided::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (!in.flags().row_contiguous) {
// Just ensuring that inputs[0] came from the ops which would ensure the
// input is row contiguous.
throw std::runtime_error(
"AsStrided must be used with row contiguous arrays only.");
}
// Compute the flags given the shape and strides
bool row_contiguous = true, col_contiguous = true;
size_t r = 1, c = 1;
for (int i = strides_.size() - 1, j = 0; i >= 0; i--, j++) {
row_contiguous &= (r == strides_[i]) || (shape_[i] == 1);
col_contiguous &= (c == strides_[j]) || (shape_[j] == 1);
r *= shape_[i];
c *= shape_[j];
}
auto flags = in.flags();
// TODO: Compute the contiguous flag in a better way cause now we are
// unnecessarily strict.
flags.contiguous = row_contiguous || col_contiguous;
flags.row_contiguous = row_contiguous;
flags.col_contiguous = col_contiguous;
// There is no easy way to compute the actual data size so we use out.size().
// The contiguous flag will almost certainly not be set so no code should
// rely on data_size anyway.
size_t data_size = out.size();
return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
}
void Broadcast::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
if (out.size() == 0) {
out.set_data(nullptr);
return;
}
std::vector<size_t> strides(out.ndim(), 0);
int diff = out.ndim() - in.ndim();
for (int i = in.ndim() - 1; i >= 0; --i) {
strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
}
auto flags = in.flags();
if (out.size() > in.size()) {
flags.row_contiguous = flags.col_contiguous = false;
}
out.copy_shared_buffer(in, strides, flags, in.data_size());
}
void Ceil::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
@@ -203,9 +148,15 @@ void Concatenate::eval(const std::vector<array>& inputs, array& out) {
}
}
void Copy::eval(const std::vector<array>& inputs, array& out) {
void Conjugate::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
out.copy_shared_buffer(inputs[0]);
const auto& in = inputs[0];
if (out.dtype() == complex64) {
unary_fp(in, out, detail::Conjugate());
} else {
throw std::invalid_argument(
"[conjugate] conjugate must be called on complex input.");
}
}
void Cos::eval(const std::vector<array>& inputs, array& out) {
@@ -232,81 +183,6 @@ void Cosh::eval(const std::vector<array>& inputs, array& out) {
}
}
void CustomVJP::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() > outputs.size());
for (int i = 0, j = inputs.size() - outputs.size(); i < outputs.size();
i++, j++) {
outputs[i].copy_shared_buffer(inputs[j]);
}
}
void Depends::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() > outputs.size());
for (int i = 0; i < outputs.size(); i++) {
outputs[i].copy_shared_buffer(inputs[i]);
}
}
void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
double numel = 1;
for (auto ax : axes_) {
numel *= inputs[0].shape(ax);
}
if (inverted_) {
numel = 1.0 / numel;
}
switch (out.dtype()) {
case bool_:
*out.data<bool>() = static_cast<bool>(numel);
break;
case uint8:
*out.data<uint8_t>() = static_cast<uint8_t>(numel);
break;
case uint16:
*out.data<uint16_t>() = static_cast<uint16_t>(numel);
break;
case uint32:
*out.data<uint32_t>() = static_cast<uint32_t>(numel);
break;
case uint64:
*out.data<uint64_t>() = static_cast<uint64_t>(numel);
break;
case int8:
*out.data<int8_t>() = static_cast<int8_t>(numel);
break;
case int16:
*out.data<int16_t>() = static_cast<int16_t>(numel);
break;
case int32:
*out.data<int32_t>() = static_cast<int32_t>(numel);
break;
case int64:
*out.data<int64_t>() = static_cast<int64_t>(numel);
break;
case float16:
*out.data<float16_t>() = static_cast<float16_t>(numel);
break;
case float32:
*out.data<float>() = static_cast<float>(numel);
break;
case bfloat16:
*out.data<bfloat16_t>() = static_cast<bfloat16_t>(numel);
break;
case complex64:
*out.data<complex64_t>() = static_cast<complex64_t>(numel);
break;
}
}
void Erf::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
@@ -536,63 +412,6 @@ void RandomBits::eval(const std::vector<array>& inputs, array& out) {
}
}
std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(
const array& in,
const array& out) {
// Special case for empty arrays or row contiguous arrays
if (in.size() == 0 || in.flags().row_contiguous) {
return {false, out.strides()};
}
// Special case for scalars
if (in.ndim() == 0) {
std::vector<size_t> out_strides(out.ndim(), 0);
return {false, out_strides};
}
// Firstly let's collapse all the contiguous dimensions of the input
auto [shape, _strides] = collapse_contiguous_dims(in);
auto& strides = _strides[0];
// If shapes fit exactly in the contiguous dims then no copy is necessary so
// let's check.
std::vector<size_t> out_strides;
bool copy_necessary = false;
int j = 0;
for (int i = 0; i < out.ndim(); i++) {
int N = out.shape(i);
if (j < shape.size() && shape[j] % N == 0) {
shape[j] /= N;
out_strides.push_back(shape[j] * strides[j]);
j += (shape[j] == 1);
} else if (N == 1) {
// i > 0 because otherwise j < shape.size() && shape[j] % 1 == 0
out_strides.push_back(out_strides.back());
} else {
copy_necessary = true;
break;
}
}
return {copy_necessary, out_strides};
}
void Reshape::shared_buffer_reshape(
const array& in,
const std::vector<size_t>& out_strides,
array& out) {
auto flags = in.flags();
if (flags.row_contiguous) {
// For row contiguous reshapes:
// - Shallow copy the buffer
// - If reshaping into a vector (all singleton dimensions except one) it
// becomes col contiguous again.
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
}
out.copy_shared_buffer(in, out_strides, flags, in.data_size());
}
void Reshape::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
@@ -663,49 +482,6 @@ void Sinh::eval(const std::vector<array>& inputs, array& out) {
}
}
std::tuple<bool, int64_t, std::vector<int64_t>> Slice::prepare_slice(
const array& in) {
int64_t data_offset = 0;
bool copy_needed = false;
std::vector<int64_t> inp_strides(in.ndim(), 0);
for (int i = 0; i < in.ndim(); ++i) {
data_offset += start_indices_[i] * in.strides()[i];
inp_strides[i] = in.strides()[i] * strides_[i];
copy_needed |= strides_[i] < 0;
}
return std::make_tuple(copy_needed, data_offset, inp_strides);
}
void Slice::shared_buffer_slice(
const array& in,
const std::vector<size_t>& out_strides,
size_t data_offset,
array& out) {
// Compute row/col contiguity
auto [data_size, is_row_contiguous, is_col_contiguous] =
check_contiguity(out.shape(), out_strides);
auto flags = in.flags();
flags.row_contiguous = is_row_contiguous;
flags.col_contiguous = is_col_contiguous;
if (data_size == 1) {
// Broadcasted scalar array is contiguous.
flags.contiguous = true;
} else if (data_size == in.data_size()) {
// Means we sliced a broadcasted dimension so leave the "no holes" flag
// alone.
} else {
// We sliced something. So either we are row or col contiguous or we
// punched a hole.
flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
}
out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset);
}
void Slice::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
if (out.size() == 0) {
@@ -737,18 +513,6 @@ void Slice::eval(const std::vector<array>& inputs, array& out) {
}
}
std::tuple<int64_t, std::vector<int64_t>> SliceUpdate::prepare_slice(
const array& in) {
int64_t data_offset = 0;
std::vector<int64_t> inp_strides(in.ndim(), 0);
for (int i = 0; i < in.ndim(); ++i) {
data_offset += start_indices_[i] * in.strides()[i];
inp_strides[i] = in.strides()[i] * strides_[i];
}
return std::make_tuple(data_offset, inp_strides);
}
void SliceUpdate::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
if (out.size() == 0) {
@@ -786,58 +550,6 @@ void SliceUpdate::eval(const std::vector<array>& inputs, array& out) {
/* CopyType ctype = */ CopyType::GeneralGeneral);
}
void Split::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
auto& in = inputs[0];
auto compute_new_flags = [](const auto& shape,
const auto& strides,
size_t in_data_size,
auto flags) {
size_t data_size = 1;
size_t f_stride = 1;
size_t b_stride = 1;
flags.row_contiguous = true;
flags.col_contiguous = true;
for (int i = 0, ri = shape.size() - 1; ri >= 0; i++, ri--) {
flags.col_contiguous &= strides[i] == f_stride || shape[i] == 1;
flags.row_contiguous &= strides[ri] == b_stride || shape[ri] == 1;
f_stride *= shape[i];
b_stride *= shape[ri];
if (strides[i] > 0) {
data_size *= shape[i];
}
}
if (data_size == 1) {
// Broadcasted scalar array is contiguous.
flags.contiguous = true;
} else if (data_size == in_data_size) {
// Means we sliced a broadcasted dimension so leave the "no holes" flag
// alone.
} else {
// We sliced something. So either we are row or col contiguous or we
// punched a hole.
flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
}
return std::pair<decltype(flags), size_t>{flags, data_size};
};
std::vector<int> indices(1, 0);
indices.insert(indices.end(), indices_.begin(), indices_.end());
for (int i = 0; i < indices.size(); i++) {
size_t offset = indices[i] * in.strides()[axis_];
auto [new_flags, data_size] = compute_new_flags(
outputs[i].shape(), in.strides(), in.data_size(), in.flags());
outputs[i].copy_shared_buffer(
in, in.strides(), new_flags, data_size, offset);
}
}
void Square::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
@@ -854,11 +566,6 @@ void Sqrt::eval(const std::vector<array>& inputs, array& out) {
}
}
void StopGradient::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
out.copy_shared_buffer(inputs[0]);
}
void Tan::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
@@ -883,38 +590,4 @@ void Tanh::eval(const std::vector<array>& inputs, array& out) {
}
}
void Transpose::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
std::vector<size_t> out_strides(out.ndim());
auto& in = inputs[0];
for (int ax = 0; ax < axes_.size(); ++ax) {
out_strides[ax] = in.strides()[axes_[ax]];
}
// Conditions for {row/col}_contiguous
// - array must be contiguous (no gaps)
// - underlying buffer size should have the same size as the array
// - cumulative product of shapes is equal to the strides (we can ignore axes
// with size == 1)
// - in the forward direction (column contiguous)
// - in the reverse direction (row contiguous)
// - vectors are both row and col contiguous (hence if both row/col are
// true, they stay true)
auto flags = in.flags();
if (flags.contiguous && in.data_size() == in.size()) {
size_t f_stride = 1;
size_t b_stride = 1;
flags.col_contiguous = true;
flags.row_contiguous = true;
for (int i = 0, ri = out.ndim() - 1; i < out.ndim(); ++i, --ri) {
flags.col_contiguous &= (out_strides[i] == f_stride || out.shape(i) == 1);
f_stride *= out.shape(i);
flags.row_contiguous &=
(out_strides[ri] == b_stride || out.shape(ri) == 1);
b_stride *= out.shape(ri);
}
}
out.copy_shared_buffer(in, out_strides, flags, in.data_size());
}
} // namespace mlx::core

View File

@@ -192,7 +192,7 @@ void _qmm_dispatch_typed(
}
void _qmm_dispatch(
array out,
array& out,
const array& x,
const array& w,
const array& scales,
@@ -253,6 +253,81 @@ void _qmm_dispatch(
}
}
void _bs_qmm_dispatch(
array& out,
const array& x,
const array& w,
const array& scales,
const array& biases,
const array& lhs_indices,
const array& rhs_indices,
int bits,
int group_size,
bool transposed_w) {
int K = x.shape(-1);
int M = x.shape(-2);
int N = out.shape(-1);
int w_els = w.shape(-1) * w.shape(-2);
int g_els = scales.shape(-1) * scales.shape(-2);
const uint32_t* lhs_indices_data = lhs_indices.data<uint32_t>();
const uint32_t* rhs_indices_data = rhs_indices.data<uint32_t>();
for (int i = 0; i < lhs_indices.size(); i++) {
int x_idx = lhs_indices_data[elem_to_loc(i, lhs_indices)];
int w_idx = rhs_indices_data[elem_to_loc(i, rhs_indices)];
switch (x.dtype()) {
case float32:
_qmm_dispatch_typed<float>(
out.data<float>() + i * M * N,
x.data<float>() + elem_to_loc(x_idx * M * K, x),
w.data<uint32_t>() + elem_to_loc(w_idx * w_els, w),
scales.data<float>() + elem_to_loc(w_idx * g_els, scales),
biases.data<float>() + elem_to_loc(w_idx * g_els, biases),
M,
N,
K,
bits,
group_size,
transposed_w);
break;
case float16:
_qmm_dispatch_typed<float16_t>(
out.data<float16_t>() + i * M * N,
x.data<float16_t>() + elem_to_loc(x_idx * M * K, x),
w.data<uint32_t>() + elem_to_loc(w_idx * w_els, w),
scales.data<float16_t>() + elem_to_loc(w_idx * g_els, scales),
biases.data<float16_t>() + elem_to_loc(w_idx * g_els, biases),
M,
N,
K,
bits,
group_size,
transposed_w);
break;
case bfloat16:
_qmm_dispatch_typed<bfloat16_t>(
out.data<bfloat16_t>() + i * M * N,
x.data<bfloat16_t>() + elem_to_loc(x_idx * M * K, x),
w.data<uint32_t>() + elem_to_loc(w_idx * w_els, w),
scales.data<bfloat16_t>() + elem_to_loc(w_idx * g_els, scales),
biases.data<bfloat16_t>() + elem_to_loc(w_idx * g_els, biases),
M,
N,
K,
bits,
group_size,
transposed_w);
break;
default:
throw std::invalid_argument(
"[quantized_matmul] only floating types are supported");
}
}
}
} // namespace
void QuantizedMatmul::eval(const std::vector<array>& inputs, array& out) {
@@ -282,4 +357,45 @@ void QuantizedMatmul::eval(const std::vector<array>& inputs, array& out) {
_qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
}
void BlockSparseQMM::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 6);
auto& x_pre = inputs[0];
auto& w_pre = inputs[1];
auto& scales_pre = inputs[2];
auto& biases_pre = inputs[3];
auto& lhs_indices = inputs[4];
auto& rhs_indices = inputs[5];
auto ensure_row_contiguous_last_dims = [](const array& arr) {
auto stride_0 = arr.strides()[arr.ndim() - 2];
auto stride_1 = arr.strides()[arr.ndim() - 1];
if (stride_0 == arr.shape(-1) && stride_1 == 1) {
return arr;
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
return arr_copy;
}
};
auto x = ensure_row_contiguous_last_dims(x_pre);
auto w = ensure_row_contiguous_last_dims(w_pre);
auto scales = ensure_row_contiguous_last_dims(scales_pre);
auto biases = ensure_row_contiguous_last_dims(biases_pre);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
_bs_qmm_dispatch(
out,
x,
w,
scales,
biases,
lhs_indices,
rhs_indices,
group_size_,
bits_,
transpose_);
}
} // namespace mlx::core

View File

@@ -3,7 +3,6 @@
#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/lapack_helper.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"
namespace mlx::core {
@@ -145,12 +144,4 @@ void SVD::eval(const std::vector<array>& inputs, std::vector<array>& outputs) {
svd_impl(inputs[0], outputs[0], outputs[1], outputs[2]);
}
std::pair<std::vector<array>, std::vector<int>> SVD::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
auto ax = axes[0] >= 0 ? 0 : -1;
auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0];
return {{linalg::svd(a, stream())}, {ax, ax, ax}};
}
} // namespace mlx::core

View File

@@ -5,10 +5,16 @@ add_custom_command(
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
${CMAKE_C_COMPILER}
${PROJECT_SOURCE_DIR}
"-D${MLX_METAL_VERSION}"
DEPENDS make_compiled_preamble.sh
kernels/compiled_preamble.h
kernels/unary.h
kernels/binary.h
kernels/bf16.h
kernels/erf.h
kernels/expm1f.h
kernels/utils.h
kernels/bf16_math.h
)
add_custom_target(

View File

@@ -140,10 +140,15 @@ void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) {
MetalAllocator::MetalAllocator()
: device_(device(mlx::core::Device::gpu).mtl_device()),
buffer_cache_(device_),
block_limit_(1.5 * device_->recommendedMaxWorkingSetSize()),
gc_limit_(0.95 * device_->recommendedMaxWorkingSetSize()),
max_pool_size_(block_limit_) {}
buffer_cache_(device_) {
auto memsize = std::get<size_t>(device_info()["memory_size"]);
block_limit_ =
std::min(1.5 * device_->recommendedMaxWorkingSetSize(), 0.95 * memsize);
gc_limit_ = std::min(
static_cast<size_t>(0.95 * device_->recommendedMaxWorkingSetSize()),
block_limit_);
max_pool_size_ = block_limit_;
}
size_t MetalAllocator::set_cache_limit(size_t limit) {
std::swap(limit, max_pool_size_);
@@ -165,6 +170,15 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
return Buffer{nullptr};
}
// More helpful message if maximum buffer length is exceeded
if (size > device_->maxBufferLength()) {
std::ostringstream msg;
msg << "Attempting to allocate " << size << " bytes which is greater than"
<< " the maximum allowed buffer size of " << device_->maxBufferLength()
<< " bytes.";
throw std::runtime_error(msg.str());
}
// Align up memory
if (size > vm_page_size) {
size = vm_page_size * ((size + vm_page_size - 1) / vm_page_size);
@@ -244,6 +258,9 @@ size_t get_active_memory() {
size_t get_peak_memory() {
return allocator().get_peak_memory();
}
void reset_peak_memory() {
allocator().reset_peak_memory();
}
size_t get_cache_memory() {
return allocator().get_cache_memory();
}

View File

@@ -62,6 +62,10 @@ class MetalAllocator : public allocator::Allocator {
size_t get_peak_memory() {
return peak_memory_;
};
void reset_peak_memory() {
std::unique_lock lk(mutex_);
peak_memory_ = 0;
};
size_t get_cache_memory() {
return buffer_cache_.cache_size();
};

View File

@@ -336,7 +336,7 @@ void Compiled::eval_gpu(
MTL::Size grid_dims(nthreads, 1, 1);
MTL::Size group_dims(
std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
@@ -347,7 +347,7 @@ void Compiled::eval_gpu(
}
auto group_dims = get_block_dims(dim0, dim1, rest);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
}

View File

@@ -59,7 +59,7 @@ void explicit_gemm_conv_ND_gpu(
MTL::Size grid_dims = MTL::Size(
conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
// Reshape weight
std::vector<int> wt_reshape{implicit_K, implicit_N};
@@ -89,6 +89,90 @@ void explicit_gemm_conv_ND_gpu(
/*copies = */ copies);
}
template <int N>
void explicit_gemm_conv_group_ND_gpu(
const Stream& s,
metal::Device& d,
const array& in,
const array& wt,
array out,
const MLXConvParams<N>& conv_params) {
const int groups = conv_params.groups;
const int C_per_group = conv_params.C / conv_params.groups;
const int O_per_group = conv_params.O / conv_params.groups;
// Get gemm shapes
const int implicit_M = out.size() / conv_params.O;
const int implicit_K = wt.size() / conv_params.O;
const int implicit_N = O_per_group;
int kernel_size = 1;
for (int i = 0; i < N; ++i) {
kernel_size *= conv_params.wS[i];
}
// Prepare unfolding array
std::vector<int> unfolded_shape{implicit_M, implicit_K * groups};
array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes()));
// Prepare unfolding kernel
std::ostringstream kname;
kname << "naive_unfold_transpose_nd_" << type_to_name(in_unfolded) << "_"
<< N;
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(in_unfolded, 1);
compute_encoder->setBytes(&conv_params, sizeof(conv_params), 2);
// Launch unfolding kernel
int tgp_x = std::min(conv_params.C, 64);
tgp_x = 32 * ((tgp_x + 32 - 1) / 32);
int tgp_y = 256 / tgp_x;
MTL::Size group_dims = MTL::Size(tgp_x, tgp_y, 1);
MTL::Size grid_dims = MTL::Size(
conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
compute_encoder.dispatchThreads(grid_dims, group_dims);
// Transpose kernel weights so that we can slice them by contiguous chunks
// of channel groups.
array wt_view(
{wt.shape(0), C_per_group, kernel_size}, wt.dtype(), nullptr, {});
wt_view.copy_shared_buffer(
wt,
{wt.strides(0), 1, static_cast<size_t>(C_per_group)},
wt.flags(),
wt.size());
// Materialize
auto wt_transpose = array(wt_view.shape(), wt_view.dtype(), nullptr, {});
copy_gpu(wt_view, wt_transpose, CopyType::General, s);
// Perform gemm
std::vector<array> copies = {in_unfolded, wt_view, wt_transpose};
return steel_matmul_conv_groups(
s,
d,
/*a = */ in_unfolded,
/*b = */ wt_transpose,
/*c = */ out,
/*M = */ implicit_M,
/*N = */ implicit_N,
/*K = */ implicit_K,
/*a_cols = */ implicit_K * groups,
/*b_cols = */ implicit_K,
/*out_cols = */ implicit_N * groups,
/*a_transposed = */ false,
/*b_transposed = */ true,
/* groups = */ groups,
/*copies = */ copies);
}
void conv_1D_gpu(
const Stream& s,
metal::Device& d,
@@ -99,6 +183,7 @@ void conv_1D_gpu(
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
int groups,
bool flip) {
// Make conv params
MLXConvParams<1> conv_params{
@@ -118,11 +203,15 @@ void conv_1D_gpu(
{wt.strides()[0], wt.strides()[1], wt.strides()[2]},
/* const size_t out_strides[NDIM + 2] = */
{out.strides()[0], out.strides()[1], out.strides()[2]},
/* const int groups = */ 1,
/* const int groups = */ groups,
/* const bool flip = */ flip};
// Direct to explicit gemm conv
return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
if (groups > 1) {
return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params);
} else {
return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
}
}
void slow_conv_2D_gpu(
@@ -158,7 +247,7 @@ void slow_conv_2D_gpu(
compute_encoder.set_output_array(out, 2);
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
void implicit_gemm_conv_2D_gpu(
@@ -263,7 +352,7 @@ void implicit_gemm_conv_2D_gpu(
compute_encoder->setBytes(&gemm_params, sizeof(ImplicitGemmConv2DParams), 4);
// Launch kernel
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
void implicit_gemm_conv_2D_general_gpu(
@@ -423,7 +512,7 @@ void implicit_gemm_conv_2D_general_gpu(
base_w.data(), sizeof(Conv2DGeneralBaseInfo) * base_w.size(), 7);
// Launch kernel
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
void winograd_conv_2D_gpu(
@@ -524,7 +613,7 @@ void winograd_conv_2D_gpu(
MTL::Size group_dims = MTL::Size(32, bo, 1);
MTL::Size grid_dims = MTL::Size(O_c / bo, 1, 1);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
// Do input transform
@@ -552,7 +641,7 @@ void winograd_conv_2D_gpu(
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
// Do batched gemm
@@ -600,7 +689,7 @@ void winograd_conv_2D_gpu(
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
}
}
@@ -670,6 +759,56 @@ void conv_2D_gpu(
}
}
void conv_3D_gpu(
const Stream& s,
metal::Device& d,
const array& in,
const array& wt,
array out,
const std::vector<int>& padding,
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
bool flip,
std::vector<array>& copies) {
// Make conv params
MLXConvParams<3> conv_params{
/* const int N = */ in.shape(0),
/* const int C = */ in.shape(4),
/* const int O = */ wt.shape(0),
/* const int iS[NDIM] = */ {in.shape(1), in.shape(2), in.shape(3)},
/* const int wS[NDIM] = */ {wt.shape(1), wt.shape(2), wt.shape(3)},
/* const int oS[NDIM] = */ {out.shape(1), out.shape(2), out.shape(3)},
/* const int str[NDIM] = */ {wt_strides[0], wt_strides[1], wt_strides[2]},
/* const int pad[NDIM] = */ {padding[0], padding[1], padding[2]},
/* const int kdil[NDIM] = */
{wt_dilation[0], wt_dilation[1], wt_dilation[2]},
/* const int idil[NDIM] = */
{in_dilation[0], in_dilation[1], in_dilation[2]},
/* const size_t in_strides[NDIM + 2] = */
{in.strides()[0],
in.strides()[1],
in.strides()[2],
in.strides()[3],
in.strides()[4]},
/* const size_t wt_strides[NDIM + 2] = */
{wt.strides()[0],
wt.strides()[1],
wt.strides()[2],
wt.strides()[3],
wt.strides()[4]},
/* const size_t out_strides[NDIM + 2] = */
{out.strides()[0],
out.strides()[1],
out.strides()[2],
out.strides()[3],
out.strides()[4]},
/* const int groups = */ 1,
/* const bool flip = */ flip,
};
return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
}
} // namespace
void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -694,8 +833,23 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
wt = arr_copy;
}
// 3D conv
if (out.ndim() == 5) {
conv_3D_gpu(
s,
d,
in,
wt,
out,
padding_,
kernel_strides_,
kernel_dilation_,
input_dilation_,
flip_,
copies);
}
// 2D conv
if (out.ndim() == 4) {
else if (out.ndim() == 4) {
conv_2D_gpu(
s,
d,
@@ -721,6 +875,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
kernel_strides_,
kernel_dilation_,
input_dilation_,
groups_,
flip_);
}
// Throw error

View File

@@ -126,7 +126,7 @@ void copy_gpu_inplace(
auto group_dims = get_block_dims(dim0, dim1, rest);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
size_t nthreads = out.data_size();
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
@@ -135,7 +135,7 @@ void copy_gpu_inplace(
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
}

View File

@@ -5,6 +5,8 @@
#include <filesystem>
#include <sstream>
#include <sys/sysctl.h>
#define NS_PRIVATE_IMPLEMENTATION
#define CA_PRIVATE_IMPLEMENTATION
#define MTL_PRIVATE_IMPLEMENTATION
@@ -23,9 +25,18 @@ namespace {
// TODO nicer way to set this or possibly expose as an environment variable
constexpr int MAX_BUFFERS_PER_QUEUE = 12;
constexpr int MAX_DISPATCHES_PER_ENCODER = 2;
constexpr const char* default_mtllib_path = METAL_PATH;
constexpr auto get_metal_version() {
#if defined METAL_3_1
return MTL::LanguageVersion3_1;
#else
return MTL::LanguageVersion3_0;
#endif
}
auto load_device() {
auto devices = MTL::CopyAllDevices();
auto device = static_cast<MTL::Device*>(devices->object(0))
@@ -35,7 +46,6 @@ auto load_device() {
}
return device;
}
std::pair<MTL::Library*, NS::Error*> load_library_from_path(
MTL::Device* device,
const char* path) {
@@ -114,6 +124,33 @@ MTL::Library* load_library(
} // namespace
void CommandEncoder::dispatchThreadgroups(
MTL::Size grid_dims,
MTL::Size group_dims) {
num_dispatches++;
enc->dispatchThreadgroups(grid_dims, group_dims);
maybe_split();
}
void CommandEncoder::dispatchThreads(
MTL::Size grid_dims,
MTL::Size group_dims) {
num_dispatches++;
enc->dispatchThreads(grid_dims, group_dims);
maybe_split();
}
void CommandEncoder::maybe_split() {
if (num_dispatches > MAX_DISPATCHES_PER_ENCODER && !concurrent) {
enc->endEncoding();
enc->release();
num_dispatches = 0;
outputs.clear();
enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
enc->retain();
}
}
Device::Device() {
auto pool = new_scoped_memory_pool();
device_ = load_device();
@@ -128,9 +165,6 @@ Device::~Device() {
for (auto& b : buffer_map_) {
b.second.second->release();
}
for (auto& e : encoder_map_) {
(*e.second)->release();
}
for (auto& k : kernel_map_) {
k.second->release();
}
@@ -167,27 +201,26 @@ void Device::increment_command_buffer_ops(int index) {
MTL::CommandBuffer* Device::get_command_buffer(int index) {
auto bit = buffer_map_.find(index);
return (bit == buffer_map_.end()) ? nullptr : bit->second.second;
}
if (bit == buffer_map_.end()) {
auto qit = queue_map_.find(index);
if (qit == queue_map_.end()) {
throw std::runtime_error(
"[metal::Device] Attempting to get command buffer for invalid queue.");
}
MTL::CommandBuffer* Device::new_command_buffer(int index) {
auto qit = queue_map_.find(index);
if (qit == queue_map_.end()) {
throw std::runtime_error(
"[metal::Device] Attempting to get command buffer for invalid queue.");
auto cb = qit->second->commandBufferWithUnretainedReferences();
if (!cb) {
throw std::runtime_error(
"[metal::Device] Unable to create new command buffer");
}
// Increment ref count so the buffer is not garbage collected
cb->retain();
bit = buffer_map_.insert({index, {0, cb}}).first;
}
auto cb = qit->second->commandBufferWithUnretainedReferences();
if (!cb) {
throw std::runtime_error(
"[metal::Device] Unable to create new command buffer");
}
// Increment ref count so the buffer is not garbage collected
cb->retain();
return buffer_map_.insert({index, {0, cb}}).first->second.second;
return bit->second.second;
}
void Device::commit_command_buffer(int index) {
@@ -198,25 +231,15 @@ void Device::commit_command_buffer(int index) {
}
void Device::end_encoding(int index) {
auto eit = encoder_map_.find(index);
if (eit != encoder_map_.end()) {
(*eit->second)->endEncoding();
(*eit->second)->release();
encoder_map_.erase(eit);
}
encoder_map_.erase(index);
}
CommandEncoder& Device::get_command_encoder(int index) {
auto eit = encoder_map_.find(index);
if (eit == encoder_map_.end()) {
auto cb = get_command_buffer(index);
auto compute_encoder =
cb->computeCommandEncoder(MTL::DispatchTypeConcurrent);
// Increment ref count so the buffer is not garbage collected
compute_encoder->retain();
eit = encoder_map_
.emplace(index, std::make_unique<CommandEncoder>(compute_encoder))
.first;
eit =
encoder_map_.emplace(index, std::make_unique<CommandEncoder>(cb)).first;
}
return *(eit->second);
}
@@ -260,13 +283,17 @@ MTL::Library* Device::get_library_(const std::string& source_string) {
NS::String::string(source_string.c_str(), NS::ASCIIStringEncoding);
NS::Error* error = nullptr;
auto mtl_lib = device_->newLibrary(ns_code, nullptr, &error);
auto options = MTL::CompileOptions::alloc()->init();
options->setFastMathEnabled(false);
options->setLanguageVersion(get_metal_version());
auto mtl_lib = device_->newLibrary(ns_code, options, &error);
options->release();
// Throw error if unable to compile library
if (!mtl_lib) {
std::ostringstream msg;
msg << "[metal::Device] Unable to load build metal library from source"
<< "\n";
msg << "[metal::Device] Unable to build metal library from source" << "\n";
if (error) {
msg << error->localizedDescription()->utf8String() << "\n";
}
@@ -285,8 +312,7 @@ MTL::Library* Device::get_library_(const MTL::StitchedLibraryDescriptor* desc) {
// Throw error if unable to compile library
if (!mtl_lib) {
std::ostringstream msg;
msg << "[metal::Device] Unable to load build stitched metal library"
<< "\n";
msg << "[metal::Device] Unable to build stitched metal library" << "\n";
if (error) {
msg << error->localizedDescription()->utf8String() << "\n";
}
@@ -344,7 +370,6 @@ MTL::Function* Device::get_function_(
}
mtl_func_consts->release();
desc->release();
return mtl_function;
}
@@ -513,11 +538,13 @@ MTL::ComputePipelineState* Device::get_kernel(
// Compile kernel to compute pipeline
auto mtl_linked_funcs = get_linked_functions_(linked_functions);
auto kernel = get_kernel_(kname, mtl_function, mtl_linked_funcs);
mtl_function->release();
mtl_linked_funcs->release();
// Add kernel to cache
kernel_map_.insert({kname, kernel});
return kernel;
}
@@ -558,4 +585,23 @@ void new_stream(Stream stream) {
}
}
std::unordered_map<std::string, std::variant<std::string, size_t>>
device_info() {
auto raw_device = device(default_device()).mtl_device();
auto arch = std::string(raw_device->architecture()->name()->utf8String());
int mib[] = {CTL_HW, HW_MEMSIZE};
size_t memsize = 0;
size_t length = sizeof(memsize);
sysctl(mib, 2, &memsize, &length, NULL, 0);
return {
{"architecture", arch},
{"max_buffer_length", raw_device->maxBufferLength()},
{"max_recommended_working_set_size",
raw_device->recommendedMaxWorkingSetSize()},
{"memory_size", memsize}};
}
} // namespace mlx::core::metal

View File

@@ -37,8 +37,10 @@ using MTLFCList =
std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;
struct CommandEncoder {
CommandEncoder(MTL::ComputeCommandEncoder* enc)
: enc(enc), concurrent(false){};
CommandEncoder(MTL::CommandBuffer* cbuf) : cbuf(cbuf) {
enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
enc->retain();
};
CommandEncoder(const CommandEncoder&) = delete;
CommandEncoder& operator=(const CommandEncoder&) = delete;
@@ -89,13 +91,25 @@ struct CommandEncoder {
}
}
void dispatchThreadgroups(MTL::Size grid_dims, MTL::Size group_dims);
void dispatchThreads(MTL::Size grid_dims, MTL::Size group_dims);
ConcurrentContext start_concurrent() {
return ConcurrentContext(*this);
}
~CommandEncoder() {
enc->endEncoding();
enc->release();
}
private:
void maybe_split();
int num_dispatches{0};
MTL::CommandBuffer* cbuf;
MTL::ComputeCommandEncoder* enc;
bool concurrent;
bool concurrent{false};
std::unordered_set<MTL::Resource*> outputs;
std::unordered_set<MTL::Resource*> concurrent_outputs;
};
@@ -112,7 +126,6 @@ class Device {
};
void new_queue(int index);
MTL::CommandBuffer* new_command_buffer(int index);
MTL::CommandBuffer* get_command_buffer(int index);
int get_command_buffer_ops(int index);
void increment_command_buffer_ops(int index);

View File

@@ -97,7 +97,7 @@ void FFT::eval_gpu(const std::vector<array>& inputs, array& out) {
auto group_dims = MTL::Size(1, m, 1);
auto grid_dims = MTL::Size(batch, m, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });

View File

@@ -107,7 +107,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
}
// Launch grid
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -216,7 +216,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
// Launch grid
MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
// Collect all idx shapes and strides into one place
@@ -286,7 +286,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
// Launch grid
MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
}

View File

@@ -39,7 +39,7 @@ set(
)
function(build_kernel_base TARGET SRCFILE DEPS)
set(METAL_FLAGS -Wall -Wextra -fno-fast-math)
set(METAL_FLAGS -Wall -Wextra -fno-fast-math -D${MLX_METAL_VERSION})
if(MLX_METAL_DEBUG)
set(METAL_FLAGS ${METAL_FLAGS}
-gline-tables-only

View File

@@ -11,14 +11,14 @@ template <typename T>
out[index] = start + index * step;
}
#define instantiate_arange(tname, type) \
template [[host_name("arange" #tname)]] \
[[kernel]] void arange<type>( \
constant const type& start, \
constant const type& step, \
device type* out, \
uint index [[thread_position_in_grid]]);
#define instantiate_arange(tname, type) \
template [[host_name("arange" #tname)]] [[kernel]] void arange<type>( \
constant const type& start, \
constant const type& step, \
device type* out, \
uint index [[thread_position_in_grid]]);
// clang-format off
instantiate_arange(uint8, uint8_t)
instantiate_arange(uint16, uint16_t)
instantiate_arange(uint32, uint32_t)
@@ -29,4 +29,4 @@ instantiate_arange(int32, int32_t)
instantiate_arange(int64, int64_t)
instantiate_arange(float16, half)
instantiate_arange(float32, float)
instantiate_arange(bfloat16, bfloat16_t)
instantiate_arange(bfloat16, bfloat16_t) // clang-format on

View File

@@ -18,7 +18,8 @@ struct ArgMin {
static constexpr constant U init = Limits<U>::max;
IndexValPair<U> reduce(IndexValPair<U> best, IndexValPair<U> current) {
if (best.val > current.val || (best.val == current.val && best.index > current.index)) {
if (best.val > current.val ||
(best.val == current.val && best.index > current.index)) {
return current;
} else {
return best;
@@ -26,11 +27,12 @@ struct ArgMin {
}
template <int N>
IndexValPair<U> reduce_many(IndexValPair<U> best, thread U* vals, uint32_t offset) {
for (int i=0; i<N; i++) {
IndexValPair<U>
reduce_many(IndexValPair<U> best, thread U* vals, uint32_t offset) {
for (int i = 0; i < N; i++) {
if (vals[i] < best.val) {
best.val = vals[i];
best.index = offset+i;
best.index = offset + i;
}
}
return best;
@@ -42,7 +44,8 @@ struct ArgMax {
static constexpr constant U init = Limits<U>::min;
IndexValPair<U> reduce(IndexValPair<U> best, IndexValPair<U> current) {
if (best.val < current.val || (best.val == current.val && best.index > current.index)) {
if (best.val < current.val ||
(best.val == current.val && best.index > current.index)) {
return current;
} else {
return best;
@@ -50,11 +53,12 @@ struct ArgMax {
}
template <int N>
IndexValPair<U> reduce_many(IndexValPair<U> best, thread U* vals, uint32_t offset) {
for (int i=0; i<N; i++) {
IndexValPair<U>
reduce_many(IndexValPair<U> best, thread U* vals, uint32_t offset) {
for (int i = 0; i < N; i++) {
if (vals[i] > best.val) {
best.val = vals[i];
best.index = offset+i;
best.index = offset + i;
}
}
return best;
@@ -64,19 +68,16 @@ struct ArgMax {
template <typename U>
IndexValPair<U> simd_shuffle_down(IndexValPair<U> data, uint16_t delta) {
return IndexValPair<U>{
simd_shuffle_down(data.index, delta),
simd_shuffle_down(data.val, delta)
};
simd_shuffle_down(data.index, delta), simd_shuffle_down(data.val, delta)};
}
template <typename T, typename Op, int N_READS>
[[kernel]] void arg_reduce_general(
const device T *in [[buffer(0)]],
device uint32_t *out [[buffer(1)]],
const device int *shape [[buffer(2)]],
const device size_t *in_strides [[buffer(3)]],
const device size_t *out_strides [[buffer(4)]],
const device T* in [[buffer(0)]],
device uint32_t* out [[buffer(1)]],
const device int* shape [[buffer(2)]],
const device size_t* in_strides [[buffer(3)]],
const device size_t* out_strides [[buffer(4)]],
const device size_t& ndim [[buffer(5)]],
const device size_t& axis_stride [[buffer(6)]],
const device size_t& axis_size [[buffer(7)]],
@@ -86,7 +87,6 @@ template <typename T, typename Op, int N_READS>
uint simd_size [[threads_per_simdgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
// Shapes and strides *do not* contain the reduction axis. The reduction size
// and stride are provided in axis_stride and axis_size.
//
@@ -113,13 +113,13 @@ template <typename T, typename Op, int N_READS>
threadgroup IndexValPair<T> local_data[32];
// Loop over the reduction axis in lsize*N_READS buckets
for (uint r=0; r < ceildiv(axis_size, N_READS*lsize); r++) {
for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize); r++) {
// Read the current value
uint32_t current_index = r*lsize*N_READS + lid*N_READS;
uint32_t current_index = r * lsize * N_READS + lid * N_READS;
uint32_t offset = current_index;
const device T * current_in = in + in_idx + current_index * axis_stride;
const device T* current_in = in + in_idx + current_index * axis_stride;
T vals[N_READS];
for (int i=0; i<N_READS; i++) {
for (int i = 0; i < N_READS; i++) {
vals[i] = (current_index < axis_size) ? *current_in : T(Op::init);
current_index++;
current_in += axis_stride;
@@ -130,7 +130,7 @@ template <typename T, typename Op, int N_READS>
// need to reduce across the thread group.
// First per simd reduction.
for (uint offset=simd_size/2; offset>0; offset/=2) {
for (uint offset = simd_size / 2; offset > 0; offset /= 2) {
IndexValPair<T> neighbor = simd_shuffle_down(best, offset);
best = op.reduce(best, neighbor);
}
@@ -149,7 +149,7 @@ template <typename T, typename Op, int N_READS>
if (simd_lane_id < simd_groups) {
best = local_data[simd_lane_id];
}
for (uint offset=simd_size/2; offset>0; offset/=2) {
for (uint offset = simd_size / 2; offset > 0; offset /= 2) {
IndexValPair<T> neighbor = simd_shuffle_down(best, offset);
best = op.reduce(best, neighbor);
}
@@ -161,24 +161,25 @@ template <typename T, typename Op, int N_READS>
}
#define instantiate_arg_reduce_helper(name, itype, op) \
template [[host_name(name)]] \
[[kernel]] void arg_reduce_general<itype, op<itype>, 4>( \
const device itype *in [[buffer(0)]], \
device uint32_t * out [[buffer(1)]], \
const device int *shape [[buffer(2)]], \
const device size_t *in_strides [[buffer(3)]], \
const device size_t *out_strides [[buffer(4)]], \
const device size_t& ndim [[buffer(5)]], \
const device size_t& axis_stride [[buffer(6)]], \
const device size_t& axis_size [[buffer(7)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_size [[threads_per_simdgroup]], \
template [[host_name(name)]] [[kernel]] void \
arg_reduce_general<itype, op<itype>, 4>( \
const device itype* in [[buffer(0)]], \
device uint32_t* out [[buffer(1)]], \
const device int* shape [[buffer(2)]], \
const device size_t* in_strides [[buffer(3)]], \
const device size_t* out_strides [[buffer(4)]], \
const device size_t& ndim [[buffer(5)]], \
const device size_t& axis_stride [[buffer(6)]], \
const device size_t& axis_size [[buffer(7)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_size [[threads_per_simdgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_arg_reduce(name, itype) \
// clang-format off
#define instantiate_arg_reduce(name, itype) \
instantiate_arg_reduce_helper("argmin_" #name , itype, ArgMin) \
instantiate_arg_reduce_helper("argmax_" #name , itype, ArgMax)
@@ -193,4 +194,4 @@ instantiate_arg_reduce(int32, int32_t)
instantiate_arg_reduce(int64, int64_t)
instantiate_arg_reduce(float16, half)
instantiate_arg_reduce(float32, float)
instantiate_arg_reduce(bfloat16, bfloat16_t)
instantiate_arg_reduce(bfloat16, bfloat16_t) // clang-format on

View File

@@ -6,7 +6,9 @@
using namespace metal;
#if defined(__HAVE_BFLOAT__)
// No support for less than metal 3.0
// anything greater has native bfloat
#ifndef METAL_3_0
typedef bfloat bfloat16_t;
@@ -312,6 +314,6 @@ METAL_FUNC bool isnan(_MLX_BFloat16 x) {
#pragma METAL internals : disable
#endif // defined(__HAVE_BFLOAT__)
#endif
#include "mlx/backend/metal/kernels/bf16_math.h"

View File

@@ -369,7 +369,7 @@ instantiate_metal_math_funcs(
return static_cast<otype>(__metal_simd_xor(static_cast<ctype>(data))); \
}
#if defined(__HAVE_BFLOAT__)
#ifndef METAL_3_0
#define bfloat16_to_uint16(x) as_type<uint16_t>(x)
#define uint16_to_bfloat16(x) as_type<bfloat16_t>(x)
@@ -391,4 +391,4 @@ instantiate_metal_simd_comm_funcs(
uint16_to_bfloat16);
instantiate_metal_simd_reduction_funcs(bfloat16_t, bfloat16_t, float);
} // namespace metal
} // namespace metal

View File

@@ -229,3 +229,45 @@ struct LogicalOr {
return x || y;
};
};
struct BitwiseAnd {
template <typename T>
T operator()(T x, T y) {
return x & y;
};
};
struct BitwiseOr {
template <typename T>
T operator()(T x, T y) {
return x | y;
};
};
struct BitwiseXor {
template <typename T>
T operator()(T x, T y) {
return x ^ y;
};
};
struct LeftShift {
template <typename T>
T operator()(T x, T y) {
return x << y;
};
};
struct RightShift {
template <typename T>
T operator()(T x, T y) {
return x >> y;
};
};
struct ArcTan2 {
template <typename T>
T operator()(T y, T x) {
return metal::precise::atan2(y, x);
}
};

View File

@@ -77,7 +77,8 @@ template <typename T, typename U, typename Op>
uint3 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_3(index, a_strides);
auto b_idx = elem_to_loc_3(index, b_strides);
size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
c[out_idx] = Op()(a[a_idx], b[b_idx]);
}
@@ -92,7 +93,8 @@ template <typename T, typename U, typename Op, int DIM>
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
c[out_idx] = Op()(a[idx.x], b[idx.y]);
}
@@ -112,111 +114,118 @@ template <typename T, typename U, typename Op>
c[out_idx] = Op()(a[idx.x], b[idx.y]);
}
#define instantiate_binary(name, itype, otype, op, bopt) \
template [[host_name(name)]] \
[[kernel]] void binary_op_##bopt<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
uint index [[thread_position_in_grid]]);
#define instantiate_binary(name, itype, otype, op, bopt) \
template \
[[host_name(name)]] [[kernel]] void binary_op_##bopt<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
uint index [[thread_position_in_grid]]);
#define instantiate_binary_g_dim(name, itype, otype, op, dims) \
template [[host_name(name "_" #dims)]] \
[[kernel]] void binary_op_g_nd<itype, otype, op, dims>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
constant const int shape[dims], \
constant const size_t a_strides[dims], \
constant const size_t b_strides[dims], \
uint3 index [[thread_position_in_grid]], \
template [[host_name(name "_" #dims)]] [[kernel]] void \
binary_op_g_nd<itype, otype, op, dims>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
constant const int shape[dims], \
constant const size_t a_strides[dims], \
constant const size_t b_strides[dims], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]);
#define instantiate_binary_g_nd(name, itype, otype, op) \
template [[host_name(name "_1")]] \
[[kernel]] void binary_op_g_nd1<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
constant const size_t& a_stride, \
constant const size_t& b_stride, \
uint index [[thread_position_in_grid]]); \
template [[host_name(name "_2")]] \
[[kernel]] void binary_op_g_nd2<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
constant const size_t a_strides[2], \
constant const size_t b_strides[2], \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]); \
template [[host_name(name "_3")]] \
[[kernel]] void binary_op_g_nd3<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
constant const size_t a_strides[3], \
constant const size_t b_strides[3], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
instantiate_binary_g_dim(name, itype, otype, op, 4) \
instantiate_binary_g_dim(name, itype, otype, op, 5)
template [[host_name(name "_1")]] [[kernel]] void \
binary_op_g_nd1<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
constant const size_t& a_stride, \
constant const size_t& b_stride, \
uint index [[thread_position_in_grid]]); \
template [[host_name(name "_2")]] [[kernel]] void \
binary_op_g_nd2<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
constant const size_t a_strides[2], \
constant const size_t b_strides[2], \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]); \
template [[host_name(name "_3")]] [[kernel]] void \
binary_op_g_nd3<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
constant const size_t a_strides[3], \
constant const size_t b_strides[3], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
instantiate_binary_g_dim(name, itype, otype, op, 4) \
instantiate_binary_g_dim(name, itype, otype, op, 5)
#define instantiate_binary_g(name, itype, otype, op) \
template [[host_name(name)]] \
[[kernel]] void binary_op_g<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
constant const int* shape, \
constant const size_t* a_strides, \
constant const size_t* b_strides, \
constant const int& ndim, \
uint3 index [[thread_position_in_grid]], \
#define instantiate_binary_g(name, itype, otype, op) \
template [[host_name(name)]] [[kernel]] void binary_op_g<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
constant const int* shape, \
constant const size_t* a_strides, \
constant const size_t* b_strides, \
constant const int& ndim, \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]);
// clang-format off
#define instantiate_binary_all(name, tname, itype, otype, op) \
instantiate_binary("ss" #name #tname, itype, otype, op, ss) \
instantiate_binary("sv" #name #tname, itype, otype, op, sv) \
instantiate_binary("vs" #name #tname, itype, otype, op, vs) \
instantiate_binary("vv" #name #tname, itype, otype, op, vv) \
instantiate_binary_g("g" #name #tname, itype, otype, op) \
instantiate_binary_g_nd("g" #name #tname, itype, otype, op)
instantiate_binary_g("g" #name #tname, itype, otype, op) \
instantiate_binary_g_nd("g" #name #tname, itype, otype, op) // clang-format on
#define instantiate_binary_float(name, op) \
instantiate_binary_all(name, float16, half, half, op) \
instantiate_binary_all(name, float32, float, float, op) \
instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op)
#define instantiate_binary_types(name, op) \
instantiate_binary_all(name, bool_, bool, bool, op) \
instantiate_binary_all(name, uint8, uint8_t, uint8_t, op) \
// clang-format off
#define instantiate_binary_integer(name, op) \
instantiate_binary_all(name, uint8, uint8_t, uint8_t, op) \
instantiate_binary_all(name, uint16, uint16_t, uint16_t, op) \
instantiate_binary_all(name, uint32, uint32_t, uint32_t, op) \
instantiate_binary_all(name, uint64, uint64_t, uint64_t, op) \
instantiate_binary_all(name, int8, int8_t, int8_t, op) \
instantiate_binary_all(name, int16, int16_t, int16_t, op) \
instantiate_binary_all(name, int32, int32_t, int32_t, op) \
instantiate_binary_all(name, int64, int64_t, int64_t, op) \
instantiate_binary_all(name, int8, int8_t, int8_t, op) \
instantiate_binary_all(name, int16, int16_t, int16_t, op) \
instantiate_binary_all(name, int32, int32_t, int32_t, op) \
instantiate_binary_all(name, int64, int64_t, int64_t, op) // clang-format on
// clang-format off
#define instantiate_binary_float(name, op) \
instantiate_binary_all(name, float16, half, half, op) \
instantiate_binary_all(name, float32, float, float, op) \
instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op) // clang-format on
// clang-format off
#define instantiate_binary_types(name, op) \
instantiate_binary_all(name, bool_, bool, bool, op) \
instantiate_binary_integer(name, op) \
instantiate_binary_all(name, complex64, complex64_t, complex64_t, op) \
instantiate_binary_float(name, op)
instantiate_binary_float(name, op) // clang-format on
#define instantiate_binary_types_bool(name, op) \
instantiate_binary_all(name, bool_, bool, bool, op) \
instantiate_binary_all(name, uint8, uint8_t, bool, op) \
instantiate_binary_all(name, uint16, uint16_t, bool, op) \
instantiate_binary_all(name, uint32, uint32_t, bool, op) \
instantiate_binary_all(name, uint64, uint64_t, bool, op) \
instantiate_binary_all(name, int8, int8_t, bool, op) \
instantiate_binary_all(name, int16, int16_t, bool, op) \
instantiate_binary_all(name, int32, int32_t, bool, op) \
instantiate_binary_all(name, int64, int64_t, bool, op) \
instantiate_binary_all(name, float16, half, bool, op) \
instantiate_binary_all(name, float32, float, bool, op) \
// clang-format off
#define instantiate_binary_types_bool(name, op) \
instantiate_binary_all(name, bool_, bool, bool, op) \
instantiate_binary_all(name, uint8, uint8_t, bool, op) \
instantiate_binary_all(name, uint16, uint16_t, bool, op) \
instantiate_binary_all(name, uint32, uint32_t, bool, op) \
instantiate_binary_all(name, uint64, uint64_t, bool, op) \
instantiate_binary_all(name, int8, int8_t, bool, op) \
instantiate_binary_all(name, int16, int16_t, bool, op) \
instantiate_binary_all(name, int32, int32_t, bool, op) \
instantiate_binary_all(name, int64, int64_t, bool, op) \
instantiate_binary_all(name, float16, half, bool, op) \
instantiate_binary_all(name, float32, float, bool, op) \
instantiate_binary_all(name, bfloat16, bfloat16_t, bool, op) \
instantiate_binary_all(name, complex64, complex64_t, bool, op)
instantiate_binary_all(name, complex64, complex64_t, bool, op) // clang-format on
// clang-format off
instantiate_binary_types(add, Add)
instantiate_binary_types(div, Divide)
instantiate_binary_types_bool(eq, Equal)
@@ -232,6 +241,7 @@ instantiate_binary_types(mul, Multiply)
instantiate_binary_types(sub, Subtract)
instantiate_binary_types(pow, Power)
instantiate_binary_types(rem, Remainder)
instantiate_binary_float(arctan2, ArcTan2)
// NaNEqual only needed for floating point types with boolean output
instantiate_binary_all(naneq, float16, half, bool, NaNEqual)
@@ -241,3 +251,13 @@ instantiate_binary_all(naneq, complex64, complex64_t, bool, NaNEqual)
instantiate_binary_all(lor, bool_, bool, bool, LogicalOr)
instantiate_binary_all(land, bool_, bool, bool, LogicalAnd)
// Bitwise ops only need integer types and bool (except for l/r shift)
instantiate_binary_integer(bitwise_and, BitwiseAnd)
instantiate_binary_all(bitwise_and, bool_, bool, bool, BitwiseAnd)
instantiate_binary_integer(bitwise_or, BitwiseOr)
instantiate_binary_all(bitwise_or, bool_, bool, bool, BitwiseOr)
instantiate_binary_integer(bitwise_xor, BitwiseXor)
instantiate_binary_all(bitwise_xor, bool_, bool, bool, BitwiseXor)
instantiate_binary_integer(left_shift, LeftShift)
instantiate_binary_integer(right_shift, RightShift) // clang-format on

View File

@@ -3,28 +3,42 @@
#include <metal_integer>
#include <metal_math>
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"
struct FloorDivide {
template <typename T> T operator()(T x, T y) { return x / y; }
template <> float operator()(float x, float y) { return trunc(x / y); }
template <> half operator()(half x, half y) { return trunc(x / y); }
template <> bfloat16_t operator()(bfloat16_t x, bfloat16_t y) { return trunc(x / y); }
template <typename T>
T operator()(T x, T y) {
return x / y;
}
template <>
float operator()(float x, float y) {
return trunc(x / y);
}
template <>
half operator()(half x, half y) {
return trunc(x / y);
}
template <>
bfloat16_t operator()(bfloat16_t x, bfloat16_t y) {
return trunc(x / y);
}
};
struct Remainder {
template <typename T>
metal::enable_if_t<metal::is_integral_v<T> & !metal::is_signed_v<T>, T> operator()(T x, T y) {
metal::enable_if_t<metal::is_integral_v<T> & !metal::is_signed_v<T>, T>
operator()(T x, T y) {
return x % y;
}
template <typename T>
metal::enable_if_t<metal::is_integral_v<T> & metal::is_signed_v<T>, T> operator()(T x, T y) {
metal::enable_if_t<metal::is_integral_v<T> & metal::is_signed_v<T>, T>
operator()(T x, T y) {
auto r = x % y;
if (r != 0 && (r < 0 != y < 0)) {
r += y;
}
return r;
return r;
}
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
@@ -32,10 +46,11 @@ struct Remainder {
if (r != 0 && (r < 0 != y < 0)) {
r += y;
}
return r;
return r;
}
template <> complex64_t operator()(complex64_t x, complex64_t y) {
return x % y;
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
return x % y;
}
};
@@ -50,7 +65,6 @@ template <typename T, typename U, typename Op1, typename Op2>
d[index] = Op2()(a[0], b[0]);
}
template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_ss(
device const T* a,
@@ -139,7 +153,8 @@ template <typename T, typename U, typename Op1, typename Op2>
uint3 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_3(index, a_strides);
auto b_idx = elem_to_loc_3(index, b_strides);
size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
c[out_idx] = Op1()(a[a_idx], b[b_idx]);
d[out_idx] = Op2()(a[a_idx], b[b_idx]);
}
@@ -156,7 +171,8 @@ template <typename T, typename U, typename Op1, typename Op2, int DIM>
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
c[out_idx] = Op1()(a[idx.x], b[idx.y]);
d[out_idx] = Op2()(a[idx.x], b[idx.y]);
}
@@ -180,99 +196,102 @@ template <typename T, typename U, typename Op1, typename Op2>
}
#define instantiate_binary(name, itype, otype, op1, op2, bopt) \
template [[host_name(name)]] \
[[kernel]] void binary_op_##bopt<itype, otype, op1, op2>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
device otype* d, \
uint index [[thread_position_in_grid]]);
template [[host_name(name)]] [[kernel]] void \
binary_op_##bopt<itype, otype, op1, op2>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
device otype* d, \
uint index [[thread_position_in_grid]]);
#define instantiate_binary_g_dim(name, itype, otype, op1, op2, dims) \
template [[host_name(name "_" #dims)]] \
[[kernel]] void binary_op_g_nd<itype, otype, op1, op2, dims>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
device otype* d, \
constant const int shape[dims], \
constant const size_t a_strides[dims], \
constant const size_t b_strides[dims], \
uint3 index [[thread_position_in_grid]], \
template [[host_name(name "_" #dims)]] [[kernel]] void \
binary_op_g_nd<itype, otype, op1, op2, dims>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
device otype* d, \
constant const int shape[dims], \
constant const size_t a_strides[dims], \
constant const size_t b_strides[dims], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]);
// clang-format off
#define instantiate_binary_g_nd(name, itype, otype, op1, op2) \
template [[host_name(name "_1")]] \
[[kernel]] void binary_op_g_nd1<itype, otype, op1, op2>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
device otype* d, \
constant const size_t& a_stride, \
constant const size_t& b_stride, \
uint index [[thread_position_in_grid]]); \
template [[host_name(name "_2")]] \
[[kernel]] void binary_op_g_nd2<itype, otype, op1, op2>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
device otype* d, \
constant const size_t a_strides[2], \
constant const size_t b_strides[2], \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]); \
template [[host_name(name "_3")]] \
[[kernel]] void binary_op_g_nd3<itype, otype, op1, op2>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
device otype* d, \
constant const size_t a_strides[3], \
constant const size_t b_strides[3], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
instantiate_binary_g_dim(name, itype, otype, op1, op2, 4) \
instantiate_binary_g_dim(name, itype, otype, op1, op2, 5)
template [[host_name(name "_1")]] [[kernel]] void \
binary_op_g_nd1<itype, otype, op1, op2>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
device otype* d, \
constant const size_t& a_stride, \
constant const size_t& b_stride, \
uint index [[thread_position_in_grid]]); \
template [[host_name(name "_2")]] [[kernel]] void \
binary_op_g_nd2<itype, otype, op1, op2>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
device otype* d, \
constant const size_t a_strides[2], \
constant const size_t b_strides[2], \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]); \
template [[host_name(name "_3")]] [[kernel]] void \
binary_op_g_nd3<itype, otype, op1, op2>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
device otype* d, \
constant const size_t a_strides[3], \
constant const size_t b_strides[3], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
instantiate_binary_g_dim(name, itype, otype, op1, op2, 4) \
instantiate_binary_g_dim(name, itype, otype, op1, op2, 5) // clang-format on
#define instantiate_binary_g(name, itype, otype, op1, op2) \
template [[host_name(name)]] \
[[kernel]] void binary_op_g<itype, otype, op2, op2>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
device otype* d, \
constant const int* shape, \
constant const size_t* a_strides, \
constant const size_t* b_strides, \
constant const int& ndim, \
uint3 index [[thread_position_in_grid]], \
template [[host_name(name)]] [[kernel]] void \
binary_op_g<itype, otype, op2, op2>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
device otype* d, \
constant const int* shape, \
constant const size_t* a_strides, \
constant const size_t* b_strides, \
constant const int& ndim, \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]);
// clang-format off
#define instantiate_binary_all(name, tname, itype, otype, op1, op2) \
instantiate_binary("ss" #name #tname, itype, otype, op1, op2, ss) \
instantiate_binary("sv" #name #tname, itype, otype, op1, op2, sv) \
instantiate_binary("vs" #name #tname, itype, otype, op1, op2, vs) \
instantiate_binary("vv" #name #tname, itype, otype, op1, op2, vv) \
instantiate_binary_g("g" #name #tname, itype, otype, op1, op2) \
instantiate_binary_g_nd("g" #name #tname, itype, otype, op1, op2)
instantiate_binary_g("g" #name #tname, itype, otype, op1, op2) \
instantiate_binary_g_nd("g" #name #tname, itype, otype, op1, op2) // clang-format on
#define instantiate_binary_float(name, op1, op2) \
instantiate_binary_all(name, float16, half, half, op1, op2) \
// clang-format off
#define instantiate_binary_float(name, op1, op2) \
instantiate_binary_all(name, float16, half, half, op1, op2) \
instantiate_binary_all(name, float32, float, float, op1, op2) \
instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op1, op2)
instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op1, op2) // clang-format on
#define instantiate_binary_types(name, op1, op2) \
instantiate_binary_all(name, bool_, bool, bool, op1, op2) \
instantiate_binary_all(name, uint8, uint8_t, uint8_t, op1, op2) \
instantiate_binary_all(name, uint16, uint16_t, uint16_t, op1, op2) \
instantiate_binary_all(name, uint32, uint32_t, uint32_t, op1, op2) \
instantiate_binary_all(name, uint64, uint64_t, uint64_t, op1, op2) \
instantiate_binary_all(name, int8, int8_t, int8_t, op1, op2) \
instantiate_binary_all(name, int16, int16_t, int16_t, op1, op2) \
instantiate_binary_all(name, int32, int32_t, int32_t, op1, op2) \
instantiate_binary_all(name, int64, int64_t, int64_t, op1, op2) \
// clang-format off
#define instantiate_binary_types(name, op1, op2) \
instantiate_binary_all(name, bool_, bool, bool, op1, op2) \
instantiate_binary_all(name, uint8, uint8_t, uint8_t, op1, op2) \
instantiate_binary_all(name, uint16, uint16_t, uint16_t, op1, op2) \
instantiate_binary_all(name, uint32, uint32_t, uint32_t, op1, op2) \
instantiate_binary_all(name, uint64, uint64_t, uint64_t, op1, op2) \
instantiate_binary_all(name, int8, int8_t, int8_t, op1, op2) \
instantiate_binary_all(name, int16, int16_t, int16_t, op1, op2) \
instantiate_binary_all(name, int32, int32_t, int32_t, op1, op2) \
instantiate_binary_all(name, int64, int64_t, int64_t, op1, op2) \
instantiate_binary_all(name, complex64, complex64_t, complex64_t, op1, op2) \
instantiate_binary_float(name, op1, op2)
instantiate_binary_types(divmod, FloorDivide, Remainder)
instantiate_binary_types(divmod, FloorDivide, Remainder) // clang-format on

View File

@@ -22,7 +22,7 @@ struct complex64_t {
float imag;
// Constructors
constexpr complex64_t(float real, float imag) : real(real), imag(imag){};
constexpr complex64_t(float real, float imag) : real(real), imag(imag) {};
// Conversions to complex64_t
template <

View File

@@ -1,13 +1,11 @@
// Copyright © 2023-2024 Apple Inc.
#include <metal_stdlib>
#include <metal_simdgroup>
#include <metal_simdgroup_matrix>
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/steel/conv/params.h"
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
#define MLX_MTL_CONST static constant constexpr const
@@ -23,17 +21,18 @@ template <typename T, int N>
device T* out [[buffer(1)]],
const constant MLXConvParams<N>* params [[buffer(2)]],
uint3 gid [[thread_position_in_grid]]) {
int filter_size = params->C;
for(short i = 0; i < N; i++) filter_size *= params->wS[i];
for (short i = 0; i < N; i++)
filter_size *= params->wS[i];
int out_pixels = 1;
for(short i = 0; i < N; i++) out_pixels *= params->oS[i];
for (short i = 0; i < N; i++)
out_pixels *= params->oS[i];
// Set out
// Set out
out += gid.z * filter_size + gid.y * (params->C);
// Corrdinates in input
// Coordinates in input
int is[N] = {0};
// gid.z: N oS (Batch and row in unfolded output)
@@ -46,11 +45,11 @@ template <typename T, int N>
bool valid = n < params->N;
// Unroll dimensions
// Unroll dimensions
for (int i = N - 1; i >= 0; --i) {
int os_ = (oS % params->oS[i]);
int ws_ = (wS % params->wS[i]);
ws_ = params->flip ? params->wS[i] - ws_ - 1 : ws_;
int is_ = os_ * params->str[i] - params->pad[i] + ws_ * params->kdil[i];
@@ -64,10 +63,10 @@ template <typename T, int N>
wS /= params->wS[i];
}
if(valid) {
if (valid) {
size_t in_offset = n * params->in_strides[0];
for(int i = 0; i < N; ++i) {
for (int i = 0; i < N; ++i) {
in_offset += is[i] * params->in_strides[i + 1];
}
@@ -75,21 +74,91 @@ template <typename T, int N>
} else {
out[gid.x] = T(0);
}
}
#define instantiate_naive_unfold_nd(name, itype, n) \
template [[host_name("naive_unfold_nd_" #name "_" #n)]] \
[[kernel]] void naive_unfold_Nd( \
const device itype* in [[buffer(0)]], \
device itype* out [[buffer(1)]], \
const constant MLXConvParams<n>* params [[buffer(2)]], \
uint3 gid [[thread_position_in_grid]]);
// This kernel unfolds the input array of size (N, *spatial_dims, C)
// into an array of size (N x *spatial_dims, C x *kernel_dims).
template <typename T, int N>
[[kernel]] void naive_unfold_transpose_Nd(
const device T* in [[buffer(0)]],
device T* out [[buffer(1)]],
const constant MLXConvParams<N>* params [[buffer(2)]],
uint3 gid [[thread_position_in_grid]]) {
int filter_size = params->C;
for (short i = 0; i < N; i++)
filter_size *= params->wS[i];
#define instantiate_naive_unfold_nd_dims(name, itype) \
instantiate_naive_unfold_nd(name, itype, 1) \
instantiate_naive_unfold_nd(name, itype, 2) \
instantiate_naive_unfold_nd(name, itype, 3)
int out_pixels = 1;
for (short i = 0; i < N; i++)
out_pixels *= params->oS[i];
// Set out
out += gid.z * filter_size + gid.x * (filter_size / params->C);
// Coordinates in input
int is[N] = {0};
// gid.z: N oS (Batch and row in unfolded output)
// gid.y: wS (Filter location to unfold input)
// gid.x: C (channel)
int n = (gid.z) / out_pixels;
int oS = (gid.z) % out_pixels;
int wS = gid.y;
bool valid = n < params->N;
// Unroll dimensions
for (int i = N - 1; i >= 0; --i) {
int os_ = (oS % params->oS[i]);
int ws_ = (wS % params->wS[i]);
ws_ = params->flip ? params->wS[i] - ws_ - 1 : ws_;
int is_ = os_ * params->str[i] - params->pad[i] + ws_ * params->kdil[i];
int is_max = 1 + params->idil[i] * (params->iS[i] - 1);
valid &= is_ >= 0 && is_ < is_max && (is_ % params->idil[i] == 0);
is[i] = is_ / params->idil[i];
oS /= params->oS[i];
wS /= params->wS[i];
out += ws_ * params->str[i];
}
if (valid) {
size_t in_offset = n * params->in_strides[0];
for (int i = 0; i < N; ++i) {
in_offset += is[i] * params->in_strides[i + 1];
}
out[0] = in[in_offset + gid.x];
} else {
out[0] = T(0);
}
}
#define instantiate_naive_unfold_nd(name, itype, n) \
template [[host_name("naive_unfold_nd_" #name "_" #n)]] [[kernel]] void \
naive_unfold_Nd( \
const device itype* in [[buffer(0)]], \
device itype* out [[buffer(1)]], \
const constant MLXConvParams<n>* params [[buffer(2)]], \
uint3 gid [[thread_position_in_grid]]); \
template \
[[host_name("naive_unfold_transpose_nd_" #name "_" #n)]] [[kernel]] void \
naive_unfold_transpose_Nd( \
const device itype* in [[buffer(0)]], \
device itype* out [[buffer(1)]], \
const constant MLXConvParams<n>* params [[buffer(2)]], \
uint3 gid [[thread_position_in_grid]]);
#define instantiate_naive_unfold_nd_dims(name, itype) \
instantiate_naive_unfold_nd(name, itype, 1) instantiate_naive_unfold_nd( \
name, itype, 2) instantiate_naive_unfold_nd(name, itype, 3)
instantiate_naive_unfold_nd_dims(float32, float);
instantiate_naive_unfold_nd_dims(float16, half);
@@ -99,12 +168,13 @@ instantiate_naive_unfold_nd_dims(bfloat16, bfloat16_t);
/// Slow and naive conv2d kernels
///////////////////////////////////////////////////////////////////////////////
template <typename T,
const int BM, /* Threadgroup rows (in threads) */
const int BN, /* Threadgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN, /* Thread cols (in elements) */
const int BC = 16>
template <
typename T,
const int BM, /* Threadgroup rows (in threads) */
const int BN, /* Threadgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN, /* Thread cols (in elements) */
const int BC = 16>
[[kernel]] void naive_conv_2d(
const device T* in [[buffer(0)]],
const device T* wt [[buffer(1)]],
@@ -114,7 +184,6 @@ template <typename T,
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
(void)simd_gid;
(void)simd_lid;
@@ -123,80 +192,82 @@ template <typename T,
int out_o = tid.y * BN * TN + lid.y * TN;
int out_hw = tid.x * BM * TM + lid.x * TM;
int out_h[TM];
int out_w[TN];
for(int m = 0; m < TM; ++m) {
for (int m = 0; m < TM; ++m) {
int mm = (out_hw + m);
out_h[m] = mm / params.oS[1];
out_w[m] = mm % params.oS[1];
}
T in_local[TM];
T wt_local[TN];
T out_local[TM * TN] = {T(0)};
for(int h = 0; h < params.wS[0]; ++h) {
for(int w = 0; w < params.wS[1]; ++w) {
for(int c = 0; c < params.C; ++c) {
for (int h = 0; h < params.wS[0]; ++h) {
for (int w = 0; w < params.wS[1]; ++w) {
for (int c = 0; c < params.C; ++c) {
// Local in
for(int m = 0; m < TM; m++) {
for (int m = 0; m < TM; m++) {
int i = out_h[m] * params.str[0] - params.pad[0] + h * params.kdil[0];
int j = out_w[m] * params.str[1] - params.pad[1] + w * params.kdil[1];
bool valid = i >= 0 && i < params.iS[0] && j >= 0 && j < params.iS[1];
in_local[m] = valid ? in[i * params.in_strides[1] + j * params.in_strides[2] + c] : T(0);
in_local[m] = valid
? in[i * params.in_strides[1] + j * params.in_strides[2] + c]
: T(0);
}
// Load weight
for (int n = 0; n < TN; ++n) {
int o = out_o + n;
wt_local[n] = o < params.O ? wt[o * params.wt_strides[0] +
h * params.wt_strides[1] +
w * params.wt_strides[2] + c] : T(0);
wt_local[n] = o < params.O
? wt[o * params.wt_strides[0] + h * params.wt_strides[1] +
w * params.wt_strides[2] + c]
: T(0);
}
// Accumulate
for(int m = 0; m < TM; ++m) {
for(int n = 0; n < TN; ++n) {
for (int m = 0; m < TM; ++m) {
for (int n = 0; n < TN; ++n) {
out_local[m * TN + n] += in_local[m] * wt_local[n];
}
}
}
}
}
for(int m = 0; m < TM; ++m) {
for(int n = 0; n < TN; ++n) {
if(out_h[m] < params.oS[0] && out_w[m] < params.oS[1] && (out_o + n) < params.O)
out[out_h[m] * params.out_strides[1] +
out_w[m] * params.out_strides[2] + out_o + n] = out_local[m * TN + n];
for (int m = 0; m < TM; ++m) {
for (int n = 0; n < TN; ++n) {
if (out_h[m] < params.oS[0] && out_w[m] < params.oS[1] &&
(out_o + n) < params.O)
out[out_h[m] * params.out_strides[1] +
out_w[m] * params.out_strides[2] + out_o + n] =
out_local[m * TN + n];
}
}
}
// Instantiations
#define instantiate_naive_conv_2d(name, itype, bm, bn, tm, tn) \
template [[host_name("naive_conv_2d_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn)]] \
[[kernel]] void naive_conv_2d<itype, bm, bn, tm, tn>( \
const device itype* in [[buffer(0)]], \
const device itype* wt [[buffer(1)]], \
device itype* out [[buffer(2)]], \
const constant MLXConvParams<2>& params [[buffer(3)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
#define instantiate_naive_conv_2d(name, itype, bm, bn, tm, tn) \
template [[host_name("naive_conv_2d_" #name "_bm" #bm "_bn" #bn "_tm" #tm \
"_tn" #tn)]] [[kernel]] void \
naive_conv_2d<itype, bm, bn, tm, tn>( \
const device itype* in [[buffer(0)]], \
const device itype* wt [[buffer(1)]], \
device itype* out [[buffer(2)]], \
const constant MLXConvParams<2>& params [[buffer(3)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_naive_conv_2d_blocks(name, itype) \
instantiate_naive_conv_2d(name, itype, 16, 8, 4, 4) \
instantiate_naive_conv_2d(name, itype, 16, 8, 2, 4)
instantiate_naive_conv_2d(name, itype, 16, 8, 4, 4) \
instantiate_naive_conv_2d(name, itype, 16, 8, 2, 4)
instantiate_naive_conv_2d_blocks(float32, float);
instantiate_naive_conv_2d_blocks(float16, half);
@@ -207,9 +278,7 @@ instantiate_naive_conv_2d_blocks(bfloat16, bfloat16_t);
///////////////////////////////////////////////////////////////////////////////
template <int M, int R, int S>
struct WinogradTransforms {
};
struct WinogradTransforms {};
template <>
struct WinogradTransforms<6, 3, 8> {
@@ -218,36 +287,36 @@ struct WinogradTransforms<6, 3, 8> {
MLX_MTL_CONST int IN_TILE_SIZE = OUT_TILE_SIZE + FILTER_SIZE - 1;
MLX_MTL_CONST int SIMD_MATRIX_SIZE = 8;
MLX_MTL_CONST float in_transform[SIMD_MATRIX_SIZE][SIMD_MATRIX_SIZE] = {
{ 1.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f},
{ 0.00f, 1.00f, -1.00f, 0.50f, -0.50f, 2.00f, -2.00f, -1.00f},
{-5.25f, 1.00f, 1.00f, 0.25f, 0.25f, 4.00f, 4.00f, 0.00f},
{ 0.00f, -4.25f, 4.25f, -2.50f, 2.50f, -2.50f, 2.50f, 5.25f},
{ 5.25f, -4.25f, -4.25f, -1.25f, -1.25f, -5.00f, -5.00f, 0.00f},
{ 0.00f, 1.00f, -1.00f, 2.00f, -2.00f, 0.50f, -0.50f, -5.25f},
{-1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 0.00f},
{ 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 1.00f},
{1.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f},
{0.00f, 1.00f, -1.00f, 0.50f, -0.50f, 2.00f, -2.00f, -1.00f},
{-5.25f, 1.00f, 1.00f, 0.25f, 0.25f, 4.00f, 4.00f, 0.00f},
{0.00f, -4.25f, 4.25f, -2.50f, 2.50f, -2.50f, 2.50f, 5.25f},
{5.25f, -4.25f, -4.25f, -1.25f, -1.25f, -5.00f, -5.00f, 0.00f},
{0.00f, 1.00f, -1.00f, 2.00f, -2.00f, 0.50f, -0.50f, -5.25f},
{-1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 0.00f},
{0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 1.00f},
};
MLX_MTL_CONST float out_transform[SIMD_MATRIX_SIZE][SIMD_MATRIX_SIZE] = {
{ 1.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f},
{ 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f},
{ 1.00f, -1.00f, 1.00f, -1.00f, 1.00f, -1.00f},
{ 1.00f, 2.00f, 4.00f, 8.00f, 16.00f, 32.00f},
{ 1.00f, -2.00f, 4.00f, -8.00f, 16.00f, -32.00f},
{ 1.00f, 0.50f, 0.25f, 0.125f, 0.0625f, 0.03125f},
{ 1.00f, -0.50f, 0.25f, -0.125f, 0.0625f, -0.03125f},
{ 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 1.00f},
{1.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f},
{1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f},
{1.00f, -1.00f, 1.00f, -1.00f, 1.00f, -1.00f},
{1.00f, 2.00f, 4.00f, 8.00f, 16.00f, 32.00f},
{1.00f, -2.00f, 4.00f, -8.00f, 16.00f, -32.00f},
{1.00f, 0.50f, 0.25f, 0.125f, 0.0625f, 0.03125f},
{1.00f, -0.50f, 0.25f, -0.125f, 0.0625f, -0.03125f},
{0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 1.00f},
};
MLX_MTL_CONST float wt_transform[SIMD_MATRIX_SIZE][SIMD_MATRIX_SIZE] = {
{ 1.00, 0.00, 0.00},
{ -2.0/9.00, -2.0/9.00, -2.0/9.00},
{ -2.0/9.00, 2.0/9.00, -2.0/9.00},
{ 1.0/90.0, 1.0/45.0, 2.0/45.0},
{ 1.0/90.0, -1.0/45.0, 2.0/45.0},
{ 32.0/45.0, 16.0/45.0, 8.0/45.0},
{ 32.0/45.0, -16.0/45.0, 8.0/45.0},
{ 0.00, 0.00, 1.00},
{1.00, 0.00, 0.00},
{-2.0 / 9.00, -2.0 / 9.00, -2.0 / 9.00},
{-2.0 / 9.00, 2.0 / 9.00, -2.0 / 9.00},
{1.0 / 90.0, 1.0 / 45.0, 2.0 / 45.0},
{1.0 / 90.0, -1.0 / 45.0, 2.0 / 45.0},
{32.0 / 45.0, 16.0 / 45.0, 8.0 / 45.0},
{32.0 / 45.0, -16.0 / 45.0, 8.0 / 45.0},
{0.00, 0.00, 1.00},
};
};
@@ -255,12 +324,9 @@ constant constexpr const float WinogradTransforms<6, 3, 8>::wt_transform[8][8];
constant constexpr const float WinogradTransforms<6, 3, 8>::in_transform[8][8];
constant constexpr const float WinogradTransforms<6, 3, 8>::out_transform[8][8];
template <typename T,
int BC = 32,
int BO = 4,
int M = 6,
int R = 3>
[[kernel, max_total_threads_per_threadgroup(BO * 32)]] void winograd_conv_2d_weight_transform(
template <typename T, int BC = 32, int BO = 4, int M = 6, int R = 3>
[[kernel, max_total_threads_per_threadgroup(BO * 32)]] void
winograd_conv_2d_weight_transform(
const device T* wt_in [[buffer(0)]],
device T* wt_out [[buffer(1)]],
const constant int& C [[buffer(2)]],
@@ -268,7 +334,6 @@ template <typename T,
uint tid [[threadgroup_position_in_grid]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]]) {
using WGT = WinogradTransforms<M, R, 8>;
// Get lane position in simdgroup
@@ -288,35 +353,37 @@ template <typename T,
// Move to the correct output filter
size_t ko = BO * tid + simd_group_id;
wt_in += ko * R * R * C;
wt_in += ko * R * R * C;
// wt_out is stored transposed (A x A x C x O)
short ohw_0 = sm * 8 + sn;
short ohw_1 = sm * 8 + sn + 1;
device T* wt_out_0 = wt_out + ohw_0 * C * O + ko;
device T* wt_out_1 = wt_out + ohw_1 * C * O + ko;
device T* wt_out_1 = wt_out + ohw_1 * C * O + ko;
// Prepare shared memory
threadgroup T Ws[BO][R][R][BC];
// Loop over C
for(int bc = 0; bc < C; bc += BC) {
for (int bc = 0; bc < C; bc += BC) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Read into shared memory
for(int kh = 0; kh < R; ++kh) {
for(int kw = 0; kw < R; ++kw) {
for(int kc = simd_lane_id; kc < BC; kc += 32) {
for (int kh = 0; kh < R; ++kh) {
for (int kw = 0; kw < R; ++kw) {
for (int kc = simd_lane_id; kc < BC; kc += 32) {
Ws[simd_group_id][kh][kw][kc] = wt_in[kh * R * C + kw * C + kc];
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Do transform and store the result
for(int c = 0; c < BC; ++c) {
// Do transform and store the result
for (int c = 0; c < BC; ++c) {
simdgroup_matrix<T, 8, 8> g;
g.thread_elements()[0] = sm < R && sn < R ? Ws[simd_group_id][sm][sn][c] : T(0);
g.thread_elements()[1] = sm < R && sn + 1 < R ? Ws[simd_group_id][sm][sn + 1][c] : T(0);
g.thread_elements()[0] =
sm < R && sn < R ? Ws[simd_group_id][sm][sn][c] : T(0);
g.thread_elements()[1] =
sm < R && sn + 1 < R ? Ws[simd_group_id][sm][sn + 1][c] : T(0);
simdgroup_matrix<T, 8, 8> g_out = (G * g) * Gt;
wt_out_0[c * O] = g_out.thread_elements()[0];
@@ -327,27 +394,23 @@ template <typename T,
wt_out_0 += BC * O;
wt_out_1 += BC * O;
}
}
#define instantiate_winograd_conv_2d_weight_transform_base(name, itype, bc) \
template [[host_name("winograd_conv_2d_weight_transform_" #name "_bc" #bc)]]\
[[kernel]] void winograd_conv_2d_weight_transform<itype, bc>(\
const device itype* wt_in [[buffer(0)]],\
device itype* wt_out [[buffer(1)]],\
const constant int& C [[buffer(2)]],\
const constant int& O [[buffer(3)]],\
uint tid [[threadgroup_position_in_grid]],\
uint simd_group_id [[simdgroup_index_in_threadgroup]],\
template [[host_name("winograd_conv_2d_weight_transform_" #name \
"_bc" #bc)]] [[kernel]] void \
winograd_conv_2d_weight_transform<itype, bc>( \
const device itype* wt_in [[buffer(0)]], \
device itype* wt_out [[buffer(1)]], \
const constant int& C [[buffer(2)]], \
const constant int& O [[buffer(3)]], \
uint tid [[threadgroup_position_in_grid]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]]);
template <typename T,
int BC,
int WM,
int WN,
int M = 6,
int R = 3>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void winograd_conv_2d_input_transform(
template <typename T, int BC, int WM, int WN, int M = 6, int R = 3>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
winograd_conv_2d_input_transform(
const device T* inp_in [[buffer(0)]],
device T* inp_out [[buffer(1)]],
const constant MLXConvParams<2>& params [[buffer(2)]],
@@ -356,7 +419,6 @@ template <typename T,
uint3 tgp_per_grid [[threadgroups_per_grid]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]]) {
(void)lid;
using WGT = WinogradTransforms<M, R, 8>;
@@ -387,46 +449,48 @@ template <typename T,
int bw = M * tid.x + kw;
// Move to the correct input tile
inp_in += tid.z * params.in_strides[0]
+ bh * params.in_strides[1]
+ bw * params.in_strides[2];
inp_in += tid.z * params.in_strides[0] + bh * params.in_strides[1] +
bw * params.in_strides[2];
// Pre compute strides
// Pre compute strides
int jump_in[TH][TW];
for(int h = 0; h < TH; h++) {
for(int w = 0; w < TW; w++) {
jump_in[h][w] = h * params.in_strides[1] + w * params.in_strides[2];
for (int h = 0; h < TH; h++) {
for (int w = 0; w < TW; w++) {
jump_in[h][w] = h * params.in_strides[1] + w * params.in_strides[2];
}
}
// inp_out is stored interleaved (A x A x tiles x C)
size_t N_TILES = tgp_per_grid.x * tgp_per_grid.y * tgp_per_grid.z;
size_t tile_id = tid.z * tgp_per_grid.x * tgp_per_grid.y + tid.y * tgp_per_grid.x + tid.x;
size_t tile_id =
tid.z * tgp_per_grid.x * tgp_per_grid.y + tid.y * tgp_per_grid.x + tid.x;
size_t ohw_0 = sm * 8 + sn;
size_t ohw_1 = sm * 8 + sn + 1;
device T* inp_out_0 = inp_out + ohw_0 * N_TILES * params.C + tile_id * params.C;
device T* inp_out_1 = inp_out + ohw_1 * N_TILES * params.C + tile_id * params.C;
device T* inp_out_0 =
inp_out + ohw_0 * N_TILES * params.C + tile_id * params.C;
device T* inp_out_1 =
inp_out + ohw_1 * N_TILES * params.C + tile_id * params.C;
// Prepare shared memory
threadgroup T Is[A][A][BC];
// Loop over C
for(int bc = 0; bc < params.C; bc += BC) {
for (int bc = 0; bc < params.C; bc += BC) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Read into shared memory
for(int h = 0; h < TH; h++) {
for(int w = 0; w < TW; w++) {
for (int h = 0; h < TH; h++) {
for (int w = 0; w < TW; w++) {
const device T* in_ptr = inp_in + jump_in[h][w];
for(int c = simd_lane_id; c < BC; c += 32) {
for (int c = simd_lane_id; c < BC; c += 32) {
Is[kh + h][kw + w][c] = in_ptr[c];
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Do transform and store the result
for(int c = simd_group_id; c < BC; c += N_SIMD_GROUPS) {
// Do transform and store the result
for (int c = simd_group_id; c < BC; c += N_SIMD_GROUPS) {
simdgroup_matrix<T, 8, 8> I;
I.thread_elements()[0] = Is[sm][sn][c];
I.thread_elements()[1] = Is[sm][sn + 1][c];
@@ -440,28 +504,24 @@ template <typename T,
inp_out_0 += BC;
inp_out_1 += BC;
}
}
#define instantiate_winograd_conv_2d_input_transform(name, itype, bc) \
template [[host_name("winograd_conv_2d_input_transform_" #name "_bc" #bc)]]\
[[kernel]] void winograd_conv_2d_input_transform<itype, bc, 2, 2>(\
const device itype* inp_in [[buffer(0)]],\
device itype* inp_out [[buffer(1)]],\
const constant MLXConvParams<2>& params [[buffer(2)]],\
uint3 tid [[threadgroup_position_in_grid]],\
uint3 lid [[thread_position_in_threadgroup]],\
uint3 tgp_per_grid [[threadgroups_per_grid]],\
uint simd_group_id [[simdgroup_index_in_threadgroup]],\
template [[host_name("winograd_conv_2d_input_transform_" #name \
"_bc" #bc)]] [[kernel]] void \
winograd_conv_2d_input_transform<itype, bc, 2, 2>( \
const device itype* inp_in [[buffer(0)]], \
device itype* inp_out [[buffer(1)]], \
const constant MLXConvParams<2>& params [[buffer(2)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 tgp_per_grid [[threadgroups_per_grid]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]]);
template <typename T,
int BO,
int WM,
int WN,
int M = 6,
int R = 3>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void winograd_conv_2d_output_transform(
template <typename T, int BO, int WM, int WN, int M = 6, int R = 3>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
winograd_conv_2d_output_transform(
const device T* out_in [[buffer(0)]],
device T* out_out [[buffer(1)]],
const constant MLXConvParams<2>& params [[buffer(2)]],
@@ -470,7 +530,6 @@ template <typename T,
uint3 tgp_per_grid [[threadgroups_per_grid]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]]) {
(void)lid;
using WGT = WinogradTransforms<M, R, 8>;
@@ -503,57 +562,59 @@ template <typename T,
int bw = M * tid.x + kw;
// Move to the correct input tile
out_out += tid.z * params.out_strides[0]
+ bh * params.out_strides[1]
+ bw * params.out_strides[2];
out_out += tid.z * params.out_strides[0] + bh * params.out_strides[1] +
bw * params.out_strides[2];
// Pre compute strides
// Pre compute strides
int jump_in[TH][TW];
for(int h = 0; h < TH; h++) {
for(int w = 0; w < TW; w++) {
for (int h = 0; h < TH; h++) {
for (int w = 0; w < TW; w++) {
bool valid = ((bh + h) < params.oS[0]) && ((bw + w) < params.oS[1]);
jump_in[h][w] = valid ? h * params.out_strides[1] + w * params.out_strides[2] : -1;
jump_in[h][w] =
valid ? h * params.out_strides[1] + w * params.out_strides[2] : -1;
}
}
// out_in is stored interleaved (A x A x tiles x O)
size_t N_TILES = tgp_per_grid.x * tgp_per_grid.y * tgp_per_grid.z;
size_t tile_id = tid.z * tgp_per_grid.x * tgp_per_grid.y + tid.y * tgp_per_grid.x + tid.x;
size_t tile_id =
tid.z * tgp_per_grid.x * tgp_per_grid.y + tid.y * tgp_per_grid.x + tid.x;
size_t ohw_0 = sm * 8 + sn;
size_t ohw_1 = sm * 8 + sn + 1;
const device T* out_in_0 = out_in + ohw_0 * N_TILES * params.O + tile_id * params.O;
const device T* out_in_1 = out_in + ohw_1 * N_TILES * params.O + tile_id * params.O;
const device T* out_in_0 =
out_in + ohw_0 * N_TILES * params.O + tile_id * params.O;
const device T* out_in_1 =
out_in + ohw_1 * N_TILES * params.O + tile_id * params.O;
// Prepare shared memory
threadgroup T Os[M][M][BO];
// Loop over O
for(int bo = 0; bo < params.O; bo += BO) {
for (int bo = 0; bo < params.O; bo += BO) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Do transform and store the result
for(int c = simd_group_id; c < BO; c += N_SIMD_GROUPS) {
// Do transform and store the result
for (int c = simd_group_id; c < BO; c += N_SIMD_GROUPS) {
simdgroup_matrix<T, 8, 8> O_mat;
O_mat.thread_elements()[0] = out_in_0[c];
O_mat.thread_elements()[1] = out_in_1[c];
simdgroup_matrix<T, 8, 8> O_out = (Bt * (O_mat * B));
if((sm < M) && (sn < M)) {
if ((sm < M) && (sn < M)) {
Os[sm][sn][c] = O_out.thread_elements()[0];
}
if((sm < M) && ((sn + 1) < M)) {
if ((sm < M) && ((sn + 1) < M)) {
Os[sm][sn + 1][c] = O_out.thread_elements()[1];
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Read out from shared memory
for(int h = 0; h < TH; h++) {
for(int w = 0; w < TW; w++) {
if(jump_in[h][w] >= 0) {
for (int h = 0; h < TH; h++) {
for (int w = 0; w < TW; w++) {
if (jump_in[h][w] >= 0) {
device T* out_ptr = out_out + jump_in[h][w];
for(int c = simd_lane_id; c < BO; c += 32) {
for (int c = simd_lane_id; c < BO; c += 32) {
out_ptr[c] = Os[kh + h][kw + w][c];
}
}
@@ -564,25 +625,27 @@ template <typename T,
out_in_0 += BO;
out_in_1 += BO;
}
}
#define instantiate_winograd_conv_2d_output_transform(name, itype, bo) \
template [[host_name("winograd_conv_2d_output_transform_" #name "_bo" #bo)]]\
[[kernel]] void winograd_conv_2d_output_transform<itype, bo, 2, 2>(\
const device itype* out_in [[buffer(0)]],\
device itype* out_out [[buffer(1)]],\
const constant MLXConvParams<2>& params [[buffer(2)]],\
uint3 tid [[threadgroup_position_in_grid]],\
uint3 lid [[thread_position_in_threadgroup]],\
uint3 tgp_per_grid [[threadgroups_per_grid]],\
uint simd_group_id [[simdgroup_index_in_threadgroup]],\
template [[host_name("winograd_conv_2d_output_transform_" #name \
"_bo" #bo)]] [[kernel]] void \
winograd_conv_2d_output_transform<itype, bo, 2, 2>( \
const device itype* out_in [[buffer(0)]], \
device itype* out_out [[buffer(1)]], \
const constant MLXConvParams<2>& params [[buffer(2)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 tgp_per_grid [[threadgroups_per_grid]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]]);
#define instantiate_winograd_conv_2d(name, itype) \
// clang-format off
#define instantiate_winograd_conv_2d(name, itype) \
instantiate_winograd_conv_2d_weight_transform_base(name, itype, 32) \
instantiate_winograd_conv_2d_input_transform(name, itype, 32) \
instantiate_winograd_conv_2d_output_transform(name, itype, 32)
instantiate_winograd_conv_2d_input_transform(name, itype, 32) \
instantiate_winograd_conv_2d_output_transform(name, itype, 32) // clang-format on
// clang-format off
instantiate_winograd_conv_2d(float32, float);
instantiate_winograd_conv_2d(float16, half);
instantiate_winograd_conv_2d(float16, half); // clang-format on

View File

@@ -49,7 +49,8 @@ template <typename T, typename U>
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_3(index, src_strides);
int64_t dst_idx = index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
int64_t dst_idx =
index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
@@ -62,7 +63,8 @@ template <typename T, typename U, int DIM>
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
int64_t dst_idx = index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
int64_t dst_idx =
index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
@@ -76,7 +78,8 @@ template <typename T, typename U>
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim);
int64_t dst_idx = index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
int64_t dst_idx =
index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
@@ -143,116 +146,110 @@ template <typename T, typename U>
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
#define instantiate_copy(name, itype, otype, ctype) \
template [[host_name(name)]] \
[[kernel]] void copy_##ctype<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
#define instantiate_copy(name, itype, otype, ctype) \
template [[host_name(name)]] [[kernel]] void copy_##ctype<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
uint index [[thread_position_in_grid]]);
#define instantiate_copy_g_dim(name, itype, otype, dims) \
template [[host_name(name "_" #dims)]] \
[[kernel]] void copy_g_nd<itype, otype, dims>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \
constant const int64_t* src_strides [[buffer(3)]], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
template [[host_name("g" name "_" #dims)]] \
[[kernel]] void copy_gg_nd<itype, otype, dims>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int64_t* dst_strides [[buffer(4)]], \
#define instantiate_copy_g_dim(name, itype, otype, dims) \
template [[host_name(name "_" #dims)]] [[kernel]] void \
copy_g_nd<itype, otype, dims>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \
constant const int64_t* src_strides [[buffer(3)]], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
template [[host_name("g" name "_" #dims)]] [[kernel]] void \
copy_gg_nd<itype, otype, dims>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int64_t* dst_strides [[buffer(4)]], \
uint3 index [[thread_position_in_grid]]);
#define instantiate_copy_g_nd(name, itype, otype) \
template [[host_name(name "_1")]] [[kernel]] void copy_g_nd1<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t& src_stride [[buffer(3)]], \
uint index [[thread_position_in_grid]]); \
template [[host_name(name "_2")]] [[kernel]] void copy_g_nd2<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]); \
template [[host_name(name "_3")]] [[kernel]] void copy_g_nd3<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
template [[host_name("g" name "_1")]] [[kernel]] void \
copy_gg_nd1<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t& src_stride [[buffer(3)]], \
constant const int64_t& dst_stride [[buffer(4)]], \
uint index [[thread_position_in_grid]]); \
template [[host_name("g" name "_2")]] [[kernel]] void \
copy_gg_nd2<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int64_t* dst_strides [[buffer(4)]], \
uint2 index [[thread_position_in_grid]]); \
template [[host_name("g" name "_3")]] [[kernel]] void \
copy_gg_nd3<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int64_t* dst_strides [[buffer(4)]], \
uint3 index [[thread_position_in_grid]]); \
instantiate_copy_g_dim(name, itype, otype, 4) \
instantiate_copy_g_dim(name, itype, otype, 5)
#define instantiate_copy_g_nd(name, itype, otype) \
template [[host_name(name "_1")]] \
[[kernel]] void copy_g_nd1<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t& src_stride [[buffer(3)]], \
uint index [[thread_position_in_grid]]); \
template [[host_name(name "_2")]] \
[[kernel]] void copy_g_nd2<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]); \
template [[host_name(name "_3")]] \
[[kernel]] void copy_g_nd3<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
template [[host_name("g" name "_1")]] \
[[kernel]] void copy_gg_nd1<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t& src_stride [[buffer(3)]], \
constant const int64_t& dst_stride [[buffer(4)]], \
uint index [[thread_position_in_grid]]); \
template [[host_name("g" name "_2")]] \
[[kernel]] void copy_gg_nd2<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int64_t* dst_strides [[buffer(4)]], \
uint2 index [[thread_position_in_grid]]); \
template [[host_name("g" name "_3")]] \
[[kernel]] void copy_gg_nd3<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int64_t* dst_strides [[buffer(4)]], \
uint3 index [[thread_position_in_grid]]); \
instantiate_copy_g_dim(name, itype, otype, 4) \
instantiate_copy_g_dim(name, itype, otype, 5)
#define instantiate_copy_g(name, itype, otype) \
template [[host_name(name)]] \
[[kernel]] void copy_g<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int& ndim [[buffer(5)]], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
template [[host_name("g" name)]] \
[[kernel]] void copy_gg<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int64_t* dst_strides [[buffer(4)]], \
constant const int& ndim [[buffer(5)]], \
#define instantiate_copy_g(name, itype, otype) \
template [[host_name(name)]] [[kernel]] void copy_g<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int& ndim [[buffer(5)]], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
template [[host_name("g" name)]] [[kernel]] void copy_gg<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int64_t* dst_strides [[buffer(4)]], \
constant const int& ndim [[buffer(5)]], \
uint3 index [[thread_position_in_grid]]);
#define instantiate_copy_all(tname, itype, otype) \
// clang-format off
#define instantiate_copy_all(tname, itype, otype) \
instantiate_copy("scopy" #tname, itype, otype, s) \
instantiate_copy("vcopy" #tname, itype, otype, v) \
instantiate_copy_g("gcopy" #tname, itype, otype) \
instantiate_copy_g_nd("gcopy" #tname, itype, otype)
instantiate_copy_g("gcopy" #tname, itype, otype) \
instantiate_copy_g_nd("gcopy" #tname, itype, otype) // clang-format on
#define instantiate_copy_itype(itname, itype) \
instantiate_copy_all(itname ##bool_, itype, bool) \
instantiate_copy_all(itname ##uint8, itype, uint8_t) \
instantiate_copy_all(itname ##uint16, itype, uint16_t) \
instantiate_copy_all(itname ##uint32, itype, uint32_t) \
instantiate_copy_all(itname ##uint64, itype, uint64_t) \
instantiate_copy_all(itname ##int8, itype, int8_t) \
instantiate_copy_all(itname ##int16, itype, int16_t) \
instantiate_copy_all(itname ##int32, itype, int32_t) \
instantiate_copy_all(itname ##int64, itype, int64_t) \
instantiate_copy_all(itname ##float16, itype, half) \
instantiate_copy_all(itname ##float32, itype, float) \
// clang-format off
#define instantiate_copy_itype(itname, itype) \
instantiate_copy_all(itname ##bool_, itype, bool) \
instantiate_copy_all(itname ##uint8, itype, uint8_t) \
instantiate_copy_all(itname ##uint16, itype, uint16_t) \
instantiate_copy_all(itname ##uint32, itype, uint32_t) \
instantiate_copy_all(itname ##uint64, itype, uint64_t) \
instantiate_copy_all(itname ##int8, itype, int8_t) \
instantiate_copy_all(itname ##int16, itype, int16_t) \
instantiate_copy_all(itname ##int32, itype, int32_t) \
instantiate_copy_all(itname ##int64, itype, int64_t) \
instantiate_copy_all(itname ##float16, itype, half) \
instantiate_copy_all(itname ##float32, itype, float) \
instantiate_copy_all(itname ##bfloat16, itype, bfloat16_t) \
instantiate_copy_all(itname ##complex64, itype, complex64_t)
@@ -268,4 +265,4 @@ instantiate_copy_itype(int64, int64_t)
instantiate_copy_itype(float16, half)
instantiate_copy_itype(float32, float)
instantiate_copy_itype(bfloat16, bfloat16_t)
instantiate_copy_itype(complex64, complex64_t)
instantiate_copy_itype(complex64, complex64_t) // clang-format on

View File

@@ -6,9 +6,8 @@
// - VkFFT (https://github.com/DTolm/VkFFT)
// - Eric Bainville's excellent page (http://www.bealto.com/gpu-fft.html)
#include <metal_math>
#include <metal_common>
#include <metal_math>
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"
@@ -23,7 +22,7 @@ float2 complex_mul(float2 a, float2 b) {
}
float2 get_twiddle(int k, int p) {
float theta = -1.0f * k * M_PI_F / (2*p);
float theta = -1.0f * k * M_PI_F / (2 * p);
float2 twiddle;
twiddle.x = metal::fast::cos(theta);
@@ -32,7 +31,12 @@ float2 get_twiddle(int k, int p) {
}
// single threaded radix2 implemetation
void radix2(int i, int p, int m, threadgroup float2* read_buf, threadgroup float2* write_buf) {
void radix2(
int i,
int p,
int m,
threadgroup float2* read_buf,
threadgroup float2* write_buf) {
float2 x_0 = read_buf[i];
float2 x_1 = read_buf[i + m];
@@ -53,11 +57,16 @@ void radix2(int i, int p, int m, threadgroup float2* read_buf, threadgroup float
}
// single threaded radix4 implemetation
void radix4(int i, int p, int m, threadgroup float2* read_buf, threadgroup float2* write_buf) {
void radix4(
int i,
int p,
int m,
threadgroup float2* read_buf,
threadgroup float2* write_buf) {
float2 x_0 = read_buf[i];
float2 x_1 = read_buf[i + m];
float2 x_2 = read_buf[i + 2*m];
float2 x_3 = read_buf[i + 3*m];
float2 x_2 = read_buf[i + 2 * m];
float2 x_3 = read_buf[i + 3 * m];
// The index within this sub-DFT
int k = i & (p - 1);
@@ -90,11 +99,10 @@ void radix4(int i, int p, int m, threadgroup float2* read_buf, threadgroup float
write_buf[j] = y_0;
write_buf[j + p] = y_1;
write_buf[j + 2*p] = y_2;
write_buf[j + 3*p] = y_3;
write_buf[j + 2 * p] = y_2;
write_buf[j + 3 * p] = y_3;
}
// Each FFT is computed entirely in shared GPU memory.
//
// N is decomposed into radix-2 and radix-4 DFTs:
@@ -107,11 +115,10 @@ void radix4(int i, int p, int m, threadgroup float2* read_buf, threadgroup float
// steps at compile time for a ~20% performance boost.
template <size_t n, size_t radix_2_steps, size_t radix_4_steps>
[[kernel]] void fft(
const device float2 *in [[buffer(0)]],
device float2 * out [[buffer(1)]],
const device float2* in [[buffer(0)]],
device float2* out [[buffer(1)]],
uint3 thread_position_in_grid [[thread_position_in_grid]],
uint3 threads_per_grid [[threads_per_grid]]) {
// Index of the DFT in batch
int batch_idx = thread_position_in_grid.x * n;
// The index in the DFT we're working on
@@ -132,16 +139,16 @@ template <size_t n, size_t radix_2_steps, size_t radix_4_steps>
// Copy input into shared memory
shared_in[i] = in[batch_idx + i];
shared_in[i + m] = in[batch_idx + i + m];
shared_in[i + 2*m] = in[batch_idx + i + 2*m];
shared_in[i + 3*m] = in[batch_idx + i + 3*m];
shared_in[i + 2 * m] = in[batch_idx + i + 2 * m];
shared_in[i + 3 * m] = in[batch_idx + i + 3 * m];
threadgroup_barrier(mem_flags::mem_threadgroup);
int p = 1;
for (size_t r = 0; r < radix_2_steps; r++) {
radix2(i, p, m*2, read_buf, write_buf);
radix2(i + m, p, m*2, read_buf, write_buf);
radix2(i, p, m * 2, read_buf, write_buf);
radix2(i + m, p, m * 2, read_buf, write_buf);
p *= 2;
threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -167,29 +174,26 @@ template <size_t n, size_t radix_2_steps, size_t radix_4_steps>
// Copy shared memory to output
out[batch_idx + i] = read_buf[i];
out[batch_idx + i + m] = read_buf[i + m];
out[batch_idx + i + 2*m] = read_buf[i + 2*m];
out[batch_idx + i + 3*m] = read_buf[i + 3*m];
out[batch_idx + i + 2 * m] = read_buf[i + 2 * m];
out[batch_idx + i + 3 * m] = read_buf[i + 3 * m];
}
#define instantiate_fft(name, n, radix_2_steps, radix_4_steps) \
template [[host_name("fft_" #name)]] \
[[kernel]] void fft<n, radix_2_steps, radix_4_steps>( \
const device float2* in [[buffer(0)]], \
device float2* out [[buffer(1)]], \
uint3 thread_position_in_grid [[thread_position_in_grid]], \
uint3 threads_per_grid [[threads_per_grid]]);
#define instantiate_fft(name, n, radix_2_steps, radix_4_steps) \
template [[host_name("fft_" #name)]] [[kernel]] void \
fft<n, radix_2_steps, radix_4_steps>( \
const device float2* in [[buffer(0)]], \
device float2* out [[buffer(1)]], \
uint3 thread_position_in_grid [[thread_position_in_grid]], \
uint3 threads_per_grid [[threads_per_grid]]);
// Explicitly define kernels for each power of 2.
// clang-format off
instantiate_fft(4, /* n= */ 4, /* radix_2_steps= */ 0, /* radix_4_steps= */ 1)
instantiate_fft(8, 8, 1, 1)
instantiate_fft(16, 16, 0, 2)
instantiate_fft(32, 32, 1, 2)
instantiate_fft(64, 64, 0, 3)
instantiate_fft(128, 128, 1, 3)
instantiate_fft(256, 256, 0, 4)
instantiate_fft(8, 8, 1, 1) instantiate_fft(16, 16, 0, 2)
instantiate_fft(32, 32, 1, 2) instantiate_fft(64, 64, 0, 3)
instantiate_fft(128, 128, 1, 3) instantiate_fft(256, 256, 0, 4)
instantiate_fft(512, 512, 1, 4)
instantiate_fft(1024, 1024, 0, 5)
// 2048 is the max that will fit into 32KB of threadgroup memory.
// TODO: implement 4 step FFT for larger n.
instantiate_fft(2048, 2048, 1, 5)
instantiate_fft(2048, 2048, 1, 5) // clang-format on

View File

@@ -14,17 +14,16 @@ using namespace metal;
template <typename T, typename IdxT, int NIDX, int IDX_NDIM>
METAL_FUNC void gather_impl(
const device T *src [[buffer(0)]],
device T *out [[buffer(1)]],
const constant int *src_shape [[buffer(2)]],
const constant size_t *src_strides [[buffer(3)]],
const device T* src [[buffer(0)]],
device T* out [[buffer(1)]],
const constant int* src_shape [[buffer(2)]],
const constant size_t* src_strides [[buffer(3)]],
const constant size_t& src_ndim [[buffer(4)]],
const constant int *slice_sizes [[buffer(5)]],
const constant int *axes [[buffer(6)]],
const constant int* slice_sizes [[buffer(5)]],
const constant int* axes [[buffer(6)]],
const thread Indices<IdxT, NIDX>& indices,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto ind_idx = index.x;
auto ind_offset = index.y;
@@ -43,93 +42,78 @@ METAL_FUNC void gather_impl(
indices.ndim);
}
auto ax = axes[i];
auto idx_val = offset_neg_idx(
indices.buffers[i][idx_loc], src_shape[ax]);
auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]);
src_idx += idx_val * src_strides[ax];
}
auto src_offset = elem_to_loc(
ind_offset, slice_sizes, src_strides, src_ndim);
auto src_offset = elem_to_loc(ind_offset, slice_sizes, src_strides, src_ndim);
size_t out_idx = index.y + static_cast<size_t>(grid_dim.y) * index.x;
out[out_idx] = src[src_offset + src_idx];
}
#define make_gather_impl(IDX_ARG, IDX_ARR) \
template <typename T, typename IdxT, int NIDX, int IDX_NDIM> \
[[kernel]] void gather( \
const device T *src [[buffer(0)]], \
device T *out [[buffer(1)]], \
const constant int *src_shape [[buffer(2)]], \
const constant size_t *src_strides [[buffer(3)]], \
const constant size_t& src_ndim [[buffer(4)]], \
const constant int *slice_sizes [[buffer(5)]], \
const constant int *axes [[buffer(6)]], \
const constant int *idx_shapes [[buffer(7)]], \
const constant size_t *idx_strides [[buffer(8)]], \
const constant int& idx_ndim [[buffer(9)]], \
IDX_ARG(IdxT) \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]) { \
\
Indices<IdxT, NIDX> idxs{ \
{{IDX_ARR()}}, \
idx_shapes, \
idx_strides, \
idx_ndim}; \
\
return gather_impl<T, IdxT, NIDX, IDX_NDIM>( \
src, \
out, \
src_shape, \
src_strides, \
src_ndim, \
slice_sizes, \
axes, \
idxs, \
index, \
grid_dim); \
}
#define make_gather_impl(IDX_ARG, IDX_ARR) \
template <typename T, typename IdxT, int NIDX, int IDX_NDIM> \
[[kernel]] void gather( \
const device T* src [[buffer(0)]], \
device T* out [[buffer(1)]], \
const constant int* src_shape [[buffer(2)]], \
const constant size_t* src_strides [[buffer(3)]], \
const constant size_t& src_ndim [[buffer(4)]], \
const constant int* slice_sizes [[buffer(5)]], \
const constant int* axes [[buffer(6)]], \
const constant int* idx_shapes [[buffer(7)]], \
const constant size_t* idx_strides [[buffer(8)]], \
const constant int& idx_ndim [[buffer(9)]], \
IDX_ARG(IdxT) uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]) { \
Indices<IdxT, NIDX> idxs{ \
{{IDX_ARR()}}, idx_shapes, idx_strides, idx_ndim}; \
\
return gather_impl<T, IdxT, NIDX, IDX_NDIM>( \
src, \
out, \
src_shape, \
src_strides, \
src_ndim, \
slice_sizes, \
axes, \
idxs, \
index, \
grid_dim); \
}
#define make_gather(n) make_gather_impl(IDX_ARG_ ##n, IDX_ARR_ ##n)
#define make_gather(n) make_gather_impl(IDX_ARG_##n, IDX_ARR_##n)
make_gather(0)
make_gather(1)
make_gather(2)
make_gather(3)
make_gather(4)
make_gather(5)
make_gather(6)
make_gather(7)
make_gather(8)
make_gather(9)
make_gather(10)
make_gather(0) make_gather(1) make_gather(2) make_gather(3) make_gather(4)
make_gather(5) make_gather(6) make_gather(7) make_gather(8) make_gather(9)
make_gather(10)
/////////////////////////////////////////////////////////////////////
// Gather instantiations
/////////////////////////////////////////////////////////////////////
#define instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG, nd, nd_name) \
template [[host_name("gather" name "_" #nidx "" #nd_name)]] \
[[kernel]] void gather<src_t, idx_t, nidx, nd>( \
const device src_t *src [[buffer(0)]], \
device src_t *out [[buffer(1)]], \
const constant int *src_shape [[buffer(2)]], \
const constant size_t *src_strides [[buffer(3)]], \
const constant size_t& src_ndim [[buffer(4)]], \
const constant int *slice_sizes [[buffer(5)]], \
const constant int *axes [[buffer(6)]], \
const constant int *idx_shapes [[buffer(7)]], \
const constant size_t *idx_strides [[buffer(8)]], \
const constant int& idx_ndim [[buffer(9)]], \
IDX_ARG(idx_t) \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]);
#define instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG, nd, nd_name) \
template [[host_name("gather" name "_" #nidx "" #nd_name)]] [[kernel]] void \
gather<src_t, idx_t, nidx, nd>( \
const device src_t* src [[buffer(0)]], \
device src_t* out [[buffer(1)]], \
const constant int* src_shape [[buffer(2)]], \
const constant size_t* src_strides [[buffer(3)]], \
const constant size_t& src_ndim [[buffer(4)]], \
const constant int* slice_sizes [[buffer(5)]], \
const constant int* axes [[buffer(6)]], \
const constant int* idx_shapes [[buffer(7)]], \
const constant size_t* idx_strides [[buffer(8)]], \
const constant int& idx_ndim [[buffer(9)]], \
IDX_ARG(idx_t) uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]);
// clang-format off
#define instantiate_gather5(name, src_t, idx_t, nidx, nd, nd_name) \
instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG_ ##nidx, nd, nd_name)
instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG_ ##nidx, nd, nd_name) // clang-format on
// clang-format off
#define instantiate_gather4(name, src_t, idx_t, nidx) \
instantiate_gather5(name, src_t, idx_t, nidx, 0, _0) \
instantiate_gather5(name, src_t, idx_t, nidx, 1, _1) \
@@ -148,29 +132,31 @@ instantiate_gather4("int32", int32_t, bool, 0)
instantiate_gather4("int64", int64_t, bool, 0)
instantiate_gather4("float16", half, bool, 0)
instantiate_gather4("float32", float, bool, 0)
instantiate_gather4("bfloat16", bfloat16_t, bool, 0)
instantiate_gather4("bfloat16", bfloat16_t, bool, 0) // clang-format on
// clang-format off
#define instantiate_gather3(name, src_type, ind_type) \
instantiate_gather4(name, src_type, ind_type, 1) \
instantiate_gather4(name, src_type, ind_type, 2) \
instantiate_gather4(name, src_type, ind_type, 3) \
instantiate_gather4(name, src_type, ind_type, 4) \
instantiate_gather4(name, src_type, ind_type, 5) \
instantiate_gather4(name, src_type, ind_type, 6) \
instantiate_gather4(name, src_type, ind_type, 7) \
instantiate_gather4(name, src_type, ind_type, 8) \
instantiate_gather4(name, src_type, ind_type, 9) \
instantiate_gather4(name, src_type, ind_type, 10)
instantiate_gather4(name, src_type, ind_type, 1) \
instantiate_gather4(name, src_type, ind_type, 2) \
instantiate_gather4(name, src_type, ind_type, 3) \
instantiate_gather4(name, src_type, ind_type, 4) \
instantiate_gather4(name, src_type, ind_type, 5) \
instantiate_gather4(name, src_type, ind_type, 6) \
instantiate_gather4(name, src_type, ind_type, 7) \
instantiate_gather4(name, src_type, ind_type, 8) \
instantiate_gather4(name, src_type, ind_type, 9) \
instantiate_gather4(name, src_type, ind_type, 10) // clang-format on
#define instantiate_gather(name, src_type) \
instantiate_gather3(#name "bool_", src_type, bool) \
instantiate_gather3(#name "uint8", src_type, uint8_t) \
// clang-format off
#define instantiate_gather(name, src_type) \
instantiate_gather3(#name "bool_", src_type, bool) \
instantiate_gather3(#name "uint8", src_type, uint8_t) \
instantiate_gather3(#name "uint16", src_type, uint16_t) \
instantiate_gather3(#name "uint32", src_type, uint32_t) \
instantiate_gather3(#name "uint64", src_type, uint64_t) \
instantiate_gather3(#name "int8", src_type, int8_t) \
instantiate_gather3(#name "int16", src_type, int16_t) \
instantiate_gather3(#name "int32", src_type, int32_t) \
instantiate_gather3(#name "int8", src_type, int8_t) \
instantiate_gather3(#name "int16", src_type, int16_t) \
instantiate_gather3(#name "int32", src_type, int32_t) \
instantiate_gather3(#name "int64", src_type, int64_t)
instantiate_gather(bool_, bool)
@@ -184,4 +170,4 @@ instantiate_gather(int32, int32_t)
instantiate_gather(int64, int64_t)
instantiate_gather(float16, half)
instantiate_gather(float32, float)
instantiate_gather(bfloat16, bfloat16_t)
instantiate_gather(bfloat16, bfloat16_t) // clang-format on

View File

@@ -1,12 +1,14 @@
// Copyright © 2023-2024 Apple Inc.
#include <metal_stdlib>
#include <metal_simdgroup>
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/utils.h"
using namespace metal;
///////////////////////////////////////////////////////////////////////////////
@@ -18,33 +20,34 @@ using namespace metal;
MLX_MTL_CONST int SIMD_SIZE = 32;
template <
typename T,
const int BM, /* Threadgroup rows (in threads) */
const int BN, /* Threadgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN , /* Thread cols (in elements) */
const bool kDoAxpby> /* Do out = alpha * out + beta * bias */
typename T,
const int BM, /* Threadgroup rows (in threads) */
const int BN, /* Threadgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN, /* Thread cols (in elements) */
const bool kDoAxpby> /* Do out = alpha * out + beta * bias */
struct GEMVKernel {
static_assert(BN == SIMD_SIZE, "gemv block must have a width of SIMD_SIZE");
// - The matrix of size (M = out_vec_size, N = in_vec_size) is divided up
// - The matrix of size (M = out_vec_size, N = in_vec_size) is divided up
// into blocks of (BM * TM, BN * TN) divided among threadgroups
// - Every thread works on a block of (TM, TN)
// - We assume each thead group is launched with (BN, BM, 1) threads
//
// 1. A thread loads TN elements each from mat along TM contiguous rows
// and the corresponding scalar from the vector
// 2. The thread then multiplies and adds to accumulate its local result for the block
// 3. At the end, each thread has accumulated results over all blocks across the rows
// These are then summed up across the threadgroup
// 1. A thread loads TN elements each from mat along TM contiguous rows
// and the corresponding scalar from the vector
// 2. The thread then multiplies and adds to accumulate its local result for
// the block
// 3. At the end, each thread has accumulated results over all blocks across
// the rows. These are then summed up across the threadgroup
// 4. Each threadgroup writes its accumulated BN * TN outputs
//
// Edge case handling:
// - The threadgroup with the largest tid will have blocks that exceed the matrix
// * The blocks that start outside the matrix are never read (thread results remain zero)
// * The last thread that partially overlaps with the matrix is shifted inwards
// such that the thread block fits exactly in the matrix
// - The threadgroup with the largest tid has blocks that exceed the matrix
// * The blocks that start outside the matrix are never read (thread results
// remain zero)
// * The last thread that partially overlaps with the matrix is shifted
// inwards such that the thread block fits exactly in the matrix
MLX_MTL_CONST short tgp_mem_size = BN * TN * 2;
@@ -52,7 +55,7 @@ struct GEMVKernel {
const device T* mat [[buffer(0)]],
const device T* in_vec [[buffer(1)]],
const device T* bias [[buffer(2)]],
device T* out_vec [[buffer(3)]],
device T* out_vec [[buffer(3)]],
const constant int& in_vec_size [[buffer(4)]],
const constant int& out_vec_size [[buffer(5)]],
const constant int& marix_ld [[buffer(6)]],
@@ -64,14 +67,13 @@ struct GEMVKernel {
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
// Appease compiler
// Appease compiler
(void)lid;
// Threadgroup in_vec cache
threadgroup T* in_vec_block = tgp_memory + simd_lid * TN * 2;
// Thread local accumulation results
// Thread local accumulation results
thread T result[TM] = {0};
thread T inter[TN];
thread T v_coeff[TN];
@@ -80,7 +82,7 @@ struct GEMVKernel {
int out_row = (tid.x * BM + simd_gid) * TM;
// Exit simdgroup if rows out of bound
if(out_row >= out_vec_size)
if (out_row >= out_vec_size)
return;
// Adjust tail simdgroup to ensure in bound reads
@@ -90,89 +92,81 @@ struct GEMVKernel {
mat += out_row * marix_ld;
// Loop over in_vec in blocks of BN * TN
for(int bn = simd_lid * TN; bn < in_vec_size; bn += BN * TN) {
for (int bn = simd_lid * TN; bn < in_vec_size; bn += BN * TN) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Prefetch in_vector for threadgroup use
if(simd_gid == 0) {
if (simd_gid == 0) {
// Main load loop
if(bn + TN <= in_vec_size) {
#pragma clang loop unroll(full)
for(int tn = 0; tn < TN; tn++) {
if (bn + TN <= in_vec_size) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
in_vec_block[tn] = in_vec[bn + tn];
}
} else { // Edgecase
#pragma clang loop unroll(full)
for(int tn = 0; tn < TN; tn++) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
in_vec_block[tn] = bn + tn < in_vec_size ? in_vec[bn + tn] : 0;
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load for all rows
#pragma clang loop unroll(full)
for(int tn = 0; tn < TN; tn++) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
v_coeff[tn] = in_vec_block[tn];
}
// Per thread work loop
#pragma clang loop unroll(full)
for(int tm = 0; tm < TM; tm++) {
// Load for the row
if(bn + TN <= in_vec_size) {
#pragma clang loop unroll(full)
for(int tn = 0; tn < TN; tn++) {
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
// Load for the row
if (bn + TN <= in_vec_size) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
inter[tn] = mat[tm * marix_ld + bn + tn];
}
} else { // Edgecase
#pragma clang loop unroll(full)
for(int tn = 0; tn < TN; tn++) {
int col_idx = (bn + tn) < in_vec_size ? (bn + tn) : (in_vec_size - 1);
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
int col_idx =
(bn + tn) < in_vec_size ? (bn + tn) : (in_vec_size - 1);
inter[tn] = mat[tm * marix_ld + col_idx];
}
}
// Accumulate results
for(int tn = 0; tn < TN; tn++) {
MLX_MTL_PRAGMA_UNROLL
for (int tn = 0; tn < TN; tn++) {
result[tm] += inter[tn] * v_coeff[tn];
}
}
}
// Simdgroup accumulations
#pragma clang loop unroll(full)
for(int tm = 0; tm < TM; tm++) {
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
result[tm] = simd_sum(result[tm]);
}
// Write outputs
if(simd_lid == 0) {
#pragma clang loop unroll(full)
for(int tm = 0; tm < TM; tm++) {
if(kDoAxpby) {
out_vec[out_row + tm] =
static_cast<T>(alpha) * result[tm] +
if (simd_lid == 0) {
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
if (kDoAxpby) {
out_vec[out_row + tm] = static_cast<T>(alpha) * result[tm] +
static_cast<T>(beta) * bias[(out_row + tm) * bias_stride];
} else {
out_vec[out_row + tm] = result[tm];
}
}
}
}
};
///////////////////////////////////////////////////////////////////////////////
@@ -180,32 +174,31 @@ struct GEMVKernel {
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
const int BM, /* Threadgroup rows (in threads) */
const int BN, /* Threadgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN, /* Thread cols (in elements) */
const bool kDoAxpby> /* Do out = alpha * out + beta * bias */
typename T,
const int BM, /* Threadgroup rows (in threads) */
const int BN, /* Threadgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN, /* Thread cols (in elements) */
const bool kDoAxpby> /* Do out = alpha * out + beta * bias */
struct GEMVTKernel {
// - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up
// - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up
// into blocks of (BM * TM, BN * TN) divided among threadgroups
// - Every thread works on a block of (TM, TN)
// - We assume each thead group is launched with (BN, BM, 1) threads
//
// 1. A thread loads TN elements each from mat along TM contiguous rows
// and the corresponding scalar from the vector
// 2. The thread then multiplies and adds to accumulate its local result for the block
// 3. At the end, each thread has accumulated results over all blocks across the rows
// These are then summed up across the threadgroup
// 1. A thread loads TN elements each from mat along TM contiguous rows
// and the corresponding scalar from the vector
// 2. The thread then accumulates its local result for the block
// 3. At the end, each thread has accumulated results over all blocks across
// the rows. These are then summed up across the threadgroup
// 4. Each threadgroup writes its accumulated BN * TN outputs
//
// Edge case handling:
// - The threadgroup with the largest tid will have blocks that exceed the matrix
// * The blocks that start outside the matrix are never read (thread results remain zero)
// * The last thread that partially overlaps with the matrix is shifted inwards
// such that the thread block fits exactly in the matrix
// - The threadgroup with the largest tid has blocks that exceed the matrix
// * The blocks that start outside the matrix are never read (thread results
// remain zero)
// * The last thread that partially overlaps with the matrix is shifted
// inwards such that the thread block fits exactly in the matrix
MLX_MTL_CONST short tgp_mem_size = BN * BM * TN;
@@ -213,7 +206,7 @@ struct GEMVTKernel {
const device T* mat [[buffer(0)]],
const device T* in_vec [[buffer(1)]],
const device T* bias [[buffer(2)]],
device T* out_vec [[buffer(3)]],
device T* out_vec [[buffer(3)]],
const constant int& in_vec_size [[buffer(4)]],
const constant int& out_vec_size [[buffer(5)]],
const constant int& marix_ld [[buffer(6)]],
@@ -225,8 +218,7 @@ struct GEMVTKernel {
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
// Appease compiler
// Appease compiler
(void)simd_gid;
(void)simd_lid;
@@ -243,77 +235,69 @@ struct GEMVTKernel {
// Edgecase handling
if (out_col < out_vec_size) {
out_col = out_col + TN < out_vec_size ? out_col : out_vec_size - TN;
// Per thread accumulation main loop
int bm = in_row;
for(; bm < in_vec_size; bm += BM * TM) {
for (; bm < in_vec_size; bm += BM * TM) {
// Adding a threadgroup_barrier improves performance slightly
// This is possibly it may help exploit cache better
threadgroup_barrier(mem_flags::mem_none);
if(bm + TM <= in_vec_size) {
#pragma clang loop unroll(full)
for(int tm = 0; tm < TM; tm++) {
if (bm + TM <= in_vec_size) {
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
v_coeff[tm] = in_vec[bm + tm];
}
#pragma clang loop unroll(full)
for(int tm = 0; tm < TM; tm++) {
for(int tn = 0; tn < TN; tn++) {
MLX_MTL_PRAGMA_UNROLL
for (int tm = 0; tm < TM; tm++) {
for (int tn = 0; tn < TN; tn++) {
inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
}
for(int tn = 0; tn < TN; tn++) {
for (int tn = 0; tn < TN; tn++) {
result[tn] += v_coeff[tm] * inter[tn];
}
}
} else { // Edgecase handling
for(int tm = 0; bm + tm < in_vec_size; tm++) {
for (int tm = 0; bm + tm < in_vec_size; tm++) {
v_coeff[tm] = in_vec[bm + tm];
for(int tn = 0; tn < TN; tn++) {
for (int tn = 0; tn < TN; tn++) {
inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn];
}
for(int tn = 0; tn < TN; tn++) {
for (int tn = 0; tn < TN; tn++) {
result[tn] += v_coeff[tm] * inter[tn];
}
}
}
}
}
// Threadgroup collection
#pragma clang loop unroll(full)
for(int i = 0; i < TN; i++) {
MLX_MTL_PRAGMA_UNROLL
for (int i = 0; i < TN; i++) {
tgp_results[lid.y * TN + i] = result[i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Threadgroup accumulation and writing out results
if(lid.y == 0 && out_col < out_vec_size) {
#pragma clang loop unroll(full)
for(int i = 1; i < BM; i++) {
#pragma clang loop unroll(full)
for(int j = 0; j < TN; j++) {
if (lid.y == 0 && out_col < out_vec_size) {
MLX_MTL_PRAGMA_UNROLL
for (int i = 1; i < BM; i++) {
MLX_MTL_PRAGMA_UNROLL
for (int j = 0; j < TN; j++) {
result[j] += tgp_results[i * TN + j];
}
}
#pragma clang loop unroll(full)
for(int j = 0; j < TN; j++) {
if(kDoAxpby) {
out_vec[out_col + j] =
static_cast<T>(alpha) * result[j] +
MLX_MTL_PRAGMA_UNROLL
for (int j = 0; j < TN; j++) {
if (kDoAxpby) {
out_vec[out_col + j] = static_cast<T>(alpha) * result[j] +
static_cast<T>(beta) * bias[(out_col + j) * bias_stride];
} else {
out_vec[out_col + j] = result[j];
@@ -328,18 +312,18 @@ struct GEMVTKernel {
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
typename T,
const int BM, /* Threadgroup rows (in threads) */
const int BN, /* Threadgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN, /* Thread cols (in elements) */
const bool kDoNCBatch, /* Batch ndim > 1 */
const bool kDoAxpby> /* Do out = alpha * out + beta * bias */
[[kernel, max_total_threads_per_threadgroup(BM * BN)]] void gemv(
[[kernel, max_total_threads_per_threadgroup(BM* BN)]] void gemv(
const device T* mat [[buffer(0)]],
const device T* in_vec [[buffer(1)]],
const device T* bias [[buffer(2)]],
device T* out_vec [[buffer(3)]],
device T* out_vec [[buffer(3)]],
const constant int& in_vec_size [[buffer(4)]],
const constant int& out_vec_size [[buffer(5)]],
const constant int& marix_ld [[buffer(6)]],
@@ -355,16 +339,15 @@ template <
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using gemv_kernel = GEMVKernel<T, BM, BN, TM, TN, kDoAxpby>;
threadgroup T tgp_memory[gemv_kernel::tgp_mem_size];
// Update batch offsets
if(kDoNCBatch) {
if (kDoNCBatch) {
in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);
mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);
if(kDoAxpby) {
if (kDoAxpby) {
bias += elem_to_loc(tid.z, batch_shape, bias_batch_stride, batch_ndim);
}
@@ -372,89 +355,215 @@ template <
in_vec += tid.z * vector_batch_stride[0];
mat += tid.z * matrix_batch_stride[0];
if(kDoAxpby) {
if (kDoAxpby) {
bias += tid.z * bias_batch_stride[0];
}
}
out_vec += tid.z * out_vec_size;
gemv_kernel::run(
mat,
in_vec,
bias,
out_vec,
in_vec_size,
out_vec_size,
marix_ld,
alpha,
beta,
bias_stride,
tgp_memory,
tid,
lid,
simd_gid,
simd_lid
);
gemv_kernel::run(
mat,
in_vec,
bias,
out_vec,
in_vec_size,
out_vec_size,
marix_ld,
alpha,
beta,
bias_stride,
tgp_memory,
tid,
lid,
simd_gid,
simd_lid);
}
#define instantiate_gemv_helper(name, itype, bm, bn, tm, tn, nc, axpby) \
template [[host_name("gemv_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn \
"_nc" #nc "_axpby" #axpby)]] [[kernel]] void \
gemv<itype, bm, bn, tm, tn, nc, axpby>( \
const device itype* mat [[buffer(0)]], \
const device itype* in_vec [[buffer(1)]], \
const device itype* bias [[buffer(2)]], \
device itype* out_vec [[buffer(3)]], \
const constant int& in_vec_size [[buffer(4)]], \
const constant int& out_vec_size [[buffer(5)]], \
const constant int& marix_ld [[buffer(6)]], \
const constant float& alpha [[buffer(7)]], \
const constant float& beta [[buffer(8)]], \
const constant int& batch_ndim [[buffer(9)]], \
const constant int* batch_shape [[buffer(10)]], \
const constant size_t* vector_batch_stride [[buffer(11)]], \
const constant size_t* matrix_batch_stride [[buffer(12)]], \
const constant size_t* bias_batch_stride [[buffer(13)]], \
const constant int& bias_stride [[buffer(14)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_gemv_helper(name, itype, bm, bn, tm, tn, nc, axpby) \
template [[host_name("gemv_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn "_nc" #nc "_axpby" #axpby)]] \
[[kernel]] void gemv<itype, bm, bn, tm, tn, nc, axpby>( \
const device itype* mat [[buffer(0)]], \
const device itype* in_vec [[buffer(1)]], \
const device itype* bias [[buffer(2)]], \
device itype* out_vec [[buffer(3)]], \
const constant int& in_vec_size [[buffer(4)]], \
const constant int& out_vec_size [[buffer(5)]], \
const constant int& marix_ld [[buffer(6)]], \
const constant float& alpha [[buffer(7)]], \
const constant float& beta [[buffer(8)]], \
const constant int& batch_ndim [[buffer(9)]], \
const constant int* batch_shape [[buffer(10)]], \
const constant size_t* vector_batch_stride [[buffer(11)]], \
const constant size_t* matrix_batch_stride [[buffer(12)]], \
const constant size_t* bias_batch_stride [[buffer(13)]], \
const constant int& bias_stride [[buffer(14)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_gemv(name, itype, bm, bn, tm, tn) \
// clang-format off
#define instantiate_gemv(name, itype, bm, bn, tm, tn) \
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 0, 0) \
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 0, 1) \
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 1, 0) \
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 1, 1)
instantiate_gemv_helper(name, itype, bm, bn, tm, tn, 1, 1) // clang-format on
// clang-format off
#define instantiate_gemv_blocks(name, itype) \
instantiate_gemv(name, itype, 4, 32, 1, 4) \
instantiate_gemv(name, itype, 4, 32, 4, 4) \
instantiate_gemv(name, itype, 8, 32, 4, 4)
instantiate_gemv(name, itype, 8, 32, 4, 4) // clang-format on
instantiate_gemv_blocks(float32, float);
instantiate_gemv_blocks(float16, half);
instantiate_gemv_blocks(bfloat16, bfloat16_t);
template <
typename T,
const int BM, /* Threadgroup rows (in threads) */
const int BN, /* Threadgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN> /* Thread cols (in elements) */
[[kernel, max_total_threads_per_threadgroup(BM* BN)]] void gemv_bs(
const device T* mat [[buffer(0)]],
const device T* in_vec [[buffer(1)]],
const device T* bias [[buffer(2)]],
device T* out_vec [[buffer(3)]],
const constant int& in_vec_size [[buffer(4)]],
const constant int& out_vec_size [[buffer(5)]],
const constant int& marix_ld [[buffer(6)]],
const constant float& alpha [[buffer(7)]],
const constant float& beta [[buffer(8)]],
const constant int& batch_ndim [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]],
const constant size_t* index_batch_strides [[buffer(11)]],
const constant int& vector_batch_ndim [[buffer(12)]],
const constant int* vector_batch_shape [[buffer(13)]],
const constant size_t* vector_batch_stride [[buffer(14)]],
const constant int& matrix_batch_ndim [[buffer(15)]],
const constant int* matrix_batch_shape [[buffer(16)]],
const constant size_t* matrix_batch_stride [[buffer(17)]],
const constant uint32_t* vec_indices [[buffer(18)]],
const constant uint32_t* mat_indices [[buffer(19)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using gemv_kernel = GEMVKernel<T, BM, BN, TM, TN, false>;
threadgroup T tgp_memory[gemv_kernel::tgp_mem_size];
uint32_t indx_vec;
uint32_t indx_mat;
// Update batch offsets
if (batch_ndim > 1) {
const constant size_t* veci_bstrides = index_batch_strides;
const constant size_t* mati_bstrides = index_batch_strides + batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, veci_bstrides, mati_bstrides, batch_ndim);
indx_vec = vec_indices[batch_offsets.x];
indx_mat = mat_indices[batch_offsets.y];
} else {
indx_vec = vec_indices[index_batch_strides[0] * tid.z];
indx_mat = mat_indices[index_batch_strides[batch_ndim] * tid.z];
}
if (vector_batch_ndim > 1) {
in_vec += elem_to_loc(
indx_vec, vector_batch_shape, vector_batch_stride, vector_batch_ndim);
} else {
in_vec += indx_vec * vector_batch_stride[0];
}
if (matrix_batch_ndim > 1) {
mat += elem_to_loc(
indx_mat, matrix_batch_shape, matrix_batch_stride, matrix_batch_ndim);
} else {
mat += indx_mat * matrix_batch_stride[0];
}
out_vec += tid.z * out_vec_size;
gemv_kernel::run(
mat,
in_vec,
bias,
out_vec,
in_vec_size,
out_vec_size,
marix_ld,
alpha,
beta,
batch_ndim, // Not used
tgp_memory,
tid,
lid,
simd_gid,
simd_lid);
}
#define instantiate_gemv_bs_helper(nm, itype, bm, bn, tm, tn) \
template [[host_name("gemv_bs_" #nm "_bm" #bm "_bn" #bn "_tm" #tm \
"_tn" #tn)]] [[kernel]] void \
gemv_bs<itype, bm, bn, tm, tn>( \
const device itype* mat [[buffer(0)]], \
const device itype* in_vec [[buffer(1)]], \
const device itype* bias [[buffer(2)]], \
device itype* out_vec [[buffer(3)]], \
const constant int& in_vec_size [[buffer(4)]], \
const constant int& out_vec_size [[buffer(5)]], \
const constant int& marix_ld [[buffer(6)]], \
const constant float& alpha [[buffer(7)]], \
const constant float& beta [[buffer(8)]], \
const constant int& batch_ndim [[buffer(9)]], \
const constant int* batch_shape [[buffer(10)]], \
const constant size_t* index_batch_strides [[buffer(11)]], \
const constant int& vector_batch_ndim [[buffer(12)]], \
const constant int* vector_batch_shape [[buffer(13)]], \
const constant size_t* vector_batch_stride [[buffer(14)]], \
const constant int& matrix_batch_ndim [[buffer(15)]], \
const constant int* matrix_batch_shape [[buffer(16)]], \
const constant size_t* matrix_batch_stride [[buffer(17)]], \
const constant uint32_t* vec_indices [[buffer(18)]], \
const constant uint32_t* mat_indices [[buffer(19)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
// clang-format off
#define instantiate_gemv_bs_blocks(name, itype) \
instantiate_gemv_bs_helper(name, itype, 4, 32, 1, 4) \
instantiate_gemv_bs_helper(name, itype, 4, 32, 4, 4) \
instantiate_gemv_bs_helper(name, itype, 8, 32, 4, 4) // clang-format on
instantiate_gemv_bs_blocks(float32, float);
instantiate_gemv_bs_blocks(float16, half);
instantiate_gemv_bs_blocks(bfloat16, bfloat16_t);
///////////////////////////////////////////////////////////////////////////////
/// Vector matrix multiplication
///////////////////////////////////////////////////////////////////////////////
template <
typename T,
typename T,
const int BM, /* Threadgroup rows (in threads) */
const int BN, /* Threadgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN, /* Thread cols (in elements) */
const bool kDoNCBatch, /* Batch ndim > 1 */
const bool kDoAxpby> /* Do out = alpha * out + beta * bias */
[[kernel, max_total_threads_per_threadgroup(BM * BN)]] void gemv_t(
[[kernel, max_total_threads_per_threadgroup(BM* BN)]] void gemv_t(
const device T* mat [[buffer(0)]],
const device T* in_vec [[buffer(1)]],
const device T* bias [[buffer(2)]],
device T* out_vec [[buffer(3)]],
device T* out_vec [[buffer(3)]],
const constant int& in_vec_size [[buffer(4)]],
const constant int& out_vec_size [[buffer(5)]],
const constant int& marix_ld [[buffer(6)]],
@@ -470,16 +579,15 @@ template <
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using gemv_kernel = GEMVTKernel<T, BM, BN, TM, TN, kDoAxpby>;
threadgroup T tgp_memory[gemv_kernel::tgp_mem_size];
// Update batch offsets
if(kDoNCBatch) {
if (kDoNCBatch) {
in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);
mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);
if(kDoAxpby) {
if (kDoAxpby) {
bias += elem_to_loc(tid.z, batch_shape, bias_batch_stride, batch_ndim);
}
@@ -487,70 +595,202 @@ template <
in_vec += tid.z * vector_batch_stride[0];
mat += tid.z * matrix_batch_stride[0];
if(kDoAxpby) {
if (kDoAxpby) {
bias += tid.z * bias_batch_stride[0];
}
}
out_vec += tid.z * out_vec_size;
gemv_kernel::run(
mat,
in_vec,
bias,
out_vec,
in_vec_size,
out_vec_size,
marix_ld,
alpha,
beta,
bias_stride,
tgp_memory,
tid,
lid,
simd_gid,
simd_lid
);
gemv_kernel::run(
mat,
in_vec,
bias,
out_vec,
in_vec_size,
out_vec_size,
marix_ld,
alpha,
beta,
bias_stride,
tgp_memory,
tid,
lid,
simd_gid,
simd_lid);
}
#define instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, nc, axpby) \
template [[host_name("gemv_t_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn "_nc" #nc "_axpby" #axpby)]] \
[[kernel]] void gemv_t<itype, bm, bn, tm, tn, nc, axpby>( \
const device itype* mat [[buffer(0)]], \
const device itype* in_vec [[buffer(1)]], \
const device itype* bias [[buffer(2)]], \
device itype* out_vec [[buffer(3)]], \
const constant int& in_vec_size [[buffer(4)]], \
const constant int& out_vec_size [[buffer(5)]], \
const constant int& marix_ld [[buffer(6)]], \
const constant float& alpha [[buffer(7)]], \
const constant float& beta [[buffer(8)]], \
const constant int& batch_ndim [[buffer(9)]], \
const constant int* batch_shape [[buffer(10)]], \
const constant size_t* vector_batch_stride [[buffer(11)]], \
const constant size_t* matrix_batch_stride [[buffer(12)]], \
const constant size_t* bias_batch_stride [[buffer(13)]], \
const constant int& bias_stride [[buffer(14)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, nc, axpby) \
template [[host_name("gemv_t_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn \
"_nc" #nc "_axpby" #axpby)]] [[kernel]] void \
gemv_t<itype, bm, bn, tm, tn, nc, axpby>( \
const device itype* mat [[buffer(0)]], \
const device itype* in_vec [[buffer(1)]], \
const device itype* bias [[buffer(2)]], \
device itype* out_vec [[buffer(3)]], \
const constant int& in_vec_size [[buffer(4)]], \
const constant int& out_vec_size [[buffer(5)]], \
const constant int& marix_ld [[buffer(6)]], \
const constant float& alpha [[buffer(7)]], \
const constant float& beta [[buffer(8)]], \
const constant int& batch_ndim [[buffer(9)]], \
const constant int* batch_shape [[buffer(10)]], \
const constant size_t* vector_batch_stride [[buffer(11)]], \
const constant size_t* matrix_batch_stride [[buffer(12)]], \
const constant size_t* bias_batch_stride [[buffer(13)]], \
const constant int& bias_stride [[buffer(14)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_gemv_t(name, itype, bm, bn, tm, tn) \
// clang-format off
#define instantiate_gemv_t(name, itype, bm, bn, tm, tn) \
instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 0, 0) \
instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 0, 1) \
instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 1, 0) \
instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 1, 1)
instantiate_gemv_t_helper(name, itype, bm, bn, tm, tn, 1, 1) // clang-format on
// clang-format off
#define instantiate_gemv_t_blocks(name, itype) \
instantiate_gemv_t(name, itype, 8, 8, 4, 1) \
instantiate_gemv_t(name, itype, 8, 8, 4, 4) \
instantiate_gemv_t(name, itype, 8, 8, 4, 1) \
instantiate_gemv_t(name, itype, 8, 8, 4, 4) \
instantiate_gemv_t(name, itype, 8, 16, 4, 4) \
instantiate_gemv_t(name, itype, 8, 32, 4, 4) \
instantiate_gemv_t(name, itype, 8, 64, 4, 4) \
instantiate_gemv_t(name, itype, 8, 128, 4, 4)
instantiate_gemv_t(name, itype, 8, 128, 4, 4) // clang-format on
// clang-format off
instantiate_gemv_t_blocks(float32, float);
instantiate_gemv_t_blocks(float16, half);
instantiate_gemv_t_blocks(bfloat16, bfloat16_t);
instantiate_gemv_t_blocks(bfloat16, bfloat16_t); // clang-format on
template <
typename T,
const int BM, /* Threadgroup rows (in threads) */
const int BN, /* Threadgroup cols (in threads) */
const int TM, /* Thread rows (in elements) */
const int TN> /* Thread cols (in elements) */
[[kernel, max_total_threads_per_threadgroup(BM* BN)]] void gemv_t_bs(
const device T* mat [[buffer(0)]],
const device T* in_vec [[buffer(1)]],
const device T* bias [[buffer(2)]],
device T* out_vec [[buffer(3)]],
const constant int& in_vec_size [[buffer(4)]],
const constant int& out_vec_size [[buffer(5)]],
const constant int& marix_ld [[buffer(6)]],
const constant float& alpha [[buffer(7)]],
const constant float& beta [[buffer(8)]],
const constant int& batch_ndim [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]],
const constant size_t* index_batch_strides [[buffer(11)]],
const constant int& vector_batch_ndim [[buffer(12)]],
const constant int* vector_batch_shape [[buffer(13)]],
const constant size_t* vector_batch_stride [[buffer(14)]],
const constant int& matrix_batch_ndim [[buffer(15)]],
const constant int* matrix_batch_shape [[buffer(16)]],
const constant size_t* matrix_batch_stride [[buffer(17)]],
const constant uint32_t* vec_indices [[buffer(18)]],
const constant uint32_t* mat_indices [[buffer(19)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using gemv_kernel = GEMVTKernel<T, BM, BN, TM, TN, false>;
threadgroup T tgp_memory[gemv_kernel::tgp_mem_size];
uint32_t indx_vec;
uint32_t indx_mat;
// Update batch offsets
if (batch_ndim > 1) {
const constant size_t* veci_bstrides = index_batch_strides;
const constant size_t* mati_bstrides = index_batch_strides + batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, veci_bstrides, mati_bstrides, batch_ndim);
indx_vec = vec_indices[batch_offsets.x];
indx_mat = mat_indices[batch_offsets.y];
} else {
indx_vec = vec_indices[index_batch_strides[0] * tid.z];
indx_mat = mat_indices[index_batch_strides[batch_ndim] * tid.z];
}
if (vector_batch_ndim > 1) {
in_vec += elem_to_loc(
indx_vec, vector_batch_shape, vector_batch_stride, vector_batch_ndim);
} else {
in_vec += indx_vec * vector_batch_stride[0];
}
if (matrix_batch_ndim > 1) {
mat += elem_to_loc(
indx_mat, matrix_batch_shape, matrix_batch_stride, matrix_batch_ndim);
} else {
mat += indx_mat * matrix_batch_stride[0];
}
out_vec += tid.z * out_vec_size;
gemv_kernel::run(
mat,
in_vec,
bias,
out_vec,
in_vec_size,
out_vec_size,
marix_ld,
alpha,
beta,
batch_ndim, // Not used,
tgp_memory,
tid,
lid,
simd_gid,
simd_lid);
}
#define instantiate_gemv_t_bs_helper(nm, itype, bm, bn, tm, tn) \
template [[host_name("gemv_t_bs_" #nm "_bm" #bm "_bn" #bn "_tm" #tm \
"_tn" #tn)]] [[kernel]] void \
gemv_t_bs<itype, bm, bn, tm, tn>( \
const device itype* mat [[buffer(0)]], \
const device itype* in_vec [[buffer(1)]], \
const device itype* bias [[buffer(2)]], \
device itype* out_vec [[buffer(3)]], \
const constant int& in_vec_size [[buffer(4)]], \
const constant int& out_vec_size [[buffer(5)]], \
const constant int& marix_ld [[buffer(6)]], \
const constant float& alpha [[buffer(7)]], \
const constant float& beta [[buffer(8)]], \
const constant int& batch_ndim [[buffer(9)]], \
const constant int* batch_shape [[buffer(10)]], \
const constant size_t* index_batch_strides [[buffer(11)]], \
const constant int& vector_batch_ndim [[buffer(12)]], \
const constant int* vector_batch_shape [[buffer(13)]], \
const constant size_t* vector_batch_stride [[buffer(14)]], \
const constant int& matrix_batch_ndim [[buffer(15)]], \
const constant int* matrix_batch_shape [[buffer(16)]], \
const constant size_t* matrix_batch_stride [[buffer(17)]], \
const constant uint32_t* vec_indices [[buffer(18)]], \
const constant uint32_t* mat_indices [[buffer(19)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
// clang-format off
#define instantiate_gemv_t_bs_blocks(name, itype) \
instantiate_gemv_t_bs_helper(name, itype, 8, 8, 4, 1) \
instantiate_gemv_t_bs_helper(name, itype, 8, 8, 4, 4) \
instantiate_gemv_t_bs_helper(name, itype, 8, 16, 4, 4) \
instantiate_gemv_t_bs_helper(name, itype, 8, 32, 4, 4) \
instantiate_gemv_t_bs_helper(name, itype, 8, 64, 4, 4) \
instantiate_gemv_t_bs_helper(name, itype, 8, 128, 4, 4) // clang-format on
// clang-format off
instantiate_gemv_t_bs_blocks(float32, float);
instantiate_gemv_t_bs_blocks(float16, half);
instantiate_gemv_t_bs_blocks(bfloat16, bfloat16_t); // clang-format on

View File

@@ -99,7 +99,8 @@ template <typename T, int N_READS = RMS_N_READS>
for (int i = 0; i < N_READS; i++) {
if ((lid * N_READS + i) < axis_size) {
thread_x[i] = (thread_x[i] - mean) * normalizer;
out[i] = w[w_stride * i] * static_cast<T>(thread_x[i]) + b[b_stride * i];
out[i] =
w[w_stride * i] * static_cast<T>(thread_x[i]) + b[b_stride * i];
}
}
}
@@ -192,13 +193,15 @@ template <typename T, int N_READS = RMS_N_READS>
if (r + lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
float xi = (x[r + i] - mean) * normalizer;
out[r + i] = w[w_stride * (i + r)] * static_cast<T>(xi) + b[b_stride * (i + r)];
out[r + i] =
w[w_stride * (i + r)] * static_cast<T>(xi) + b[b_stride * (i + r)];
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((r + lid * N_READS + i) < axis_size) {
float xi = (x[r + i] - mean) * normalizer;
out[r + i] = w[w_stride * (i + r)] * static_cast<T>(xi) + b[b_stride * (i + r)];
out[r + i] = w[w_stride * (i + r)] * static_cast<T>(xi) +
b[b_stride * (i + r)];
}
}
}
@@ -323,16 +326,18 @@ template <typename T, int N_READS = RMS_N_READS>
if (lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
thread_x[i] = (thread_x[i] - mean) * normalizer;
gx[i] = static_cast<T>(normalizer * (thread_w[i] * thread_g[i] - meanwg) -
thread_x[i] * meanwgxc * normalizer2);
gx[i] = static_cast<T>(
normalizer * (thread_w[i] * thread_g[i] - meanwg) -
thread_x[i] * meanwgxc * normalizer2);
gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((lid * N_READS + i) < axis_size) {
thread_x[i] = (thread_x[i] - mean) * normalizer;
gx[i] = static_cast<T>(normalizer * (thread_w[i] * thread_g[i] - meanwg) -
thread_x[i] * meanwgxc * normalizer2);
gx[i] = static_cast<T>(
normalizer * (thread_w[i] * thread_g[i] - meanwg) -
thread_x[i] * meanwgxc * normalizer2);
gw[i] = static_cast<T>(thread_g[i] * thread_x[i]);
}
}
@@ -460,8 +465,8 @@ template <typename T, int N_READS = RMS_N_READS>
float xi = (x[i + r] - mean) * normalizer;
float wi = w[(i + r) * w_stride];
float gi = g[i + r];
gx[i + r] = static_cast<T>(normalizer * (wi * gi - meanwg) -
xi * meanwgxc * normalizer2);
gx[i + r] = static_cast<T>(
normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2);
gw[i + r] = static_cast<T>(gi * xi);
}
} else {
@@ -470,8 +475,8 @@ template <typename T, int N_READS = RMS_N_READS>
float xi = (x[i + r] - mean) * normalizer;
float wi = w[(i + r) * w_stride];
float gi = g[i + r];
gx[i + r] = static_cast<T>(normalizer * (wi * gi - meanwg) -
xi * meanwgxc * normalizer2);
gx[i + r] = static_cast<T>(
normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2);
gw[i + r] = static_cast<T>(gi * xi);
}
}
@@ -548,6 +553,4 @@ template <typename T, int N_READS = RMS_N_READS>
instantiate_layer_norm(float32, float)
instantiate_layer_norm(float16, half)
instantiate_layer_norm(bfloat16, bfloat16_t)
// clang-format on
instantiate_layer_norm(bfloat16, bfloat16_t) // clang-format on

File diff suppressed because it is too large Load Diff

View File

@@ -3,9 +3,8 @@
#include "mlx/backend/metal/kernels/utils.h"
static constexpr constant uint32_t rotations[2][4] = {
{13, 15, 26, 6},
{17, 29, 16, 24}
};
{13, 15, 26, 6},
{17, 29, 16, 24}};
union rbits {
uint2 val;
@@ -13,7 +12,6 @@ union rbits {
};
rbits threefry2x32_hash(const thread uint2& key, uint2 count) {
uint4 ks = {key.x, key.y, key.x ^ key.y ^ 0x1BD11BDA};
rbits v;
@@ -51,7 +49,7 @@ rbits threefry2x32_hash(const thread uint2& key, uint2 count) {
out[4 * count.x + i] = bits.bytes[0][i];
}
if (!drop_last) {
if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) {
if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) {
int edge_bytes = (bytes_per_key % 4);
for (int i = 0; i < edge_bytes; ++i) {
out[4 * count.y + i] = bits.bytes[1][i];
@@ -87,7 +85,7 @@ rbits threefry2x32_hash(const thread uint2& key, uint2 count) {
out[4 * count.x + i] = bits.bytes[0][i];
}
if (!drop_last) {
if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) {
if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) {
int edge_bytes = (bytes_per_key % 4);
for (int i = 0; i < edge_bytes; ++i) {
out[4 * count.y + i] = bits.bytes[1][i];

View File

@@ -1,8 +1,8 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/kernels/reduction/utils.h"
#include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/reduction/reduce_inst.h"
#include "mlx/backend/metal/kernels/reduction/utils.h"
using namespace metal;
@@ -60,14 +60,13 @@ METAL_FUNC U per_thread_all_reduce(
// All reduce kernel
///////////////////////////////////////////////////////////////////////////////
// NB: This kernel assumes threads_per_threadgroup is at most
// 1024. This way with a simd_size of 32, we are guaranteed to
// complete the reduction in two steps of simd-level reductions.
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void all_reduce(
const device T *in [[buffer(0)]],
device mlx_atomic<U> *out [[buffer(1)]],
const device T* in [[buffer(0)]],
device mlx_atomic<U>* out [[buffer(1)]],
const device size_t& in_size [[buffer(2)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
@@ -75,11 +74,11 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
Op op;
threadgroup U local_vals[simd_size];
U total_val = per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);
U total_val =
per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);
// Reduction within simd group
total_val = op.simd_reduce(total_val);
@@ -98,10 +97,10 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
}
}
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void all_reduce_no_atomics(
const device T *in [[buffer(0)]],
device U *out [[buffer(1)]],
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const device size_t& in_size [[buffer(2)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
@@ -110,14 +109,16 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint thread_group_id [[threadgroup_position_in_grid]]) {
Op op;
threadgroup U local_vals[simd_size];
U total_val = per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);
U total_val =
per_thread_all_reduce<T, U, Op, N_READS>(in, in_size, gid, grid_size);
// Reduction within simd group (simd_add isn't supported for uint64/int64 types)
for (uint16_t lane_offset = simd_size/2; lane_offset > 0; lane_offset /= 2) {
// Reduction within simd group (simd_add isn't supported for uint64/int64
// types)
for (uint16_t lane_offset = simd_size / 2; lane_offset > 0;
lane_offset /= 2) {
total_val = op(total_val, simd_shuffle_down(total_val, lane_offset));
}
// Write simd group reduction results to local memory
@@ -128,7 +129,8 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
// Reduction of simdgroup reduction results within threadgroup.
total_val = lid < simd_per_group ? local_vals[lid] : op.init;
for (uint16_t lane_offset = simd_size/2; lane_offset > 0; lane_offset /= 2) {
for (uint16_t lane_offset = simd_size / 2; lane_offset > 0;
lane_offset /= 2) {
total_val = op(total_val, simd_shuffle_down(total_val, lane_offset));
}
@@ -138,31 +140,31 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
}
}
#define instantiate_all_reduce(name, itype, otype, op) \
template [[host_name("all_reduce_" #name)]] \
[[kernel]] void all_reduce<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device mlx_atomic<otype> *out [[buffer(1)]], \
const device size_t& in_size [[buffer(2)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint grid_size [[threads_per_grid]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
#define instantiate_all_reduce(name, itype, otype, op) \
template [[host_name("all_reduce_" #name)]] [[kernel]] void \
all_reduce<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device mlx_atomic<otype>* out [[buffer(1)]], \
const device size_t& in_size [[buffer(2)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint grid_size [[threads_per_grid]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_all_reduce_no_atomics(name, itype, otype, op) \
template [[host_name("all_reduce_no_atomics_" #name)]] \
[[kernel]] void all_reduce_no_atomics<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device otype *out [[buffer(1)]], \
const device size_t& in_size [[buffer(2)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint grid_size [[threads_per_grid]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
#define instantiate_all_reduce_no_atomics(name, itype, otype, op) \
template [[host_name("all_reduce_no_atomics_" #name)]] [[kernel]] void \
all_reduce_no_atomics<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const device size_t& in_size [[buffer(2)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint grid_size [[threads_per_grid]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint thread_group_id [[threadgroup_position_in_grid]]);
///////////////////////////////////////////////////////////////////////////////
@@ -170,11 +172,12 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
///////////////////////////////////////////////////////////////////////////////
#define instantiate_same_all_reduce_helper(name, tname, type, op) \
instantiate_all_reduce(name ##tname, type, type, op<type>)
instantiate_all_reduce(name##tname, type, type, op<type>)
#define instantiate_same_all_reduce_na_helper(name, tname, type, op) \
instantiate_all_reduce_no_atomics(name ##tname, type, type, op<type>)
instantiate_all_reduce_no_atomics(name##tname, type, type, op<type>)
// clang-format off
instantiate_reduce_ops(instantiate_same_all_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_all_reduce_na_helper, instantiate_reduce_helper_64b)
@@ -182,4 +185,4 @@ instantiate_reduce_from_types(instantiate_all_reduce, and, bool, And)
instantiate_reduce_from_types(instantiate_all_reduce, or, bool, Or)
// special case bool with larger output type
instantiate_all_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>)
instantiate_all_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>) // clang-format on

View File

@@ -1,8 +1,8 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/kernels/reduction/utils.h"
#include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/reduction/reduce_inst.h"
#include "mlx/backend/metal/kernels/reduction/utils.h"
using namespace metal;
@@ -12,8 +12,8 @@ using namespace metal;
template <typename T, typename U, typename Op>
[[kernel]] void col_reduce_small(
const device T *in [[buffer(0)]],
device U *out [[buffer(1)]],
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
@@ -25,7 +25,6 @@ template <typename T, typename U, typename Op>
const constant size_t* non_col_strides [[buffer(10)]],
const constant int& non_col_ndim [[buffer(11)]],
uint tid [[thread_position_in_grid]]) {
// Appease the compiler
(void)out_size;
@@ -35,15 +34,16 @@ template <typename T, typename U, typename Op>
auto out_idx = tid;
in += elem_to_loc(
out_idx,
shape + non_col_ndim,
strides + non_col_ndim,
ndim - non_col_ndim);
out_idx,
shape + non_col_ndim,
strides + non_col_ndim,
ndim - non_col_ndim);
for(uint i = 0; i < non_col_reductions; i++) {
size_t in_idx = elem_to_loc(i, non_col_shapes, non_col_strides, non_col_ndim);
for (uint i = 0; i < non_col_reductions; i++) {
size_t in_idx =
elem_to_loc(i, non_col_shapes, non_col_strides, non_col_ndim);
for(uint j = 0; j < reduction_size; j++, in_idx += reduction_stride) {
for (uint j = 0; j < reduction_size; j++, in_idx += reduction_stride) {
U val = static_cast<U>(in[in_idx]);
total_val = op(total_val, val);
}
@@ -52,21 +52,21 @@ template <typename T, typename U, typename Op>
out[out_idx] = total_val;
}
#define instantiate_col_reduce_small(name, itype, otype, op) \
template [[host_name("col_reduce_small_" #name)]] \
[[kernel]] void col_reduce_small<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device otype *out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
const constant size_t& non_col_reductions [[buffer(8)]], \
const constant int* non_col_shapes [[buffer(9)]], \
const constant size_t* non_col_strides [[buffer(10)]], \
const constant int& non_col_ndim [[buffer(11)]], \
#define instantiate_col_reduce_small(name, itype, otype, op) \
template [[host_name("col_reduce_small_" #name)]] [[kernel]] void \
col_reduce_small<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
const constant size_t& non_col_reductions [[buffer(8)]], \
const constant int* non_col_shapes [[buffer(9)]], \
const constant size_t* non_col_strides [[buffer(10)]], \
const constant int& non_col_ndim [[buffer(11)]], \
uint tid [[thread_position_in_grid]]);
///////////////////////////////////////////////////////////////////////////////
@@ -112,39 +112,35 @@ METAL_FUNC U _contiguous_strided_reduce(
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void col_reduce_general(
const device T *in [[buffer(0)]],
device mlx_atomic<U> *out [[buffer(1)]],
const device T* in [[buffer(0)]],
device mlx_atomic<U>* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
threadgroup U *local_data [[threadgroup(0)]],
threadgroup U* local_data [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]]) {
auto out_idx = tid.x * lsize.x + lid.x;
auto in_idx = elem_to_loc(
out_idx + tid.z * out_size,
shape,
strides,
ndim
);
auto in_idx = elem_to_loc(out_idx + tid.z * out_size, shape, strides, ndim);
Op op;
if(out_idx < out_size) {
if (out_idx < out_size) {
U val = _contiguous_strided_reduce<T, U, Op, N_READS>(
in,
local_data,
in_idx,
reduction_size,
reduction_stride,
tid.xy,
lid.xy,
lsize.xy);
in,
local_data,
in_idx,
reduction_size,
reduction_stride,
tid.xy,
lid.xy,
lsize.xy);
// Write out reduction results generated by threadgroups working on specific output element, contiguously.
// Write out reduction results generated by threadgroups working on specific
// output element, contiguously.
if (lid.y == 0) {
op.atomic_update(out, val, out_idx);
}
@@ -153,40 +149,36 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void col_reduce_general_no_atomics(
const device T *in [[buffer(0)]],
device U *out [[buffer(1)]],
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
threadgroup U *local_data [[threadgroup(0)]],
threadgroup U* local_data [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 gid [[thread_position_in_grid]],
uint3 lsize [[threads_per_threadgroup]],
uint3 gsize [[threads_per_grid]]) {
auto out_idx = tid.x * lsize.x + lid.x;
auto in_idx = elem_to_loc(
out_idx + tid.z * out_size,
shape,
strides,
ndim
);
auto in_idx = elem_to_loc(out_idx + tid.z * out_size, shape, strides, ndim);
if(out_idx < out_size) {
if (out_idx < out_size) {
U val = _contiguous_strided_reduce<T, U, Op, N_READS>(
in,
local_data,
in_idx,
reduction_size,
reduction_stride,
tid.xy,
lid.xy,
lsize.xy);
in,
local_data,
in_idx,
reduction_size,
reduction_stride,
tid.xy,
lid.xy,
lsize.xy);
// Write out reduction results generated by threadgroups working on specific output element, contiguously.
// Write out reduction results generated by threadgroups working on specific
// output element, contiguously.
if (lid.y == 0) {
uint tgsize_y = ceildiv(gsize.y, lsize.y);
uint tgsize_z = ceildiv(gsize.z, lsize.z);
@@ -195,52 +187,56 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
}
}
#define instantiate_col_reduce_general(name, itype, otype, op) \
template [[host_name("col_reduce_general_" #name)]] \
[[kernel]] void col_reduce_general<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device mlx_atomic<otype> *out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
threadgroup otype *local_data [[threadgroup(0)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
#define instantiate_col_reduce_general(name, itype, otype, op) \
template [[host_name("col_reduce_general_" #name)]] [[kernel]] void \
col_reduce_general<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device mlx_atomic<otype>* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
threadgroup otype* local_data [[threadgroup(0)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]]);
#define instantiate_col_reduce_general_no_atomics(name, itype, otype, op) \
template [[host_name("col_reduce_general_no_atomics_" #name)]] \
[[kernel]] void col_reduce_general_no_atomics<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device otype *out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
threadgroup otype *local_data [[threadgroup(0)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 gid [[thread_position_in_grid]], \
uint3 lsize [[threads_per_threadgroup]], \
uint3 gsize [[threads_per_grid]]);
#define instantiate_col_reduce_general_no_atomics(name, itype, otype, op) \
template \
[[host_name("col_reduce_general_no_atomics_" #name)]] [[kernel]] void \
col_reduce_general_no_atomics<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
threadgroup otype* local_data [[threadgroup(0)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 gid [[thread_position_in_grid]], \
uint3 lsize [[threads_per_threadgroup]], \
uint3 gsize [[threads_per_grid]]);
///////////////////////////////////////////////////////////////////////////////
// Instantiations
///////////////////////////////////////////////////////////////////////////////
#define instantiate_same_col_reduce_helper(name, tname, type, op) \
// clang-format off
#define instantiate_same_col_reduce_helper(name, tname, type, op) \
instantiate_col_reduce_small(name ##tname, type, type, op<type>) \
instantiate_col_reduce_general(name ##tname, type, type, op<type>)
instantiate_col_reduce_general(name ##tname, type, type, op<type>) // clang-format on
// clang-format off
#define instantiate_same_col_reduce_na_helper(name, tname, type, op) \
instantiate_col_reduce_small(name ##tname, type, type, op<type>) \
instantiate_col_reduce_general_no_atomics(name ##tname, type, type, op<type>)
instantiate_col_reduce_small(name ##tname, type, type, op<type>) \
instantiate_col_reduce_general_no_atomics(name ##tname, type, type, op<type>) // clang-format on
// clang-format off
instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_col_reduce_na_helper, instantiate_reduce_helper_64b)
@@ -250,4 +246,4 @@ instantiate_reduce_from_types(instantiate_col_reduce_general, or, bool, Or)
instantiate_col_reduce_small(sumbool_, bool, uint32_t, Sum<uint32_t>)
instantiate_reduce_from_types(instantiate_col_reduce_small, and, bool, And)
instantiate_reduce_from_types(instantiate_col_reduce_small, or, bool, Or)
instantiate_reduce_from_types(instantiate_col_reduce_small, or, bool, Or) // clang-format on

View File

@@ -1,8 +1,8 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/kernels/reduction/utils.h"
#include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/reduction/reduce_inst.h"
#include "mlx/backend/metal/kernels/reduction/utils.h"
using namespace metal;
@@ -12,22 +12,21 @@ using namespace metal;
template <typename T, typename Op>
[[kernel]] void init_reduce(
device T *out [[buffer(0)]],
device T* out [[buffer(0)]],
uint tid [[thread_position_in_grid]]) {
out[tid] = Op::init;
}
#define instantiate_init_reduce(name, otype, op) \
template [[host_name("i" #name)]] \
[[kernel]] void init_reduce<otype, op>( \
device otype *out [[buffer(1)]], \
uint tid [[thread_position_in_grid]]);
#define instantiate_init_reduce(name, otype, op) \
template [[host_name("i" #name)]] [[kernel]] void init_reduce<otype, op>( \
device otype * out [[buffer(1)]], uint tid [[thread_position_in_grid]]);
#define instantiate_init_reduce_helper(name, tname, type, op) \
instantiate_init_reduce(name ##tname, type, op<type>)
instantiate_init_reduce(name##tname, type, op<type>)
// clang-format off
instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_64b)
instantiate_init_reduce(andbool_, bool, And)
instantiate_init_reduce(orbool_, bool, Or)
instantiate_init_reduce(orbool_, bool, Or) // clang-format on

View File

@@ -1,8 +1,8 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/kernels/reduction/utils.h"
#include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/reduction/reduce_inst.h"
#include "mlx/backend/metal/kernels/reduction/utils.h"
using namespace metal;
@@ -13,8 +13,8 @@ using namespace metal;
// Each thread reduces for one output
template <typename T, typename U, typename Op>
[[kernel]] void row_reduce_general_small(
const device T *in [[buffer(0)]],
device U *out [[buffer(1)]],
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant size_t& non_row_reductions [[buffer(4)]],
@@ -22,22 +22,21 @@ template <typename T, typename U, typename Op>
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
uint lid [[thread_position_in_grid]]) {
Op op;
uint out_idx = lid;
if(out_idx >= out_size) {
if (out_idx >= out_size) {
return;
}
U total_val = Op::init;
for(short r = 0; r < short(non_row_reductions); r++) {
for (short r = 0; r < short(non_row_reductions); r++) {
uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
const device T * in_row = in + in_idx;
for(short i = 0; i < short(reduction_size); i++) {
const device T* in_row = in + in_idx;
for (short i = 0; i < short(reduction_size); i++) {
total_val = op(static_cast<U>(in_row[i]), total_val);
}
}
@@ -48,8 +47,8 @@ template <typename T, typename U, typename Op>
// Each simdgroup reduces for one output
template <typename T, typename U, typename Op>
[[kernel]] void row_reduce_general_med(
const device T *in [[buffer(0)]],
device U *out [[buffer(1)]],
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant size_t& non_row_reductions [[buffer(4)]],
@@ -60,45 +59,42 @@ template <typename T, typename U, typename Op>
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_per_group [[dispatch_simdgroups_per_threadgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
Op op;
uint out_idx = simd_per_group * tid + simd_group_id;
if(out_idx >= out_size) {
if (out_idx >= out_size) {
return;
}
U total_val = Op::init;
if(short(non_row_reductions) == 1) {
if (short(non_row_reductions) == 1) {
uint in_idx = elem_to_loc(out_idx, shape, strides, ndim);
const device T * in_row = in + in_idx;
const device T* in_row = in + in_idx;
for(short i = simd_lane_id; i < short(reduction_size); i += 32) {
for (short i = simd_lane_id; i < short(reduction_size); i += 32) {
total_val = op(static_cast<U>(in_row[i]), total_val);
}
}
else if (short(non_row_reductions) >= 32) {
for(short r = simd_lane_id; r < short(non_row_reductions); r+=32) {
for (short r = simd_lane_id; r < short(non_row_reductions); r += 32) {
uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
const device T * in_row = in + in_idx;
const device T* in_row = in + in_idx;
for(short i = 0; i < short(reduction_size); i++) {
for (short i = 0; i < short(reduction_size); i++) {
total_val = op(static_cast<U>(in_row[i]), total_val);
}
}
}
else {
const short n_reductions = short(reduction_size) * short(non_row_reductions);
const short reductions_per_thread = (n_reductions + simd_size - 1) / simd_size;
const short n_reductions =
short(reduction_size) * short(non_row_reductions);
const short reductions_per_thread =
(n_reductions + simd_size - 1) / simd_size;
const short r_st = simd_lane_id / reductions_per_thread;
const short r_ed = short(non_row_reductions);
@@ -108,54 +104,50 @@ template <typename T, typename U, typename Op>
const short i_ed = short(reduction_size);
const short i_jump = reductions_per_thread;
if(r_st < r_jump) {
for(short r = r_st; r < r_ed; r += r_jump) {
if (r_st < r_jump) {
for (short r = r_st; r < r_ed; r += r_jump) {
uint in_idx = elem_to_loc(out_idx + r * out_size, shape, strides, ndim);
const device T * in_row = in + in_idx;
const device T* in_row = in + in_idx;
for(short i = i_st; i < i_ed; i += i_jump) {
for (short i = i_st; i < i_ed; i += i_jump) {
total_val = op(static_cast<U>(in_row[i]), total_val);
}
}
}
}
total_val = op.simd_reduce(total_val);
if(simd_lane_id == 0) {
if (simd_lane_id == 0) {
out[out_idx] = total_val;
}
}
#define instantiate_row_reduce_small(name, itype, otype, op) \
template[[host_name("row_reduce_general_small_" #name)]] \
[[kernel]] void row_reduce_general_small<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device otype *out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant size_t& non_row_reductions [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
uint lid [[thread_position_in_grid]]); \
template[[host_name("row_reduce_general_med_" #name)]] \
[[kernel]] void row_reduce_general_med<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device otype *out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant size_t& non_row_reductions [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
uint tid [[threadgroup_position_in_grid]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_per_group [[dispatch_simdgroups_per_threadgroup]], \
#define instantiate_row_reduce_small(name, itype, otype, op) \
template [[host_name("row_reduce_general_small_" #name)]] [[kernel]] void \
row_reduce_general_small<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant size_t& non_row_reductions [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
uint lid [[thread_position_in_grid]]); \
template [[host_name("row_reduce_general_med_" #name)]] [[kernel]] void \
row_reduce_general_med<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant size_t& non_row_reductions [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
uint tid [[threadgroup_position_in_grid]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_per_group [[dispatch_simdgroups_per_threadgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
///////////////////////////////////////////////////////////////////////////////
@@ -217,10 +209,10 @@ METAL_FUNC U per_thread_row_reduce(
return total_val;
}
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void row_reduce_general(
const device T *in [[buffer(0)]],
device mlx_atomic<U> *out [[buffer(1)]],
const device T* in [[buffer(0)]],
device mlx_atomic<U>* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant size_t& non_row_reductions [[buffer(4)]],
@@ -233,25 +225,33 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
(void)non_row_reductions;
Op op;
threadgroup U local_vals[simd_size];
U total_val = per_thread_row_reduce<T, U, Op, N_READS>(in, reduction_size, out_size, shape, strides, ndim, lsize.x, lid.x, tid.xy);
U total_val = per_thread_row_reduce<T, U, Op, N_READS>(
in,
reduction_size,
out_size,
shape,
strides,
ndim,
lsize.x,
lid.x,
tid.xy);
total_val = op.simd_reduce(total_val);
// Prepare next level
if (simd_lane_id == 0) {
local_vals[simd_group_id] = total_val;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Reduction within thread group
// Only needed if multiple simd groups
if(reduction_size > simd_size) {
if (reduction_size > simd_size) {
total_val = lid.x < simd_per_group ? local_vals[lid.x] : op.init;
total_val = op.simd_reduce(total_val);
}
@@ -261,10 +261,10 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
}
}
template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
[[kernel]] void row_reduce_general_no_atomics(
const device T *in [[buffer(0)]],
device U *out [[buffer(1)]],
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant size_t& non_row_reductions [[buffer(4)]],
@@ -278,16 +278,24 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
(void)non_row_reductions;
Op op;
threadgroup U local_vals[simd_size];
U total_val = per_thread_row_reduce<T, U, Op, N_READS>(in, reduction_size, out_size, shape, strides, ndim, lsize.x, lid.x, tid.xy);
U total_val = per_thread_row_reduce<T, U, Op, N_READS>(
in,
reduction_size,
out_size,
shape,
strides,
ndim,
lsize.x,
lid.x,
tid.xy);
// Reduction within simd group - simd_add isn't supported for int64 types
for (uint16_t i = simd_size/2; i > 0; i /= 2) {
for (uint16_t i = simd_size / 2; i > 0; i /= 2) {
total_val = op(total_val, simd_shuffle_down(total_val, i));
}
@@ -299,9 +307,9 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
// Reduction within thread group
// Only needed if thread group has multiple simd groups
if(ceildiv(reduction_size, N_READS) > simd_size) {
if (ceildiv(reduction_size, N_READS) > simd_size) {
total_val = lid.x < simd_per_group ? local_vals[lid.x] : op.init;
for (uint16_t i = simd_size/2; i > 0; i /= 2) {
for (uint16_t i = simd_size / 2; i > 0; i /= 2) {
total_val = op(total_val, simd_shuffle_down(total_val, i));
}
}
@@ -311,61 +319,60 @@ template <typename T, typename U, typename Op, int N_READS=REDUCE_N_READS>
}
}
#define instantiate_row_reduce_general(name, itype, otype, op) \
instantiate_row_reduce_small(name, itype, otype, op) \
template [[host_name("row_reduce_general_" #name)]] \
[[kernel]] void row_reduce_general<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device mlx_atomic<otype> *out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant size_t& non_row_reductions [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_row_reduce_general_no_atomics(name, itype, otype, op) \
instantiate_row_reduce_small(name, itype, otype, op) \
template [[host_name("row_reduce_general_no_atomics_" #name)]] \
[[kernel]] void row_reduce_general_no_atomics<itype, otype, op>( \
const device itype *in [[buffer(0)]], \
device otype *out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant size_t& non_row_reductions [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]], \
uint3 gsize [[threads_per_grid]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_row_reduce_general(name, itype, otype, op) \
instantiate_row_reduce_small(name, itype, otype, op) template \
[[host_name("row_reduce_general_" #name)]] [[kernel]] void \
row_reduce_general<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device mlx_atomic<otype>* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant size_t& non_row_reductions [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_row_reduce_general_no_atomics(name, itype, otype, op) \
instantiate_row_reduce_small(name, itype, otype, op) template \
[[host_name("row_reduce_general_no_atomics_" #name)]] [[kernel]] void \
row_reduce_general_no_atomics<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant size_t& non_row_reductions [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]], \
uint3 gsize [[threads_per_grid]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
///////////////////////////////////////////////////////////////////////////////
// Instantiations
///////////////////////////////////////////////////////////////////////////////
#define instantiate_same_row_reduce_helper(name, tname, type, op) \
instantiate_row_reduce_general(name ##tname, type, type, op<type>)
instantiate_row_reduce_general(name##tname, type, type, op<type>)
#define instantiate_same_row_reduce_na_helper(name, tname, type, op) \
instantiate_row_reduce_general_no_atomics(name ##tname, type, type, op<type>)
instantiate_row_reduce_general_no_atomics(name##tname, type, type, op<type>)
// clang-format off
instantiate_reduce_ops(instantiate_same_row_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_row_reduce_na_helper, instantiate_reduce_helper_64b)
instantiate_reduce_from_types(instantiate_row_reduce_general, and, bool, And)
instantiate_reduce_from_types(instantiate_row_reduce_general, or, bool, Or)
instantiate_row_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>)
instantiate_row_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>) // clang-format on

View File

@@ -8,64 +8,67 @@
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/reduction/ops.h"
// clang-format off
#define instantiate_reduce_helper_floats(inst_f, name, op) \
inst_f(name, float16, half, op) inst_f(name, float32, float, op) \
inst_f(name, bfloat16, bfloat16_t, op)
inst_f(name, bfloat16, bfloat16_t, op)
#define instantiate_reduce_helper_uints(inst_f, name, op) \
inst_f(name, uint8, uint8_t, op) inst_f(name, uint16, uint16_t, op) \
inst_f(name, uint32, uint32_t, op)
inst_f(name, uint32, uint32_t, op)
#define instantiate_reduce_helper_ints(inst_f, name, op) \
inst_f(name, int8, int8_t, op) inst_f(name, int16, int16_t, op) \
inst_f(name, int32, int32_t, op)
inst_f(name, int32, int32_t, op)
#define instantiate_reduce_helper_64b(inst_f, name, op) \
inst_f(name, int64, int64_t, op) inst_f(name, uint64, uint64_t, op)
#define instantiate_reduce_helper_types(inst_f, name, op) \
instantiate_reduce_helper_floats(inst_f, name, op) \
instantiate_reduce_helper_uints(inst_f, name, op) \
instantiate_reduce_helper_ints(inst_f, name, op)
instantiate_reduce_helper_uints(inst_f, name, op) \
instantiate_reduce_helper_ints(inst_f, name, op)
#define instantiate_reduce_ops(inst_f, type_f) \
type_f(inst_f, sum, Sum) type_f(inst_f, prod, Prod) \
type_f(inst_f, min_, Min) type_f(inst_f, max_, Max)
type_f(inst_f, min_, Min) type_f(inst_f, max_, Max)
// Special case for bool reductions
#define instantiate_reduce_from_types_helper( \
inst_f, name, tname, itype, otype, op) \
inst_f(name##tname, itype, otype, op)
inst_f(name##tname, itype, otype, op)
#define instantiate_reduce_from_types(inst_f, name, otype, op) \
instantiate_reduce_from_types_helper(inst_f, name, bool_, bool, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, uint8, uint8_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, uint16, uint16_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, uint32, uint32_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, int8, int8_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, int16, int16_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, int32, int32_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, int64, int64_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, float16, half, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, \
name, \
float32, \
float, \
otype, \
op) \
instantiate_reduce_from_types_helper( \
inst_f, \
name, \
bfloat16, \
bfloat16_t, \
otype, \
op)
#define instantiate_reduce_from_types(inst_f, name, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, bool_, bool, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, uint8, uint8_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, uint16, uint16_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, uint32, uint32_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, int8, int8_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, int16, int16_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, int32, int32_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, int64, int64_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, float16, half, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, \
name, \
float32, \
float, \
otype, \
op) \
instantiate_reduce_from_types_helper( \
inst_f, \
name, \
bfloat16, \
bfloat16_t, \
otype, \
op)
// clang-format on

View File

@@ -237,13 +237,17 @@ template <typename T, int N_READS = RMS_N_READS>
gw += gid * axis_size + lid * N_READS;
if (lid * N_READS + N_READS <= axis_size) {
for (int i = 0; i < N_READS; i++) {
gx[i] = static_cast<T>(thread_g[i] * thread_w[i] * normalizer - thread_x[i] * meangwx * normalizer3);
gx[i] = static_cast<T>(
thread_g[i] * thread_w[i] * normalizer -
thread_x[i] * meangwx * normalizer3);
gw[i] = static_cast<T>(thread_g[i] * thread_x[i] * normalizer);
}
} else {
for (int i = 0; i < N_READS; i++) {
if ((lid * N_READS + i) < axis_size) {
gx[i] = static_cast<T>(thread_g[i] * thread_w[i] * normalizer - thread_x[i] * meangwx * normalizer3);
gx[i] = static_cast<T>(
thread_g[i] * thread_w[i] * normalizer -
thread_x[i] * meangwx * normalizer3);
gw[i] = static_cast<T>(thread_g[i] * thread_x[i] * normalizer);
}
}
@@ -342,7 +346,8 @@ template <typename T, int N_READS = RMS_N_READS>
float wi = w[w_stride * (i + r)];
float gi = g[i + r];
gx[i + r] = static_cast<T>(gi * wi * normalizer - xi * meangwx * normalizer3);
gx[i + r] =
static_cast<T>(gi * wi * normalizer - xi * meangwx * normalizer3);
gw[i + r] = static_cast<T>(gi * xi * normalizer);
}
} else {
@@ -352,7 +357,8 @@ template <typename T, int N_READS = RMS_N_READS>
float wi = w[w_stride * (i + r)];
float gi = g[i + r];
gx[i + r] = static_cast<T>(gi * wi * normalizer - xi * meangwx * normalizer3);
gx[i + r] =
static_cast<T>(gi * wi * normalizer - xi * meangwx * normalizer3);
gw[i + r] = static_cast<T>(gi * xi * normalizer);
}
}
@@ -431,5 +437,4 @@ template <typename T, int N_READS = RMS_N_READS>
instantiate_rms(float32, float)
instantiate_rms(float16, half)
instantiate_rms(bfloat16, bfloat16_t)
// clang-format on
instantiate_rms(bfloat16, bfloat16_t) // clang-format on

View File

@@ -7,8 +7,8 @@
template <typename T, bool traditional, bool forward>
[[kernel]] void rope(
const device T *in [[buffer(0)]],
device T * out [[buffer(1)]],
const device T* in [[buffer(0)]],
device T* out [[buffer(1)]],
constant const size_t strides[3],
constant const size_t out_strides[3],
constant const int& offset,
@@ -20,12 +20,15 @@ template <typename T, bool traditional, bool forward>
uint in_index_1, in_index_2;
uint out_index_1, out_index_2;
if (traditional) {
out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] + pos.z * out_strides[0];
out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] +
pos.z * out_strides[0];
out_index_2 = out_index_1 + 1;
in_index_1 = 2 * pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0];
in_index_1 =
2 * pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0];
in_index_2 = in_index_1 + strides[2];
} else {
out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] + pos.z * out_strides[0];
out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] +
pos.z * out_strides[0];
out_index_2 = out_index_1 + grid.x * out_strides[2];
in_index_1 = pos.x * strides[2] + pos.y * strides[1] + pos.z * strides[0];
in_index_2 = in_index_1 + grid.x * strides[2];
@@ -57,18 +60,19 @@ template <typename T, bool traditional, bool forward>
}
#define instantiate_rope(name, type, traditional, forward) \
template [[host_name("rope_" #name)]] \
[[kernel]] void rope<type, traditional, forward>( \
const device type* in [[buffer(0)]], \
device type* out [[buffer(1)]], \
constant const size_t strides[3], \
constant const size_t out_strides[3], \
constant const int& offset, \
constant const float& base, \
constant const float& scale, \
uint3 pos [[thread_position_in_grid]], \
uint3 grid [[threads_per_grid]]);
template [[host_name("rope_" #name)]] [[kernel]] void \
rope<type, traditional, forward>( \
const device type* in [[buffer(0)]], \
device type* out [[buffer(1)]], \
constant const size_t strides[3], \
constant const size_t out_strides[3], \
constant const int& offset, \
constant const float& base, \
constant const float& scale, \
uint3 pos [[thread_position_in_grid]], \
uint3 grid [[threads_per_grid]]);
// clang-format off
instantiate_rope(traditional_float16, half, true, true)
instantiate_rope(traditional_bfloat16, bfloat16_t, true, true)
instantiate_rope(traditional_float32, float, true, true)
@@ -80,4 +84,4 @@ instantiate_rope(vjp_traditional_bfloat16, bfloat16_t, true, false)
instantiate_rope(vjp_traditional_float32, float, true, false)
instantiate_rope(vjp_float16, half, false, false)
instantiate_rope(vjp_bfloat16, bfloat16_t, false, false)
instantiate_rope(vjp_float32, float, false, false)
instantiate_rope(vjp_float32, float, false, false) // clang-format on

View File

@@ -1,451 +1,551 @@
#include <metal_stdlib>
#include <metal_simdgroup>
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h"
using namespace metal;
template<typename T, typename T2, typename T4, uint16_t TILE_SIZE_CONST, uint16_t NSIMDGROUPS>
[[kernel]] void fast_inference_sdpa_compute_partials_template(const device T *Q [[buffer(0)]],
const device T *K [[buffer(1)]],
const device T *V [[buffer(2)]],
const device uint64_t& L [[buffer(3)]],
const device MLXScaledDotProductAttentionParams& params [[buffer(4)]],
device float* O_partials [[buffer(5)]],
device float* p_lse [[buffer(6)]],
device float* p_maxes [[buffer(7)]],
threadgroup T* threadgroup_block [[threadgroup(0)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]]) {
constexpr const size_t DK = 128;
constexpr const ulong SIMDGROUP_MATRIX_LOAD_FACTOR = 8;
constexpr const size_t THREADS_PER_SIMDGROUP = 32;
constexpr const uint iter_offset = NSIMDGROUPS * 4;
const bool is_gqa = params.N_KV_HEADS != params.N_Q_HEADS;
uint kv_head_offset_factor = tid.x;
if(is_gqa) {
int q_kv_head_ratio = params.N_Q_HEADS / params.N_KV_HEADS;
kv_head_offset_factor = tid.x / q_kv_head_ratio;
template <
typename T,
typename T2,
typename T4,
uint16_t TILE_SIZE_CONST,
uint16_t NSIMDGROUPS>
[[kernel]] void fast_inference_sdpa_compute_partials_template(
const device T* Q [[buffer(0)]],
const device T* K [[buffer(1)]],
const device T* V [[buffer(2)]],
const device uint64_t& L [[buffer(3)]],
const device MLXScaledDotProductAttentionParams& params [[buffer(4)]],
device float* O_partials [[buffer(5)]],
device float* p_lse [[buffer(6)]],
device float* p_maxes [[buffer(7)]],
threadgroup T* threadgroup_block [[threadgroup(0)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]]) {
constexpr const size_t DK = 128;
constexpr const ulong SIMDGROUP_MATRIX_LOAD_FACTOR = 8;
constexpr const size_t THREADS_PER_SIMDGROUP = 32;
constexpr const uint iter_offset = NSIMDGROUPS * 4;
const bool is_gqa = params.N_KV_HEADS != params.N_Q_HEADS;
uint kv_head_offset_factor = tid.x;
if (is_gqa) {
int q_kv_head_ratio = params.N_Q_HEADS / params.N_KV_HEADS;
kv_head_offset_factor = tid.x / q_kv_head_ratio;
}
constexpr const uint16_t P_VEC4 = TILE_SIZE_CONST / NSIMDGROUPS / 4;
constexpr const size_t MATRIX_LOADS_PER_SIMDGROUP =
TILE_SIZE_CONST / (SIMDGROUP_MATRIX_LOAD_FACTOR * NSIMDGROUPS);
constexpr const size_t MATRIX_COLS = DK / SIMDGROUP_MATRIX_LOAD_FACTOR;
constexpr const uint totalSmemV = SIMDGROUP_MATRIX_LOAD_FACTOR *
SIMDGROUP_MATRIX_LOAD_FACTOR * (MATRIX_LOADS_PER_SIMDGROUP + 1) *
NSIMDGROUPS;
threadgroup T4* smemFlush = (threadgroup T4*)threadgroup_block;
#pragma clang loop unroll(full)
for (uint i = 0; i < 8; i++) {
smemFlush
[simd_lane_id + simd_group_id * THREADS_PER_SIMDGROUP +
i * NSIMDGROUPS * THREADS_PER_SIMDGROUP] = T4(0.f);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// TODO: multiple query sequence length for speculative decoding
const uint tgroup_query_head_offset =
tid.x * DK + tid.z * (params.N_Q_HEADS * DK);
const uint tgroup_k_head_offset = kv_head_offset_factor * DK * L;
const uint tgroup_k_tile_offset = tid.y * TILE_SIZE_CONST * DK;
const uint tgroup_k_batch_offset = tid.z * L * params.N_KV_HEADS * DK;
const device T* baseK =
K + tgroup_k_batch_offset + tgroup_k_tile_offset + tgroup_k_head_offset;
const device T* baseQ = Q + tgroup_query_head_offset;
device T4* simdgroupQueryData = (device T4*)baseQ;
constexpr const size_t ACCUM_PER_GROUP = TILE_SIZE_CONST / NSIMDGROUPS;
float threadAccum[ACCUM_PER_GROUP];
#pragma clang loop unroll(full)
for (size_t threadAccumIndex = 0; threadAccumIndex < ACCUM_PER_GROUP;
threadAccumIndex++) {
threadAccum[threadAccumIndex] = -INFINITY;
}
uint KROW_ACCUM_INDEX = 0;
const int32_t SEQUENCE_LENGTH_LESS_TILE_SIZE = L - TILE_SIZE_CONST;
const bool LAST_TILE = (tid.y + 1) * TILE_SIZE_CONST >= L;
const bool LAST_TILE_ALIGNED =
(SEQUENCE_LENGTH_LESS_TILE_SIZE == int32_t(tid.y * TILE_SIZE_CONST));
T4 thread_data_x4;
T4 thread_data_y4;
if (!LAST_TILE || LAST_TILE_ALIGNED) {
thread_data_x4 = *(simdgroupQueryData + simd_lane_id);
#pragma clang loop unroll(full)
for (size_t KROW = simd_group_id; KROW < TILE_SIZE_CONST;
KROW += NSIMDGROUPS) {
const uint KROW_OFFSET = KROW * DK;
const device T* baseKRow = baseK + KROW_OFFSET;
device T4* keysData = (device T4*)baseKRow;
thread_data_y4 = *(keysData + simd_lane_id);
T kq_scalar = dot(thread_data_x4, thread_data_y4);
threadAccum[KROW_ACCUM_INDEX] = float(kq_scalar);
KROW_ACCUM_INDEX++;
}
constexpr const uint16_t P_VEC4 = TILE_SIZE_CONST / NSIMDGROUPS / 4;
constexpr const size_t MATRIX_LOADS_PER_SIMDGROUP = TILE_SIZE_CONST / (SIMDGROUP_MATRIX_LOAD_FACTOR * NSIMDGROUPS);
constexpr const size_t MATRIX_COLS = DK / SIMDGROUP_MATRIX_LOAD_FACTOR;
constexpr const uint totalSmemV = SIMDGROUP_MATRIX_LOAD_FACTOR * SIMDGROUP_MATRIX_LOAD_FACTOR * (MATRIX_LOADS_PER_SIMDGROUP + 1) * NSIMDGROUPS;
} else {
thread_data_x4 = *(simdgroupQueryData + simd_lane_id);
const uint START_ROW = tid.y * TILE_SIZE_CONST;
const device T* baseKThisHead =
K + tgroup_k_batch_offset + tgroup_k_head_offset;
threadgroup T4* smemFlush = (threadgroup T4*)threadgroup_block;
#pragma clang loop unroll(full)
for(uint i = 0; i < 8; i++) {
smemFlush[simd_lane_id + simd_group_id * THREADS_PER_SIMDGROUP + i * NSIMDGROUPS * THREADS_PER_SIMDGROUP] = T4(0.f);
for (size_t KROW = START_ROW + simd_group_id; KROW < L;
KROW += NSIMDGROUPS) {
const uint KROW_OFFSET = KROW * DK;
const device T* baseKRow = baseKThisHead + KROW_OFFSET;
device T4* keysData = (device T4*)baseKRow;
thread_data_y4 = *(keysData + simd_lane_id);
T kq_scalar = dot(thread_data_x4, thread_data_y4);
threadAccum[KROW_ACCUM_INDEX] = float(kq_scalar);
KROW_ACCUM_INDEX++;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// TODO: multiple query sequence length for speculative decoding
const uint tgroup_query_head_offset = tid.x * DK + tid.z * (params.N_Q_HEADS * DK);
}
threadgroup float* smemP = (threadgroup float*)threadgroup_block;
const uint tgroup_k_head_offset = kv_head_offset_factor * DK * L;
const uint tgroup_k_tile_offset = tid.y * TILE_SIZE_CONST * DK;
const uint tgroup_k_batch_offset = tid.z * L * params.N_KV_HEADS * DK;
const device T* baseK = K + tgroup_k_batch_offset + tgroup_k_tile_offset + tgroup_k_head_offset;
const device T* baseQ = Q + tgroup_query_head_offset;
device T4* simdgroupQueryData = (device T4*)baseQ;
constexpr const size_t ACCUM_PER_GROUP = TILE_SIZE_CONST / NSIMDGROUPS;
float threadAccum[ACCUM_PER_GROUP];
#pragma clang loop unroll(full)
for(size_t threadAccumIndex = 0; threadAccumIndex < ACCUM_PER_GROUP; threadAccumIndex++) {
threadAccum[threadAccumIndex] = -INFINITY;
#pragma clang loop unroll(full)
for (size_t i = 0; i < P_VEC4; i++) {
thread_data_x4 =
T4(threadAccum[4 * i],
threadAccum[4 * i + 1],
threadAccum[4 * i + 2],
threadAccum[4 * i + 3]);
simdgroup_barrier(mem_flags::mem_none);
thread_data_y4 = simd_sum(thread_data_x4);
if (simd_lane_id == 0) {
const uint base_smem_p_offset = i * iter_offset + simd_group_id;
smemP[base_smem_p_offset + NSIMDGROUPS * 0] = float(thread_data_y4.x);
smemP[base_smem_p_offset + NSIMDGROUPS * 1] = float(thread_data_y4.y);
smemP[base_smem_p_offset + NSIMDGROUPS * 2] = float(thread_data_y4.z);
smemP[base_smem_p_offset + NSIMDGROUPS * 3] = float(thread_data_y4.w);
}
}
uint KROW_ACCUM_INDEX = 0;
threadgroup_barrier(mem_flags::mem_threadgroup);
const int32_t SEQUENCE_LENGTH_LESS_TILE_SIZE = L - TILE_SIZE_CONST;
const bool LAST_TILE = (tid.y + 1) * TILE_SIZE_CONST >= L;
const bool LAST_TILE_ALIGNED = (SEQUENCE_LENGTH_LESS_TILE_SIZE == int32_t(tid.y * TILE_SIZE_CONST));
float groupMax;
float lse = 0.f;
T4 thread_data_x4;
T4 thread_data_y4;
if(!LAST_TILE || LAST_TILE_ALIGNED) {
thread_data_x4 = *(simdgroupQueryData + simd_lane_id);
#pragma clang loop unroll(full)
for(size_t KROW = simd_group_id; KROW < TILE_SIZE_CONST; KROW += NSIMDGROUPS) {
const uint KROW_OFFSET = KROW * DK;
const device T* baseKRow = baseK + KROW_OFFSET;
device T4* keysData = (device T4*)baseKRow;
thread_data_y4 = *(keysData + simd_lane_id);
T kq_scalar = dot(thread_data_x4, thread_data_y4);
threadAccum[KROW_ACCUM_INDEX] = float(kq_scalar);
KROW_ACCUM_INDEX++;
constexpr const size_t THREADS_PER_THREADGROUP_TIMES_4 = 4 * 32;
constexpr const size_t ACCUM_ARRAY_LENGTH =
TILE_SIZE_CONST / THREADS_PER_THREADGROUP_TIMES_4 + 1;
float4 pvals[ACCUM_ARRAY_LENGTH];
#pragma clang loop unroll(full)
for (uint accum_array_iter = 0; accum_array_iter < ACCUM_ARRAY_LENGTH;
accum_array_iter++) {
pvals[accum_array_iter] = float4(-INFINITY);
}
if (TILE_SIZE_CONST == 64) {
threadgroup float2* smemPtrFlt2 = (threadgroup float2*)threadgroup_block;
float2 vals = smemPtrFlt2[simd_lane_id];
vals *= params.INV_ALPHA;
float maxval = max(vals.x, vals.y);
simdgroup_barrier(mem_flags::mem_none);
groupMax = simd_max(maxval);
float2 expf_shifted = exp(vals - groupMax);
float sumExpLocal = expf_shifted.x + expf_shifted.y;
simdgroup_barrier(mem_flags::mem_none);
float tgroupExpSum = simd_sum(sumExpLocal);
lse = log(tgroupExpSum);
float2 local_p_hat = expf_shifted / tgroupExpSum;
pvals[0].x = local_p_hat.x;
pvals[0].y = local_p_hat.y;
smemPtrFlt2[simd_lane_id] = float2(0.f);
}
constexpr const bool TILE_SIZE_LARGER_THAN_64 = TILE_SIZE_CONST > 64;
constexpr const int TILE_SIZE_ITERS_128 = TILE_SIZE_CONST / 128;
if (TILE_SIZE_LARGER_THAN_64) {
float maxval = -INFINITY;
threadgroup float4* smemPtrFlt4 = (threadgroup float4*)threadgroup_block;
#pragma clang loop unroll(full)
for (int i = 0; i < TILE_SIZE_ITERS_128; i++) {
float4 vals = smemPtrFlt4[simd_lane_id + i * THREADS_PER_SIMDGROUP];
vals *= params.INV_ALPHA;
pvals[i] = vals;
maxval = fmax3(vals.x, vals.y, maxval);
maxval = fmax3(vals.z, vals.w, maxval);
}
simdgroup_barrier(mem_flags::mem_none);
groupMax = simd_max(maxval);
float sumExpLocal = 0.f;
#pragma clang loop unroll(full)
for (int i = 0; i < TILE_SIZE_ITERS_128; i++) {
pvals[i] = exp(pvals[i] - groupMax);
sumExpLocal += pvals[i].x + pvals[i].y + pvals[i].z + pvals[i].w;
}
simdgroup_barrier(mem_flags::mem_none);
float tgroupExpSum = simd_sum(sumExpLocal);
lse = log(tgroupExpSum);
#pragma clang loop unroll(full)
for (int i = 0; i < TILE_SIZE_ITERS_128; i++) {
pvals[i] = pvals[i] / tgroupExpSum;
smemPtrFlt4[simd_lane_id + i * THREADS_PER_SIMDGROUP] = float4(0.f);
}
}
threadgroup T* smemV = (threadgroup T*)threadgroup_block;
const size_t v_batch_offset = tid.z * params.N_KV_HEADS * L * DK;
const size_t v_head_offset = kv_head_offset_factor * L * DK;
const size_t v_tile_offset = tid.y * TILE_SIZE_CONST * DK;
const size_t v_offset = v_batch_offset + v_head_offset + v_tile_offset;
device T* baseV = (device T*)V + v_offset;
threadgroup float* smemOpartial = (threadgroup float*)(smemV + totalSmemV);
if (!LAST_TILE || LAST_TILE_ALIGNED) {
#pragma clang loop unroll(full)
for (size_t col = 0; col < MATRIX_COLS; col++) {
uint matrix_load_loop_iter = 0;
constexpr const size_t TILE_SIZE_CONST_DIV_8 = TILE_SIZE_CONST / 8;
for (size_t tile_start = simd_group_id;
tile_start < TILE_SIZE_CONST_DIV_8;
tile_start += NSIMDGROUPS) {
simdgroup_matrix<T, 8, 8> tmp;
ulong simdgroup_matrix_offset =
matrix_load_loop_iter * NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR +
simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR;
ulong2 matrixOrigin =
ulong2(col * SIMDGROUP_MATRIX_LOAD_FACTOR, simdgroup_matrix_offset);
simdgroup_load(tmp, baseV, DK, matrixOrigin, true);
const ulong2 matrixOriginSmem = ulong2(simdgroup_matrix_offset, 0);
const ulong elemsPerRowSmem = TILE_SIZE_CONST;
simdgroup_store(tmp, smemV, elemsPerRowSmem, matrixOriginSmem, false);
matrix_load_loop_iter++;
};
threadgroup_barrier(mem_flags::mem_threadgroup);
if (TILE_SIZE_CONST == 64) {
T2 local_p_hat = T2(pvals[0].x, pvals[0].y);
uint loop_iter = 0;
threadgroup float* oPartialSmem =
smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
#pragma clang loop unroll(full)
for (size_t row = simd_group_id; row < SIMDGROUP_MATRIX_LOAD_FACTOR;
row += NSIMDGROUPS) {
threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row);
threadgroup T2* smemV2 = (threadgroup T2*)smemV_row;
T2 v_local = *(smemV2 + simd_lane_id);
T val = dot(local_p_hat, v_local);
simdgroup_barrier(mem_flags::mem_none);
T row_sum = simd_sum(val);
oPartialSmem[simd_group_id + loop_iter * NSIMDGROUPS] =
float(row_sum);
loop_iter++;
}
} else {
thread_data_x4 = *(simdgroupQueryData + simd_lane_id);
const uint START_ROW = tid.y * TILE_SIZE_CONST;
const device T* baseKThisHead = K + tgroup_k_batch_offset + tgroup_k_head_offset;
}
for(size_t KROW = START_ROW + simd_group_id; KROW < L; KROW += NSIMDGROUPS) {
const uint KROW_OFFSET = KROW * DK;
const device T* baseKRow = baseKThisHead + KROW_OFFSET;
device T4* keysData = (device T4*)baseKRow;
thread_data_y4 = *(keysData + simd_lane_id);
T kq_scalar = dot(thread_data_x4, thread_data_y4);
threadAccum[KROW_ACCUM_INDEX] = float(kq_scalar);
KROW_ACCUM_INDEX++;
if (TILE_SIZE_CONST > 64) {
constexpr const size_t TILE_SIZE_CONST_DIV_128 =
(TILE_SIZE_CONST + 1) / 128;
threadgroup float* oPartialSmem =
smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
uint loop_iter = 0;
for (size_t row = simd_group_id; row < SIMDGROUP_MATRIX_LOAD_FACTOR;
row += NSIMDGROUPS) {
threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row);
T row_sum = 0.f;
for (size_t i = 0; i < TILE_SIZE_CONST_DIV_128; i++) {
threadgroup T4* smemV2 = (threadgroup T4*)smemV_row;
T4 v_local = *(smemV2 + simd_lane_id + i * THREADS_PER_SIMDGROUP);
T4 p_local = T4(pvals[i]);
T val = dot(p_local, v_local);
row_sum += val;
}
simdgroup_barrier(mem_flags::mem_none);
row_sum = simd_sum(row_sum);
oPartialSmem[simd_group_id + loop_iter * NSIMDGROUPS] =
float(row_sum);
loop_iter++;
}
}
}
threadgroup float* smemP = (threadgroup float*)threadgroup_block;
} else {
const int32_t START_ROW = tid.y * TILE_SIZE_CONST;
const int32_t MAX_START_ROW = L - SIMDGROUP_MATRIX_LOAD_FACTOR + 1;
const device T* baseVThisHead = V + v_batch_offset + v_head_offset;
constexpr const int ROWS_PER_ITER = 8;
#pragma clang loop unroll(full)
for (size_t col = 0; col < MATRIX_COLS; col++) {
uint smem_col_index = simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR;
int32_t tile_start;
for (tile_start =
START_ROW + simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR;
tile_start < MAX_START_ROW;
tile_start += NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR) {
simdgroup_matrix<T, 8, 8> tmp;
ulong2 matrixOrigin =
ulong2(col * SIMDGROUP_MATRIX_LOAD_FACTOR, tile_start);
simdgroup_load(
tmp, baseVThisHead, DK, matrixOrigin, /* transpose */ true);
const ulong2 matrixOriginSmem = ulong2(smem_col_index, 0);
constexpr const ulong elemsPerRowSmem = TILE_SIZE_CONST;
simdgroup_store(
tmp,
smemV,
elemsPerRowSmem,
matrixOriginSmem,
/* transpose */ false);
smem_col_index += NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR;
};
#pragma clang loop unroll(full)
for(size_t i = 0; i < P_VEC4; i++) {
thread_data_x4 = T4(threadAccum[4 * i], threadAccum[4 * i + 1], threadAccum[4 * i + 2], threadAccum[4 * i + 3]);
simdgroup_barrier(mem_flags::mem_none);
thread_data_y4 = simd_sum(thread_data_x4);
if(simd_lane_id == 0) {
const uint base_smem_p_offset = i * iter_offset + simd_group_id;
smemP[base_smem_p_offset + NSIMDGROUPS * 0] = float(thread_data_y4.x);
smemP[base_smem_p_offset + NSIMDGROUPS * 1] = float(thread_data_y4.y);
smemP[base_smem_p_offset + NSIMDGROUPS * 2] = float(thread_data_y4.z);
smemP[base_smem_p_offset + NSIMDGROUPS * 3] = float(thread_data_y4.w);
tile_start =
((L / SIMDGROUP_MATRIX_LOAD_FACTOR) * SIMDGROUP_MATRIX_LOAD_FACTOR);
const int32_t INT_L = int32_t(L);
for (int row_index = tile_start + simd_group_id; row_index < INT_L;
row_index += NSIMDGROUPS) {
if (simd_lane_id < SIMDGROUP_MATRIX_LOAD_FACTOR) {
const uint elems_per_row_gmem = DK;
const uint col_index_v_gmem =
col * SIMDGROUP_MATRIX_LOAD_FACTOR + simd_lane_id;
const uint row_index_v_gmem = row_index;
const uint elems_per_row_smem = TILE_SIZE_CONST;
const uint col_index_v_smem = row_index % TILE_SIZE_CONST;
const uint row_index_v_smem = simd_lane_id;
const uint scalar_offset_gmem =
row_index_v_gmem * elems_per_row_gmem + col_index_v_gmem;
const uint scalar_offset_smem =
row_index_v_smem * elems_per_row_smem + col_index_v_smem;
T vdata = T(*(baseVThisHead + scalar_offset_gmem));
smemV[scalar_offset_smem] = vdata;
smem_col_index += NSIMDGROUPS;
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
threadgroup_barrier(mem_flags::mem_threadgroup);
float groupMax;
float lse = 0.f;
constexpr const size_t THREADS_PER_THREADGROUP_TIMES_4 = 4 * 32;
constexpr const size_t ACCUM_ARRAY_LENGTH = TILE_SIZE_CONST / THREADS_PER_THREADGROUP_TIMES_4 + 1;
float4 pvals[ACCUM_ARRAY_LENGTH];
#pragma clang loop unroll(full)
for(uint accum_array_iter = 0; accum_array_iter < ACCUM_ARRAY_LENGTH; accum_array_iter++) {
pvals[accum_array_iter] = float4(-INFINITY);
}
if (TILE_SIZE_CONST == 64) {
threadgroup float2* smemPtrFlt2 = (threadgroup float2*)threadgroup_block;
float2 vals = smemPtrFlt2[simd_lane_id];
vals *= params.INV_ALPHA;
float maxval = max(vals.x, vals.y);
simdgroup_barrier(mem_flags::mem_none);
groupMax = simd_max(maxval);
float2 expf_shifted = exp(vals - groupMax);
float sumExpLocal = expf_shifted.x + expf_shifted.y;
simdgroup_barrier(mem_flags::mem_none);
float tgroupExpSum = simd_sum(sumExpLocal);
lse = log(tgroupExpSum);
float2 local_p_hat = expf_shifted / tgroupExpSum;
pvals[0].x = local_p_hat.x;
pvals[0].y = local_p_hat.y;
smemPtrFlt2[simd_lane_id] = float2(0.f);
}
constexpr const bool TILE_SIZE_LARGER_THAN_64 = TILE_SIZE_CONST > 64;
constexpr const int TILE_SIZE_ITERS_128 = TILE_SIZE_CONST / 128;
if (TILE_SIZE_LARGER_THAN_64) {
float maxval = -INFINITY;
threadgroup float4* smemPtrFlt4 = (threadgroup float4*)threadgroup_block;
#pragma clang loop unroll(full)
for(int i = 0; i < TILE_SIZE_ITERS_128; i++) {
float4 vals = smemPtrFlt4[simd_lane_id + i * THREADS_PER_SIMDGROUP];
vals *= params.INV_ALPHA;
pvals[i] = vals;
maxval = fmax3(vals.x, vals.y, maxval);
maxval = fmax3(vals.z, vals.w, maxval);
if (TILE_SIZE_CONST == 64) {
T2 local_p_hat = T2(pvals[0].x, pvals[0].y);
threadgroup float* oPartialSmem =
smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
for (size_t smem_row_index = simd_group_id;
smem_row_index < ROWS_PER_ITER;
smem_row_index += NSIMDGROUPS) {
threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * smem_row_index);
threadgroup T2* smemV2 = (threadgroup T2*)smemV_row;
T2 v_local = *(smemV2 + simd_lane_id);
T val = dot(local_p_hat, v_local);
simdgroup_barrier(mem_flags::mem_none);
T row_sum = simd_sum(val);
oPartialSmem[smem_row_index] = float(row_sum);
}
simdgroup_barrier(mem_flags::mem_none);
groupMax = simd_max(maxval);
}
float sumExpLocal = 0.f;
#pragma clang loop unroll(full)
for(int i = 0; i < TILE_SIZE_ITERS_128; i++) {
pvals[i] = exp(pvals[i] - groupMax);
sumExpLocal += pvals[i].x + pvals[i].y + pvals[i].z + pvals[i].w;
}
simdgroup_barrier(mem_flags::mem_none);
float tgroupExpSum = simd_sum(sumExpLocal);
lse = log(tgroupExpSum);
#pragma clang loop unroll(full)
for(int i = 0; i < TILE_SIZE_ITERS_128; i++) {
pvals[i] = pvals[i] / tgroupExpSum;
smemPtrFlt4[simd_lane_id + i * THREADS_PER_SIMDGROUP] = float4(0.f);
if (TILE_SIZE_CONST > 64) {
threadgroup float* oPartialSmem =
smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
uint loop_count = 0;
for (size_t row_index = simd_group_id; row_index < ROWS_PER_ITER;
row_index += NSIMDGROUPS) {
T row_sum = 0.f;
for (size_t tile_iters = 0; tile_iters < TILE_SIZE_ITERS_128;
tile_iters++) {
threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row_index);
threadgroup T4* smemV2 = (threadgroup T4*)smemV_row;
T4 v_local =
*(smemV2 + simd_lane_id + tile_iters * THREADS_PER_SIMDGROUP);
T4 p_local = T4(pvals[tile_iters]);
row_sum += dot(p_local, v_local);
}
simdgroup_barrier(mem_flags::mem_none);
row_sum = simd_sum(row_sum);
oPartialSmem[simd_group_id + NSIMDGROUPS * loop_count] =
float(row_sum);
loop_count++;
}
}
}
}
threadgroup T* smemV = (threadgroup T*)threadgroup_block;
threadgroup_barrier(mem_flags::mem_threadgroup);
const size_t v_batch_offset = tid.z * params.N_KV_HEADS * L * DK;
const size_t v_head_offset = kv_head_offset_factor * L * DK;
if (simd_group_id == 0) {
threadgroup float4* oPartialVec4 = (threadgroup float4*)smemOpartial;
float4 vals = *(oPartialVec4 + simd_lane_id);
device float* oPartialGmem =
O_partials + tid.x * DK * params.KV_TILES + tid.y * DK;
device float4* oPartialGmemVec4 = (device float4*)oPartialGmem;
oPartialGmemVec4[simd_lane_id] = vals;
}
const size_t v_tile_offset = tid.y * TILE_SIZE_CONST * DK;
const size_t v_offset = v_batch_offset + v_head_offset + v_tile_offset;
device T* baseV = (device T*)V + v_offset;
threadgroup float* smemOpartial = (threadgroup float*)(smemV + totalSmemV);
if (!LAST_TILE || LAST_TILE_ALIGNED) {
#pragma clang loop unroll(full)
for(size_t col = 0; col < MATRIX_COLS; col++) {
uint matrix_load_loop_iter = 0;
constexpr const size_t TILE_SIZE_CONST_DIV_8 = TILE_SIZE_CONST / 8;
for(size_t tile_start = simd_group_id; tile_start < TILE_SIZE_CONST_DIV_8; tile_start += NSIMDGROUPS) {
simdgroup_matrix<T, 8, 8> tmp;
ulong simdgroup_matrix_offset = matrix_load_loop_iter * NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR + simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR;
ulong2 matrixOrigin = ulong2(col * SIMDGROUP_MATRIX_LOAD_FACTOR, simdgroup_matrix_offset);
simdgroup_load(tmp, baseV, DK, matrixOrigin, true);
const ulong2 matrixOriginSmem = ulong2(simdgroup_matrix_offset, 0);
const ulong elemsPerRowSmem = TILE_SIZE_CONST;
simdgroup_store(tmp, smemV, elemsPerRowSmem, matrixOriginSmem, false);
matrix_load_loop_iter++;
};
threadgroup_barrier(mem_flags::mem_threadgroup);
if (TILE_SIZE_CONST == 64) {
T2 local_p_hat = T2(pvals[0].x, pvals[0].y);
uint loop_iter = 0;
threadgroup float* oPartialSmem = smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
#pragma clang loop unroll(full)
for(size_t row = simd_group_id; row < SIMDGROUP_MATRIX_LOAD_FACTOR; row += NSIMDGROUPS) {
threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row);
threadgroup T2* smemV2 = (threadgroup T2*)smemV_row;
T2 v_local = *(smemV2 + simd_lane_id);
T val = dot(local_p_hat, v_local);
simdgroup_barrier(mem_flags::mem_none);
T row_sum = simd_sum(val);
oPartialSmem[simd_group_id + loop_iter * NSIMDGROUPS] = float(row_sum);
loop_iter++;
}
}
if (TILE_SIZE_CONST > 64) {
constexpr const size_t TILE_SIZE_CONST_DIV_128 = (TILE_SIZE_CONST + 1) / 128;
threadgroup float* oPartialSmem = smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
uint loop_iter = 0;
for(size_t row = simd_group_id; row < SIMDGROUP_MATRIX_LOAD_FACTOR; row += NSIMDGROUPS) {
threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row);
T row_sum = 0.f;
for(size_t i = 0; i < TILE_SIZE_CONST_DIV_128; i++) {
threadgroup T4* smemV2 = (threadgroup T4*)smemV_row;
T4 v_local = *(smemV2 + simd_lane_id + i * THREADS_PER_SIMDGROUP);
T4 p_local = T4(pvals[i]);
T val = dot(p_local, v_local);
row_sum += val;
}
simdgroup_barrier(mem_flags::mem_none);
row_sum = simd_sum(row_sum);
oPartialSmem[simd_group_id + loop_iter * NSIMDGROUPS] = float(row_sum);
loop_iter++;
}
}
}
} else {
const int32_t START_ROW = tid.y * TILE_SIZE_CONST;
const int32_t MAX_START_ROW = L - SIMDGROUP_MATRIX_LOAD_FACTOR + 1;
const device T* baseVThisHead = V + v_batch_offset + v_head_offset;
constexpr const int ROWS_PER_ITER = 8;
#pragma clang loop unroll(full)
for(size_t col = 0; col < MATRIX_COLS; col++) {
uint smem_col_index = simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR;
int32_t tile_start;
for(tile_start = START_ROW + simd_group_id * SIMDGROUP_MATRIX_LOAD_FACTOR; tile_start < MAX_START_ROW; tile_start += NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR) {
simdgroup_matrix<T, 8, 8> tmp;
ulong2 matrixOrigin = ulong2(col * SIMDGROUP_MATRIX_LOAD_FACTOR, tile_start);
simdgroup_load(tmp, baseVThisHead, DK, matrixOrigin, /* transpose */ true);
const ulong2 matrixOriginSmem = ulong2(smem_col_index, 0);
constexpr const ulong elemsPerRowSmem = TILE_SIZE_CONST;
simdgroup_store(tmp, smemV, elemsPerRowSmem, matrixOriginSmem, /* transpose */ false);
smem_col_index += NSIMDGROUPS * SIMDGROUP_MATRIX_LOAD_FACTOR;
};
tile_start = ((L / SIMDGROUP_MATRIX_LOAD_FACTOR) * SIMDGROUP_MATRIX_LOAD_FACTOR);
const int32_t INT_L = int32_t(L);
for(int row_index = tile_start + simd_group_id ; row_index < INT_L; row_index += NSIMDGROUPS) {
if(simd_lane_id < SIMDGROUP_MATRIX_LOAD_FACTOR) {
const uint elems_per_row_gmem = DK;
const uint col_index_v_gmem = col * SIMDGROUP_MATRIX_LOAD_FACTOR + simd_lane_id;
const uint row_index_v_gmem = row_index;
const uint elems_per_row_smem = TILE_SIZE_CONST;
const uint col_index_v_smem = row_index % TILE_SIZE_CONST;
const uint row_index_v_smem = simd_lane_id;
const uint scalar_offset_gmem = row_index_v_gmem * elems_per_row_gmem + col_index_v_gmem;
const uint scalar_offset_smem = row_index_v_smem * elems_per_row_smem + col_index_v_smem;
T vdata = T(*(baseVThisHead + scalar_offset_gmem));
smemV[scalar_offset_smem] = vdata;
smem_col_index += NSIMDGROUPS;
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (TILE_SIZE_CONST == 64) {
T2 local_p_hat = T2(pvals[0].x, pvals[0].y);
threadgroup float* oPartialSmem = smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
for(size_t smem_row_index = simd_group_id;
smem_row_index < ROWS_PER_ITER; smem_row_index += NSIMDGROUPS) {
threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * smem_row_index);
threadgroup T2* smemV2 = (threadgroup T2*)smemV_row;
T2 v_local = *(smemV2 + simd_lane_id);
T val = dot(local_p_hat, v_local);
simdgroup_barrier(mem_flags::mem_none);
T row_sum = simd_sum(val);
oPartialSmem[smem_row_index] = float(row_sum);
}
}
if (TILE_SIZE_CONST > 64) {
threadgroup float* oPartialSmem = smemOpartial + SIMDGROUP_MATRIX_LOAD_FACTOR * col;
uint loop_count = 0;
for(size_t row_index = simd_group_id;
row_index < ROWS_PER_ITER; row_index += NSIMDGROUPS) {
T row_sum = 0.f;
for(size_t tile_iters = 0; tile_iters < TILE_SIZE_ITERS_128; tile_iters++) {
threadgroup T* smemV_row = smemV + (TILE_SIZE_CONST * row_index);
threadgroup T4* smemV2 = (threadgroup T4*)smemV_row;
T4 v_local = *(smemV2 + simd_lane_id + tile_iters * THREADS_PER_SIMDGROUP);
T4 p_local = T4(pvals[tile_iters]);
row_sum += dot(p_local, v_local);
}
simdgroup_barrier(mem_flags::mem_none);
row_sum = simd_sum(row_sum);
oPartialSmem[simd_group_id + NSIMDGROUPS * loop_count] = float(row_sum);
loop_count++;
}
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if(simd_group_id == 0) {
threadgroup float4* oPartialVec4 = (threadgroup float4*)smemOpartial;
float4 vals = *(oPartialVec4 + simd_lane_id);
device float* oPartialGmem = O_partials + tid.x * DK * params.KV_TILES + tid.y * DK;
device float4* oPartialGmemVec4 = (device float4*)oPartialGmem;
oPartialGmemVec4[simd_lane_id] = vals;
}
if(simd_group_id == 0 && simd_lane_id == 0) {
const uint tileIndex = tid.y;
const uint gmem_partial_scalar_offset = tid.z * params.N_Q_HEADS * params.KV_TILES + tid.x * params.KV_TILES + tileIndex;
p_lse[gmem_partial_scalar_offset] = lse;
p_maxes[gmem_partial_scalar_offset] = groupMax;
}
if (simd_group_id == 0 && simd_lane_id == 0) {
const uint tileIndex = tid.y;
const uint gmem_partial_scalar_offset =
tid.z * params.N_Q_HEADS * params.KV_TILES + tid.x * params.KV_TILES +
tileIndex;
p_lse[gmem_partial_scalar_offset] = lse;
p_maxes[gmem_partial_scalar_offset] = groupMax;
}
}
#define instantiate_fast_inference_sdpa_to_partials_kernel(itype, itype2, itype4, tile_size, nsimdgroups) \
template [[host_name("fast_inference_sdpa_compute_partials_" #itype "_" #tile_size "_" #nsimdgroups )]] \
[[kernel]] void fast_inference_sdpa_compute_partials_template<itype, itype2, itype4, tile_size, nsimdgroups>( \
const device itype *Q [[buffer(0)]], \
const device itype *K [[buffer(1)]], \
const device itype *V [[buffer(2)]], \
const device uint64_t& L [[buffer(3)]], \
const device MLXScaledDotProductAttentionParams& params [[buffer(4)]], \
device float* O_partials [[buffer(5)]], \
device float* p_lse [[buffer(6)]], \
device float* p_maxes [[buffer(7)]], \
threadgroup itype *threadgroup_block [[threadgroup(0)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]]);
#define instantiate_fast_inference_sdpa_to_partials_kernel( \
itype, itype2, itype4, tile_size, nsimdgroups) \
template [[host_name("fast_inference_sdpa_compute_partials_" #itype \
"_" #tile_size "_" #nsimdgroups)]] [[kernel]] void \
fast_inference_sdpa_compute_partials_template< \
itype, \
itype2, \
itype4, \
tile_size, \
nsimdgroups>( \
const device itype* Q [[buffer(0)]], \
const device itype* K [[buffer(1)]], \
const device itype* V [[buffer(2)]], \
const device uint64_t& L [[buffer(3)]], \
const device MLXScaledDotProductAttentionParams& params [[buffer(4)]], \
device float* O_partials [[buffer(5)]], \
device float* p_lse [[buffer(6)]], \
device float* p_maxes [[buffer(7)]], \
threadgroup itype* threadgroup_block [[threadgroup(0)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]]);
// clang-format off
#define instantiate_fast_inference_sdpa_to_partials_shapes_helper( \
itype, itype2, itype4, tile_size) \
instantiate_fast_inference_sdpa_to_partials_kernel( \
itype, itype2, itype4, tile_size, 4) \
instantiate_fast_inference_sdpa_to_partials_kernel( \
itype, itype2, itype4, tile_size, 8) // clang-format on
#define instantiate_fast_inference_sdpa_to_partials_shapes_helper(itype, itype2, itype4, tile_size) \
instantiate_fast_inference_sdpa_to_partials_kernel(itype, itype2, itype4, tile_size, 4) \
instantiate_fast_inference_sdpa_to_partials_kernel(itype, itype2, itype4, tile_size, 8) \
instantiate_fast_inference_sdpa_to_partials_shapes_helper(float, float2, float4, 64);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(float, float2, float4, 128);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(float, float2, float4, 256);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(float, float2, float4, 512);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(half, half2, half4, 64);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(half, half2, half4, 128);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(half, half2, half4, 256);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(half, half2, half4, 512);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
float,
float2,
float4,
64);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
float,
float2,
float4,
128);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
float,
float2,
float4,
256);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
float,
float2,
float4,
512);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
half,
half2,
half4,
64);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
half,
half2,
half4,
128);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
half,
half2,
half4,
256);
instantiate_fast_inference_sdpa_to_partials_shapes_helper(
half,
half2,
half4,
512);
template <typename T>
void fast_inference_sdpa_reduce_tiles_template(
const device float *O_partials [[buffer(0)]],
const device float *p_lse[[buffer(1)]],
const device float *p_maxes [[buffer(2)]],
const device float* O_partials [[buffer(0)]],
const device float* p_lse [[buffer(1)]],
const device float* p_maxes [[buffer(2)]],
const device MLXScaledDotProductAttentionParams& params [[buffer(3)]],
device T* O [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
constexpr const int DK = 128;
const ulong offset_rows =
tid.z * params.KV_TILES * params.N_Q_HEADS + tid.x * params.KV_TILES;
const device float* p_lse_row = p_lse + offset_rows;
const device float* p_rowmax_row = p_maxes + offset_rows;
// reserve some number of registers. this constitutes an assumption on max
// value of KV TILES.
constexpr const uint8_t reserve = 128;
float p_lse_regs[reserve];
float p_rowmax_regs[reserve];
float weights[reserve];
constexpr const int DK = 128;
const ulong offset_rows = tid.z * params.KV_TILES * params.N_Q_HEADS + tid.x * params.KV_TILES;
const device float* p_lse_row = p_lse + offset_rows;
const device float* p_rowmax_row = p_maxes + offset_rows;
// reserve some number of registers. this constitutes an assumption on max value of KV TILES.
constexpr const uint8_t reserve = 128;
float p_lse_regs[reserve];
float p_rowmax_regs[reserve];
float weights[reserve];
float true_max = -INFINITY;
for (size_t i = 0; i < params.KV_TILES; i++) {
p_lse_regs[i] = float(*(p_lse_row + i));
p_rowmax_regs[i] = float(*(p_rowmax_row + i));
true_max = fmax(p_rowmax_regs[i], true_max);
weights[i] = exp(p_lse_regs[i]);
}
float true_max = -INFINITY;
for(size_t i = 0; i < params.KV_TILES; i++) {
p_lse_regs[i] = float(*(p_lse_row + i));
p_rowmax_regs[i] = float(*(p_rowmax_row + i));
true_max = fmax(p_rowmax_regs[i], true_max);
weights[i] = exp(p_lse_regs[i]);
}
float denom = 0.f;
for (size_t i = 0; i < params.KV_TILES; i++) {
weights[i] *= exp(p_rowmax_regs[i] - true_max);
denom += weights[i];
}
float denom = 0.f;
for(size_t i = 0; i < params.KV_TILES; i++) {
weights[i] *= exp(p_rowmax_regs[i]-true_max);
denom += weights[i];
}
const device float* O_partials_with_offset = O_partials +
tid.z * params.N_Q_HEADS * DK * params.KV_TILES +
tid.x * DK * params.KV_TILES;
const device float* O_partials_with_offset = O_partials + tid.z * params.N_Q_HEADS * DK * params.KV_TILES + tid.x * DK * params.KV_TILES;
float o_value = 0.f;
for(size_t i = 0; i < params.KV_TILES; i++) {
float val = *(O_partials_with_offset + i * DK + lid.x);
o_value += val * weights[i] / denom;
}
device T* O_gmem = O + tid.z * params.N_Q_HEADS * DK + tid.x * DK;
O_gmem[lid.x] = T(o_value);
return;
float o_value = 0.f;
for (size_t i = 0; i < params.KV_TILES; i++) {
float val = *(O_partials_with_offset + i * DK + lid.x);
o_value += val * weights[i] / denom;
}
device T* O_gmem = O + tid.z * params.N_Q_HEADS * DK + tid.x * DK;
O_gmem[lid.x] = T(o_value);
return;
}
kernel void fast_inference_sdpa_reduce_tiles_float(
const device float *O_partials [[buffer(0)]],
const device float *p_lse[[buffer(1)]],
const device float *p_maxes [[buffer(2)]],
const device float* O_partials [[buffer(0)]],
const device float* p_lse [[buffer(1)]],
const device float* p_maxes [[buffer(2)]],
const device MLXScaledDotProductAttentionParams& params [[buffer(3)]],
device float* O [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]])
{
fast_inference_sdpa_reduce_tiles_template<float>(O_partials, p_lse, p_maxes, params,
O, tid, lid);
uint3 lid [[thread_position_in_threadgroup]]) {
fast_inference_sdpa_reduce_tiles_template<float>(
O_partials, p_lse, p_maxes, params, O, tid, lid);
}
kernel void fast_inference_sdpa_reduce_tiles_half(
const device float *O_partials [[buffer(0)]],
const device float *p_lse[[buffer(1)]],
const device float *p_maxes [[buffer(2)]],
const device float* O_partials [[buffer(0)]],
const device float* p_lse [[buffer(1)]],
const device float* p_maxes [[buffer(2)]],
const device MLXScaledDotProductAttentionParams& params [[buffer(3)]],
device half* O [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]])
{
fast_inference_sdpa_reduce_tiles_template<half>(O_partials, p_lse, p_maxes, params,
O, tid, lid);
uint3 lid [[thread_position_in_threadgroup]]) {
fast_inference_sdpa_reduce_tiles_template<half>(
O_partials, p_lse, p_maxes, params, O, tid, lid);
}

View File

@@ -54,7 +54,7 @@ struct CumProd<bool> {
}
bool simd_scan(bool x) {
for (int i=1; i<=16; i*=2) {
for (int i = 1; i <= 16; i *= 2) {
bool other = simd_shuffle_up(x, i);
x &= other;
}
@@ -77,7 +77,7 @@ struct CumMax {
}
U simd_scan(U x) {
for (int i=1; i<=16; i*=2) {
for (int i = 1; i <= 16; i *= 2) {
U other = simd_shuffle_up(x, i);
x = (x >= other) ? x : other;
}
@@ -100,7 +100,7 @@ struct CumMin {
}
U simd_scan(U x) {
for (int i=1; i<=16; i*=2) {
for (int i = 1; i <= 16; i *= 2) {
U other = simd_shuffle_up(x, i);
x = (x <= other) ? x : other;
}
@@ -114,54 +114,60 @@ struct CumMin {
};
template <typename T, typename U, int N_READS, bool reverse>
inline void load_unsafe(U values[N_READS], const device T * input) {
inline void load_unsafe(U values[N_READS], const device T* input) {
if (reverse) {
for (int i=0; i<N_READS; i++) {
values[N_READS-i-1] = input[i];
for (int i = 0; i < N_READS; i++) {
values[N_READS - i - 1] = input[i];
}
} else {
for (int i=0; i<N_READS; i++) {
for (int i = 0; i < N_READS; i++) {
values[i] = input[i];
}
}
}
template <typename T, typename U, int N_READS, bool reverse>
inline void load_safe(U values[N_READS], const device T * input, int start, int total, U init) {
inline void load_safe(
U values[N_READS],
const device T* input,
int start,
int total,
U init) {
if (reverse) {
for (int i=0; i<N_READS; i++) {
values[N_READS-i-1] = (start + N_READS - i - 1 < total) ? input[i] : init;
for (int i = 0; i < N_READS; i++) {
values[N_READS - i - 1] =
(start + N_READS - i - 1 < total) ? input[i] : init;
}
} else {
for (int i=0; i<N_READS; i++) {
for (int i = 0; i < N_READS; i++) {
values[i] = (start + i < total) ? input[i] : init;
}
}
}
template <typename U, int N_READS, bool reverse>
inline void write_unsafe(U values[N_READS], device U * out) {
inline void write_unsafe(U values[N_READS], device U* out) {
if (reverse) {
for (int i=0; i<N_READS; i++) {
out[i] = values[N_READS-i-1];
for (int i = 0; i < N_READS; i++) {
out[i] = values[N_READS - i - 1];
}
} else {
for (int i=0; i<N_READS; i++) {
for (int i = 0; i < N_READS; i++) {
out[i] = values[i];
}
}
}
template <typename U, int N_READS, bool reverse>
inline void write_safe(U values[N_READS], device U * out, int start, int total) {
inline void write_safe(U values[N_READS], device U* out, int start, int total) {
if (reverse) {
for (int i=0; i<N_READS; i++) {
for (int i = 0; i < N_READS; i++) {
if (start + N_READS - i - 1 < total) {
out[i] = values[N_READS-i-1];
out[i] = values[N_READS - i - 1];
}
}
} else {
for (int i=0; i<N_READS; i++) {
for (int i = 0; i < N_READS; i++) {
if (start + i < total) {
out[i] = values[i];
}
@@ -169,12 +175,17 @@ inline void write_safe(U values[N_READS], device U * out, int start, int total)
}
}
template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool reverse>
template <
typename T,
typename U,
typename Op,
int N_READS,
bool inclusive,
bool reverse>
[[kernel]] void contiguous_scan(
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const constant size_t & axis_size [[buffer(2)]],
const constant size_t& axis_size [[buffer(2)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
@@ -195,42 +206,51 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
U values[N_READS];
threadgroup U simdgroup_sums[32];
// Loop over the reduced axis in blocks of size ceildiv(axis_size, N_READS*lsize)
// Loop over the reduced axis in blocks of size ceildiv(axis_size,
// N_READS*lsize)
// Read block
// Compute inclusive scan of the block
// Compute inclusive scan per thread
// Compute exclusive scan of thread sums in simdgroup
// Write simdgroup sums in SM
// Compute exclusive scan of simdgroup sums
// Compute the output by scanning prefix, prev_simdgroup, prev_thread, value
// Compute the output by scanning prefix, prev_simdgroup, prev_thread,
// value
// Write block
for (uint r = 0; r < ceildiv(axis_size, N_READS*lsize); r++) {
for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize); r++) {
// Compute the block offset
uint offset = r*lsize*N_READS + lid*N_READS;
uint offset = r * lsize * N_READS + lid * N_READS;
// Read the values
if (reverse) {
if ((offset + N_READS) < axis_size) {
load_unsafe<T, U, N_READS, reverse>(values, in + axis_size - offset - N_READS);
load_unsafe<T, U, N_READS, reverse>(
values, in + axis_size - offset - N_READS);
} else {
load_safe<T, U, N_READS, reverse>(values, in + axis_size - offset - N_READS, offset, axis_size, Op::init);
load_safe<T, U, N_READS, reverse>(
values,
in + axis_size - offset - N_READS,
offset,
axis_size,
Op::init);
}
} else {
if ((offset + N_READS) < axis_size) {
load_unsafe<T, U, N_READS, reverse>(values, in + offset);
} else {
load_safe<T, U, N_READS, reverse>(values, in + offset, offset, axis_size, Op::init);
load_safe<T, U, N_READS, reverse>(
values, in + offset, offset, axis_size, Op::init);
}
}
// Compute an inclusive scan per thread
for (int i=1; i<N_READS; i++) {
values[i] = op(values[i], values[i-1]);
for (int i = 1; i < N_READS; i++) {
values[i] = op(values[i], values[i - 1]);
}
// Compute exclusive scan of thread sums
U prev_thread = op.simd_exclusive_scan(values[N_READS-1]);
U prev_thread = op.simd_exclusive_scan(values[N_READS - 1]);
// Write simdgroup_sums to SM
if (simd_lane_id == simd_size - 1) {
@@ -246,7 +266,7 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
threadgroup_barrier(mem_flags::mem_threadgroup);
// Compute the output
for (int i=0; i<N_READS; i++) {
for (int i = 0; i < N_READS; i++) {
values[i] = op(values[i], prefix);
values[i] = op(values[i], simdgroup_sums[simd_group_id]);
values[i] = op(values[i], prev_thread);
@@ -256,18 +276,25 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
if (reverse) {
if (inclusive) {
if ((offset + N_READS) < axis_size) {
write_unsafe<U, N_READS, reverse>(values, out + axis_size - offset - N_READS);
write_unsafe<U, N_READS, reverse>(
values, out + axis_size - offset - N_READS);
} else {
write_safe<U, N_READS, reverse>(values, out + axis_size - offset - N_READS, offset, axis_size);
write_safe<U, N_READS, reverse>(
values, out + axis_size - offset - N_READS, offset, axis_size);
}
} else {
if (lid == 0 && offset == 0) {
out[axis_size-1] = Op::init;
out[axis_size - 1] = Op::init;
}
if ((offset + N_READS + 1) < axis_size) {
write_unsafe<U, N_READS, reverse>(values, out + axis_size - offset - 1 - N_READS);
write_unsafe<U, N_READS, reverse>(
values, out + axis_size - offset - 1 - N_READS);
} else {
write_safe<U, N_READS, reverse>(values, out + axis_size - offset - 1 - N_READS, offset + 1, axis_size);
write_safe<U, N_READS, reverse>(
values,
out + axis_size - offset - 1 - N_READS,
offset + 1,
axis_size);
}
}
} else {
@@ -275,7 +302,8 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
if ((offset + N_READS) < axis_size) {
write_unsafe<U, N_READS, reverse>(values, out + offset);
} else {
write_safe<U, N_READS, reverse>(values, out + offset, offset, axis_size);
write_safe<U, N_READS, reverse>(
values, out + offset, offset, axis_size);
}
} else {
if (lid == 0 && offset == 0) {
@@ -284,26 +312,33 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
if ((offset + N_READS + 1) < axis_size) {
write_unsafe<U, N_READS, reverse>(values, out + offset + 1);
} else {
write_safe<U, N_READS, reverse>(values, out + offset + 1, offset + 1, axis_size);
write_safe<U, N_READS, reverse>(
values, out + offset + 1, offset + 1, axis_size);
}
}
}
// Share the prefix
if (simd_group_id == simd_groups - 1 && simd_lane_id == simd_size - 1) {
simdgroup_sums[0] = values[N_READS-1];
simdgroup_sums[0] = values[N_READS - 1];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
prefix = simdgroup_sums[0];
}
}
template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool reverse>
template <
typename T,
typename U,
typename Op,
int N_READS,
bool inclusive,
bool reverse>
[[kernel]] void strided_scan(
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const constant size_t & axis_size [[buffer(2)]],
const constant size_t & stride [[buffer(3)]],
const constant size_t& axis_size [[buffer(2)]],
const constant size_t& stride [[buffer(3)]],
uint2 gid [[threadgroup_position_in_grid]],
uint2 lid [[thread_position_in_threadgroup]],
uint2 lsize [[threads_per_threadgroup]],
@@ -311,10 +346,10 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
Op op;
// Allocate memory
threadgroup U read_buffer[N_READS*32*32 + N_READS*32];
threadgroup U read_buffer[N_READS * 32 * 32 + N_READS * 32];
U values[N_READS];
U prefix[N_READS];
for (int i=0; i<N_READS; i++) {
for (int i = 0; i < N_READS; i++) {
prefix[i] = Op::init;
}
@@ -322,7 +357,7 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
int offset = gid.y * axis_size * stride;
int global_index_x = gid.x * lsize.y * N_READS;
for (uint j=0; j<axis_size; j+=simd_size) {
for (uint j = 0; j < axis_size; j += simd_size) {
// Calculate the indices for the current thread
uint index_y = j + lid.y;
uint check_index_y = index_y;
@@ -333,37 +368,43 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
// Read in SM
if (check_index_y < axis_size && (index_x + N_READS) < stride) {
for (int i=0; i<N_READS; i++) {
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] = in[offset + index_y * stride + index_x + i];
for (int i = 0; i < N_READS; i++) {
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] =
in[offset + index_y * stride + index_x + i];
}
} else {
for (int i=0; i<N_READS; i++) {
for (int i = 0; i < N_READS; i++) {
if (check_index_y < axis_size && (index_x + i) < stride) {
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] = in[offset + index_y * stride + index_x + i];
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] =
in[offset + index_y * stride + index_x + i];
} else {
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] = Op::init;
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] =
Op::init;
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Read strided into registers
for (int i=0; i<N_READS; i++) {
values[i] = read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i];
for (int i = 0; i < N_READS; i++) {
values[i] =
read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i];
}
// Do we need the following barrier? Shouldn't all simd threads execute simultaneously?
// Do we need the following barrier? Shouldn't all simd threads execute
// simultaneously?
simdgroup_barrier(mem_flags::mem_threadgroup);
// Perform the scan
for (int i=0; i<N_READS; i++) {
for (int i = 0; i < N_READS; i++) {
values[i] = op.simd_scan(values[i]);
values[i] = op(values[i], prefix[i]);
prefix[i] = simd_shuffle(values[i], simd_size-1);
prefix[i] = simd_shuffle(values[i], simd_size - 1);
}
// Write to SM
for (int i=0; i<N_READS; i++) {
read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i] = values[i];
for (int i = 0; i < N_READS; i++) {
read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i] =
values[i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -371,11 +412,11 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
if (!inclusive) {
if (check_index_y == 0) {
if ((index_x + N_READS) < stride) {
for (int i=0; i<N_READS; i++) {
for (int i = 0; i < N_READS; i++) {
out[offset + index_y * stride + index_x + i] = Op::init;
}
} else {
for (int i=0; i<N_READS; i++) {
for (int i = 0; i < N_READS; i++) {
if ((index_x + i) < stride) {
out[offset + index_y * stride + index_x + i] = Op::init;
}
@@ -391,55 +432,60 @@ template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool
}
}
if (check_index_y < axis_size && (index_x + N_READS) < stride) {
for (int i=0; i<N_READS; i++) {
out[offset + index_y * stride + index_x + i] = read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i];
for (int i = 0; i < N_READS; i++) {
out[offset + index_y * stride + index_x + i] =
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i];
}
} else {
for (int i=0; i<N_READS; i++) {
for (int i = 0; i < N_READS; i++) {
if (check_index_y < axis_size && (index_x + i) < stride) {
out[offset + index_y * stride + index_x + i] = read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i];
out[offset + index_y * stride + index_x + i] =
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i];
}
}
}
}
}
#define instantiate_contiguous_scan(name, itype, otype, op, inclusive, reverse, nreads) \
template [[host_name("contiguous_scan_" #name)]] \
[[kernel]] void contiguous_scan<itype, otype, op<otype>, nreads, inclusive, reverse>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t & axis_size [[buffer(2)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_size [[threads_per_simdgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_contiguous_scan( \
name, itype, otype, op, inclusive, reverse, nreads) \
template [[host_name("contiguous_scan_" #name)]] [[kernel]] void \
contiguous_scan<itype, otype, op<otype>, nreads, inclusive, reverse>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& axis_size [[buffer(2)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_size [[threads_per_simdgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_strided_scan(name, itype, otype, op, inclusive, reverse, nreads) \
template [[host_name("strided_scan_" #name)]] \
[[kernel]] void strided_scan<itype, otype, op<otype>, nreads, inclusive, reverse>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t & axis_size [[buffer(2)]], \
const constant size_t & stride [[buffer(3)]], \
uint2 gid [[thread_position_in_grid]], \
uint2 lid [[thread_position_in_threadgroup]], \
uint2 lsize [[threads_per_threadgroup]], \
uint simd_size [[threads_per_simdgroup]]);
#define instantiate_strided_scan( \
name, itype, otype, op, inclusive, reverse, nreads) \
template [[host_name("strided_scan_" #name)]] [[kernel]] void \
strided_scan<itype, otype, op<otype>, nreads, inclusive, reverse>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& axis_size [[buffer(2)]], \
const constant size_t& stride [[buffer(3)]], \
uint2 gid [[thread_position_in_grid]], \
uint2 lid [[thread_position_in_threadgroup]], \
uint2 lsize [[threads_per_threadgroup]], \
uint simd_size [[threads_per_simdgroup]]);
#define instantiate_scan_helper(name, itype, otype, op, nreads) \
instantiate_contiguous_scan(inclusive_##name, itype, otype, op, true, false, nreads) \
instantiate_contiguous_scan(exclusive_##name, itype, otype, op, false, false, nreads) \
instantiate_contiguous_scan(reverse_inclusive_##name, itype, otype, op, true, true, nreads) \
// clang-format off
#define instantiate_scan_helper(name, itype, otype, op, nreads) \
instantiate_contiguous_scan(inclusive_##name, itype, otype, op, true, false, nreads) \
instantiate_contiguous_scan(exclusive_##name, itype, otype, op, false, false, nreads) \
instantiate_contiguous_scan(reverse_inclusive_##name, itype, otype, op, true, true, nreads) \
instantiate_contiguous_scan(reverse_exclusive_##name, itype, otype, op, false, true, nreads) \
instantiate_strided_scan(inclusive_##name, itype, otype, op, true, false, nreads) \
instantiate_strided_scan(exclusive_##name, itype, otype, op, false, false, nreads) \
instantiate_strided_scan(reverse_inclusive_##name, itype, otype, op, true, true, nreads) \
instantiate_strided_scan(reverse_exclusive_##name, itype, otype, op, false, true, nreads)
instantiate_strided_scan(inclusive_##name, itype, otype, op, true, false, nreads) \
instantiate_strided_scan(exclusive_##name, itype, otype, op, false, false, nreads) \
instantiate_strided_scan(reverse_inclusive_##name, itype, otype, op, true, true, nreads) \
instantiate_strided_scan(reverse_exclusive_##name, itype, otype, op, false, true, nreads) // clang-format on
// clang-format off
instantiate_scan_helper(sum_bool__int32, bool, int32_t, CumSum, 4)
instantiate_scan_helper(sum_uint8_uint8, uint8_t, uint8_t, CumSum, 4)
instantiate_scan_helper(sum_uint16_uint16, uint16_t, uint16_t, CumSum, 4)
@@ -491,4 +537,4 @@ instantiate_scan_helper(min_int32_int32, int32_t, int32_t, CumMi
instantiate_scan_helper(min_float16_float16, half, half, CumMin, 4)
instantiate_scan_helper(min_float32_float32, float, float, CumMin, 4)
instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMin, 4)
//instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin)
//instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin) // clang-format on

View File

@@ -13,67 +13,55 @@ using namespace metal;
// Scatter kernel
/////////////////////////////////////////////////////////////////////
template <typename T, typename IdxT, typename Op, int NIDX> \
template <typename T, typename IdxT, typename Op, int NIDX>
METAL_FUNC void scatter_1d_index_impl(
const device T *updates [[buffer(1)]],
device mlx_atomic<T> *out [[buffer(2)]],
const constant int* out_shape [[buffer(3)]],
const constant size_t* out_strides [[buffer(4)]],
const constant size_t& upd_size [[buffer(5)]],
const thread array<const device IdxT*, NIDX>& idx_buffers,
uint2 gid [[thread_position_in_grid]]) {
const device T* updates [[buffer(1)]],
device mlx_atomic<T>* out [[buffer(2)]],
const constant int* out_shape [[buffer(3)]],
const constant size_t* out_strides [[buffer(4)]],
const constant size_t& upd_size [[buffer(5)]],
const thread array<const device IdxT*, NIDX>& idx_buffers,
uint2 gid [[thread_position_in_grid]]) {
Op op;
uint out_idx = 0;
for (int i = 0; i < NIDX; i++) {
auto idx_val = offset_neg_idx(
idx_buffers[i][gid.y], out_shape[i]);
auto idx_val = offset_neg_idx(idx_buffers[i][gid.y], out_shape[i]);
out_idx += idx_val * out_strides[i];
}
op.atomic_update(out, updates[gid.y * upd_size + gid.x], out_idx + gid.x);
}
#define make_scatter_1d_index(IDX_ARG, IDX_ARR) \
template <typename T, typename IdxT, typename Op, int NIDX> \
[[kernel]] void scatter_1d_index( \
const device T *updates [[buffer(1)]], \
device mlx_atomic<T> *out [[buffer(2)]], \
const constant int* out_shape [[buffer(3)]], \
const constant size_t* out_strides [[buffer(4)]], \
const constant size_t& upd_size [[buffer(5)]], \
IDX_ARG(IdxT) \
uint2 gid [[thread_position_in_grid]]) { \
\
const array<const device IdxT*, NIDX> idx_buffers = {IDX_ARR()}; \
\
return scatter_1d_index_impl<T, IdxT, Op, NIDX>( \
updates, \
out, \
out_shape, \
out_strides, \
upd_size, \
idx_buffers, \
gid); \
\
}
#define make_scatter_1d_index(IDX_ARG, IDX_ARR) \
template <typename T, typename IdxT, typename Op, int NIDX> \
[[kernel]] void scatter_1d_index( \
const device T* updates [[buffer(1)]], \
device mlx_atomic<T>* out [[buffer(2)]], \
const constant int* out_shape [[buffer(3)]], \
const constant size_t* out_strides [[buffer(4)]], \
const constant size_t& upd_size [[buffer(5)]], \
IDX_ARG(IdxT) uint2 gid [[thread_position_in_grid]]) { \
const array<const device IdxT*, NIDX> idx_buffers = {IDX_ARR()}; \
\
return scatter_1d_index_impl<T, IdxT, Op, NIDX>( \
updates, out, out_shape, out_strides, upd_size, idx_buffers, gid); \
}
template <typename T, typename IdxT, typename Op, int NIDX>
METAL_FUNC void scatter_impl(
const device T *updates [[buffer(1)]],
device mlx_atomic<T> *out [[buffer(2)]],
const constant int *upd_shape [[buffer(3)]],
const constant size_t *upd_strides [[buffer(4)]],
const device T* updates [[buffer(1)]],
device mlx_atomic<T>* out [[buffer(2)]],
const constant int* upd_shape [[buffer(3)]],
const constant size_t* upd_strides [[buffer(4)]],
const constant size_t& upd_ndim [[buffer(5)]],
const constant size_t& upd_size [[buffer(6)]],
const constant int *out_shape [[buffer(7)]],
const constant size_t *out_strides [[buffer(8)]],
const constant int* out_shape [[buffer(7)]],
const constant size_t* out_strides [[buffer(8)]],
const constant size_t& out_ndim [[buffer(9)]],
const constant int* axes [[buffer(10)]],
const thread Indices<IdxT, NIDX>& indices,
uint2 gid [[thread_position_in_grid]]) {
Op op;
auto ind_idx = gid.y;
auto ind_offset = gid.x;
@@ -86,8 +74,7 @@ METAL_FUNC void scatter_impl(
&indices.strides[indices.ndim * i],
indices.ndim);
auto ax = axes[i];
auto idx_val = offset_neg_idx(
indices.buffers[i][idx_loc], out_shape[ax]);
auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], out_shape[ax]);
out_idx += idx_val * out_strides[ax];
}
@@ -97,142 +84,134 @@ METAL_FUNC void scatter_impl(
out_idx += out_offset;
}
auto upd_idx = elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim);
auto upd_idx =
elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim);
op.atomic_update(out, updates[upd_idx], out_idx);
}
#define make_scatter_impl(IDX_ARG, IDX_ARR) \
template <typename T, typename IdxT, typename Op, int NIDX> \
[[kernel]] void scatter( \
const device T *updates [[buffer(1)]], \
device mlx_atomic<T> *out [[buffer(2)]], \
const constant int *upd_shape [[buffer(3)]], \
const constant size_t *upd_strides [[buffer(4)]], \
const constant size_t& upd_ndim [[buffer(5)]], \
const constant size_t& upd_size [[buffer(6)]], \
const constant int *out_shape [[buffer(7)]], \
const constant size_t *out_strides [[buffer(8)]], \
const constant size_t& out_ndim [[buffer(9)]], \
const constant int* axes [[buffer(10)]], \
const constant int *idx_shapes [[buffer(11)]], \
const constant size_t *idx_strides [[buffer(12)]], \
const constant int& idx_ndim [[buffer(13)]], \
IDX_ARG(IdxT) \
uint2 gid [[thread_position_in_grid]]) { \
\
Indices<IdxT, NIDX> idxs{ \
{{IDX_ARR()}}, \
idx_shapes, \
idx_strides, \
idx_ndim}; \
\
return scatter_impl<T, IdxT, Op, NIDX>( \
updates, \
out, \
upd_shape, \
upd_strides, \
upd_ndim, \
upd_size, \
out_shape, \
out_strides, \
out_ndim, \
axes, \
idxs, \
gid); \
}
#define make_scatter_impl(IDX_ARG, IDX_ARR) \
template <typename T, typename IdxT, typename Op, int NIDX> \
[[kernel]] void scatter( \
const device T* updates [[buffer(1)]], \
device mlx_atomic<T>* out [[buffer(2)]], \
const constant int* upd_shape [[buffer(3)]], \
const constant size_t* upd_strides [[buffer(4)]], \
const constant size_t& upd_ndim [[buffer(5)]], \
const constant size_t& upd_size [[buffer(6)]], \
const constant int* out_shape [[buffer(7)]], \
const constant size_t* out_strides [[buffer(8)]], \
const constant size_t& out_ndim [[buffer(9)]], \
const constant int* axes [[buffer(10)]], \
const constant int* idx_shapes [[buffer(11)]], \
const constant size_t* idx_strides [[buffer(12)]], \
const constant int& idx_ndim [[buffer(13)]], \
IDX_ARG(IdxT) uint2 gid [[thread_position_in_grid]]) { \
Indices<IdxT, NIDX> idxs{ \
{{IDX_ARR()}}, idx_shapes, idx_strides, idx_ndim}; \
\
return scatter_impl<T, IdxT, Op, NIDX>( \
updates, \
out, \
upd_shape, \
upd_strides, \
upd_ndim, \
upd_size, \
out_shape, \
out_strides, \
out_ndim, \
axes, \
idxs, \
gid); \
}
#define make_scatter(n) \
make_scatter_impl(IDX_ARG_ ##n, IDX_ARR_ ##n) \
make_scatter_1d_index(IDX_ARG_ ##n, IDX_ARR_ ##n)
#define make_scatter(n) \
make_scatter_impl(IDX_ARG_##n, IDX_ARR_##n) \
make_scatter_1d_index(IDX_ARG_##n, IDX_ARR_##n)
make_scatter(0)
make_scatter(1)
make_scatter(2)
make_scatter(3)
make_scatter(4)
make_scatter(5)
make_scatter(6)
make_scatter(7)
make_scatter(8)
make_scatter(9)
make_scatter(10)
make_scatter(0) make_scatter(1) make_scatter(2) make_scatter(3) make_scatter(4)
make_scatter(5) make_scatter(6) make_scatter(7) make_scatter(8)
make_scatter(9) make_scatter(10)
/////////////////////////////////////////////////////////////////////
// Scatter instantiations
/////////////////////////////////////////////////////////////////////
#define instantiate_scatter5(name, src_t, idx_t, op_t, nidx, IDX_ARG) \
template [[host_name("scatter" name "_" #nidx)]] \
[[kernel]] void scatter<src_t, idx_t, op_t, nidx>( \
const device src_t *updates [[buffer(1)]], \
device mlx_atomic<src_t> *out [[buffer(2)]], \
const constant int *upd_shape [[buffer(3)]], \
const constant size_t *upd_strides [[buffer(4)]], \
const constant size_t& upd_ndim [[buffer(5)]], \
const constant size_t& upd_size [[buffer(6)]], \
const constant int *out_shape [[buffer(7)]], \
const constant size_t *out_strides [[buffer(8)]], \
const constant size_t& out_ndim [[buffer(9)]], \
const constant int* axes [[buffer(10)]], \
const constant int *idx_shapes [[buffer(11)]], \
const constant size_t *idx_strides [[buffer(12)]], \
const constant int& idx_ndim [[buffer(13)]], \
IDX_ARG(idx_t) \
uint2 gid [[thread_position_in_grid]]);
template [[host_name("scatter" name "_" #nidx)]] [[kernel]] void \
scatter<src_t, idx_t, op_t, nidx>( \
const device src_t* updates [[buffer(1)]], \
device mlx_atomic<src_t>* out [[buffer(2)]], \
const constant int* upd_shape [[buffer(3)]], \
const constant size_t* upd_strides [[buffer(4)]], \
const constant size_t& upd_ndim [[buffer(5)]], \
const constant size_t& upd_size [[buffer(6)]], \
const constant int* out_shape [[buffer(7)]], \
const constant size_t* out_strides [[buffer(8)]], \
const constant size_t& out_ndim [[buffer(9)]], \
const constant int* axes [[buffer(10)]], \
const constant int* idx_shapes [[buffer(11)]], \
const constant size_t* idx_strides [[buffer(12)]], \
const constant int& idx_ndim [[buffer(13)]], \
IDX_ARG(idx_t) uint2 gid [[thread_position_in_grid]]);
#define instantiate_scatter6(name, src_t, idx_t, op_t, nidx, IDX_ARG) \
template [[host_name("scatter_1d_index" name "_" #nidx)]] \
[[kernel]] void scatter_1d_index<src_t, idx_t, op_t, nidx>( \
const device src_t *updates [[buffer(1)]], \
device mlx_atomic<src_t> *out [[buffer(2)]], \
const constant int* out_shape [[buffer(3)]], \
const constant size_t* out_strides [[buffer(4)]], \
const constant size_t& upd_size [[buffer(5)]], \
IDX_ARG(idx_t) \
uint2 gid [[thread_position_in_grid]]);
#define instantiate_scatter6(name, src_t, idx_t, op_t, nidx, IDX_ARG) \
template [[host_name("scatter_1d_index" name "_" #nidx)]] [[kernel]] void \
scatter_1d_index<src_t, idx_t, op_t, nidx>( \
const device src_t* updates [[buffer(1)]], \
device mlx_atomic<src_t>* out [[buffer(2)]], \
const constant int* out_shape [[buffer(3)]], \
const constant size_t* out_strides [[buffer(4)]], \
const constant size_t& upd_size [[buffer(5)]], \
IDX_ARG(idx_t) uint2 gid [[thread_position_in_grid]]);
#define instantiate_scatter4(name, src_t, idx_t, op_t, nidx) \
// clang-format off
#define instantiate_scatter4(name, src_t, idx_t, op_t, nidx) \
instantiate_scatter5(name, src_t, idx_t, op_t, nidx, IDX_ARG_ ##nidx) \
instantiate_scatter6(name, src_t, idx_t, op_t, nidx, IDX_ARG_ ##nidx)
instantiate_scatter6(name, src_t, idx_t, op_t, nidx, IDX_ARG_ ##nidx) // clang-format on
// clang-format off
// Special case NINDEX=0
#define instantiate_scatter_nd0(name, type) \
instantiate_scatter4(#name "none", type, bool, None, 0) \
instantiate_scatter4(#name "_sum", type, bool, Sum<type>, 0) \
#define instantiate_scatter_nd0(name, type) \
instantiate_scatter4(#name "none", type, bool, None, 0) \
instantiate_scatter4(#name "_sum", type, bool, Sum<type>, 0) \
instantiate_scatter4(#name "_prod", type, bool, Prod<type>, 0) \
instantiate_scatter4(#name "_max", type, bool, Max<type>, 0) \
instantiate_scatter4(#name "_min", type, bool, Min<type>, 0)
instantiate_scatter4(#name "_max", type, bool, Max<type>, 0) \
instantiate_scatter4(#name "_min", type, bool, Min<type>, 0) // clang-format on
// clang-format off
#define instantiate_scatter3(name, type, ind_type, op_type) \
instantiate_scatter4(name, type, ind_type, op_type, 1) \
instantiate_scatter4(name, type, ind_type, op_type, 2) \
instantiate_scatter4(name, type, ind_type, op_type, 3) \
instantiate_scatter4(name, type, ind_type, op_type, 4) \
instantiate_scatter4(name, type, ind_type, op_type, 5) \
instantiate_scatter4(name, type, ind_type, op_type, 6) \
instantiate_scatter4(name, type, ind_type, op_type, 7) \
instantiate_scatter4(name, type, ind_type, op_type, 8) \
instantiate_scatter4(name, type, ind_type, op_type, 9) \
instantiate_scatter4(name, type, ind_type, op_type, 10)
instantiate_scatter4(name, type, ind_type, op_type, 1) \
instantiate_scatter4(name, type, ind_type, op_type, 2) \
instantiate_scatter4(name, type, ind_type, op_type, 3) \
instantiate_scatter4(name, type, ind_type, op_type, 4) \
instantiate_scatter4(name, type, ind_type, op_type, 5) \
instantiate_scatter4(name, type, ind_type, op_type, 6) \
instantiate_scatter4(name, type, ind_type, op_type, 7) \
instantiate_scatter4(name, type, ind_type, op_type, 8) \
instantiate_scatter4(name, type, ind_type, op_type, 9) \
instantiate_scatter4(name, type, ind_type, op_type, 10) // clang-format on
#define instantiate_scatter2(name, type, ind_type) \
instantiate_scatter3(name "_none", type, ind_type, None) \
instantiate_scatter3(name "_sum", type, ind_type, Sum<type>) \
// clang-format off
#define instantiate_scatter2(name, type, ind_type) \
instantiate_scatter3(name "_none", type, ind_type, None) \
instantiate_scatter3(name "_sum", type, ind_type, Sum<type>) \
instantiate_scatter3(name "_prod", type, ind_type, Prod<type>) \
instantiate_scatter3(name "_max", type, ind_type, Max<type>) \
instantiate_scatter3(name "_min", type, ind_type, Min<type>)
instantiate_scatter3(name "_max", type, ind_type, Max<type>) \
instantiate_scatter3(name "_min", type, ind_type, Min<type>) // clang-format on
#define instantiate_scatter(name, type) \
instantiate_scatter2(#name "bool_", type, bool) \
instantiate_scatter2(#name "uint8", type, uint8_t) \
// clang-format off
#define instantiate_scatter(name, type) \
instantiate_scatter2(#name "bool_", type, bool) \
instantiate_scatter2(#name "uint8", type, uint8_t) \
instantiate_scatter2(#name "uint16", type, uint16_t) \
instantiate_scatter2(#name "uint32", type, uint32_t) \
instantiate_scatter2(#name "uint64", type, uint64_t) \
instantiate_scatter2(#name "int8", type, int8_t) \
instantiate_scatter2(#name "int16", type, int16_t) \
instantiate_scatter2(#name "int32", type, int32_t) \
instantiate_scatter2(#name "int64", type, int64_t)
instantiate_scatter2(#name "int8", type, int8_t) \
instantiate_scatter2(#name "int16", type, int16_t) \
instantiate_scatter2(#name "int32", type, int32_t) \
instantiate_scatter2(#name "int64", type, int64_t) // clang-format on
// clang-format off
// TODO uint64 and int64 unsupported
instantiate_scatter_nd0(bool_, bool)
instantiate_scatter_nd0(uint8, uint8_t)
@@ -254,4 +233,4 @@ instantiate_scatter(int16, int16_t)
instantiate_scatter(int32, int32_t)
instantiate_scatter(float16, half)
instantiate_scatter(float32, float)
instantiate_scatter(bfloat16, bfloat16_t)
instantiate_scatter(bfloat16, bfloat16_t) // clang-format on

View File

@@ -198,17 +198,16 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
}
}
// clang-format off
#define instantiate_softmax(name, itype) \
template [[host_name("softmax_" #name)]] [[kernel]] void \
softmax_single_row<itype>( \
const device itype* in, \
device itype* out, \
constant int& axis_size, \
uint gid [[thread_position_in_grid]], \
uint _lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
#define instantiate_softmax(name, itype) \
template [[host_name("softmax_" #name)]] [[kernel]] void \
softmax_single_row<itype>( \
const device itype* in, \
device itype* out, \
constant int& axis_size, \
uint gid [[thread_position_in_grid]], \
uint _lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
template [[host_name("softmax_looped_" #name)]] [[kernel]] void \
softmax_looped<itype>( \
const device itype* in, \
@@ -220,16 +219,16 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_softmax_precise(name, itype) \
template [[host_name("softmax_precise_" #name)]] [[kernel]] void \
softmax_single_row<itype, float>( \
const device itype* in, \
device itype* out, \
constant int& axis_size, \
uint gid [[thread_position_in_grid]], \
uint _lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
#define instantiate_softmax_precise(name, itype) \
template [[host_name("softmax_precise_" #name)]] [[kernel]] void \
softmax_single_row<itype, float>( \
const device itype* in, \
device itype* out, \
constant int& axis_size, \
uint gid [[thread_position_in_grid]], \
uint _lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]); \
template [[host_name("softmax_looped_precise_" #name)]] [[kernel]] void \
softmax_looped<itype, float>( \
const device itype* in, \
@@ -241,9 +240,9 @@ template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
// clang-format off
instantiate_softmax(float32, float)
instantiate_softmax(float16, half)
instantiate_softmax(bfloat16, bfloat16_t)
instantiate_softmax_precise(float16, half)
instantiate_softmax_precise(bfloat16, bfloat16_t)
// clang-format on
instantiate_softmax_precise(bfloat16, bfloat16_t) // clang-format on

View File

@@ -11,7 +11,8 @@
using namespace metal;
// Based on GPU merge sort algorithm at https://github.com/NVIDIA/cccl/tree/main/cub/cub
// Based on GPU merge sort algorithm at
// https://github.com/NVIDIA/cccl/tree/main/cub/cub
///////////////////////////////////////////////////////////////////////////////
// Thread-level sort
@@ -43,20 +44,18 @@ struct ThreadSort {
static METAL_FUNC void sort(
thread val_t (&vals)[N_PER_THREAD],
thread idx_t (&idxs)[N_PER_THREAD]) {
CompareOp op;
MLX_MTL_LOOP_UNROLL
for(short i = 0; i < N_PER_THREAD; ++i) {
MLX_MTL_LOOP_UNROLL
for(short j = i & 1; j < N_PER_THREAD - 1; j += 2) {
if(op(vals[j + 1], vals[j])) {
for (short i = 0; i < N_PER_THREAD; ++i) {
MLX_MTL_LOOP_UNROLL
for (short j = i & 1; j < N_PER_THREAD - 1; j += 2) {
if (op(vals[j + 1], vals[j])) {
thread_swap(vals[j + 1], vals[j]);
thread_swap(idxs[j + 1], idxs[j]);
}
}
}
}
};
@@ -72,25 +71,25 @@ template <
short N_PER_THREAD,
typename CompareOp>
struct BlockMergeSort {
using thread_sort_t = ThreadSort<val_t, idx_t, ARG_SORT, N_PER_THREAD, CompareOp>;
using thread_sort_t =
ThreadSort<val_t, idx_t, ARG_SORT, N_PER_THREAD, CompareOp>;
static METAL_FUNC int merge_partition(
const threadgroup val_t* As,
const threadgroup val_t* Bs,
short A_sz,
short B_sz,
short sort_md) {
CompareOp op;
short A_st = max(0, sort_md - B_sz);
short A_ed = min(sort_md, A_sz);
while(A_st < A_ed) {
while (A_st < A_ed) {
short md = A_st + (A_ed - A_st) / 2;
auto a = As[md];
auto b = Bs[sort_md - 1 - md];
if(op(b, a)) {
if (op(b, a)) {
A_ed = md;
} else {
A_st = md + 1;
@@ -98,7 +97,6 @@ struct BlockMergeSort {
}
return A_ed;
}
static METAL_FUNC void merge_step(
@@ -110,12 +108,11 @@ struct BlockMergeSort {
short B_sz,
thread val_t (&vals)[N_PER_THREAD],
thread idx_t (&idxs)[N_PER_THREAD]) {
CompareOp op;
short a_idx = 0;
short b_idx = 0;
for(int i = 0; i < N_PER_THREAD; ++i) {
for (int i = 0; i < N_PER_THREAD; ++i) {
auto a = As[a_idx];
auto b = Bs[b_idx];
bool pred = (b_idx < B_sz) && (a_idx >= A_sz || op(b, a));
@@ -126,7 +123,6 @@ struct BlockMergeSort {
b_idx += short(pred);
a_idx += short(!pred);
}
}
static METAL_FUNC void sort(
@@ -134,32 +130,32 @@ struct BlockMergeSort {
threadgroup idx_t* tgp_idxs [[threadgroup(1)]],
int size_sorted_axis,
uint3 lid [[thread_position_in_threadgroup]]) {
// Get thread location
int idx = lid.x * N_PER_THREAD;
// Load from shared memory
thread val_t thread_vals[N_PER_THREAD];
thread idx_t thread_idxs[N_PER_THREAD];
for(int i = 0; i < N_PER_THREAD; ++i) {
for (int i = 0; i < N_PER_THREAD; ++i) {
thread_vals[i] = tgp_vals[idx + i];
if(ARG_SORT) {
if (ARG_SORT) {
thread_idxs[i] = tgp_idxs[idx + i];
}
}
// Per thread sort
if(idx < size_sorted_axis) {
// Per thread sort
if (idx < size_sorted_axis) {
thread_sort_t::sort(thread_vals, thread_idxs);
}
// Do merges using threadgroup memory
for (int merge_threads = 2; merge_threads <= BLOCK_THREADS; merge_threads *= 2) {
for (int merge_threads = 2; merge_threads <= BLOCK_THREADS;
merge_threads *= 2) {
// Update threadgroup memory
threadgroup_barrier(mem_flags::mem_threadgroup);
for(int i = 0; i < N_PER_THREAD; ++i) {
for (int i = 0; i < N_PER_THREAD; ++i) {
tgp_vals[idx + i] = thread_vals[i];
if(ARG_SORT) {
if (ARG_SORT) {
tgp_idxs[idx + i] = thread_idxs[i];
}
}
@@ -167,7 +163,7 @@ struct BlockMergeSort {
// Find location in merge step
int merge_group = lid.x / merge_threads;
int merge_lane = lid.x % merge_threads;
int merge_lane = lid.x % merge_threads;
int sort_sz = N_PER_THREAD * merge_threads;
int sort_st = N_PER_THREAD * merge_threads * merge_group;
@@ -185,16 +181,11 @@ struct BlockMergeSort {
int B_sz = B_ed - B_st;
// Find a partition of merge elements
// Ci = merge(As[partition:], Bs[sort_md - partition:])
// Ci = merge(As[partition:], Bs[sort_md - partition:])
// of size N_PER_THREAD for each merge lane i
// C = [Ci] is sorted
// C = [Ci] is sorted
int sort_md = N_PER_THREAD * merge_lane;
int partition = merge_partition(
As,
Bs,
A_sz,
B_sz,
sort_md);
int partition = merge_partition(As, Bs, A_sz, B_sz, sort_md);
As += partition;
Bs += sort_md - partition;
@@ -202,27 +193,20 @@ struct BlockMergeSort {
A_sz -= partition;
B_sz -= sort_md - partition;
const threadgroup idx_t* As_idx = ARG_SORT ? tgp_idxs + A_st + partition : nullptr;
const threadgroup idx_t* Bs_idx = ARG_SORT ? tgp_idxs + B_st + sort_md - partition : nullptr;
const threadgroup idx_t* As_idx =
ARG_SORT ? tgp_idxs + A_st + partition : nullptr;
const threadgroup idx_t* Bs_idx =
ARG_SORT ? tgp_idxs + B_st + sort_md - partition : nullptr;
// Merge starting at the partition and store results in thread registers
merge_step(
As,
Bs,
As_idx,
Bs_idx,
A_sz,
B_sz,
thread_vals,
thread_idxs);
merge_step(As, Bs, As_idx, Bs_idx, A_sz, B_sz, thread_vals, thread_idxs);
}
// Write out to shared memory
threadgroup_barrier(mem_flags::mem_threadgroup);
for(int i = 0; i < N_PER_THREAD; ++i) {
for (int i = 0; i < N_PER_THREAD; ++i) {
tgp_vals[idx + i] = thread_vals[i];
if(ARG_SORT) {
if (ARG_SORT) {
tgp_idxs[idx + i] = thread_idxs[i];
}
}
@@ -235,7 +219,7 @@ struct BlockMergeSort {
template <
typename T,
typename U,
typename U,
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD,
@@ -244,13 +228,13 @@ struct KernelMergeSort {
using val_t = T;
using idx_t = uint;
using block_merge_sort_t = BlockMergeSort<
val_t,
val_t,
idx_t,
ARG_SORT,
BLOCK_THREADS,
BLOCK_THREADS,
N_PER_THREAD,
CompareOp>;
MLX_MTL_CONST short N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD;
static METAL_FUNC void block_sort(
@@ -263,15 +247,15 @@ struct KernelMergeSort {
threadgroup idx_t* tgp_idxs,
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
// tid.y tells us the segment index
inp += tid.y * stride_segment_axis;
out += tid.y * stride_segment_axis;
// Copy into threadgroup memory
for(short i = lid.x; i < N_PER_BLOCK; i+= BLOCK_THREADS) {
tgp_vals[i] = i < size_sorted_axis ? inp[i * stride_sorted_axis] : val_t(CompareOp::init);
if(ARG_SORT) {
for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) {
tgp_vals[i] = i < size_sorted_axis ? inp[i * stride_sorted_axis]
: val_t(CompareOp::init);
if (ARG_SORT) {
tgp_idxs[i] = i;
}
}
@@ -284,8 +268,8 @@ struct KernelMergeSort {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Write output
for(int i = lid.x; i < size_sorted_axis; i+= BLOCK_THREADS) {
if(ARG_SORT) {
for (int i = lid.x; i < size_sorted_axis; i += BLOCK_THREADS) {
if (ARG_SORT) {
out[i * stride_sorted_axis] = tgp_idxs[i];
} else {
out[i * stride_sorted_axis] = tgp_vals[i];
@@ -296,7 +280,7 @@ struct KernelMergeSort {
template <
typename T,
typename U,
typename U,
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD>
@@ -308,12 +292,12 @@ template <
const constant int& stride_segment_axis [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using sort_kernel = KernelMergeSort<T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
using sort_kernel =
KernelMergeSort<T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
using val_t = typename sort_kernel::val_t;
using idx_t = typename sort_kernel::idx_t;
if(ARG_SORT) {
if (ARG_SORT) {
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
sort_kernel::block_sort(
@@ -339,14 +323,13 @@ template <
tid,
lid);
}
}
constant constexpr const int zero_helper = 0;
template <
typename T,
typename U,
typename U,
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD>
@@ -360,8 +343,8 @@ template <
const device size_t* nc_strides [[buffer(6)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using sort_kernel = KernelMergeSort<T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
using sort_kernel =
KernelMergeSort<T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
using val_t = typename sort_kernel::val_t;
using idx_t = typename sort_kernel::idx_t;
@@ -369,7 +352,7 @@ template <
inp += block_idx;
out += block_idx;
if(ARG_SORT) {
if (ARG_SORT) {
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
sort_kernel::block_sort(
@@ -395,50 +378,55 @@ template <
tid,
lid);
}
}
///////////////////////////////////////////////////////////////////////////////
// Instantiations
///////////////////////////////////////////////////////////////////////////////
#define instantiate_block_sort(name, itname, itype, otname, otype, arg_sort, bn, tn) \
template [[host_name(#name "_" #itname "_" #otname "_bn" #bn "_tn" #tn)]] \
[[kernel]] void block_sort<itype, otype, arg_sort, bn, tn>( \
const device itype* inp [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant int& size_sorted_axis [[buffer(2)]], \
const constant int& stride_sorted_axis [[buffer(3)]], \
const constant int& stride_segment_axis [[buffer(4)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]); \
template [[host_name(#name "_" #itname "_" #otname "_bn" #bn "_tn" #tn "_nc")]] \
[[kernel]] void block_sort_nc<itype, otype, arg_sort, bn, tn>( \
const device itype* inp [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant int& size_sorted_axis [[buffer(2)]], \
const constant int& stride_sorted_axis [[buffer(3)]], \
const constant int& nc_dim [[buffer(4)]], \
const device int* nc_shape [[buffer(5)]], \
const device size_t* nc_strides [[buffer(6)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
#define instantiate_block_sort( \
name, itname, itype, otname, otype, arg_sort, bn, tn) \
template [[host_name(#name "_" #itname "_" #otname "_bn" #bn \
"_tn" #tn)]] [[kernel]] void \
block_sort<itype, otype, arg_sort, bn, tn>( \
const device itype* inp [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant int& size_sorted_axis [[buffer(2)]], \
const constant int& stride_sorted_axis [[buffer(3)]], \
const constant int& stride_segment_axis [[buffer(4)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]); \
template [[host_name(#name "_" #itname "_" #otname "_bn" #bn "_tn" #tn \
"_nc")]] [[kernel]] void \
block_sort_nc<itype, otype, arg_sort, bn, tn>( \
const device itype* inp [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant int& size_sorted_axis [[buffer(2)]], \
const constant int& stride_sorted_axis [[buffer(3)]], \
const constant int& nc_dim [[buffer(4)]], \
const device int* nc_shape [[buffer(5)]], \
const device size_t* nc_strides [[buffer(6)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
#define instantiate_arg_block_sort_base(itname, itype, bn, tn) \
instantiate_block_sort(arg_block_merge_sort, itname, itype, uint32, uint32_t, true, bn, tn)
instantiate_block_sort( \
arg_block_merge_sort, itname, itype, uint32, uint32_t, true, bn, tn)
#define instantiate_block_sort_base(itname, itype, bn, tn) \
instantiate_block_sort(block_merge_sort, itname, itype, itname, itype, false, bn, tn)
instantiate_block_sort( \
block_merge_sort, itname, itype, itname, itype, false, bn, tn)
// clang-format off
#define instantiate_block_sort_tn(itname, itype, bn) \
instantiate_block_sort_base(itname, itype, bn, 8) \
instantiate_arg_block_sort_base(itname, itype, bn, 8)
instantiate_block_sort_base(itname, itype, bn, 8) \
instantiate_arg_block_sort_base(itname, itype, bn, 8) // clang-format on
// clang-format off
#define instantiate_block_sort_bn(itname, itype) \
instantiate_block_sort_tn(itname, itype, 128) \
instantiate_block_sort_tn(itname, itype, 256) \
instantiate_block_sort_tn(itname, itype, 512)
instantiate_block_sort_tn(itname, itype, 128) \
instantiate_block_sort_tn(itname, itype, 256) \
instantiate_block_sort_tn(itname, itype, 512)
instantiate_block_sort_bn(uint8, uint8_t)
instantiate_block_sort_bn(uint16, uint16_t)
@@ -448,35 +436,35 @@ instantiate_block_sort_bn(int16, int16_t)
instantiate_block_sort_bn(int32, int32_t)
instantiate_block_sort_bn(float16, half)
instantiate_block_sort_bn(float32, float)
instantiate_block_sort_bn(bfloat16, bfloat16_t)
instantiate_block_sort_bn(bfloat16, bfloat16_t) // clang-format on
// clang-format off
#define instantiate_block_sort_long(itname, itype) \
instantiate_block_sort_tn(itname, itype, 128) \
instantiate_block_sort_tn(itname, itype, 128) \
instantiate_block_sort_tn(itname, itype, 256)
instantiate_block_sort_long(uint64, uint64_t)
instantiate_block_sort_long(int64, int64_t)
instantiate_block_sort_long(int64, int64_t) // clang-format on
///////////////////////////////////////////////////////////////////////////////
// Multi block merge sort
///////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
// Multi block merge sort
///////////////////////////////////////////////////////////////////////////////
template <
typename val_t,
typename idx_t,
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD,
typename CompareOp = LessThan<val_t>>
struct KernelMultiBlockMergeSort {
template <
typename val_t,
typename idx_t,
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD,
typename CompareOp = LessThan<val_t>>
struct KernelMultiBlockMergeSort {
using block_merge_sort_t = BlockMergeSort<
val_t,
val_t,
idx_t,
ARG_SORT,
BLOCK_THREADS,
BLOCK_THREADS,
N_PER_THREAD,
CompareOp>;
MLX_MTL_CONST short N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD;
static METAL_FUNC void block_sort(
@@ -489,14 +477,14 @@ struct KernelMultiBlockMergeSort {
threadgroup idx_t* tgp_idxs,
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
// tid.y tells us the segment index
int base_idx = tid.x * N_PER_BLOCK;
// Copy into threadgroup memory
for(short i = lid.x; i < N_PER_BLOCK; i+= BLOCK_THREADS) {
for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) {
int idx = base_idx + i;
tgp_vals[i] = idx < size_sorted_axis ? inp[idx * stride_sorted_axis] : val_t(CompareOp::init);
tgp_vals[i] = idx < size_sorted_axis ? inp[idx * stride_sorted_axis]
: val_t(CompareOp::init);
tgp_idxs[i] = idx;
}
@@ -508,9 +496,9 @@ struct KernelMultiBlockMergeSort {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Write output
for(int i = lid.x; i < N_PER_BLOCK; i+= BLOCK_THREADS) {
for (int i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) {
int idx = base_idx + i;
if(idx < size_sorted_axis) {
if (idx < size_sorted_axis) {
out_vals[idx] = tgp_vals[i];
out_idxs[idx] = tgp_idxs[i];
}
@@ -523,18 +511,17 @@ struct KernelMultiBlockMergeSort {
int A_sz,
int B_sz,
int sort_md) {
CompareOp op;
int A_st = max(0, sort_md - B_sz);
int A_ed = min(sort_md, A_sz);
while(A_st < A_ed) {
while (A_st < A_ed) {
int md = A_st + (A_ed - A_st) / 2;
auto a = As[md];
auto b = Bs[sort_md - 1 - md];
if(op(b, a)) {
if (op(b, a)) {
A_ed = md;
} else {
A_st = md + 1;
@@ -542,7 +529,6 @@ struct KernelMultiBlockMergeSort {
}
return A_ed;
}
};
@@ -563,8 +549,12 @@ template <
const device size_t* nc_strides [[buffer(7)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using sort_kernel = KernelMultiBlockMergeSort<val_t, idx_t, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
using sort_kernel = KernelMultiBlockMergeSort<
val_t,
idx_t,
ARG_SORT,
BLOCK_THREADS,
N_PER_THREAD>;
auto block_idx = elem_to_loc(tid.y, nc_shape, nc_strides, nc_dim);
inp += block_idx;
@@ -575,12 +565,12 @@ template <
threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
sort_kernel::block_sort(
inp,
out_vals,
out_idxs,
size_sorted_axis,
stride_sorted_axis,
tgp_vals,
inp,
out_vals,
out_idxs,
size_sorted_axis,
stride_sorted_axis,
tgp_vals,
tgp_idxs,
tid,
lid);
@@ -592,7 +582,8 @@ template <
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD>
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_partition(
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void
mb_block_partition(
device idx_t* block_partitions [[buffer(0)]],
const device val_t* dev_vals [[buffer(1)]],
const device idx_t* dev_idxs [[buffer(2)]],
@@ -601,21 +592,20 @@ template <
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 tgp_dims [[threads_per_threadgroup]]) {
using sort_kernel = KernelMultiBlockMergeSort<
val_t,
idx_t,
ARG_SORT,
BLOCK_THREADS,
val_t,
idx_t,
ARG_SORT,
BLOCK_THREADS,
N_PER_THREAD>;
block_partitions += tid.y * tgp_dims.x;
dev_vals += tid.y * size_sorted_axis;
dev_idxs += tid.y * size_sorted_axis;
// Find location in merge step
int merge_group = lid.x / merge_tiles;
int merge_lane = lid.x % merge_tiles;
int merge_lane = lid.x % merge_tiles;
int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;
int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;
@@ -627,14 +617,9 @@ template <
int partition_at = min(B_ed - A_st, sort_kernel::N_PER_BLOCK * merge_lane);
int partition = sort_kernel::merge_partition(
dev_vals + A_st,
dev_vals + B_st,
A_ed - A_st,
B_ed - B_st,
partition_at);
dev_vals + A_st, dev_vals + B_st, A_ed - A_st, B_ed - B_st, partition_at);
block_partitions[lid.x] = A_st + partition;
}
template <
@@ -644,7 +629,8 @@ template <
short BLOCK_THREADS,
short N_PER_THREAD,
typename CompareOp = LessThan<val_t>>
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_merge(
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void
mb_block_merge(
const device idx_t* block_partitions [[buffer(0)]],
const device val_t* dev_vals_in [[buffer(1)]],
const device idx_t* dev_idxs_in [[buffer(2)]],
@@ -655,20 +641,19 @@ template <
const constant int& num_tiles [[buffer(7)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using sort_kernel = KernelMultiBlockMergeSort<
val_t,
idx_t,
ARG_SORT,
BLOCK_THREADS,
N_PER_THREAD,
val_t,
idx_t,
ARG_SORT,
BLOCK_THREADS,
N_PER_THREAD,
CompareOp>;
using block_sort_t = typename sort_kernel::block_merge_sort_t;
block_partitions += tid.y * (num_tiles + 1);
dev_vals_in += tid.y * size_sorted_axis;
dev_idxs_in += tid.y * size_sorted_axis;
dev_vals_in += tid.y * size_sorted_axis;
dev_idxs_in += tid.y * size_sorted_axis;
dev_vals_out += tid.y * size_sorted_axis;
dev_idxs_out += tid.y * size_sorted_axis;
@@ -680,25 +665,29 @@ template <
int A_st = block_partitions[block_idx + 0];
int A_ed = block_partitions[block_idx + 1];
int B_st = min(size_sorted_axis, 2 * sort_st + sort_sz/2 + sort_md - A_st);
int B_ed = min(size_sorted_axis, 2 * sort_st + sort_sz/2 + sort_md + sort_kernel::N_PER_BLOCK - A_ed);
int B_st = min(size_sorted_axis, 2 * sort_st + sort_sz / 2 + sort_md - A_st);
int B_ed = min(
size_sorted_axis,
2 * sort_st + sort_sz / 2 + sort_md + sort_kernel::N_PER_BLOCK - A_ed);
if((block_idx % merge_tiles) == merge_tiles - 1) {
A_ed = min(size_sorted_axis, sort_st + sort_sz/2);
if ((block_idx % merge_tiles) == merge_tiles - 1) {
A_ed = min(size_sorted_axis, sort_st + sort_sz / 2);
B_ed = min(size_sorted_axis, sort_st + sort_sz);
}
int A_sz = A_ed - A_st;
int B_sz = B_ed - B_st;
// Load from global memory
thread val_t thread_vals[N_PER_THREAD];
thread idx_t thread_idxs[N_PER_THREAD];
for(int i = 0; i < N_PER_THREAD; i++) {
for (int i = 0; i < N_PER_THREAD; i++) {
int idx = BLOCK_THREADS * i + lid.x;
if(idx < (A_sz + B_sz)) {
thread_vals[i] = (idx < A_sz) ? dev_vals_in[A_st + idx] : dev_vals_in[B_st + idx - A_sz];
thread_idxs[i] = (idx < A_sz) ? dev_idxs_in[A_st + idx] : dev_idxs_in[B_st + idx - A_sz];
if (idx < (A_sz + B_sz)) {
thread_vals[i] = (idx < A_sz) ? dev_vals_in[A_st + idx]
: dev_vals_in[B_st + idx - A_sz];
thread_idxs[i] = (idx < A_sz) ? dev_idxs_in[A_st + idx]
: dev_idxs_in[B_st + idx - A_sz];
} else {
thread_vals[i] = CompareOp::init;
thread_idxs[i] = 0;
@@ -709,7 +698,7 @@ template <
threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK];
threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK];
threadgroup_barrier(mem_flags::mem_threadgroup);
for(int i = 0; i < N_PER_THREAD; i++) {
for (int i = 0; i < N_PER_THREAD; i++) {
int idx = BLOCK_THREADS * i + lid.x;
tgp_vals[idx] = thread_vals[i];
tgp_idxs[idx] = thread_idxs[i];
@@ -720,11 +709,7 @@ template <
int sort_md_local = min(A_sz + B_sz, N_PER_THREAD * int(lid.x));
int A_st_local = block_sort_t::merge_partition(
tgp_vals,
tgp_vals + A_sz,
A_sz,
B_sz,
sort_md_local);
tgp_vals, tgp_vals + A_sz, A_sz, B_sz, sort_md_local);
int A_ed_local = A_sz;
int B_st_local = sort_md_local - A_st_local;
@@ -733,7 +718,7 @@ template <
int A_sz_local = A_ed_local - A_st_local;
int B_sz_local = B_ed_local - B_st_local;
// Do merge
// Do merge
block_sort_t::merge_step(
tgp_vals + A_st_local,
tgp_vals + A_ed_local + B_st_local,
@@ -745,61 +730,65 @@ template <
thread_idxs);
threadgroup_barrier(mem_flags::mem_threadgroup);
for(int i = 0; i < N_PER_THREAD; ++i) {
for (int i = 0; i < N_PER_THREAD; ++i) {
int idx = lid.x * N_PER_THREAD;
tgp_vals[idx + i] = thread_vals[i];
tgp_idxs[idx + i] = thread_idxs[i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Write output
int base_idx = tid.x * sort_kernel::N_PER_BLOCK;
for(int i = lid.x; i < sort_kernel::N_PER_BLOCK; i+= BLOCK_THREADS) {
for (int i = lid.x; i < sort_kernel::N_PER_BLOCK; i += BLOCK_THREADS) {
int idx = base_idx + i;
if(idx < size_sorted_axis) {
if (idx < size_sorted_axis) {
dev_vals_out[idx] = tgp_vals[i];
dev_idxs_out[idx] = tgp_idxs[i];
}
}
}
#define instantiate_multi_block_sort(vtname, vtype, itname, itype, arg_sort, bn, tn) \
template [[host_name("mb_block_sort_" #vtname "_" #itname "_bn" #bn "_tn" #tn)]] \
[[kernel]] void mb_block_sort<vtype, itype, arg_sort, bn, tn>( \
const device vtype* inp [[buffer(0)]], \
device vtype* out_vals [[buffer(1)]], \
device itype* out_idxs [[buffer(2)]], \
const constant int& size_sorted_axis [[buffer(3)]], \
const constant int& stride_sorted_axis [[buffer(4)]], \
const constant int& nc_dim [[buffer(5)]], \
const device int* nc_shape [[buffer(6)]], \
const device size_t* nc_strides [[buffer(7)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]); \
template [[host_name("mb_block_partition_" #vtname "_" #itname "_bn" #bn "_tn" #tn)]] \
[[kernel]] void mb_block_partition<vtype, itype, arg_sort, bn, tn>( \
device itype* block_partitions [[buffer(0)]], \
const device vtype* dev_vals [[buffer(1)]], \
const device itype* dev_idxs [[buffer(2)]], \
const constant int& size_sorted_axis [[buffer(3)]], \
const constant int& merge_tiles [[buffer(4)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 tgp_dims [[threads_per_threadgroup]]); \
template [[host_name("mb_block_merge_" #vtname "_" #itname "_bn" #bn "_tn" #tn)]] \
[[kernel]] void mb_block_merge<vtype, itype, arg_sort, bn, tn>( \
const device itype* block_partitions [[buffer(0)]], \
const device vtype* dev_vals_in [[buffer(1)]], \
const device itype* dev_idxs_in [[buffer(2)]], \
device vtype* dev_vals_out [[buffer(3)]], \
device itype* dev_idxs_out [[buffer(4)]], \
const constant int& size_sorted_axis [[buffer(5)]], \
const constant int& merge_tiles [[buffer(6)]], \
const constant int& num_tiles [[buffer(7)]], \
uint3 tid [[threadgroup_position_in_grid]], \
#define instantiate_multi_block_sort( \
vtname, vtype, itname, itype, arg_sort, bn, tn) \
template [[host_name("mb_block_sort_" #vtname "_" #itname "_bn" #bn \
"_tn" #tn)]] [[kernel]] void \
mb_block_sort<vtype, itype, arg_sort, bn, tn>( \
const device vtype* inp [[buffer(0)]], \
device vtype* out_vals [[buffer(1)]], \
device itype* out_idxs [[buffer(2)]], \
const constant int& size_sorted_axis [[buffer(3)]], \
const constant int& stride_sorted_axis [[buffer(4)]], \
const constant int& nc_dim [[buffer(5)]], \
const device int* nc_shape [[buffer(6)]], \
const device size_t* nc_strides [[buffer(7)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]); \
template [[host_name("mb_block_partition_" #vtname "_" #itname "_bn" #bn \
"_tn" #tn)]] [[kernel]] void \
mb_block_partition<vtype, itype, arg_sort, bn, tn>( \
device itype * block_partitions [[buffer(0)]], \
const device vtype* dev_vals [[buffer(1)]], \
const device itype* dev_idxs [[buffer(2)]], \
const constant int& size_sorted_axis [[buffer(3)]], \
const constant int& merge_tiles [[buffer(4)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 tgp_dims [[threads_per_threadgroup]]); \
template [[host_name("mb_block_merge_" #vtname "_" #itname "_bn" #bn \
"_tn" #tn)]] [[kernel]] void \
mb_block_merge<vtype, itype, arg_sort, bn, tn>( \
const device itype* block_partitions [[buffer(0)]], \
const device vtype* dev_vals_in [[buffer(1)]], \
const device itype* dev_idxs_in [[buffer(2)]], \
device vtype* dev_vals_out [[buffer(3)]], \
device itype* dev_idxs_out [[buffer(4)]], \
const constant int& size_sorted_axis [[buffer(5)]], \
const constant int& merge_tiles [[buffer(6)]], \
const constant int& num_tiles [[buffer(7)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
// clang-format off
#define instantiate_multi_block_sort_base(vtname, vtype) \
instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 512, 8)
@@ -811,10 +800,11 @@ instantiate_multi_block_sort_base(int16, int16_t)
instantiate_multi_block_sort_base(int32, int32_t)
instantiate_multi_block_sort_base(float16, half)
instantiate_multi_block_sort_base(float32, float)
instantiate_multi_block_sort_base(bfloat16, bfloat16_t)
instantiate_multi_block_sort_base(bfloat16, bfloat16_t) // clang-format on
// clang-format off
#define instantiate_multi_block_sort_long(vtname, vtype) \
instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 256, 8)
instantiate_multi_block_sort_long(uint64, uint64_t)
instantiate_multi_block_sort_long(int64, int64_t)
instantiate_multi_block_sort_long(int64, int64_t) // clang-format on

View File

@@ -4,21 +4,23 @@
#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/steel/conv/conv.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
#include "mlx/backend/metal/kernels/bf16.h"
using namespace metal;
template <typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
int N_CHANNELS = 0,
bool SMALL_FILTER = false>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void implicit_gemm_conv_2d(
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
int N_CHANNELS = 0,
bool SMALL_FILTER = false>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
implicit_gemm_conv_2d(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
device T* C [[buffer(2)]],
@@ -28,12 +30,10 @@ template <typename T,
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
using namespace mlx::steel;
(void)lid;
constexpr bool transpose_a = false;
constexpr bool transpose_b = true;
constexpr short tgp_padding_a = 16 / sizeof(T);
@@ -47,46 +47,64 @@ template <typename T,
constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows;
constexpr short tgp_size = WM * WN * 32;
// Input loader
// Input loader
using loader_a_t = typename metal::conditional_t<
// Check for small channel specialization
N_CHANNELS != 0 && N_CHANNELS <= 4,
// Go to small channel specialization
Conv2DInputBlockLoaderSmallChannels<
T, BM, BN, BK, tgp_size, N_CHANNELS, tgp_padding_a>,
// Go to small channel specialization
Conv2DInputBlockLoaderSmallChannels<
T,
BM,
BN,
BK,
tgp_size,
N_CHANNELS,
tgp_padding_a>,
// Else go to general loader
typename metal::conditional_t<
// Check if filter size is small enough
SMALL_FILTER,
// Else go to general loader
typename metal::conditional_t<
// Check if filter size is small enough
SMALL_FILTER,
// Go to small filter specialization
Conv2DInputBlockLoaderSmallFilter<
T, BM, BN, BK, tgp_size, tgp_padding_a>,
// Else go to large filter generalization
Conv2DInputBlockLoaderLargeFilter<
T, BM, BN, BK, tgp_size, tgp_padding_a>
>
>;
// Go to small filter specialization
Conv2DInputBlockLoaderSmallFilter<
T,
BM,
BN,
BK,
tgp_size,
tgp_padding_a>,
// Else go to large filter generalization
Conv2DInputBlockLoaderLargeFilter<
T,
BM,
BN,
BK,
tgp_size,
tgp_padding_a>>>;
// Weight loader
using loader_b_t = typename metal::conditional_t<
// Check for small channel specialization
N_CHANNELS != 0 && N_CHANNELS <= 4,
// Go to small channel specialization
Conv2DWeightBlockLoaderSmallChannels<
T, BM, BN, BK, tgp_size, N_CHANNELS, tgp_padding_b>,
// Go to small channel specialization
Conv2DWeightBlockLoaderSmallChannels<
T,
BM,
BN,
BK,
tgp_size,
N_CHANNELS,
tgp_padding_b>,
// Else go to general loader
Conv2DWeightBlockLoader<T, BM, BN, BK, tgp_size, tgp_padding_b>>;
// Else go to general loader
Conv2DWeightBlockLoader<T, BM, BN, BK, tgp_size, tgp_padding_b>
>;
using mma_t = BlockMMA<
T,
T,
@@ -99,12 +117,12 @@ template <typename T,
transpose_b,
shape_a_cols,
shape_b_cols>;
threadgroup T As[tgp_mem_size_a];
threadgroup T Bs[tgp_mem_size_b];
const int tid_y = ((tid.y) << gemm_params->swizzle_log) +
((tid.x) & ((1 << gemm_params->swizzle_log) - 1));
((tid.x) & ((1 << gemm_params->swizzle_log) - 1));
const int tid_x = (tid.x) >> gemm_params->swizzle_log;
if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) {
@@ -123,8 +141,10 @@ template <typename T,
const int2 offsets_b(0, c_col);
// Prepare threadgroup loading operations
loader_a_t loader_a(A, As, offsets_a, params, gemm_params, simd_gid, simd_lid);
loader_b_t loader_b(B, Bs, offsets_b, params, gemm_params, simd_gid, simd_lid);
loader_a_t loader_a(
A, As, offsets_a, params, gemm_params, simd_gid, simd_lid);
loader_b_t loader_b(
B, Bs, offsets_b, params, gemm_params, simd_gid, simd_lid);
// Prepare threadgroup mma operation
mma_t mma_op(simd_gid, simd_lid);
@@ -152,38 +172,53 @@ template <typename T,
short tgp_bm = min(BM, gemm_params->M - c_row);
short tgp_bn = min(BN, gemm_params->N - c_col);
mma_op.store_result_safe(C, N, short2(tgp_bn, tgp_bm));
}
#define instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, channel_name, n_channels, filter_name, small_filter) \
template [[host_name("implicit_gemm_conv_2d_" #name "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_channel_" #channel_name "_filter_" #filter_name)]] \
[[kernel]] void implicit_gemm_conv_2d<itype, bm, bn, bk, wm, wn, n_channels, small_filter>( \
const device itype* A [[buffer(0)]], \
const device itype* B [[buffer(1)]], \
device itype* C [[buffer(2)]], \
const constant MLXConvParams<2>* params [[buffer(3)]], \
const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
#define instantiate_implicit_conv_2d( \
name, \
itype, \
bm, \
bn, \
bk, \
wm, \
wn, \
channel_name, \
n_channels, \
filter_name, \
small_filter) \
template [[host_name("implicit_gemm_conv_2d_" #name "_bm" #bm "_bn" #bn \
"_bk" #bk "_wm" #wm "_wn" #wn "_channel_" #channel_name \
"_filter_" #filter_name)]] [[kernel]] void \
implicit_gemm_conv_2d<itype, bm, bn, bk, wm, wn, n_channels, small_filter>( \
const device itype* A [[buffer(0)]], \
const device itype* B [[buffer(1)]], \
device itype* C [[buffer(2)]], \
const constant MLXConvParams<2>* params [[buffer(3)]], \
const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_implicit_2d_filter(name, itype, bm, bn, bk, wm, wn) \
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, l, 0, s, true) \
// clang-format off
#define instantiate_implicit_2d_filter(name, itype, bm, bn, bk, wm, wn) \
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, l, 0, s, true) \
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, l, 0, l, false) \
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 1, 1, l, false) \
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 2, 2, l, false) \
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 3, 3, l, false) \
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 4, 4, l, false)
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn, 4, 4, l, false) // clang-format on
#define instantiate_implicit_2d_blocks(name, itype) \
// clang-format off
#define instantiate_implicit_2d_blocks(name, itype) \
instantiate_implicit_2d_filter(name, itype, 32, 8, 16, 4, 1) \
instantiate_implicit_2d_filter(name, itype, 64, 8, 16, 4, 1) \
instantiate_implicit_2d_filter(name, itype, 32, 32, 16, 2, 2) \
instantiate_implicit_2d_filter(name, itype, 32, 64, 16, 2, 2) \
instantiate_implicit_2d_filter(name, itype, 64, 32, 16, 2, 2) \
instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2)
instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2) // clang-format on
// clang-format off
instantiate_implicit_2d_blocks(float32, float);
instantiate_implicit_2d_blocks(float16, half);
instantiate_implicit_2d_blocks(bfloat16, bfloat16_t);
instantiate_implicit_2d_blocks(bfloat16, bfloat16_t); // clang-format on

View File

@@ -4,23 +4,25 @@
#include "mlx/backend/metal/kernels/steel/gemm/mma.h"
#include "mlx/backend/metal/kernels/steel/conv/conv.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
#include "mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h"
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/steel/conv/conv.h"
#include "mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
using namespace metal;
using namespace mlx::steel;
template <typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
typename AccumType = float,
typename Epilogue = TransformNone<T, AccumType>>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void implicit_gemm_conv_2d_general(
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
typename AccumType = float,
typename Epilogue = TransformNone<T, AccumType>>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
implicit_gemm_conv_2d_general(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
device T* C [[buffer(2)]],
@@ -33,9 +35,8 @@ template <typename T,
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
(void)lid;
constexpr bool transpose_a = false;
constexpr bool transpose_b = true;
constexpr short tgp_padding_a = 16 / sizeof(T);
@@ -49,15 +50,15 @@ template <typename T,
constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows;
constexpr short tgp_size = WM * WN * 32;
// Input loader
using loader_a_t = Conv2DInputBlockLoaderGeneral<
T, BM, BN, BK, tgp_size, tgp_padding_a>;
// Input loader
using loader_a_t =
Conv2DInputBlockLoaderGeneral<T, BM, BN, BK, tgp_size, tgp_padding_a>;
// Weight loader
using loader_b_t = Conv2DWeightBlockLoaderGeneral<
T, BM, BN, BK, tgp_size, tgp_padding_b>;
using loader_b_t =
Conv2DWeightBlockLoaderGeneral<T, BM, BN, BK, tgp_size, tgp_padding_b>;
using mma_t = BlockMMA<
T,
T,
@@ -70,12 +71,12 @@ template <typename T,
transpose_b,
shape_a_cols,
shape_b_cols>;
threadgroup T As[tgp_mem_size_a];
threadgroup T Bs[tgp_mem_size_b];
const int tid_y = ((tid.y) << gemm_params->swizzle_log) +
((tid.x) & ((1 << gemm_params->swizzle_log) - 1));
((tid.x) & ((1 << gemm_params->swizzle_log) - 1));
const int tid_x = (tid.x) >> gemm_params->swizzle_log;
if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) {
@@ -103,13 +104,32 @@ template <typename T,
const int2 offsets_b(0, c_col);
// Prepare threadgroup loading operations
loader_a_t loader_a(A, As, offsets_a, params, jump_params, base_wh, base_ww, simd_gid, simd_lid);
loader_b_t loader_b(B, Bs, offsets_b, params, jump_params, base_wh, base_ww, simd_gid, simd_lid);
loader_a_t loader_a(
A,
As,
offsets_a,
params,
jump_params,
base_wh,
base_ww,
simd_gid,
simd_lid);
loader_b_t loader_b(
B,
Bs,
offsets_b,
params,
jump_params,
base_wh,
base_ww,
simd_gid,
simd_lid);
// Prepare threadgroup mma operation
mma_t mma_op(simd_gid, simd_lid);
int gemm_k_iterations = base_wh_size * base_ww_size * gemm_params->gemm_k_iterations;
int gemm_k_iterations =
base_wh_size * base_ww_size * gemm_params->gemm_k_iterations;
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -143,22 +163,24 @@ template <typename T,
STEEL_PRAGMA_UNROLL
for (int i = 0; i < mma_t::TM; i++) {
int cm = offset_m + i * mma_t::TM_stride;
int n = cm / jump_params->adj_out_hw;
int hw = cm % jump_params->adj_out_hw;
int oh = (hw / jump_params->adj_out_w) * jump_params->f_out_jump_h + base_oh;
int ow = (hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + base_ow;
int oh =
(hw / jump_params->adj_out_w) * jump_params->f_out_jump_h + base_oh;
int ow =
(hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + base_ow;
if(n < params->N && oh < params->oS[0] && ow < params->oS[1]) {
int offset_cm = n * params->out_strides[0] + oh * params->out_strides[1] + ow * params->out_strides[2];
if (n < params->N && oh < params->oS[0] && ow < params->oS[1]) {
int offset_cm = n * params->out_strides[0] +
oh * params->out_strides[1] + ow * params->out_strides[2];
STEEL_PRAGMA_UNROLL
for (int j = 0; j < mma_t::TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = mma_op.results[i * mma_t::TN + j].thread_elements();
thread const auto& accum =
mma_op.results[i * mma_t::TN + j].thread_elements();
int offset = offset_cm + (j * mma_t::TN_stride);
// Apply epilogue and output C
@@ -170,40 +192,42 @@ template <typename T,
C[offset + 1] = Epilogue::apply(accum[1]);
}
}
}
}
}
}
#define instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn) \
template [[host_name("implicit_gemm_conv_2d_general_" #name "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn)]] \
[[kernel]] void implicit_gemm_conv_2d_general<itype, bm, bn, bk, wm, wn>( \
const device itype* A [[buffer(0)]], \
const device itype* B [[buffer(1)]], \
device itype* C [[buffer(2)]], \
const constant MLXConvParams<2>* params [[buffer(3)]], \
const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]], \
const constant Conv2DGeneralJumpParams* jump_params [[buffer(5)]], \
const constant Conv2DGeneralBaseInfo* base_h [[buffer(6)]], \
const constant Conv2DGeneralBaseInfo* base_w [[buffer(7)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn) \
template \
[[host_name("implicit_gemm_conv_2d_general_" #name "_bm" #bm "_bn" #bn \
"_bk" #bk "_wm" #wm "_wn" #wn)]] [[kernel]] void \
implicit_gemm_conv_2d_general<itype, bm, bn, bk, wm, wn>( \
const device itype* A [[buffer(0)]], \
const device itype* B [[buffer(1)]], \
device itype* C [[buffer(2)]], \
const constant MLXConvParams<2>* params [[buffer(3)]], \
const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]], \
const constant Conv2DGeneralJumpParams* jump_params [[buffer(5)]], \
const constant Conv2DGeneralBaseInfo* base_h [[buffer(6)]], \
const constant Conv2DGeneralBaseInfo* base_w [[buffer(7)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_implicit_2d_filter(name, itype, bm, bn, bk, wm, wn) \
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn)
instantiate_implicit_conv_2d(name, itype, bm, bn, bk, wm, wn)
#define instantiate_implicit_2d_blocks(name, itype) \
// clang-format off
#define instantiate_implicit_2d_blocks(name, itype) \
instantiate_implicit_2d_filter(name, itype, 32, 8, 16, 4, 1) \
instantiate_implicit_2d_filter(name, itype, 64, 8, 16, 4, 1) \
instantiate_implicit_2d_filter(name, itype, 32, 32, 16, 2, 2) \
instantiate_implicit_2d_filter(name, itype, 32, 64, 16, 2, 2) \
instantiate_implicit_2d_filter(name, itype, 64, 32, 16, 2, 2) \
instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2)
instantiate_implicit_2d_filter(name, itype, 64, 64, 16, 2, 2) // clang-format on
// clang-format off
instantiate_implicit_2d_blocks(float32, float);
instantiate_implicit_2d_blocks(float16, half);
instantiate_implicit_2d_blocks(bfloat16, bfloat16_t);
instantiate_implicit_2d_blocks(bfloat16, bfloat16_t); // clang-format on

View File

@@ -164,10 +164,12 @@ struct GEMMKernel {
// Find block in A, B, C
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
const size_t c_row_long = size_t(c_row);
const size_t c_col_long = size_t(c_col);
A += transpose_a ? c_row : c_row * params->lda;
B += transpose_b ? c_col * params->ldb : c_col;
D += c_row * params->ldd + c_col;
A += transpose_a ? c_row_long : c_row_long * params->lda;
B += transpose_b ? c_col_long * params->ldb : c_col_long;
D += c_row_long * params->ldd + c_col_long;
// Prepare threadgroup loading operations
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);

View File

@@ -1,107 +0,0 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
using namespace metal;
using namespace mlx::steel;
///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////
template <typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
bool MN_aligned,
bool K_aligned>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void gemm(
const device T *A [[buffer(0)]],
const device T *B [[buffer(1)]],
device T *D [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]],
const constant int* batch_shape [[buffer(6)]],
const constant size_t* batch_strides [[buffer(7)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using gemm_kernel = GEMMKernel<T, T, BM, BN, BK, WM, WN, transpose_a, transpose_b, MN_aligned, K_aligned>;
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
// Adjust for batch
if(params->batch_ndim > 1) {
const constant size_t* A_bstrides = batch_strides;
const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
A += batch_offsets.x;
B += batch_offsets.y;
} else {
A += params->batch_stride_a * tid.z;
B += params->batch_stride_b * tid.z;
}
D += params->batch_stride_d * tid.z;
gemm_kernel::run(
A, B, D,
params,
As, Bs,
simd_lane_id, simd_group_id, tid, lid
);
}
///////////////////////////////////////////////////////////////////////////////
// GEMM kernel initializations
///////////////////////////////////////////////////////////////////////////////
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
template [[host_name("steel_gemm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname)]] \
[[kernel]] void gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned>( \
const device itype *A [[buffer(0)]], \
const device itype *B [[buffer(1)]], \
device itype *D [[buffer(3)]], \
const constant GEMMParams* params [[buffer(4)]], \
const constant int* batch_shape [[buffer(6)]], \
const constant size_t* batch_strides [[buffer(7)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2)
instantiate_gemm_shapes_helper(float16, half, float16, half);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
instantiate_gemm_shapes_helper(float32, float, float32, float);

View File

@@ -1,276 +0,0 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
using namespace metal;
using namespace mlx::steel;
///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////
template <typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
bool MN_aligned,
bool K_aligned,
typename AccumType = float,
typename Epilogue = TransformAdd<T, AccumType>>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void addmm(
const device T *A [[buffer(0)]],
const device T *B [[buffer(1)]],
const device T *C [[buffer(2)]],
device T *D [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]],
const constant GEMMAddMMParams* addmm_params [[buffer(5)]],
const constant int* batch_shape [[buffer(6)]],
const constant size_t* batch_strides [[buffer(7)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
// Pacifying compiler
(void)lid;
using gemm_kernel =
GEMMKernel<T, T, BM, BN, BK, WM, WN,
transpose_a, transpose_b,
MN_aligned, K_aligned,
AccumType, Epilogue>;
using loader_a_t = typename gemm_kernel::loader_a_t;
using loader_b_t = typename gemm_kernel::loader_b_t;
using mma_t = typename gemm_kernel::mma_t;
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
// Adjust for batch
if(params->batch_ndim > 1) {
const constant size_t* A_bstrides = batch_strides;
const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
const constant size_t* C_bstrides = B_bstrides + params->batch_ndim;
ulong3 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, A_bstrides, B_bstrides, C_bstrides, params->batch_ndim);
A += batch_offsets.x;
B += batch_offsets.y;
C += batch_offsets.z;
} else {
A += params->batch_stride_a * tid.z;
B += params->batch_stride_b * tid.z;
C += addmm_params->batch_stride_c * tid.z;
}
D += params->batch_stride_d * tid.z;
const int tid_y = ((tid.y) << params->swizzle_log) +
((tid.x) & ((1 << params->swizzle_log) - 1));
const int tid_x = (tid.x) >> params->swizzle_log;
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
return;
}
threadgroup_barrier(mem_flags::mem_none);
// Find block in A, B, C
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
A += transpose_a ? c_row : c_row * params->lda;
B += transpose_b ? c_col * params->ldb : c_col;
D += c_row * params->ldd + c_col;
C += c_row * addmm_params->ldc + c_col * addmm_params->fdc;
// Prepare threadgroup loading operations
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
// Prepare threadgroup mma operation
thread mma_t mma_op(simd_group_id, simd_lane_id);
int gemm_k_iterations = params->gemm_k_iterations_aligned;
const Epilogue epilogue_op(addmm_params->alpha, addmm_params->beta);
///////////////////////////////////////////////////////////////////////////////
// MNK aligned loop
if (MN_aligned) {
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
threadgroup_barrier(mem_flags::mem_none);
// Loop tail
if (!K_aligned) {
int lbk = params->K - params->gemm_k_iterations_aligned * BK;
short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
loader_a.load_safe(tile_dims_A);
loader_b.load_safe(tile_dims_B);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
// Store results to device memory
mma_op.store_result(D, params->ldd, C, addmm_params->ldc, addmm_params->fdc, epilogue_op);
return;
}
///////////////////////////////////////////////////////////////////////////////
// MN unaligned loop
else { // Loop over K - unaligned case
short tgp_bm = min(BM, params->M - c_row);
short tgp_bn = min(BN, params->N - c_col);
short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK;
if (tgp_bm == BM && tgp_bn == BN) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<true, true, K_aligned>{});
mma_op.store_result(D, params->ldd, C, addmm_params->ldc, addmm_params->fdc, epilogue_op);
return;
} else if (tgp_bn == BN) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<false, true, K_aligned>{});
return mma_op.store_result_safe(
D, params->ldd,
C, addmm_params->ldc, addmm_params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op);
} else if (tgp_bm == BM) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<true, false, K_aligned>{});
return mma_op.store_result_safe(
D, params->ldd,
C, addmm_params->ldc, addmm_params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op);
} else {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<false, false, K_aligned>{});
return mma_op.store_result_safe(
D, params->ldd,
C, addmm_params->ldc, addmm_params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op);
}
}
}
///////////////////////////////////////////////////////////////////////////////
// GEMM kernel initializations
///////////////////////////////////////////////////////////////////////////////
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, ep_name, epilogue) \
template [[host_name("steel_addmm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname "_" #ep_name)]] \
[[kernel]] void addmm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned, float, epilogue<itype, float>>( \
const device itype *A [[buffer(0)]], \
const device itype *B [[buffer(1)]], \
const device itype *C [[buffer(2)]], \
device itype *D [[buffer(3)]], \
const constant GEMMParams* gemm_params [[buffer(4)]], \
const constant GEMMAddMMParams* params [[buffer(5)]], \
const constant int* batch_shape [[buffer(6)]], \
const constant size_t* batch_strides [[buffer(7)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
#define instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, add, TransformAdd) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, axpby, TransformAxpby)
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2)
instantiate_gemm_shapes_helper(float16, half, float16, half);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
instantiate_gemm_shapes_helper(float32, float, float32, float);

View File

@@ -0,0 +1,468 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
using namespace metal;
using namespace mlx::steel;
///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////
constant bool has_batch [[function_constant(10)]];
constant bool use_out_source [[function_constant(100)]];
constant bool do_axpby [[function_constant(110)]];
constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]];
constant bool align_K [[function_constant(202)]];
constant bool do_gather [[function_constant(300)]];
constant bool gather_bias = do_gather && use_out_source;
// clang-format off
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
const device T* C [[buffer(2), function_constant(use_out_source)]],
device T* D [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]],
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
const constant int* batch_shape [[buffer(6)]],
const constant size_t* batch_strides [[buffer(7)]],
const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]],
const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]],
const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]],
const constant int* operand_shape [[buffer(13), function_constant(do_gather)]],
const constant size_t* operand_strides [[buffer(14), function_constant(do_gather)]],
const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on
// Pacifying compiler
(void)lid;
using gemm_kernel = GEMMKernel<
T,
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
true,
true,
AccumType>;
using loader_a_t = typename gemm_kernel::loader_a_t;
using loader_b_t = typename gemm_kernel::loader_b_t;
using mma_t = typename gemm_kernel::mma_t;
// Find block
const int tid_y = ((tid.y) << params->swizzle_log) +
((tid.x) & ((1 << params->swizzle_log) - 1));
const int tid_x = (tid.x) >> params->swizzle_log;
// Exit early if out of bounds
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
return;
}
// Adjust for batch
// Handle gather
if (do_gather) {
// Read indices
uint32_t indx_A, indx_B, indx_C;
if (has_batch) {
const constant size_t* indx_A_bstrides = batch_strides;
const constant size_t* indx_B_bstrides =
batch_strides + params->batch_ndim;
ulong2 indx_offsets = elem_to_loc_broadcast(
tid.z,
batch_shape,
indx_A_bstrides,
indx_B_bstrides,
params->batch_ndim);
indx_A = lhs_indices[indx_offsets.x];
indx_B = rhs_indices[indx_offsets.y];
if (use_out_source) {
const constant size_t* indx_C_bstrides =
indx_B_bstrides + params->batch_ndim;
auto indx_offset_C = elem_to_loc(
tid.z, batch_shape, indx_C_bstrides, params->batch_ndim);
indx_C = C_indices[indx_offset_C];
}
} else {
indx_A = lhs_indices[params->batch_stride_a * tid.z];
indx_B = rhs_indices[params->batch_stride_b * tid.z];
if (use_out_source) {
indx_C = C_indices[addmm_params->batch_stride_c * tid.z];
}
}
// Translate indices to offsets
int batch_ndim_A = operand_batch_ndim.x;
const constant int* batch_shape_A = operand_shape;
const constant size_t* batch_strides_A = operand_strides;
A += elem_to_loc(indx_A, batch_shape_A, batch_strides_A, batch_ndim_A);
int batch_ndim_B = operand_batch_ndim.y;
const constant int* batch_shape_B = batch_shape_A + batch_ndim_A;
const constant size_t* batch_strides_B = batch_strides_A + batch_ndim_A;
B += elem_to_loc(indx_B, batch_shape_B, batch_strides_B, batch_ndim_B);
if (use_out_source) {
int batch_ndim_C = operand_batch_ndim.z;
const constant int* batch_shape_C = batch_shape_B + batch_ndim_B;
const constant size_t* batch_strides_C = batch_strides_B + batch_ndim_B;
C += elem_to_loc(indx_C, batch_shape_C, batch_strides_C, batch_ndim_C);
}
}
// Handle regular batch
else {
if (has_batch) {
const constant size_t* A_bstrides = batch_strides;
const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
A += batch_offsets.x;
B += batch_offsets.y;
if (use_out_source) {
const constant size_t* C_bstrides = B_bstrides + params->batch_ndim;
C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim);
}
} else {
A += params->batch_stride_a * tid.z;
B += params->batch_stride_b * tid.z;
if (use_out_source) {
C += addmm_params->batch_stride_c * tid.z;
}
}
}
D += params->batch_stride_d * tid.z;
// Prepare threadgroup memory
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
threadgroup_barrier(mem_flags::mem_none);
// Find block in A, B, C
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
const size_t c_row_long = size_t(c_row);
const size_t c_col_long = size_t(c_col);
A += transpose_a ? c_row_long : c_row_long * params->lda;
B += transpose_b ? c_col_long * params->ldb : c_col_long;
D += c_row_long * params->ldd + c_col_long;
if (use_out_source) {
C += c_row_long * addmm_params->ldc + c_col_long * addmm_params->fdc;
}
// Prepare threadgroup mma operation
thread mma_t mma_op(simd_group_id, simd_lane_id);
// Prepare threadgroup loading operations
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
// Prepare threadgroup bounds
const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row));
const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col));
// Prepare iterations
int gemm_k_iterations = params->gemm_k_iterations_aligned;
// Do unaligned K iterations first
if (!align_K) {
const int k_last = params->gemm_k_iterations_aligned * BK;
const int k_remain = params->K - k_last;
const size_t k_jump_a =
transpose_a ? params->lda * size_t(k_last) : size_t(k_last);
const size_t k_jump_b =
transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);
// Move loader source ahead to end
loader_a.src += k_jump_a;
loader_b.src += k_jump_b;
// Load tile
const short2 tile_dims_A =
transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
const short2 tile_dims_B =
transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
loader_a.load_safe(tile_dims_A);
loader_b.load_safe(tile_dims_B);
threadgroup_barrier(mem_flags::mem_threadgroup);
// Do matmul
mma_op.mma(As, Bs);
// Reset source back to start
loader_a.src -= k_jump_a;
loader_b.src -= k_jump_b;
}
const TransformAdd<AccumType, AccumType> epilogue_op_add(
addmm_params->alpha, addmm_params->beta);
const TransformAxpby<AccumType, AccumType> epilogue_op_axpby(
addmm_params->alpha, addmm_params->beta);
///////////////////////////////////////////////////////////////////////////////
// MNK aligned loop
if (align_M && align_N) {
// Do gemm
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
threadgroup_barrier(mem_flags::mem_none);
// Do epilogue
if (use_out_source) {
if (do_axpby) {
mma_op.apply_epilogue(
C, addmm_params->ldc, addmm_params->fdc, epilogue_op_axpby);
} else {
mma_op.apply_epilogue(
C, addmm_params->ldc, addmm_params->fdc, epilogue_op_add);
}
}
// Store results to device memory
return mma_op.store_result(D, params->ldd);
}
///////////////////////////////////////////////////////////////////////////////
// MN unaligned loop
else { // Loop over K - unaligned case
const int leftover_bk = 0;
if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
// Do gemm
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<true, true, true>{});
// Do epilogue
if (use_out_source) {
if (do_axpby) {
mma_op.apply_epilogue(
C, addmm_params->ldc, addmm_params->fdc, epilogue_op_axpby);
} else {
mma_op.apply_epilogue(
C, addmm_params->ldc, addmm_params->fdc, epilogue_op_add);
}
}
// Store results to device memory
return mma_op.store_result(D, params->ldd);
} else if (align_N || tgp_bn == BN) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<false, true, true>{});
// Do epilogue
if (use_out_source) {
if (do_axpby) {
mma_op.apply_epilogue_safe(
C,
addmm_params->ldc,
addmm_params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op_axpby);
} else {
mma_op.apply_epilogue_safe(
C,
addmm_params->ldc,
addmm_params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op_add);
}
}
// Store results to device memory
return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
} else if (align_M || tgp_bm == BM) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<true, false, true>{});
// Do epilogue
if (use_out_source) {
if (do_axpby) {
mma_op.apply_epilogue_safe(
C,
addmm_params->ldc,
addmm_params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op_axpby);
} else {
mma_op.apply_epilogue_safe(
C,
addmm_params->ldc,
addmm_params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op_add);
}
}
// Store results to device memory
return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
} else {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<false, false, true>{});
// Do epilogue
if (use_out_source) {
if (do_axpby) {
mma_op.apply_epilogue_safe(
C,
addmm_params->ldc,
addmm_params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op_axpby);
} else {
mma_op.apply_epilogue_safe(
C,
addmm_params->ldc,
addmm_params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op_add);
}
}
// Store results to device memory
return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
}
}
}
///////////////////////////////////////////////////////////////////////////////
// GEMM kernel initializations
///////////////////////////////////////////////////////////////////////////////
// clang-format off
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
template [[host_name("steel_gemm_fused_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn)]] \
[[kernel]] void gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, float>( \
const device itype *A [[buffer(0)]], \
const device itype *B [[buffer(1)]], \
const device itype *C [[buffer(2), function_constant(use_out_source)]], \
device itype *D [[buffer(3)]], \
const constant GEMMParams* params [[buffer(4)]], \
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]], \
const constant int* batch_shape [[buffer(6)]], \
const constant size_t* batch_strides [[buffer(7)]], \
const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]], \
const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]], \
const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]], \
const constant int* operand_shape [[buffer(13), function_constant(do_gather)]], \
const constant size_t* operand_strides [[buffer(14), function_constant(do_gather)]], \
const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]); // clang-format on
// clang-format off
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) // clang-format on
// clang-format off
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2) // clang-format on
instantiate_gemm_shapes_helper(float16, half, float16, half);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
instantiate_gemm_shapes_helper(float32, float, float32, float);

View File

@@ -1,8 +1,8 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
#include "mlx/backend/metal/kernels/utils.h"
using namespace metal;
using namespace mlx::steel;
@@ -11,319 +11,380 @@ using namespace mlx::steel;
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////
template <typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
bool MN_aligned,
bool K_aligned,
bool has_operand_mask=false>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void block_masked_gemm(
const device T *A [[buffer(0)]],
const device T *B [[buffer(1)]],
device T *D [[buffer(3)]],
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
bool MN_aligned,
bool K_aligned,
bool has_operand_mask = false>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
block_masked_gemm(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
device T* D [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]],
const constant int* batch_shape [[buffer(6)]],
const constant size_t* batch_strides [[buffer(7)]],
const device bool *out_mask [[buffer(10)]],
const device bool *lhs_mask [[buffer(11)]],
const device bool *rhs_mask [[buffer(12)]],
const device bool* out_mask [[buffer(10)]],
const device bool* lhs_mask [[buffer(11)]],
const device bool* rhs_mask [[buffer(12)]],
const constant int* mask_strides [[buffer(13)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
uint3 lid [[thread_position_in_threadgroup]]) {
// Appease the compiler
(void)lid;
// Appease the compiler
(void)lid;
using gemm_kernel = GEMMKernel<T, T, BM, BN, BK, WM, WN, transpose_a, transpose_b, MN_aligned, K_aligned>;
const int tid_y = ((tid.y) << params->swizzle_log) +
((tid.x) & ((1 << params->swizzle_log) - 1));
const int tid_x = (tid.x) >> params->swizzle_log;
using gemm_kernel = GEMMKernel<
T,
T,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
MN_aligned,
K_aligned>;
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
return;
}
const int tid_y = ((tid.y) << params->swizzle_log) +
((tid.x) & ((1 << params->swizzle_log) - 1));
const int tid_x = (tid.x) >> params->swizzle_log;
if(params->batch_ndim > 1) {
const constant size_t* mask_batch_strides = batch_strides + 2 * params->batch_ndim;
out_mask += elem_to_loc(tid.z, batch_shape, mask_batch_strides, params->batch_ndim);
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
return;
}
if(has_operand_mask) {
const constant size_t* mask_strides_lhs = mask_batch_strides + params->batch_ndim;
const constant size_t* mask_strides_rhs = mask_strides_lhs + params->batch_ndim;
if (params->batch_ndim > 1) {
const constant size_t* mask_batch_strides =
batch_strides + 2 * params->batch_ndim;
out_mask +=
elem_to_loc(tid.z, batch_shape, mask_batch_strides, params->batch_ndim);
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, mask_strides_lhs, mask_strides_rhs, params->batch_ndim);
lhs_mask += batch_offsets.x;
rhs_mask += batch_offsets.y;
}
} else {
out_mask += tid.z * batch_strides[2 * params->batch_ndim];
if(has_operand_mask) {
lhs_mask += tid.z * batch_strides[3 * params->batch_ndim];
rhs_mask += tid.z * batch_strides[4 * params->batch_ndim];
}
}
// Adjust for batch
if(params->batch_ndim > 1) {
const constant size_t* A_bstrides = batch_strides;
const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
if (has_operand_mask) {
const constant size_t* mask_strides_lhs =
mask_batch_strides + params->batch_ndim;
const constant size_t* mask_strides_rhs =
mask_strides_lhs + params->batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
tid.z,
batch_shape,
mask_strides_lhs,
mask_strides_rhs,
params->batch_ndim);
A += batch_offsets.x;
B += batch_offsets.y;
} else {
A += params->batch_stride_a * tid.z;
B += params->batch_stride_b * tid.z;
lhs_mask += batch_offsets.x;
rhs_mask += batch_offsets.y;
}
D += params->batch_stride_d * tid.z;
} else {
out_mask += tid.z * batch_strides[2 * params->batch_ndim];
if (has_operand_mask) {
lhs_mask += tid.z * batch_strides[3 * params->batch_ndim];
rhs_mask += tid.z * batch_strides[4 * params->batch_ndim];
}
}
// Find block in A, B, C
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
// Adjust for batch
if (params->batch_ndim > 1) {
const constant size_t* A_bstrides = batch_strides;
const constant size_t* B_bstrides = batch_strides + params->batch_ndim;
A += transpose_a ? c_row : c_row * params->lda;
B += transpose_b ? c_col * params->ldb : c_col;
D += c_row * params->ldd + c_col;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);
A += batch_offsets.x;
B += batch_offsets.y;
bool mask_out = out_mask[tid_y * mask_strides[1] + tid_x * mask_strides[0]];
} else {
A += params->batch_stride_a * tid.z;
B += params->batch_stride_b * tid.z;
}
// Write zeros and return
if(!mask_out) {
constexpr short tgp_size = WM * WN * 32;
constexpr short vec_size = 4;
D += params->batch_stride_d * tid.z;
// Tile threads in threadgroup
constexpr short TN = BN / vec_size;
constexpr short TM = tgp_size / TN;
// Find block in A, B, C
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
const size_t c_row_long = size_t(c_row);
const size_t c_col_long = size_t(c_col);
const short thread_idx = simd_group_id * 32 + simd_lane_id;
const short bi = thread_idx / TN;
const short bj = vec_size * (thread_idx % TN);
A += transpose_a ? c_row_long : c_row_long * params->lda;
B += transpose_b ? c_col_long * params->ldb : c_col_long;
D += c_row_long * params->ldd + c_col_long;
D += bi * params->ldd + bj;
bool mask_out = out_mask[tid_y * mask_strides[1] + tid_x * mask_strides[0]];
short tgp_bm = min(BM, params->M - c_row);
short tgp_bn = min(BN, params->N - c_col);
// Write zeros and return
if (!mask_out) {
constexpr short tgp_size = WM * WN * 32;
constexpr short vec_size = 4;
if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
for (short ti = 0; ti < BM; ti += TM) {
STEEL_PRAGMA_UNROLL
for(short j = 0; j < vec_size; j++) {
D[ti * params->ldd + j] = T(0.);
}
}
} else {
short jmax = tgp_bn - bj;
jmax = jmax < vec_size ? jmax : vec_size;
for (short ti = 0; (bi + ti) < tgp_bm; ti += TM) {
for(short j = 0; j < jmax; j++) {
D[ti * params->ldd + j] = T(0.);
}
// Tile threads in threadgroup
constexpr short TN = BN / vec_size;
constexpr short TM = tgp_size / TN;
const short thread_idx = simd_group_id * 32 + simd_lane_id;
const short bi = thread_idx / TN;
const short bj = vec_size * (thread_idx % TN);
D += bi * params->ldd + bj;
short tgp_bm = min(BM, params->M - c_row);
short tgp_bn = min(BN, params->N - c_col);
if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
for (short ti = 0; ti < BM; ti += TM) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
D[ti * params->ldd + j] = T(0.);
}
}
return;
} else {
short jmax = tgp_bn - bj;
jmax = jmax < vec_size ? jmax : vec_size;
for (short ti = 0; (bi + ti) < tgp_bm; ti += TM) {
for (short j = 0; j < jmax; j++) {
D[ti * params->ldd + j] = T(0.);
}
}
}
return;
}
threadgroup_barrier(mem_flags::mem_none);
// Prepare threadgroup mma operation
thread typename gemm_kernel::mma_t mma_op(simd_group_id, simd_lane_id);
int gemm_k_iterations = params->gemm_k_iterations_aligned;
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
// Prepare threadgroup loading operations
thread typename gemm_kernel::loader_a_t loader_a(
A, params->lda, As, simd_group_id, simd_lane_id);
thread typename gemm_kernel::loader_b_t loader_b(
B, params->ldb, Bs, simd_group_id, simd_lane_id);
///////////////////////////////////////////////////////////////////////////////
// MNK aligned loop
if (MN_aligned) {
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
if (!has_operand_mask ||
(lhs_mask
[tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] &&
rhs_mask
[((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
}
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
threadgroup_barrier(mem_flags::mem_none);
// Prepare threadgroup mma operation
thread typename gemm_kernel::mma_t mma_op(simd_group_id, simd_lane_id);
// Loop tail
if (!K_aligned) {
if (!has_operand_mask ||
(lhs_mask
[tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] &&
rhs_mask
[(params->K / BM) * mask_strides[5] +
tid_x * mask_strides[4]])) {
int lbk = params->K - params->gemm_k_iterations_aligned * BK;
short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
int gemm_k_iterations = params->gemm_k_iterations_aligned;
loader_a.load_safe(tile_dims_A);
loader_b.load_safe(tile_dims_B);
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
// Prepare threadgroup loading operations
thread typename gemm_kernel::loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
thread typename gemm_kernel::loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
///////////////////////////////////////////////////////////////////////////////
// MNK aligned loop
if (MN_aligned) {
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
if(!has_operand_mask ||
(lhs_mask[tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] &&
rhs_mask[((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
mma_op.mma(As, Bs);
}
}
// Load elements into threadgroup
// Store results to device memory
mma_op.store_result(D, params->ldd);
return;
}
///////////////////////////////////////////////////////////////////////////////
// MN unaligned loop
else { // Loop over K - unaligned case
short tgp_bm = min(BM, params->M - c_row);
short tgp_bn = min(BN, params->N - c_col);
short lbk = params->K - params->gemm_k_iterations_aligned * BK;
bool M_aligned = (tgp_bm == BM);
bool N_aligned = (tgp_bn == BN);
short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
if (!has_operand_mask ||
(lhs_mask
[tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] &&
rhs_mask
[((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
// Load elements into threadgroup
if (M_aligned) {
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
}
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
threadgroup_barrier(mem_flags::mem_none);
// Loop tail
if (!K_aligned) {
if(!has_operand_mask ||
(lhs_mask[tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] &&
rhs_mask[(params->K / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
int lbk = params->K - params->gemm_k_iterations_aligned * BK;
short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
} else {
loader_a.load_safe(tile_dims_A);
}
if (N_aligned) {
loader_b.load_unsafe();
} else {
loader_b.load_safe(tile_dims_B);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
}
// Store results to device memory
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
if (!K_aligned) {
threadgroup_barrier(mem_flags::mem_threadgroup);
if (!has_operand_mask ||
(lhs_mask
[tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] &&
rhs_mask
[(params->K / BM) * mask_strides[5] +
tid_x * mask_strides[4]])) {
short2 tile_dims_A_last =
transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm);
short2 tile_dims_B_last =
transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk);
loader_a.load_safe(tile_dims_A_last);
loader_b.load_safe(tile_dims_B_last);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
}
if (M_aligned && N_aligned) {
mma_op.store_result(D, params->ldd);
return;
}
///////////////////////////////////////////////////////////////////////////////
// MN unaligned loop
else { // Loop over K - unaligned case
short tgp_bm = min(BM, params->M - c_row);
short tgp_bn = min(BN, params->N - c_col);
short lbk = params->K - params->gemm_k_iterations_aligned * BK;
bool M_aligned = (tgp_bm == BM);
bool N_aligned = (tgp_bn == BN);
short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
if(!has_operand_mask ||
(lhs_mask[tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] &&
rhs_mask[((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
// Load elements into threadgroup
if (M_aligned) {
loader_a.load_unsafe();
} else {
loader_a.load_safe(tile_dims_A);
}
if (N_aligned) {
loader_b.load_unsafe();
} else {
loader_b.load_safe(tile_dims_B);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
}
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
if (!K_aligned) {
threadgroup_barrier(mem_flags::mem_threadgroup);
if(!has_operand_mask ||
(lhs_mask[tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] &&
rhs_mask[(params->K / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
short2 tile_dims_A_last =
transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm);
short2 tile_dims_B_last =
transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk);
loader_a.load_safe(tile_dims_A_last);
loader_b.load_safe(tile_dims_B_last);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
}
if(M_aligned && N_aligned) {
mma_op.store_result(D, params->ldd);
} else {
mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
}
} else {
mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
}
}
}
///////////////////////////////////////////////////////////////////////////////
// GEMM kernel initializations
///////////////////////////////////////////////////////////////////////////////
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, omname, op_mask) \
template [[host_name("steel_block_masked_gemm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname "_op_mask_" #omname)]] \
[[kernel]] void block_masked_gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned, op_mask>( \
const device itype *A [[buffer(0)]], \
const device itype *B [[buffer(1)]], \
device itype *D [[buffer(3)]], \
const constant GEMMParams* params [[buffer(4)]], \
const constant int* batch_shape [[buffer(6)]], \
const constant size_t* batch_strides [[buffer(7)]], \
const device bool *out_mask [[buffer(10)]], \
const device bool *lhs_mask [[buffer(11)]], \
const device bool *rhs_mask [[buffer(12)]], \
const constant int* mask_strides [[buffer(13)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
#define instantiate_gemm( \
tname, \
trans_a, \
trans_b, \
iname, \
itype, \
oname, \
otype, \
bm, \
bn, \
bk, \
wm, \
wn, \
aname, \
mn_aligned, \
kname, \
k_aligned, \
omname, \
op_mask) \
template [[host_name("steel_block_masked_gemm_" #tname "_" #iname "_" #oname \
"_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn \
"_MN_" #aname "_K_" #kname \
"_op_mask_" #omname)]] [[kernel]] void \
block_masked_gemm< \
itype, \
bm, \
bn, \
bk, \
wm, \
wn, \
trans_a, \
trans_b, \
mn_aligned, \
k_aligned, \
op_mask>( \
const device itype* A [[buffer(0)]], \
const device itype* B [[buffer(1)]], \
device itype* D [[buffer(3)]], \
const constant GEMMParams* params [[buffer(4)]], \
const constant int* batch_shape [[buffer(6)]], \
const constant size_t* batch_strides [[buffer(7)]], \
const device bool* out_mask [[buffer(10)]], \
const device bool* lhs_mask [[buffer(11)]], \
const device bool* rhs_mask [[buffer(12)]], \
const constant int* mask_strides [[buffer(13)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
// clang-format off
#define instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, N, false) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, T, true)
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, N, false) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, T, true) // clang-format on
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
// clang-format off
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)
instantiate_gemm_mask_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) // clang-format on
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
// clang-format off
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) // clang-format on
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
// clang-format off
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2)
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) // clang-format on
// clang-format off
instantiate_gemm_shapes_helper(float16, half, float16, half);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
instantiate_gemm_shapes_helper(float32, float, float32, float);
instantiate_gemm_shapes_helper(float32, float, float32, float); // clang-format on

View File

@@ -10,81 +10,99 @@ using namespace mlx::steel;
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////
template <typename T,
typename U,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
bool MN_aligned,
bool K_aligned>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void gemm_splitk(
const device T *A [[buffer(0)]],
const device T *B [[buffer(1)]],
device U *C [[buffer(2)]],
template <
typename T,
typename U,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
bool MN_aligned,
bool K_aligned>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm_splitk(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
device U* C [[buffer(2)]],
const constant GEMMSpiltKParams* params [[buffer(3)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
uint3 lid [[thread_position_in_threadgroup]]) {
(void)lid;
(void)lid;
using gemm_kernel = GEMMKernel<T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, MN_aligned, K_aligned>;
using loader_a_t = typename gemm_kernel::loader_a_t;
using loader_b_t = typename gemm_kernel::loader_b_t;
using mma_t = typename gemm_kernel::mma_t;
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
using gemm_kernel = GEMMKernel<
T,
U,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
MN_aligned,
K_aligned>;
using loader_a_t = typename gemm_kernel::loader_a_t;
using loader_b_t = typename gemm_kernel::loader_b_t;
using mma_t = typename gemm_kernel::mma_t;
const int tid_x = tid.x;
const int tid_y = tid.y;
const int tid_z = tid.z;
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
return;
}
const int tid_x = tid.x;
const int tid_y = tid.y;
const int tid_z = tid.z;
// Find block in A, B, C
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
const int k_start = params->split_k_partition_size * tid_z;
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
return;
}
A += transpose_a ? (c_row + k_start * params->lda) : (k_start + c_row * params->lda);
B += transpose_b ? (k_start + c_col * params->ldb) : (c_col + k_start * params->ldb);
C += (params->split_k_partition_stride * tid_z) + (c_row * params->ldc + c_col);
// Find block in A, B, C
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
const int k_start = params->split_k_partition_size * tid_z;
// Prepare threadgroup loading operations
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
const size_t c_row_long = size_t(c_row);
const size_t c_col_long = size_t(c_col);
const size_t k_start_long = size_t(k_start);
// Prepare threadgroup mma operation
thread mma_t mma_op(simd_group_id, simd_lane_id);
A += transpose_a ? (c_row_long + k_start_long * params->lda)
: (k_start_long + c_row_long * params->lda);
B += transpose_b ? (k_start_long + c_col_long * params->ldb)
: (c_col_long + k_start_long * params->ldb);
C += (size_t(params->split_k_partition_stride) * tid_z) +
(c_row_long * params->ldc + c_col_long);
int gemm_k_iterations = params->gemm_k_iterations_aligned;
// Prepare threadgroup loading operations
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
short tgp_bm = min(BM, params->M - c_row);
short tgp_bn = min(BN, params->N - c_col);
short leftover_bk = params->K % BK;
// Prepare threadgroup mma operation
thread mma_t mma_op(simd_group_id, simd_lane_id);
if(MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<true, true, true>{});
} else if (tgp_bn == BN) {
gemm_kernel::gemm_loop(
int gemm_k_iterations = params->gemm_k_iterations_aligned;
short tgp_bm = min(BM, params->M - c_row);
short tgp_bn = min(BN, params->N - c_col);
short leftover_bk = params->K % BK;
if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<true, true, true>{});
} else if (tgp_bn == BN) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
@@ -95,37 +113,38 @@ template <typename T,
tgp_bn,
leftover_bk,
LoopAlignment<false, true, true>{});
} else if (tgp_bm == BM) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<true, false, true>{});
} else {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<false, false, true>{});
}
} else if (tgp_bm == BM) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<true, false, true>{});
} else {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<false, false, true>{});
}
threadgroup_barrier(mem_flags::mem_threadgroup);
threadgroup_barrier(mem_flags::mem_threadgroup);
if ((tid_z + 1) == (params->split_k_partitions)) {
int gemm_k_iter_remaining = (params->K - (k_start + params->split_k_partition_size)) / BK;
if(!K_aligned || gemm_k_iter_remaining > 0)
if ((tid_z + 1) == (params->split_k_partitions)) {
int gemm_k_iter_remaining =
(params->K - (k_start + params->split_k_partition_size)) / BK;
if (!K_aligned || gemm_k_iter_remaining > 0)
gemm_kernel::gemm_loop(
As,
Bs,
@@ -137,144 +156,178 @@ template <typename T,
tgp_bn,
leftover_bk,
LoopAlignment<false, false, K_aligned>{});
}
}
if(MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
mma_op.store_result(C, params->ldc);
} else {
mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
}
if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
mma_op.store_result(C, params->ldc);
} else {
mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
}
}
///////////////////////////////////////////////////////////////////////////////
// GEMM kernel initializations
///////////////////////////////////////////////////////////////////////////////
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
template [[host_name("steel_gemm_splitk_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname)]] \
[[kernel]] void gemm_splitk<itype, otype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned>( \
const device itype *A [[buffer(0)]], \
const device itype *B [[buffer(1)]], \
device otype *C [[buffer(2)]], \
const constant GEMMSpiltKParams* params [[buffer(3)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
#define instantiate_gemm( \
tname, \
trans_a, \
trans_b, \
iname, \
itype, \
oname, \
otype, \
bm, \
bn, \
bk, \
wm, \
wn, \
aname, \
mn_aligned, \
kname, \
k_aligned) \
template [[host_name("steel_gemm_splitk_" #tname "_" #iname "_" #oname \
"_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn \
"_MN_" #aname "_K_" #kname)]] [[kernel]] void \
gemm_splitk< \
itype, \
otype, \
bm, \
bn, \
bk, \
wm, \
wn, \
trans_a, \
trans_b, \
mn_aligned, \
k_aligned>( \
const device itype* A [[buffer(0)]], \
const device itype* B [[buffer(1)]], \
device otype* C [[buffer(2)]], \
const constant GEMMSpiltKParams* params [[buffer(3)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
// clang-format off
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) // clang-format on
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
// clang-format off
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) // clang-format on
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
// clang-format off
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 16, 16, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 16, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 16, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2)
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) // clang-format on
// clang-format off
instantiate_gemm_shapes_helper(float16, half, float32, float);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, float32, float);
instantiate_gemm_shapes_helper(float32, float, float32, float);
instantiate_gemm_shapes_helper(float32, float, float32, float); // clang-format on
///////////////////////////////////////////////////////////////////////////////
// Split k accumulation kernel
// Split k accumulation kernel
///////////////////////////////////////////////////////////////////////////////
template <typename AccT,
typename OutT,
typename Epilogue = TransformNone<OutT, AccT>>
template <
typename AccT,
typename OutT,
typename Epilogue = TransformNone<OutT, AccT>>
[[kernel]] void gemm_splitk_accum(
const device AccT *C_split [[buffer(0)]],
device OutT *D [[buffer(1)]],
const device AccT* C_split [[buffer(0)]],
device OutT* D [[buffer(1)]],
const constant int& k_partitions [[buffer(2)]],
const constant int& partition_stride [[buffer(3)]],
const constant int& ldd [[buffer(4)]],
uint2 gid [[thread_position_in_grid]]) {
// Ajust D and C
D += gid.x + gid.y * ldd;
C_split += gid.x + gid.y * ldd;
D += gid.x + gid.y * size_t(ldd);
C_split += gid.x + gid.y * size_t(ldd);
int offset = 0;
size_t offset = 0;
AccT out = 0;
for(int i = 0; i < k_partitions; i++) {
for (int i = 0; i < k_partitions; i++) {
out += C_split[offset];
offset += partition_stride;
}
// Write output
// Write output
D[0] = Epilogue::apply(out);
}
template <typename AccT,
typename OutT,
typename Epilogue = TransformAxpby<OutT, AccT>>
template <
typename AccT,
typename OutT,
typename Epilogue = TransformAxpby<OutT, AccT>>
[[kernel]] void gemm_splitk_accum_axpby(
const device AccT *C_split [[buffer(0)]],
device OutT *D [[buffer(1)]],
const device AccT* C_split [[buffer(0)]],
device OutT* D [[buffer(1)]],
const constant int& k_partitions [[buffer(2)]],
const constant int& partition_stride [[buffer(3)]],
const constant int& ldd [[buffer(4)]],
const device OutT *C [[buffer(5)]],
const device OutT* C [[buffer(5)]],
const constant int& ldc [[buffer(6)]],
const constant int& fdc [[buffer(7)]],
const constant float& alpha [[buffer(8)]],
const constant float& beta [[buffer(9)]],
uint2 gid [[thread_position_in_grid]]) {
// Ajust D and C
C += gid.x * fdc + gid.y * ldc;
D += gid.x + gid.y * ldd;
C_split += gid.x + gid.y * ldd;
C += gid.x * size_t(fdc) + gid.y * size_t(ldc);
D += gid.x + gid.y * size_t(ldd);
C_split += gid.x + gid.y * size_t(ldd);
int offset = 0;
size_t offset = 0;
AccT out = 0;
for(int i = 0; i < k_partitions; i++) {
for (int i = 0; i < k_partitions; i++) {
out += C_split[offset];
offset += partition_stride;
}
// Write output
// Write output
Epilogue op(alpha, beta);
D[0] = op.apply(out, *C);
}
#define instantiate_accum(oname, otype, aname, atype) \
template [[host_name("steel_gemm_splitk_accum_" #oname "_" #aname)]] \
[[kernel]] void gemm_splitk_accum<atype, otype>( \
const device atype *C_split [[buffer(0)]], \
device otype *D [[buffer(1)]], \
const constant int& k_partitions [[buffer(2)]], \
const constant int& partition_stride [[buffer(3)]], \
const constant int& ldd [[buffer(4)]], \
uint2 gid [[thread_position_in_grid]]); \
template [[host_name("steel_gemm_splitk_accum_" #oname "_" #aname "_axpby")]] \
[[kernel]] void gemm_splitk_accum_axpby<atype, otype>( \
const device atype *C_split [[buffer(0)]], \
device otype *D [[buffer(1)]], \
const constant int& k_partitions [[buffer(2)]], \
const constant int& partition_stride [[buffer(3)]], \
const constant int& ldd [[buffer(4)]], \
const device otype *C [[buffer(5)]], \
const constant int& ldc [[buffer(6)]], \
const constant int& fdc [[buffer(7)]], \
const constant float& alpha [[buffer(8)]], \
const constant float& beta [[buffer(9)]], \
#define instantiate_accum(oname, otype, aname, atype) \
template [[host_name("steel_gemm_splitk_accum_" #oname \
"_" #aname)]] [[kernel]] void \
gemm_splitk_accum<atype, otype>( \
const device atype* C_split [[buffer(0)]], \
device otype* D [[buffer(1)]], \
const constant int& k_partitions [[buffer(2)]], \
const constant int& partition_stride [[buffer(3)]], \
const constant int& ldd [[buffer(4)]], \
uint2 gid [[thread_position_in_grid]]); \
template [[host_name("steel_gemm_splitk_accum_" #oname "_" #aname \
"_axpby")]] [[kernel]] void \
gemm_splitk_accum_axpby<atype, otype>( \
const device atype* C_split [[buffer(0)]], \
device otype* D [[buffer(1)]], \
const constant int& k_partitions [[buffer(2)]], \
const constant int& partition_stride [[buffer(3)]], \
const constant int& ldd [[buffer(4)]], \
const device otype* C [[buffer(5)]], \
const constant int& ldc [[buffer(6)]], \
const constant int& fdc [[buffer(7)]], \
const constant float& alpha [[buffer(8)]], \
const constant float& beta [[buffer(9)]], \
uint2 gid [[thread_position_in_grid]]);
// clang-format off
instantiate_accum(bfloat16, bfloat16_t, float32, float);
instantiate_accum(float16, half, float32, float);
instantiate_accum(float32, float, float32, float);
instantiate_accum(float32, float, float32, float); // clang-format on

View File

@@ -198,6 +198,73 @@ struct BlockMMA {
}
}
/* Apply epilogue */
template <typename BinaryEpilogue>
METAL_FUNC void apply_epilogue(
const device U* C,
const int ldc,
const int fdc,
thread const BinaryEpilogue& epilogue_op) {
// Adjust for simdgroup and thread location
C += (sm + tm) * ldc + (tn + sn) * fdc;
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread auto& accum = results[i * TN + j].thread_elements();
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
// Apply epilogue
accum[0] = epilogue_op.apply(accum[0], C[offset_c]);
accum[1] = epilogue_op.apply(accum[1], C[offset_c + fdc]);
}
}
}
/* Apply epilogue */
template <typename BinaryEpilogue>
METAL_FUNC void apply_epilogue_safe(
const device U* C,
const int ldc,
const int fdc,
short2 dst_tile_dims,
thread const BinaryEpilogue& epilogue_op) {
// Adjust for simdgroup and thread location
C += (sm + tm) * ldc + (tn + sn) * fdc;
dst_tile_dims -= short2(tn + sn, sm + tm);
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
return;
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread auto& accum = results[i * TN + j].thread_elements();
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
// Read C
U c_elems[2] = {0};
if ((j * TN_stride + 1) < dst_tile_dims.x) {
c_elems[0] = C[offset_c];
c_elems[1] = C[offset_c + fdc];
} else if ((j * TN_stride) < dst_tile_dims.x) {
c_elems[0] = C[offset_c];
}
// Apply epilogue
accum[0] = epilogue_op.apply(accum[0], c_elems[0]);
accum[1] = epilogue_op.apply(accum[1], c_elems[1]);
}
}
}
/* Store results from simdgroup_matrix results into device memory */
METAL_FUNC void store_result(
device U* D,

View File

@@ -21,9 +21,9 @@ struct GEMMParams {
const int tiles_n;
const int tiles_m;
const int batch_stride_a;
const int batch_stride_b;
const int batch_stride_d;
const size_t batch_stride_a;
const size_t batch_stride_b;
const size_t batch_stride_d;
const int swizzle_log;
const int gemm_k_iterations_aligned;
@@ -54,7 +54,7 @@ struct GEMMAddMMParams {
const int ldc;
const int fdc;
const int batch_stride_c;
const size_t batch_stride_c;
const float alpha;
const float beta;

View File

@@ -26,6 +26,10 @@ template <typename OutT, typename InT>
struct TransformAdd {
TransformAdd(const float, const float) {}
static METAL_FUNC OutT apply(InT x) {
return static_cast<OutT>(x);
}
static METAL_FUNC OutT apply(InT x, OutT c) {
return static_cast<OutT>(x) + c;
}
@@ -39,6 +43,10 @@ struct TransformAxpby {
TransformAxpby(const float alpha_, const float beta_)
: alpha(alpha_), beta(beta_) {}
static METAL_FUNC OutT apply(InT x) {
return static_cast<OutT>(x);
}
METAL_FUNC OutT apply(InT x, OutT c) const {
return static_cast<OutT>(x * alpha + (beta * c));
}

View File

@@ -3,9 +3,9 @@
#include <metal_integer>
#include <metal_math>
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/ternary.h"
#include "mlx/backend/metal/kernels/utils.h"
template <typename T, typename Op>
[[kernel]] void ternary_op_v(
@@ -65,7 +65,8 @@ template <typename T, typename Op>
auto a_idx = elem_to_loc_3(index, a_strides);
auto b_idx = elem_to_loc_3(index, b_strides);
auto c_idx = elem_to_loc_3(index, c_strides);
size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]);
}
@@ -81,8 +82,10 @@ template <typename T, typename Op, int DIM>
constant const size_t c_strides[DIM],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_3_nd<DIM>(index, shape, a_strides, b_strides, c_strides);
size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
auto idx =
elem_to_loc_3_nd<DIM>(index, shape, a_strides, b_strides, c_strides);
size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]);
}
@@ -99,103 +102,104 @@ template <typename T, typename Op>
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_3_nd(index, shape, a_strides, b_strides, c_strides, ndim);
auto idx =
elem_to_loc_3_nd(index, shape, a_strides, b_strides, c_strides, ndim);
size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
d[out_idx] = Op()(a[idx.x], b[idx.y], c[idx.z]);
}
#define instantiate_ternary_v(name, type, op) \
template [[host_name(name)]] \
[[kernel]] void ternary_op_v<type, op>( \
device const bool* a, \
device const type* b, \
device const type* c, \
device type* d, \
uint index [[thread_position_in_grid]]); \
#define instantiate_ternary_v(name, type, op) \
template [[host_name(name)]] [[kernel]] void ternary_op_v<type, op>( \
device const bool* a, \
device const type* b, \
device const type* c, \
device type* d, \
uint index [[thread_position_in_grid]]);
#define instantiate_ternary_g(name, type, op) \
template [[host_name(name)]] \
[[kernel]] void ternary_op_g<type, op>( \
device const bool* a, \
device const type* b, \
device const type* c, \
device type* d, \
constant const int* shape, \
constant const size_t* a_strides, \
constant const size_t* b_strides, \
constant const size_t* c_strides, \
constant const int& ndim, \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
#define instantiate_ternary_g(name, type, op) \
template [[host_name(name)]] [[kernel]] void ternary_op_g<type, op>( \
device const bool* a, \
device const type* b, \
device const type* c, \
device type* d, \
constant const int* shape, \
constant const size_t* a_strides, \
constant const size_t* b_strides, \
constant const size_t* c_strides, \
constant const int& ndim, \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]);
#define instantiate_ternary_g_dim(name, type, op, dims) \
template [[host_name(name "_" #dims)]] \
[[kernel]] void ternary_op_g_nd<type, op, dims>( \
device const bool* a, \
device const type* b, \
device const type* c, \
device type* d, \
constant const int shape[dims], \
constant const size_t a_strides[dims], \
constant const size_t b_strides[dims], \
constant const size_t c_strides[dims], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
#define instantiate_ternary_g_dim(name, type, op, dims) \
template [[host_name(name "_" #dims)]] [[kernel]] void \
ternary_op_g_nd<type, op, dims>( \
device const bool* a, \
device const type* b, \
device const type* c, \
device type* d, \
constant const int shape[dims], \
constant const size_t a_strides[dims], \
constant const size_t b_strides[dims], \
constant const size_t c_strides[dims], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]);
#define instantiate_ternary_g_nd(name, type, op) \
template [[host_name(name "_1")]] \
[[kernel]] void ternary_op_g_nd1<type, op>( \
device const bool* a, \
device const type* b, \
device const type* c, \
device type* d, \
constant const size_t& a_strides, \
constant const size_t& b_strides, \
constant const size_t& c_strides, \
uint index [[thread_position_in_grid]]); \
template [[host_name(name "_2")]] \
[[kernel]] void ternary_op_g_nd2<type, op>( \
device const bool* a, \
device const type* b, \
device const type* c, \
device type* d, \
constant const size_t a_strides[2], \
constant const size_t b_strides[2], \
constant const size_t c_strides[2], \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]); \
template [[host_name(name "_3")]] \
[[kernel]] void ternary_op_g_nd3<type, op>( \
device const bool* a, \
device const type* b, \
device const type* c, \
device type* d, \
constant const size_t a_strides[3], \
constant const size_t b_strides[3], \
constant const size_t c_strides[3], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
instantiate_ternary_g_dim(name, type, op, 4) \
instantiate_ternary_g_dim(name, type, op, 5) \
#define instantiate_ternary_g_nd(name, type, op) \
template [[host_name(name "_1")]] [[kernel]] void \
ternary_op_g_nd1<type, op>( \
device const bool* a, \
device const type* b, \
device const type* c, \
device type* d, \
constant const size_t& a_strides, \
constant const size_t& b_strides, \
constant const size_t& c_strides, \
uint index [[thread_position_in_grid]]); \
template [[host_name(name "_2")]] [[kernel]] void \
ternary_op_g_nd2<type, op>( \
device const bool* a, \
device const type* b, \
device const type* c, \
device type* d, \
constant const size_t a_strides[2], \
constant const size_t b_strides[2], \
constant const size_t c_strides[2], \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]); \
template [[host_name(name "_3")]] [[kernel]] void \
ternary_op_g_nd3<type, op>( \
device const bool* a, \
device const type* b, \
device const type* c, \
device type* d, \
constant const size_t a_strides[3], \
constant const size_t b_strides[3], \
constant const size_t c_strides[3], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
instantiate_ternary_g_dim(name, type, op, 4) \
instantiate_ternary_g_dim(name, type, op, 5)
// clang-format off
#define instantiate_ternary_all(name, tname, type, op) \
instantiate_ternary_v("v" #name #tname, type, op) \
instantiate_ternary_g("g" #name #tname, type, op) \
instantiate_ternary_g_nd("g" #name #tname, type, op) \
instantiate_ternary_v("v" #name #tname, type, op) \
instantiate_ternary_g("g" #name #tname, type, op) \
instantiate_ternary_g_nd("g" #name #tname, type, op) // clang-format on
#define instantiate_ternary_types(name, op) \
instantiate_ternary_all(name, bool_, bool, op) \
instantiate_ternary_all(name, uint8, uint8_t, op) \
instantiate_ternary_all(name, uint16, uint16_t, op) \
instantiate_ternary_all(name, uint32, uint32_t, op) \
instantiate_ternary_all(name, uint64, uint64_t, op) \
instantiate_ternary_all(name, int8, int8_t, op) \
instantiate_ternary_all(name, int16, int16_t, op) \
instantiate_ternary_all(name, int32, int32_t, op) \
instantiate_ternary_all(name, int64, int64_t, op) \
instantiate_ternary_all(name, float16, half, op) \
instantiate_ternary_all(name, float32, float, op) \
// clang-format off
#define instantiate_ternary_types(name, op) \
instantiate_ternary_all(name, bool_, bool, op) \
instantiate_ternary_all(name, uint8, uint8_t, op) \
instantiate_ternary_all(name, uint16, uint16_t, op) \
instantiate_ternary_all(name, uint32, uint32_t, op) \
instantiate_ternary_all(name, uint64, uint64_t, op) \
instantiate_ternary_all(name, int8, int8_t, op) \
instantiate_ternary_all(name, int16, int16_t, op) \
instantiate_ternary_all(name, int32, int32_t, op) \
instantiate_ternary_all(name, int64, int64_t, op) \
instantiate_ternary_all(name, float16, half, op) \
instantiate_ternary_all(name, float32, float, op) \
instantiate_ternary_all(name, bfloat16, bfloat16_t, op) \
instantiate_ternary_all(name, complex64, complex64_t, op) \
instantiate_ternary_all(name, complex64, complex64_t, op) // clang-format on
instantiate_ternary_types(select, Select)
instantiate_ternary_types(select, Select)

View File

@@ -158,6 +158,12 @@ struct Cosh {
};
};
struct Conjugate {
complex64_t operator()(complex64_t x) {
return complex64_t{x.real, -x.imag};
}
};
struct Erf {
template <typename T>
T operator()(T x) {

View File

@@ -22,44 +22,46 @@ template <typename T, typename Op>
out[index] = Op()(in[idx]);
}
#define instantiate_unary_v(name, type, op) \
template [[host_name(name)]] \
[[kernel]] void unary_op_v<type, op>( \
device const type* in, \
device type* out, \
#define instantiate_unary_v(name, type, op) \
template [[host_name(name)]] [[kernel]] void unary_op_v<type, op>( \
device const type* in, \
device type* out, \
uint index [[thread_position_in_grid]]);
#define instantiate_unary_g(name, type, op) \
template [[host_name(name)]] \
[[kernel]] void unary_op_g<type, op>( \
device const type* in, \
device type* out, \
device const int* in_shape, \
device const size_t* in_strides, \
device const int& ndim, \
#define instantiate_unary_g(name, type, op) \
template [[host_name(name)]] [[kernel]] void unary_op_g<type, op>( \
device const type* in, \
device type* out, \
device const int* in_shape, \
device const size_t* in_strides, \
device const int& ndim, \
uint index [[thread_position_in_grid]]);
// clang-format off
#define instantiate_unary_all(name, tname, type, op) \
instantiate_unary_v("v" #name #tname, type, op) \
instantiate_unary_g("g" #name #tname, type, op)
instantiate_unary_v("v" #name #tname, type, op) \
instantiate_unary_g("g" #name #tname, type, op) // clang-format on
#define instantiate_unary_float(name, op) \
instantiate_unary_all(name, float16, half, op) \
instantiate_unary_all(name, float32, float, op) \
instantiate_unary_all(name, bfloat16, bfloat16_t, op) \
// clang-format off
#define instantiate_unary_float(name, op) \
instantiate_unary_all(name, float16, half, op) \
instantiate_unary_all(name, float32, float, op) \
instantiate_unary_all(name, bfloat16, bfloat16_t, op) // clang-format on
#define instantiate_unary_types(name, op) \
instantiate_unary_all(name, bool_, bool, op) \
instantiate_unary_all(name, uint8, uint8_t, op) \
// clang-format off
#define instantiate_unary_types(name, op) \
instantiate_unary_all(name, bool_, bool, op) \
instantiate_unary_all(name, uint8, uint8_t, op) \
instantiate_unary_all(name, uint16, uint16_t, op) \
instantiate_unary_all(name, uint32, uint32_t, op) \
instantiate_unary_all(name, uint64, uint64_t, op) \
instantiate_unary_all(name, int8, int8_t, op) \
instantiate_unary_all(name, int16, int16_t, op) \
instantiate_unary_all(name, int32, int32_t, op) \
instantiate_unary_all(name, int64, int64_t, op) \
instantiate_unary_float(name, op)
instantiate_unary_all(name, int8, int8_t, op) \
instantiate_unary_all(name, int16, int16_t, op) \
instantiate_unary_all(name, int32, int32_t, op) \
instantiate_unary_all(name, int64, int64_t, op) \
instantiate_unary_float(name, op) // clang-format on
// clang-format off
instantiate_unary_types(abs, Abs)
instantiate_unary_float(arccos, ArcCos)
instantiate_unary_float(arccosh, ArcCosh)
@@ -92,6 +94,7 @@ instantiate_unary_float(tanh, Tanh)
instantiate_unary_float(round, Round)
instantiate_unary_all(abs, complex64, complex64_t, Abs)
instantiate_unary_all(conj, complex64, complex64_t, Conjugate)
instantiate_unary_all(cos, complex64, complex64_t, Cos)
instantiate_unary_all(cosh, complex64, complex64_t, Cosh)
instantiate_unary_all(exp, complex64, complex64_t, Exp)
@@ -102,4 +105,4 @@ instantiate_unary_all(tan, complex64, complex64_t, Tan)
instantiate_unary_all(tanh, complex64, complex64_t, Tanh)
instantiate_unary_all(round, complex64, complex64_t, Round)
instantiate_unary_all(lnot, bool_, bool, LogicalNot)
instantiate_unary_all(lnot, bool_, bool, LogicalNot) // clang-format on

View File

@@ -9,8 +9,9 @@
OUTPUT_FILE=$1
CC=$2
SRCDIR=$3
CFLAGS=$4
CONTENT=$($CC -I $SRCDIR -E $SRCDIR/mlx/backend/metal/kernels/compiled_preamble.h 2>/dev/null)
CONTENT=$($CC -I $SRCDIR -E $SRCDIR/mlx/backend/metal/kernels/compiled_preamble.h $CFLAGS 2>/dev/null)
cat << EOF > "$OUTPUT_FILE"
// Copyright © 2023-24 Apple Inc.

View File

@@ -260,6 +260,138 @@ inline auto collapse_batches(const array& a, const array& b, const array& c) {
// Steel matmul fallback
///////////////////////////////////////////////////////////////////////////////
void steel_matmul_conv_groups(
const Stream& s,
metal::Device& d,
const array& a,
const array& b,
array& out,
int M,
int N,
int K,
int lda,
int ldb,
int ldd,
bool transpose_a,
bool transpose_b,
int groups,
std::vector<array>& copies) {
using namespace mlx::steel;
/////////////////////////////////////////////////////////////////////////////
// Regular kernel dispatch
// Determine dispatch kernel
int bm = 32, bn = 32, bk = 16;
int wm = 2, wn = 2;
if ((size_t)M * N >= 1ul << 20) {
if (!transpose_a && transpose_b) {
bm = 64;
bn = (out.dtype() == float32) ? 64 : 32;
bk = (out.dtype() == float32) ? 16 : 32;
} else {
bm = 64;
bn = 64;
}
}
// Prepare kernel name
std::ostringstream kname;
kname << "steel_gemm_fused_" << (transpose_a ? 't' : 'n')
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
<< type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
<< "_wm" << wm << "_wn" << wn;
std::string base_name = kname.str();
const bool has_batch = false;
const bool use_out_source = false;
const bool do_axpby = false;
const bool align_M = (M % bm) == 0;
const bool align_N = (N % bn) == 0;
const bool align_K = (K % bk) == 0;
const bool do_gather = false;
metal::MTLFCList func_consts = {
{&has_batch, MTL::DataType::DataTypeBool, 10},
{&use_out_source, MTL::DataType::DataTypeBool, 100},
{&do_axpby, MTL::DataType::DataTypeBool, 110},
{&align_M, MTL::DataType::DataTypeBool, 200},
{&align_N, MTL::DataType::DataTypeBool, 201},
{&align_K, MTL::DataType::DataTypeBool, 202},
{&do_gather, MTL::DataType::DataTypeBool, 300},
};
// clang-format off
kname << "_has_batch_" << (has_batch ? 't' : 'n')
<< "_use_out_source_" << (use_out_source ? 't' : 'n')
<< "_do_axpby_" << (do_axpby ? 't' : 'n')
<< "_align_M_" << (align_M ? 't' : 'n')
<< "_align_N_" << (align_N ? 't' : 'n')
<< "_align_K_" << (align_K ? 't' : 'n')
<< "_do_gather_" << (do_gather ? 't' : 'n'); // clang-format on
std::string hash_name = kname.str();
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
compute_encoder->setComputePipelineState(kernel);
// Use problem size to determine threadblock swizzle
int tn = (N + bn - 1) / bn;
int tm = (M + bm - 1) / bm;
// TODO: Explore device-based tuning for swizzle
int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 0 : 2);
// Prepare steel matmul params
GEMMParams params{
/* const int M = */ M,
/* const int N = */ N,
/* const int K = */ K,
/* const int lda = */ lda,
/* const int ldb = */ ldb,
/* const int ldd = */ ldd,
/* const int tiles_n = */ tn,
/* const int tiles_m = */ tm,
/* const size_t batch_stride_a = */ size_t(K),
/* const size_t batch_stride_b = */ size_t(N) * K,
/* const size_t batch_stride_d = */ size_t(N),
/* const int swizzle_log = */ swizzle_log,
/* const int gemm_k_iterations_aligned = */ (K / bk),
/* const int batch_ndim = */ 1};
// Prepare launch grid params
int tile = 1 << swizzle_log;
tm = (tm + tile - 1) / tile;
tn = tn * tile;
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(tn, tm, groups);
std::vector<int> batch_shape = {1};
std::vector<size_t> batch_strides = {0};
// Launch kernel
compute_encoder.set_input_array(a, 0);
compute_encoder.set_input_array(b, 1);
compute_encoder.set_output_array(out, 3);
compute_encoder->setBytes(&params, sizeof(GEMMParams), 4);
set_vector_bytes(compute_encoder, batch_shape, 6);
set_vector_bytes(compute_encoder, batch_strides, 7);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
// Clear copies
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
return;
}
void steel_matmul(
const Stream& s,
metal::Device& d,
@@ -301,7 +433,7 @@ void steel_matmul(
}
}
int matrix_stride_out = M * N;
size_t matrix_stride_out = size_t(M) * N;
/////////////////////////////////////////////////////////////////////////////
// Split K specialization
@@ -364,7 +496,7 @@ void steel_matmul(
compute_encoder.set_output_array(C_split, 2);
compute_encoder->setBytes(&params, sizeof(GEMMSpiltKParams), 3);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
// Do accum kernel
{
@@ -389,7 +521,7 @@ void steel_matmul(
MTL::Size grid_dims = MTL::Size(N, M, 1);
MTL::Size group_dims = MTL::Size(std::min(1024, N * M), 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
d.get_command_buffer(s.index)->addCompletedHandler(
@@ -417,16 +549,46 @@ void steel_matmul(
// Prepare kernel name
std::ostringstream kname;
kname << "steel_gemm_" << (transpose_a ? 't' : 'n')
kname << "steel_gemm_fused_" << (transpose_a ? 't' : 'n')
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
<< type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
<< "_wm" << wm << "_wn" << wn << "_MN_"
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
<< ((K % bk == 0) ? "t" : "n") << "aligned";
<< "_wm" << wm << "_wn" << wn;
std::string base_name = kname.str();
const bool has_batch = (batch_shape.size() > 1);
const bool use_out_source = false;
const bool do_axpby = false;
const bool align_M = (M % bm) == 0;
const bool align_N = (N % bn) == 0;
const bool align_K = (K % bk) == 0;
const bool do_gather = false;
metal::MTLFCList func_consts = {
{&has_batch, MTL::DataType::DataTypeBool, 10},
{&use_out_source, MTL::DataType::DataTypeBool, 100},
{&do_axpby, MTL::DataType::DataTypeBool, 110},
{&align_M, MTL::DataType::DataTypeBool, 200},
{&align_N, MTL::DataType::DataTypeBool, 201},
{&align_K, MTL::DataType::DataTypeBool, 202},
{&do_gather, MTL::DataType::DataTypeBool, 300},
};
// clang-format off
kname << "_has_batch_" << (has_batch ? 't' : 'n')
<< "_use_out_source_" << (use_out_source ? 't' : 'n')
<< "_do_axpby_" << (do_axpby ? 't' : 'n')
<< "_align_M_" << (align_M ? 't' : 'n')
<< "_align_N_" << (align_N ? 't' : 'n')
<< "_align_K_" << (align_K ? 't' : 'n')
<< "_do_gather_" << (do_gather ? 't' : 'n'); // clang-format on
std::string hash_name = kname.str();
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
compute_encoder->setComputePipelineState(kernel);
// Use problem size to determine threadblock swizzle
@@ -446,9 +608,9 @@ void steel_matmul(
/* const int ldd = */ N,
/* const int tiles_n = */ tn,
/* const int tiles_m = */ tm,
/* const int batch_stride_a = */ int(A_batch_stride.back()),
/* const int batch_stride_b = */ int(B_batch_stride.back()),
/* const int batch_stride_d = */ matrix_stride_out,
/* const size_t batch_stride_a = */ A_batch_stride.back(),
/* const size_t batch_stride_b = */ B_batch_stride.back(),
/* const size_t batch_stride_d = */ matrix_stride_out,
/* const int swizzle_log = */ swizzle_log,
/* const int gemm_k_iterations_aligned = */ (K / bk),
/* const int batch_ndim = */ int(batch_shape.size())};
@@ -472,12 +634,10 @@ void steel_matmul(
compute_encoder->setBytes(&params, sizeof(GEMMParams), 4);
compute_encoder->setBytes(
batch_shape.data(), sizeof(int) * batch_shape.size(), 6);
compute_encoder->setBytes(
batch_strides.data(), sizeof(size_t) * batch_strides.size(), 7);
set_vector_bytes(compute_encoder, batch_shape, 6);
set_vector_bytes(compute_encoder, batch_strides, 7);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
// Clear copies
d.get_command_buffer(s.index)->addCompletedHandler(
@@ -541,7 +701,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
auto [batch_shape, A_batch_stride, B_batch_stride] = collapse_batches(a, b);
auto batch_size_out = out.size() / (M * N);
auto batch_size_out = out.size() / (size_t(M) * size_t(N));
// Collapse batches into M if needed
if (batch_size_out > 1 && !a_transposed && batch_shape.size() == 1 &&
@@ -644,7 +804,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder->setBytes(
batch_strides_mat.data(), batch_ndim * sizeof(size_t), 12);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
@@ -749,7 +909,8 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
auto [batch_shape, A_batch_stride, B_batch_stride, C_batch_stride] =
collapse_batches(a, b, c);
auto batch_size_out = out.size() / (M * N);
size_t matrix_stride_out = size_t(M) * size_t(N);
auto batch_size_out = out.size() / (matrix_stride_out);
// Collapse batches into M if needed
if (batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 &&
@@ -765,8 +926,6 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
batch_shape = {1};
}
int matrix_stride_out = M * N;
/////////////////////////////////////////////////////////////////////////////
// Gemv specialization
@@ -854,18 +1013,15 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder->setBytes(&beta_, sizeof(float), 8);
compute_encoder->setBytes(&batch_ndim, sizeof(int), 9);
compute_encoder->setBytes(batch_shape.data(), batch_ndim * sizeof(int), 10);
compute_encoder->setBytes(
batch_strides_vec.data(), batch_ndim * sizeof(size_t), 11);
compute_encoder->setBytes(
batch_strides_mat.data(), batch_ndim * sizeof(size_t), 12);
compute_encoder->setBytes(
C_batch_stride.data(), batch_ndim * sizeof(size_t), 13);
set_vector_bytes(compute_encoder, batch_shape, 10);
set_vector_bytes(compute_encoder, batch_strides_vec, 11);
set_vector_bytes(compute_encoder, batch_strides_mat, 12);
set_vector_bytes(compute_encoder, C_batch_stride, 13);
int bias_stride = c.strides()[c.ndim() - 1];
compute_encoder->setBytes(&bias_stride, sizeof(int), 14);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
@@ -935,7 +1091,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_output_array(C_split, 2);
compute_encoder->setBytes(&params, sizeof(GEMMSpiltKParams), 3);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
// Do accum kernel
{
@@ -960,7 +1116,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
MTL::Size grid_dims = MTL::Size(N, M, 1);
MTL::Size group_dims = MTL::Size(std::min(1024, N * M), 1, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
d.get_command_buffer(s.index)->addCompletedHandler(
@@ -986,18 +1142,48 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
}
}
// Prepare kernel name
std::ostringstream kname;
kname << "steel_addmm_" << (transpose_a ? 't' : 'n')
kname << "steel_gemm_fused_" << (transpose_a ? 't' : 'n')
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
<< type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
<< "_wm" << wm << "_wn" << wn << "_MN_"
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
<< ((K % bk == 0) ? "t" : "n") << "aligned"
<< ((alpha_ == 1. && beta_ == 1.) ? "_add" : "_axpby");
<< "_wm" << wm << "_wn" << wn;
std::string base_name = kname.str();
const bool has_batch = (batch_shape.size() > 1);
const bool use_out_source = true;
const bool do_axpby = !(alpha_ == 1. && beta_ == 1.);
const bool align_M = (M % bm) == 0;
const bool align_N = (N % bn) == 0;
const bool align_K = (K % bk) == 0;
const bool do_gather = false;
metal::MTLFCList func_consts = {
{&has_batch, MTL::DataType::DataTypeBool, 10},
{&use_out_source, MTL::DataType::DataTypeBool, 100},
{&do_axpby, MTL::DataType::DataTypeBool, 110},
{&align_M, MTL::DataType::DataTypeBool, 200},
{&align_N, MTL::DataType::DataTypeBool, 201},
{&align_K, MTL::DataType::DataTypeBool, 202},
{&do_gather, MTL::DataType::DataTypeBool, 300},
};
// clang-format off
kname << "_has_batch_" << (has_batch ? 't' : 'n')
<< "_use_out_source_" << (use_out_source ? 't' : 'n')
<< "_do_axpby_" << (do_axpby ? 't' : 'n')
<< "_align_M_" << (align_M ? 't' : 'n')
<< "_align_N_" << (align_N ? 't' : 'n')
<< "_align_K_" << (align_K ? 't' : 'n')
<< "_do_gather_" << (do_gather ? 't' : 'n'); // clang-format on
std::string hash_name = kname.str();
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
compute_encoder->setComputePipelineState(kernel);
int tn = (N + bn - 1) / bn;
@@ -1016,9 +1202,9 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
/* const int ldd = */ N,
/* const int tiles_n = */ tn,
/* const int tiles_m = */ tm,
/* const int batch_stride_a = */ int(A_batch_stride.back()),
/* const int batch_stride_b = */ int(B_batch_stride.back()),
/* const int batch_stride_d = */ matrix_stride_out,
/* const size_t batch_stride_a = */ A_batch_stride.back(),
/* const size_t batch_stride_b = */ B_batch_stride.back(),
/* const size_t batch_stride_d = */ matrix_stride_out,
/* const int swizzle_log = */ swizzle_log,
/* const int gemm_k_iterations_aligned = */ (K / bk),
/* const int batch_ndim = */ int(batch_shape.size())};
@@ -1026,7 +1212,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
GEMMAddMMParams params{
/* const int ldc = */ ldc,
/* const int fdc = */ fdc,
/* const int batch_stride_c = */ int(C_batch_stride.back()),
/* const size_t batch_stride_c = */ C_batch_stride.back(),
/* const float alpha = */ alpha_,
/* const float beta = */ beta_};
@@ -1052,12 +1238,10 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder->setBytes(&gemm_params, sizeof(GEMMParams), 4);
compute_encoder->setBytes(&params, sizeof(GEMMAddMMParams), 5);
compute_encoder->setBytes(
batch_shape.data(), sizeof(int) * batch_shape.size(), 6);
compute_encoder->setBytes(
batch_strides.data(), sizeof(size_t) * batch_strides.size(), 7);
set_vector_bytes(compute_encoder, batch_shape, 6);
set_vector_bytes(compute_encoder, batch_strides, 7);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
@@ -1126,8 +1310,8 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& out_mask = inputs[2];
std::vector<int> batch_shape{1};
int A_batch_str = 0;
int B_batch_str = 0;
size_t A_batch_str = 0;
size_t B_batch_str = 0;
std::vector<size_t> batch_strides;
@@ -1145,8 +1329,8 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
auto [bshape_c, bstrides_c] = collapse_contiguous_dims(bshape, bstrides);
batch_shape = bshape_c;
A_batch_str = int(bstrides_c[0].back());
B_batch_str = int(bstrides_c[1].back());
A_batch_str = bstrides_c[0].back();
B_batch_str = bstrides_c[1].back();
for (auto& bstr : bstrides_c) {
batch_strides.insert(batch_strides.end(), bstr.begin(), bstr.end());
@@ -1155,8 +1339,8 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
batch_strides = std::vector<size_t>(inputs.size(), 0);
}
auto batch_size_out = out.size() / (M * N);
int matrix_stride_out = M * N;
size_t matrix_stride_out = size_t(M) * N;
size_t batch_size_out = out.size() / (matrix_stride_out);
/////////////////////////////////////////////////////////////////////////////
// Regular kernel dispatch
@@ -1197,9 +1381,9 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
/* const int ldd = */ N,
/* const int tiles_n = */ tn,
/* const int tiles_m = */ tm,
/* const int batch_stride_a = */ A_batch_str,
/* const int batch_stride_b = */ B_batch_str,
/* const int batch_stride_d = */ matrix_stride_out,
/* const size_t batch_stride_a = */ A_batch_str,
/* const size_t batch_stride_b = */ B_batch_str,
/* const size_t batch_stride_d = */ matrix_stride_out,
/* const int swizzle_log = */ swizzle_log,
/* const int gemm_k_iterations_aligned = */ (K / bk),
/* const int batch_ndim = */ int(batch_shape.size())};
@@ -1243,7 +1427,355 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_input_array(out_mask, 10);
set_vector_bytes(compute_encoder, mask_strides, 13);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
// Clear copies
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
return;
}
void BlockSparseMM::eval_gpu(const std::vector<array>& inputs, array& out) {
using namespace mlx::steel;
// assert(inputs.size() == 2);
if (!issubdtype(out.dtype(), floating)) {
throw std::runtime_error(
"[matmul] Does not yet support non-floating point types.");
}
auto& s = stream();
auto& d = metal::device(s.device);
auto& a_pre = inputs[0];
auto& b_pre = inputs[1];
// Return 0s if either input is empty
if (a_pre.size() == 0 || b_pre.size() == 0) {
array zero = array(0, a_pre.dtype());
copy_gpu(zero, out, CopyType::Scalar, s);
auto command_buffer = d.get_command_buffer(s.index);
command_buffer->addCompletedHandler([zero](MTL::CommandBuffer*) {});
return;
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
/////////////////////////////////////////////////////////////////////////////
// Init checks and prep
// Keep a vector with copies to be cleared in the completed buffer to release
// the arrays
std::vector<array> copies;
auto check_transpose = [&copies, &s](const array& arr) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (sty == 1) {
return std::make_tuple(false, stx, arr);
} else if (stx == 1) {
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
copies.push_back(arr_copy);
size_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
}
};
auto [transpose_a, a_cols, a] = check_transpose(a_pre);
auto [transpose_b, b_cols, b] = check_transpose(b_pre);
int lda = a_cols;
int ldb = b_cols;
int M = a.shape(-2);
int N = b.shape(-1);
int K = a.shape(-1);
/////////////////////////////////////////////////////////////////////////////
// Check and collapse batch dimensions
auto get_batch_dims = [](const auto& v) {
return decltype(v){v.begin(), v.end() - 2};
};
auto& lhs_indices = inputs[2];
auto& rhs_indices = inputs[3];
std::vector<int> batch_shape = get_batch_dims(out.shape());
std::vector<size_t> batch_strides;
batch_strides.insert(
batch_strides.end(),
lhs_indices.strides().begin(),
lhs_indices.strides().end());
size_t lhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back();
batch_strides.insert(
batch_strides.end(),
rhs_indices.strides().begin(),
rhs_indices.strides().end());
size_t rhs_indices_str = batch_strides.empty() ? 0 : batch_strides.back();
int batch_ndim = batch_shape.size();
if (batch_ndim == 0) {
batch_shape = {1};
batch_strides = {0};
}
int batch_ndim_A = a.ndim() - 2;
int batch_ndim_B = b.ndim() - 2;
std::vector<int> operand_batch_ndim = {batch_ndim_A, batch_ndim_B};
std::vector<int> batch_shape_A = get_batch_dims(a.shape());
std::vector<size_t> batch_strides_A = get_batch_dims(a.strides());
std::vector<int> batch_shape_B = get_batch_dims(b.shape());
std::vector<size_t> batch_strides_B = get_batch_dims(b.strides());
if (batch_ndim_A == 0) {
batch_shape_A = {1};
batch_strides_A = {0};
}
if (batch_ndim_B == 0) {
batch_shape_B = {1};
batch_strides_B = {0};
}
size_t matrix_stride_out = size_t(M) * N;
auto batch_size_out = out.size() / matrix_stride_out;
/////////////////////////////////////////////////////////////////////////////
// Gemv specialization
// Route to gemv if needed
if (std::min(M, N) == 1) {
// Collect problem info
bool is_b_matrix = N != 1;
auto& mat = is_b_matrix ? b : a;
auto& vec = is_b_matrix ? a : b;
bool transpose_mat = is_b_matrix ? !transpose_b : transpose_a;
int in_vector_len = K;
int out_vector_len = is_b_matrix ? N : M;
int mat_cols = transpose_mat ? out_vector_len : in_vector_len;
int mat_rows = transpose_mat ? in_vector_len : out_vector_len;
int mat_ld = is_b_matrix ? b_cols : a_cols;
auto batch_strides_mat = is_b_matrix ? batch_strides_B : batch_strides_A;
auto batch_strides_vec = is_b_matrix ? batch_strides_A : batch_strides_B;
auto batch_shape_mat = is_b_matrix ? batch_shape_B : batch_shape_A;
auto batch_shape_vec = is_b_matrix ? batch_shape_A : batch_shape_B;
if (!is_b_matrix) {
batch_strides = rhs_indices.strides();
batch_strides.insert(
batch_strides.end(),
lhs_indices.strides().begin(),
lhs_indices.strides().end());
}
int batch_ndim = batch_shape.size();
// Determine dispatch kernel
int tm = 4, tn = 4;
int bm, bn, n_out_per_tgp;
std::ostringstream kname;
if (transpose_mat) {
bm = 8;
bn = 8;
if (out_vector_len >= 24576) {
bn = 128;
} else if (out_vector_len >= 16384) {
bn = 64;
} else if (out_vector_len >= 8192) {
bn = 16;
}
// Specialized kernel for very small outputs
tn = out_vector_len < tn ? 1 : tn;
n_out_per_tgp = bn * tn;
kname << "gemv_t_bs_" << type_to_name(out);
} else {
bm = out_vector_len >= 4096 ? 8 : 4;
bn = 32;
// Specialized kernel for very small outputs
tm = out_vector_len < tm ? 1 : tm;
n_out_per_tgp = bm * tm;
kname << "gemv_bs_" << type_to_name(out);
}
kname << "_bm" << bm << "_bn" << bn << "_tm" << tm << "_tn" << tn;
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
MTL::Size group_dims = MTL::Size(bn, bm, 1);
MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);
compute_encoder.set_input_array(mat, 0);
compute_encoder.set_input_array(vec, 1);
compute_encoder.set_output_array(out, 3);
compute_encoder->setBytes(&in_vector_len, sizeof(int), 4);
compute_encoder->setBytes(&out_vector_len, sizeof(int), 5);
compute_encoder->setBytes(&mat_ld, sizeof(int), 6);
compute_encoder->setBytes(&batch_ndim, sizeof(int), 9);
set_vector_bytes(compute_encoder, batch_shape, 10);
set_vector_bytes(compute_encoder, batch_strides, 11);
int batch_ndim_vec = batch_shape_vec.size();
compute_encoder->setBytes(&batch_ndim_vec, sizeof(int), 12);
set_vector_bytes(compute_encoder, batch_shape_vec, 13);
set_vector_bytes(compute_encoder, batch_strides_vec, 14);
int batch_ndim_mat = batch_shape_mat.size();
compute_encoder->setBytes(&batch_ndim_mat, sizeof(int), 15);
set_vector_bytes(compute_encoder, batch_shape_mat, 16);
set_vector_bytes(compute_encoder, batch_strides_mat, 17);
compute_encoder.set_input_array(lhs_indices, 18 + int(!is_b_matrix));
compute_encoder.set_input_array(rhs_indices, 18 + int(is_b_matrix));
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
return;
}
/////////////////////////////////////////////////////////////////////////////
// Regular kernel dispatch
// Determine dispatch kernel
int bm = 32, bn = 32, bk = 16;
int wm = 2, wn = 2;
if ((size_t)batch_size_out * M * N >= 1ul << 20) {
if (!transpose_a && transpose_b) {
bm = 64;
bn = (out.dtype() == float32) ? 64 : 32;
bk = (out.dtype() == float32) ? 16 : 32;
} else {
bm = 64;
bn = 64;
}
}
// Prepare kernel name
std::ostringstream kname;
kname << "steel_gemm_fused_" << (transpose_a ? 't' : 'n')
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
<< type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
<< "_wm" << wm << "_wn" << wn;
std::string base_name = kname.str();
const bool has_batch = batch_ndim > 1;
const bool use_out_source = false;
const bool do_axpby = false;
const bool align_M = (M % bm) == 0;
const bool align_N = (N % bn) == 0;
const bool align_K = (K % bk) == 0;
const bool do_gather = true;
metal::MTLFCList func_consts = {
{&has_batch, MTL::DataType::DataTypeBool, 10},
{&use_out_source, MTL::DataType::DataTypeBool, 100},
{&do_axpby, MTL::DataType::DataTypeBool, 110},
{&align_M, MTL::DataType::DataTypeBool, 200},
{&align_N, MTL::DataType::DataTypeBool, 201},
{&align_K, MTL::DataType::DataTypeBool, 202},
{&do_gather, MTL::DataType::DataTypeBool, 300},
};
// clang-format off
kname << "_has_batch_" << (has_batch ? 't' : 'n')
<< "_use_out_source_" << (use_out_source ? 't' : 'n')
<< "_do_axpby_" << (do_axpby ? 't' : 'n')
<< "_align_M_" << (align_M ? 't' : 'n')
<< "_align_N_" << (align_N ? 't' : 'n')
<< "_align_K_" << (align_K ? 't' : 'n')
<< "_do_gather_" << (do_gather ? 't' : 'n'); // clang-format on
std::string hash_name = kname.str();
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
compute_encoder->setComputePipelineState(kernel);
// Use problem size to determine threadblock swizzle
int tn = (N + bn - 1) / bn;
int tm = (M + bm - 1) / bm;
// TODO: Explore device-based tuning for swizzle
int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 0 : 2);
// Prepare steel matmul params
GEMMParams params{
/* const int M = */ M,
/* const int N = */ N,
/* const int K = */ K,
/* const int lda = */ lda,
/* const int ldb = */ ldb,
/* const int ldd = */ N,
/* const int tiles_n = */ tn,
/* const int tiles_m = */ tm,
/* const size_t batch_stride_a = */ lhs_indices_str,
/* const size_t batch_stride_b = */ rhs_indices_str,
/* const size_t batch_stride_d = */ matrix_stride_out,
/* const int swizzle_log = */ swizzle_log,
/* const int gemm_k_iterations_aligned = */ (K / bk),
/* const int batch_ndim = */ batch_ndim};
// Prepare launch grid params
int tile = 1 << swizzle_log;
tm = (tm + tile - 1) / tile;
tn = tn * tile;
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);
// Launch kernel
compute_encoder.set_input_array(a, 0);
compute_encoder.set_input_array(b, 1);
compute_encoder.set_output_array(out, 3);
compute_encoder->setBytes(&params, sizeof(GEMMParams), 4);
set_vector_bytes(compute_encoder, batch_shape, 6);
set_vector_bytes(compute_encoder, batch_strides, 7);
compute_encoder.set_input_array(lhs_indices, 10);
compute_encoder.set_input_array(rhs_indices, 11);
std::vector operand_shape = batch_shape_A;
operand_shape.insert(
operand_shape.end(), batch_shape_B.begin(), batch_shape_B.end());
std::vector operand_strides = batch_strides_A;
operand_strides.insert(
operand_strides.end(), batch_strides_B.begin(), batch_strides_B.end());
operand_batch_ndim.push_back(0);
set_vector_bytes(compute_encoder, operand_shape, 13);
set_vector_bytes(compute_encoder, operand_strides, 14);
set_vector_bytes(compute_encoder, operand_batch_ndim, 15);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
// Clear copies
d.get_command_buffer(s.index)->addCompletedHandler(

View File

@@ -12,6 +12,23 @@
namespace mlx::core {
void steel_matmul_conv_groups(
const Stream& s,
metal::Device& d,
const array& a,
const array& b,
array& out,
int M,
int N,
int K,
int lda,
int ldb,
int ldd,
bool transpose_a,
bool transpose_b,
int groups,
std::vector<array>& copies);
void steel_matmul(
const Stream& s,
metal::Device& d,

View File

@@ -27,24 +27,6 @@ int max_ops_per_buffer() {
#define MAX_OPS_PER_BUFFER max_ops_per_buffer()
MTL::CommandBuffer* increment_command_buffer(Stream s) {
auto& d = metal::device(s.device);
auto command_buffer = d.get_command_buffer(s.index);
if (command_buffer == nullptr ||
d.get_command_buffer_ops(s.index) >= MAX_OPS_PER_BUFFER) {
if (command_buffer != nullptr) {
d.end_encoding(s.index);
scheduler::notify_new_task(s);
command_buffer->addCompletedHandler(
[s](MTL::CommandBuffer*) { scheduler::notify_task_completion(s); });
d.commit_command_buffer(s.index);
}
command_buffer = d.new_command_buffer(s.index);
}
d.increment_command_buffer_ops(s.index);
return command_buffer;
}
inline void check_error(MTL::CommandBuffer* cbuf) {
if (cbuf->status() == MTL::CommandBufferStatusError) {
std::ostringstream msg;
@@ -58,7 +40,10 @@ std::function<void()> make_task(array arr, bool signal) {
auto task = [arr = std::move(arr), signal]() mutable {
auto pool = new_scoped_memory_pool();
auto s = arr.primitive().stream();
auto command_buffer = increment_command_buffer(s);
auto& d = metal::device(s.device);
auto command_buffer = d.get_command_buffer(s.index);
d.increment_command_buffer_ops(s.index);
for (auto& input : arr.inputs()) {
if (input.event().valid() &&
input.event().stream() != arr.primitive().stream()) {
@@ -91,11 +76,13 @@ std::function<void()> make_task(array arr, bool signal) {
arr.detach();
}
if (signal) {
metal::device(s.device).end_encoding(s.index);
command_buffer->encodeSignalEvent(
static_cast<MTL::Event*>(arr.event().raw_event().get()),
arr.event().value());
if (signal || d.get_command_buffer_ops(s.index) >= MAX_OPS_PER_BUFFER) {
d.end_encoding(s.index);
if (signal) {
command_buffer->encodeSignalEvent(
static_cast<MTL::Event*>(arr.event().raw_event().get()),
arr.event().value());
}
scheduler::notify_new_task(s);
command_buffer->addCompletedHandler(
[s, buffers = std::move(buffers), event = arr.event()](
@@ -103,7 +90,8 @@ std::function<void()> make_task(array arr, bool signal) {
scheduler::notify_task_completion(s);
check_error(cbuf);
});
metal::device(s.device).commit_command_buffer(s.index);
d.commit_command_buffer(s.index);
d.get_command_buffer(s.index);
} else {
command_buffer->addCompletedHandler(
[s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
@@ -120,14 +108,12 @@ std::function<void()> make_synchronize_task(
return [s, p = std::move(p)]() {
auto& d = metal::device(s.device);
auto cb = d.get_command_buffer(s.index);
if (cb == nullptr) {
cb = d.new_command_buffer(s.index);
} else {
d.end_encoding(s.index);
}
cb->retain();
d.end_encoding(s.index);
d.commit_command_buffer(s.index);
cb->waitUntilCompleted();
check_error(cb);
cb->release();
p->set_value();
};
}

Some files were not shown because too many files have changed in this diff Show More