Compare commits

..

41 Commits

Author SHA1 Message Date
Awni Hannun
cf236fc390 version (#1191) 2024-06-06 17:16:40 -07:00
Alex Barron
27d70c7d9d Feature complete Metal FFT (#1102)
* feature complete metal fft

* fix contiguity bug

* jit fft

* simplify rader/bluestein constant computation

* remove kernel/utils.h dep

* remove bf16.h dep

* format

---------

Co-authored-by: Alex Barron <abarron22@apple.com>
2024-06-06 12:57:25 -07:00
nicolov
0e585b4409 Add docstring for scatter (#1189)
* Add docstring for scatter

* docs nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-06-06 11:51:25 -07:00
Angelos Katharopoulos
0163a8e57a Add docs for the distributed namespace (#1184) 2024-06-06 11:37:00 -07:00
Awni Hannun
578842954c fix jit scan when output doesn't have primitive (#1190) 2024-06-06 07:24:58 -07:00
Awni Hannun
496315fe1d Fix scan (#1188)
* fix scan

* improve grid size

* fix cpu cummax
2024-06-05 14:21:58 -07:00
Angelos Katharopoulos
0fe6895893 Fix the hard-shrink test (#1185) 2024-06-04 16:22:56 -07:00
Nikhil Mehta
0b7d71fd2f Add softmin, hardshrink, hardtanh (#1180)
---------

Co-authored-by: Nikhil Mehta <nikmehta@tesla.com>
2024-06-04 15:48:18 -07:00
Awni Hannun
83b11bc58d Fix Metal API validation for empty concat (#1183) 2024-06-04 13:17:08 -07:00
Alex Barron
375a8bbdcc Add some internal GPU apis (#1177)
* Add unary/binary/ternay/slice/concat internal GPU ops

* add pad internal op

* formatting + no_cpu fix
2024-06-04 09:24:26 -07:00
Awni Hannun
ea9090bbc4 Add view op (#1179)
* add view primitive

* nit

* fix view
2024-06-04 08:05:27 -07:00
nicolov
81def6ac76 Fix benchmark (#1175) 2024-06-04 07:50:46 -07:00
Angelos Katharopoulos
3de8ce3f3c In place all-reduce and forgiving init (#1178) 2024-06-03 16:47:47 -07:00
Alex Barron
4d485fca24 Add defines include (#1176)
Co-authored-by: Alex Barron <abarron22@apple.com>
2024-06-03 09:50:10 -07:00
Brian Keene
1865299a30 Metal shaders for memory efficient self attention on large sequences (#964)
* Metal shaders for efficient self attention on large sequences

Updated fast attention: GEMM-ified with Steel primitives
Uses flash attention 1 for scale correction

* more compiler silencing

* Address rebase issues

* Templatize kernel instantiation, revise cpu bindings

* Safer writes to output

* Permit batch size > 1

* Numerical fixes for sdpa self attention

* Re-enable test, remove unused variable

* add benchmarking script

* Disable sdpa prior to perf tuning, and simplify tests for per-patch CI
2024-06-03 09:16:19 -07:00
Dominik Schlösser
3576b547c5 Doc error for default for scale in SinusoidalPositionalEncoding (#1174) 2024-06-02 13:42:45 -07:00
Awni Hannun
079882495d version bump (#1172) 2024-05-31 12:29:12 -07:00
K Venkat Ramnan
ab977109db feat: Added dlpack device (#1165)
* feat: Added dlpack device

* feat: Added device_id to dlpack device

* feat: Added device_id to dlpack device

* doc: updated conversion docs

* doc: updated numpy.rst dlpack information

* doc: updated numpy.rst dlpack information

* Update docs/src/usage/numpy.rst

* Update docs/src/usage/numpy.rst

---------

Co-authored-by: Venkat Ramnan Kalyanakumar <venkatramnankalyanakumar@Venkats-MacBook-Air.local>
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
2024-05-31 12:29:01 -07:00
Awni Hannun
fd1c08137b stable cumprod grad at 0 (#1167) 2024-05-31 12:28:42 -07:00
Jagrit Digani
76b6cece46 Fix multi-block sort stride management (#1169)
* Fix multi-block sort stride management

* Add seed to tests
2024-05-31 11:10:54 -07:00
Jagrit Digani
9f0df51f8d Fix matvec vector stride bug (#1168) 2024-05-29 12:18:28 -07:00
Awni Hannun
e7a2a3dcd1 Fix a couple bugs (#1161)
* fix jit reduce for RMS norm

* make strides a single buffer

* better eval error message

* fix compiling with inf and bf16

* fix cpu compile with bf16
2024-05-28 15:18:18 -07:00
Awni Hannun
a87ef5bfc1 fix broadcast bug in bitwise ops (#1157) 2024-05-24 11:44:40 -07:00
Awni Hannun
9f9cb7a2ef version bump (#1154) 2024-05-23 18:08:08 -07:00
Awni Hannun
7e26fd8032 Option to JIT steel gemm / conv (#1139) 2024-05-23 18:07:34 -07:00
Jagrit Digani
eab2685c67 Float mask update (#1152)
* Float mask update

* Update CPU impl
2024-05-23 17:20:44 -07:00
Angelos Katharopoulos
50dfb664db Comms (#1097)
* Start the communications branch using MPI
* Add ops and primitives
* Add python bindings for distributed
2024-05-23 17:04:02 -07:00
Awni Hannun
0189ab6ab6 More jitting (#1132)
* docs + circle min size build

* jit scan, arange, softmax

* add sort

* jit reductions

* remove print

* fix deps

* clean includes / nits
2024-05-23 16:23:44 -07:00
Rifur13
9401507336 Add groups to 2-D convolutions (#1129)
* Added groups to 2-D convolutions. Only implemented for **some** specializations.

Also fixed 1D grouped convs with different kernel strides and added more tests.

* fix channels condition
2024-05-22 20:01:44 -07:00
Awni Hannun
eb8321d863 list based indexing (#1150) 2024-05-22 15:52:05 -07:00
Abe Leininger
79ef49b2c2 add mx.trace (#1143) (#1147)
* working c++ trace implementation

* updated throw + added overloads

* added python binding for trace function

* pre-commit reformatting

* add trace to docs

* resolve comments

* remove to_stream call
2024-05-22 15:50:27 -07:00
Awni Hannun
e110ca11e2 Fix offset bug for device buffers (#1151)
* fix bug with large offsets for buffers

* add a test

* remove test as its too big for small machine
2024-05-22 15:50:05 -07:00
Awni Hannun
226748b3e7 JIT compile option for binary minimization (#1091)
* try cpp 20 for compile

* unary, binary, ternary in jit

* nits

* fix gather/scatter

* fix rebase

* reorg compile

* add ternary to compile

* jit copy

* jit compile flag

* fix build

* use linked function for ternary

* some nits

* docs + circle min size build

* docs + circle min size build

* fix extension

* fix no cpu build

* improve includes
2024-05-22 12:57:13 -07:00
Awni Hannun
d568c7ee36 Rename block sparse (#1149)
* block_sparse_mm to gather_mm

* rename

* nit

* nit
2024-05-22 07:48:34 -07:00
Awni Hannun
e6fecbb3e1 Some fixes in docs (#1141)
* fixes in docs

* nit
2024-05-20 11:51:47 -07:00
Angelos Katharopoulos
da83f899bb Improve qvm speed (#1140) 2024-05-20 09:20:44 -07:00
jlwitthuhn
7e5674d8be Treate 'minimum' differently in cosine decay (#1138) 2024-05-20 08:00:48 -07:00
Shixian Sheng
0a558577bf Update README.md (#1136) 2024-05-20 06:16:40 -07:00
Awni Hannun
fb71a82ada Fix copy bug with many dims (#1137) 2024-05-17 21:10:03 -07:00
Awni Hannun
23406c9e9e Choose the right MLX bf16 for extensions (#1135)
* default to custom bf

* choose right bf

* fix extensions

* fix circle conf
2024-05-17 15:09:28 -07:00
Luca Arnaboldi
b3ec792380 Implemented Cholesky on CPU (#1119) 2024-05-17 12:31:59 -07:00
209 changed files with 15985 additions and 7212 deletions

View File

@@ -71,6 +71,7 @@ jobs:
name: Install dependencies
command: |
brew install python@3.8
brew install openmpi
python3.8 -m venv env
source env/bin/activate
pip install --upgrade pip
@@ -96,10 +97,14 @@ jobs:
source env/bin/activate
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
mpirun -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
- run:
name: Build example extension
command: |
cd examples/extensions && python3.8 -m pip install .
source env/bin/activate
cd examples/extensions
pip install -r requirements.txt
python setup.py build_ext -j8
- store_test_results:
path: test-results
- run:
@@ -111,7 +116,13 @@ jobs:
name: Run CPP tests
command: |
DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./build/tests/tests
DEVICE=cpu ./build/tests/tests
- run:
name: Build small binary
command: |
source env/bin/activate
cd build/
cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel -DBUILD_SHARED_LIBS=ON -DMLX_BUILD_CPU=OFF -DMLX_BUILD_SAFETENSORS=OFF -DMLX_BUILD_GGUF=OFF -DMLX_METAL_JIT=ON
make -j
build_release:
parameters:

View File

@@ -16,6 +16,7 @@ MLX was developed with contributions from the following individuals:
- Luca Arnaboldi: Added `Ceil` and `Floor` ops; implemented pickling, copy and deepcopy for mlx arrays.
- Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention`
- AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`. Comparison works for non array objects in `mlx.core.array`. Exception handling for invalid operations in `mlx.core.array`.
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />

View File

@@ -20,10 +20,11 @@ option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
if(NOT MLX_VERSION)
set(MLX_VERSION 0.13.1)
set(MLX_VERSION 0.15.0)
endif()
# --------------------- Processor tests -------------------------
@@ -109,7 +110,7 @@ elseif (MLX_BUILD_METAL)
$<INSTALL_INTERFACE:include/metal_cpp>
)
target_link_libraries(
mlx
mlx PUBLIC
${METAL_LIB}
${FOUNDATION_LIB}
${QUARTZ_LIB})
@@ -122,7 +123,7 @@ if (MLX_BUILD_CPU)
if (MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
set(MLX_BUILD_ACCELERATE ON)
target_link_libraries(mlx ${ACCELERATE_LIBRARY})
target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY})
add_compile_definitions(ACCELERATE_NEW_LAPACK)
else()
message(STATUS "Accelerate or arm neon not found, using default backend.")
@@ -145,7 +146,7 @@ if (MLX_BUILD_CPU)
message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
target_link_libraries(mlx ${LAPACK_LIBRARIES})
target_link_libraries(mlx PUBLIC ${LAPACK_LIBRARIES})
# List blas after lapack otherwise we may accidentally incldue an old version
# of lapack.h from the include dirs of blas.
find_package(BLAS REQUIRED)
@@ -160,12 +161,17 @@ if (MLX_BUILD_CPU)
message(STATUS "Blas lib " ${BLAS_LIBRARIES})
message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
target_link_libraries(mlx ${BLAS_LIBRARIES})
target_link_libraries(mlx PUBLIC ${BLAS_LIBRARIES})
endif()
else()
set(MLX_BUILD_ACCELERATE OFF)
endif()
find_package(MPI)
if (MPI_FOUND)
target_include_directories(mlx PRIVATE ${MPI_INCLUDE_PATH})
endif()
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)
target_include_directories(
@@ -175,6 +181,14 @@ target_include_directories(
$<INSTALL_INTERFACE:include>
)
FetchContent_Declare(fmt
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
GIT_TAG 10.2.1
EXCLUDE_FROM_ALL
)
FetchContent_MakeAvailable(fmt)
target_link_libraries(mlx PRIVATE fmt::fmt-header-only)
if (MLX_BUILD_PYTHON_BINDINGS)
message(STATUS "Building Python bindings.")
find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)

View File

@@ -88,13 +88,13 @@ for more information on building the C++ and Python APIs from source.
## Contributing
Check out the [contribution guidelines](CONTRIBUTING.md) for more information
Check out the [contribution guidelines](https://github.com/ml-explore/mlx/tree/main/CONTRIBUTING.md) for more information
on contributing to MLX. See the
[docs](https://ml-explore.github.io/mlx/build/html/install.html) for more
information on building from source, and running tests.
We are grateful for all of [our
contributors](ACKNOWLEDGMENTS.md#Individual-Contributors). If you contribute
contributors](https://github.com/ml-explore/mlx/tree/main/ACKNOWLEDGMENTS.md#Individual-Contributors). If you contribute
to MLX and wish to be acknowledged, please add your name to the list in your
pull request.

View File

@@ -185,7 +185,7 @@ def prelu(x: torch.Tensor) -> torch.Tensor:
def mish(x: torch.Tensor) -> torch.Tensor:
y = x
for _ in range(100):
return torch.nn.functional.mish(y)
y = torch.nn.functional.mish(y)
sync_if_needed(x)
@@ -283,6 +283,14 @@ def topk(axis, x):
sync_if_needed(x)
@torch.no_grad()
def step_function(x):
y = x
for i in range(100):
y = torch.where(y < 0, 0, 1)
sync_if_needed(x)
@torch.no_grad()
def selu(x):
y = x
@@ -446,5 +454,11 @@ if __name__ == "__main__":
elif args.benchmark == "topk":
print(bench(topk, axis, x))
elif args.benchmark == "step":
print(bench(step_function, x))
elif args.benchmark == "selu":
print(bench(selu, x))
else:
raise ValueError("Unknown benchmark")
raise ValueError(f"Unknown benchmark `{args.benchmark}`.")

View File

@@ -16,7 +16,9 @@ def run_or_raise(*args, **kwargs):
result = run(*args, capture_output=True, **kwargs)
return float(result.stdout)
except ValueError:
raise ValueError(f"stdout: {result.stdout}\nstderr: {result.stderr}")
raise ValueError(
f"stdout: {result.stdout.decode()}\nstderr: {result.stderr.decode()}"
)
def compare(args):

View File

@@ -9,7 +9,6 @@ from time_utils import time_fn
def bench_gelu():
def gelu(x):
return x * (1 + mx.erf(x / math.sqrt(2))) / 2
@@ -51,7 +50,6 @@ def bench_gelu():
def bench_layernorm():
weight = mx.random.uniform(shape=(4096,)).astype(mx.float16)
bias = mx.random.uniform(shape=(4096,)).astype(mx.float16)
mx.eval(weight, bias)

View File

@@ -28,11 +28,11 @@ def bench(f, a, b):
return (e - s) * 1e-9
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0)):
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
def mx_conv_2D(a, b):
ys = []
for i in range(N_iter_func):
y = mx.conv2d(a, b, stride=strides, padding=padding)
y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
mx.eval(ys)
return ys
@@ -40,12 +40,12 @@ def make_mx_conv_2D(strides=(1, 1), padding=(0, 0)):
return mx_conv_2D
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0)):
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
@torch.no_grad()
def pt_conv_2D(a, b):
ys = []
for i in range(N_iter_func):
y = torch.conv2d(a, b, stride=strides, padding=padding)
y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
ys.append(y)
torch.mps.synchronize()
return ys
@@ -53,11 +53,12 @@ def make_pt_conv_2D(strides=(1, 1), padding=(0, 0)):
return pt_conv_2D
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, np_dtype):
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
scale = 1.0 / math.sqrt(kH * kH * C)
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (O, kH, kW, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
np_dtype
)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
@@ -67,15 +68,15 @@ def bench_shape(N, H, W, C, kH, kW, O, strides, padding, np_dtype):
torch.mps.synchronize()
f_mx = make_mx_conv_2D(strides, padding)
f_pt = make_pt_conv_2D(strides, padding)
f_mx = make_mx_conv_2D(strides, padding, groups)
f_pt = make_pt_conv_2D(strides, padding, groups)
time_torch = bench(f_pt, a_pt, b_pt)
time_mlx = bench(f_mx, a_mx, b_mx)
out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding)
out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
out_pt = torch.conv2d(
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
)
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
out_pt = out_pt.numpy(force=True)
@@ -84,7 +85,7 @@ def bench_shape(N, H, W, C, kH, kW, O, strides, padding, np_dtype):
if not np.allclose(out_pt, out_mx, atol=atol):
print(
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
)
return time_mlx, time_torch
@@ -95,35 +96,40 @@ if __name__ == "__main__":
dtypes = ("float32",)
shapes = (
(4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2)),
(4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2)),
(4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2)),
(4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2)),
(4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2)),
(4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2)),
(4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2)),
(4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2)),
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2)),
(4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2)),
(4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2)),
(4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2)),
(4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2)),
(4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2)),
(4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2)),
(4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2)),
(4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
(4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
(4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 2),
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 16),
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 64),
(4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
(4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
(4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
(4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
(4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
(4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
)
for dtype in dtypes:
print("(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, diff%")
for N, H, W, C, kH, kW, O, strides, padding in shapes:
print(
"(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
)
for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
np_dtype = getattr(np, dtype)
time_mlx, time_torch = bench_shape(
N, H, W, C, kH, kW, O, strides, padding, np_dtype
N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
)
diff = time_torch / time_mlx - 1.0
print(
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {100. * diff:+5.2f}%"
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
)
if time_mlx >= 2.0 * time_torch:
print("ATTENTION ^^^^^^^")

View File

@@ -3,6 +3,8 @@
import matplotlib
import mlx.core as mx
import numpy as np
import sympy
import torch
from time_utils import measure_runtime
matplotlib.use("Agg")
@@ -16,41 +18,100 @@ def bandwidth_gb(runtime_ms, system_size):
return system_size * bytes_per_fft / runtime_ms * ms_per_s / bytes_per_gb
def run_bench(system_size):
def fft(x):
out = mx.fft.fft(x)
def run_bench(system_size, fft_sizes, backend="mlx", dim=1):
def fft_mlx(x):
if dim == 1:
out = mx.fft.fft(x)
elif dim == 2:
out = mx.fft.fft2(x)
mx.eval(out)
return out
bandwidths = []
for k in range(4, 12):
n = 2**k
x = mx.random.uniform(shape=(system_size // n, n)).astype(mx.float32)
x = x.astype(mx.complex64)
mx.eval(x)
runtime_ms = measure_runtime(fft, x=x)
bandwidths.append(bandwidth_gb(runtime_ms, system_size))
def fft_mps(x):
if dim == 1:
out = torch.fft.fft(x)
elif dim == 2:
out = torch.fft.fft2(x)
torch.mps.synchronize()
return out
return bandwidths
bandwidths = []
for n in fft_sizes:
batch_size = system_size // n**dim
shape = [batch_size] + [n for _ in range(dim)]
if backend == "mlx":
x_np = np.random.uniform(size=(system_size // n, n)).astype(np.complex64)
x = mx.array(x_np)
mx.eval(x)
fft = fft_mlx
elif backend == "mps":
x_np = np.random.uniform(size=(system_size // n, n)).astype(np.complex64)
x = torch.tensor(x_np, device="mps")
torch.mps.synchronize()
fft = fft_mps
else:
raise NotImplementedError()
runtime_ms = measure_runtime(fft, x=x)
bandwidth = bandwidth_gb(runtime_ms, np.prod(shape))
print(n, bandwidth)
bandwidths.append(bandwidth)
return np.array(bandwidths)
def time_fft():
x = np.array(range(2, 512))
system_size = int(2**26)
with mx.stream(mx.cpu):
cpu_bandwidths = run_bench(system_size=int(2**22))
print("MLX GPU")
with mx.stream(mx.gpu):
gpu_bandwidths = run_bench(system_size=int(2**29))
gpu_bandwidths = run_bench(system_size=system_size, fft_sizes=x)
# plot bandwidths
x = [2**k for k in range(4, 12)]
plt.scatter(x, gpu_bandwidths, color="green", label="GPU")
plt.scatter(x, cpu_bandwidths, color="red", label="CPU")
plt.title("MLX FFT Benchmark")
plt.xlabel("N")
plt.ylabel("Bandwidth (GB/s)")
plt.legend()
plt.savefig("fft_plot.png")
print("MPS GPU")
mps_bandwidths = run_bench(system_size=system_size, fft_sizes=x, backend="mps")
print("CPU")
system_size = int(2**20)
with mx.stream(mx.cpu):
cpu_bandwidths = run_bench(system_size=system_size, fft_sizes=x)
x = np.array(x)
all_indices = x - x[0]
radix_2to13 = (
np.array([i for i in x if all(p <= 13 for p in sympy.primefactors(i))]) - x[0]
)
bluesteins = (
np.array([i for i in x if any(p > 13 for p in sympy.primefactors(i))]) - x[0]
)
for indices, name in [
(all_indices, "All"),
(radix_2to13, "Radix 2-13"),
(bluesteins, "Bluestein's"),
]:
# plot bandwidths
print(name)
plt.scatter(x[indices], gpu_bandwidths[indices], color="green", label="GPU")
plt.scatter(x[indices], mps_bandwidths[indices], color="blue", label="MPS")
plt.scatter(x[indices], cpu_bandwidths[indices], color="red", label="CPU")
plt.title(f"MLX FFT Benchmark -- {name}")
plt.xlabel("N")
plt.ylabel("Bandwidth (GB/s)")
plt.legend()
plt.savefig(f"{name}.png")
plt.clf()
av_gpu_bandwidth = np.mean(gpu_bandwidths)
av_mps_bandwidth = np.mean(mps_bandwidths)
av_cpu_bandwidth = np.mean(cpu_bandwidths)
print("Average bandwidths:")
print("GPU:", av_gpu_bandwidth)
print("MPS:", av_mps_bandwidth)
print("CPU:", av_cpu_bandwidth)
portion_faster = len(np.where(gpu_bandwidths > mps_bandwidths)[0]) / len(x)
print("Percent MLX faster than MPS: ", portion_faster * 100)
if __name__ == "__main__":

View File

@@ -0,0 +1,62 @@
import argparse
import math
import mlx.core as mx
from time_utils import time_fn
MAX_SEQ = 300
START_SEQ = 100
SEQ_INCREMENT = 50
def time_self_attention_primitives():
mx.random.seed(3)
B = 2
H = 38
D = 64
for R in range(START_SEQ, MAX_SEQ, SEQ_INCREMENT):
q = mx.random.uniform(shape=(B, H, R, D))
k = mx.random.uniform(shape=(B, H, R, D))
v = mx.random.uniform(shape=(B, H, R, D))
scale = 1.0 / math.sqrt(float(D))
mx.eval(q, k, v)
def sdpa_primitives(qs, ks, vs, alpha):
s = (alpha * qs) @ ks.transpose(0, 1, 3, 2)
p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
o = p @ vs
return o
time_fn(sdpa_primitives, q, k, v, scale)
def time_self_attention_sdpa():
mx.random.seed(3)
B = 2
H = 38
D = 64
for R in range(START_SEQ, MAX_SEQ, SEQ_INCREMENT):
q = mx.random.uniform(shape=(B, H, R, D))
k = mx.random.uniform(shape=(B, H, R, D))
v = mx.random.uniform(shape=(B, H, R, D))
scale = 1.0 / math.sqrt(float(D))
mx.eval(q, k, v)
def sdpa_fused(qs, ks, vs, alpha):
o = mx.fast.scaled_dot_product_attention(qs, ks, vs, scale=alpha)
return o
time_fn(sdpa_fused, q, k, v, scale)
if __name__ == "__main__":
parser = argparse.ArgumentParser("MLX benchmarks.")
parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
args = parser.parse_args()
if args.gpu:
mx.set_default_device(mx.gpu)
else:
mx.set_default_device(mx.cpu)
time_self_attention_sdpa()
time_self_attention_primitives()

View File

@@ -43,6 +43,7 @@ are the CPU and GPU.
usage/function_transforms
usage/compile
usage/numpy
usage/distributed
usage/using_streams
.. toctree::
@@ -69,6 +70,7 @@ are the CPU and GPU.
python/metal
python/nn
python/optimizers
python/distributed
python/tree_utils
.. toctree::

View File

@@ -163,6 +163,8 @@ should point to the path to the built metal library.
- ON
* - MLX_BUILD_GGUF
- ON
* - MLX_METAL_JIT
- OFF
.. note::
@@ -184,21 +186,30 @@ should point to the path to the built metal library.
Binary Size Minimization
~~~~~~~~~~~~~~~~~~~~~~~~
To produce a smaller binary use the CMake flags `CMAKE_BUILD_TYPE=MinSizeRel`
and `BUILD_SHARED_LIBS=ON`.
To produce a smaller binary use the CMake flags ``CMAKE_BUILD_TYPE=MinSizeRel``
and ``BUILD_SHARED_LIBS=ON``.
The MLX CMake build has several additional options to make smaller binaries.
For example, if you don't need the CPU backend or support for safetensors and
GGUF, you can do:
```shell
cmake .. \
-DCMAKE_BUILD_TYPE=MinSizeRel \
-DBUILD_SHARED_LIBS=ON \
-DMLX_BUILD_CPU=ON \
-DMLX_BUILD_SAFETENSORS=OFF \
-DMLX_BUILD_GGUF=OFF
```
.. code-block:: shell
cmake ..
-DCMAKE_BUILD_TYPE=MinSizeRel \
-DBUILD_SHARED_LIBS=ON \
-DMLX_BUILD_CPU=OFF \
-DMLX_BUILD_SAFETENSORS=OFF \
-DMLX_BUILD_GGUF=OFF \
-DMLX_METAL_JIT=ON
THE ``MLX_METAL_JIT`` flag minimizes the size of the MLX Metal library which
contains pre-built GPU kernels. This substantially reduces the size of the
Metal library by run-time compiling kernels the first time they are used in MLX
on a given machine. Note run-time compilation incurs a cold-start cost which can
be anwywhere from a few hundred millisecond to a few seconds depending on the
application. Once a kernel is compiled, it will be cached by the system. The
Metal kernel cache persists accross reboots.
Troubleshooting
^^^^^^^^^^^^^^^

View File

@@ -0,0 +1,19 @@
.. _distributed:
.. currentmodule:: mlx.core.distributed
Distributed Communication
==========================
MLX provides a distributed communication package using MPI. The MPI library is
loaded at runtime; if MPI is available then distributed communication is also
made available.
.. autosummary::
:toctree: _autosummary
Group
is_available
init
all_sum
all_gather

View File

@@ -10,5 +10,6 @@ Linear Algebra
inv
norm
cholesky
qr
svd

View File

@@ -17,6 +17,8 @@ simple functions.
gelu_approx
gelu_fast_approx
glu
hard_shrink
hard_tanh
hardswish
leaky_relu
log_sigmoid
@@ -29,6 +31,7 @@ simple functions.
sigmoid
silu
softmax
softmin
softplus
softshrink
step

View File

@@ -21,10 +21,15 @@ Layers
Dropout3d
Embedding
GELU
GLU
GroupNorm
GRU
HardShrink
HardTanh
Hardswish
InstanceNorm
LayerNorm
LeakyReLU
Linear
LSTM
MaxPool1d
@@ -36,13 +41,19 @@ Layers
QuantizedLinear
RMSNorm
ReLU
ReLU6
RNN
RoPE
SELU
Sequential
SiLU
SinusoidalPositionalEncoding
Softmin
Softshrink
Softsign
Softmax
Softplus
Step
Tanh
Transformer
Upsample

View File

@@ -35,7 +35,6 @@ Operations
bitwise_or
bitwise_xor
block_masked_mm
block_sparse_mm
broadcast_to
ceil
clip
@@ -69,6 +68,8 @@ Operations
floor
floor_divide
full
gather_mm
gather_qmm
greater
greater_equal
identity
@@ -149,11 +150,13 @@ Operations
tensordot
tile
topk
trace
transpose
tri
tril
triu
var
view
where
zeros
zeros_like

View File

@@ -0,0 +1,166 @@
.. _usage_distributed:
Distributed Communication
=========================
.. currentmodule:: mlx.core.distributed
MLX utilizes `MPI <https://en.wikipedia.org/wiki/Message_Passing_Interface>`_ to
provide distributed communication operations that allow the computational cost
of training or inference to be shared across many physical machines. You can
see a list of the supported operations in the :ref:`API docs<distributed>`.
.. note::
A lot of operations may not be supported or not as fast as they should be.
We are adding more and tuning the ones we have as we are figuring out the
best way to do distributed computing on Macs using MLX.
Getting Started
---------------
MLX already comes with the ability to "talk" to MPI if it is installed on the
machine. The minimal distributed program in MLX is as simple as:
.. code:: python
import mlx.core as mx
world = mx.distributed.init()
x = mx.distributed.all_sum(mx.ones(10))
print(world.rank(), x)
The program above sums the array ``mx.ones(10)`` across all
distributed processes. If simply run with ``python``, however, only one
process is launched and no distributed communication takes place.
To launch the program in distributed mode we need to use ``mpirun`` or
``mpiexec`` depending on the MPI installation. The simplest possible way is the
following:
.. code:: shell
$ mpirun -np 2 python test.py
1 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
0 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
The above launches two processes on the same (local) machine and we can see
both standard output streams. The processes send the array of 1s to each other
and compute the sum which is printed. Launching with ``mpirun -np 4 ...`` would
print 4 etc.
Installing MPI
---------------
MPI can be installed with Homebrew, using the Anaconda package manager or
compiled from source. Most of our testing is done using ``openmpi`` installed
with the Anaconda package manager as follows:
.. code:: shell
$ conda install openmpi
Installing with Homebrew may require specifying the location of ``libmpi.dyld``
so that MLX can find it and load it at runtime. This can simply be achieved by
passing the ``DYLD_LIBRARY_PATH`` environment variable to ``mpirun``.
.. code:: shell
$ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python test.py
Setting up Remote Hosts
-----------------------
MPI can automatically connect to remote hosts and set up the communication over
the network if the remote hosts can be accessed via ssh. A good checklist to
debug connectivity issues is the following:
* ``ssh hostname`` works from all machines to all machines without asking for
password or host confirmation
* ``mpirun`` is accessible on all machines. You can call ``mpirun`` using its
full path to force all machines to use a specific path.
* Ensure that the ``hostname`` used by MPI is the one that you have configured
in the ``.ssh/config`` files on all machines.
.. note::
For an example hostname ``foo.bar.com`` MPI can use only ``foo`` as
the hostname passed to ssh if the current hostname matches ``*.bar.com``.
An easy way to pass the host names to MPI is using a host file. A host file
looks like the following, where ``host1`` and ``host2`` should be the fully
qualified domain names or IPs for these hosts.
.. code::
host1 slots=1
host2 slots=1
When using MLX, it is very likely that you want to use 1 slot per host, ie one
process per host. The hostfile also needs to contain the current
host if you want to run on the local host. Passing the host file to
``mpirun`` is simply done using the ``--hostfile`` command line argument.
Training Example
----------------
In this section we will adapt an MLX training loop to support data parallel
distributed training. Namely, we will average the gradients across a set of
hosts before applying them to the model.
Our training loop looks like the following code snippet if we omit the model,
dataset and optimizer initialization.
.. code:: python
model = ...
optimizer = ...
dataset = ...
def step(model, x, y):
loss, grads = loss_grad_fn(model, x, y)
optimizer.update(model, grads)
return loss
for x, y in dataset:
loss = step(model, x, y)
mx.eval(loss, model.parameters())
All we have to do to average the gradients across machines is perform an
:func:`all_sum` and divide by the size of the :class:`Group`. Namely we
have to :func:`mlx.utils.tree_map` the gradients with following function.
.. code:: python
def all_avg(x):
return mx.distributed.all_sum(x) / mx.distributed.init().size()
Putting everything together our training loop step looks as follows with
everything else remaining the same.
.. code:: python
from mlx.utils import tree_map
def all_reduce_grads(grads):
N = mx.distributed.init()
if N == 1:
return grads
return tree_map(
lambda x: mx.distributed.all_sum(x) / N,
grads)
def step(model, x, y):
loss, grads = loss_grad_fn(model, x, y)
grads = all_reduce_grads(grads) # <--- This line was added
optimizer.update(model, grads)
return loss
Tuning All Reduce
-----------------
We are working on improving the performance of all reduce on MLX but for now
the two main things one can do to extract the most out of distributed training with MLX are:
1. Perform a few large reductions instead of many small ones to improve
bandwidth and latency
2. Pass ``--mca btl_tcp_links 4`` to ``mpirun`` to configure it to use 4 tcp
connections between each host to improve bandwidth

View File

@@ -3,7 +3,11 @@
Conversion to NumPy and Other Frameworks
========================================
MLX array implements the `Python Buffer Protocol <https://docs.python.org/3/c-api/buffer.html>`_.
MLX array supports conversion between other frameworks with either:
* The `Python Buffer Protocol <https://docs.python.org/3/c-api/buffer.html>`_.
* `DLPack <https://dmlc.github.io/dlpack/latest/>`_.
Let's convert an array to NumPy and back.
.. code-block:: python

View File

@@ -9,3 +9,4 @@ build_example(tutorial.cpp)
build_example(linear_regression.cpp)
build_example(logistic_regression.cpp)
build_example(metal_capture.cpp)
build_example(distributed.cpp)

View File

@@ -0,0 +1,22 @@
// Copyright © 2024 Apple Inc.
#include <iostream>
#include "mlx/mlx.h"
using namespace mlx::core;
int main() {
if (!distributed::is_available()) {
std::cout << "No communication backend found" << std::endl;
return 1;
}
auto global_group = distributed::init();
std::cout << global_group.rank() << " / " << global_group.size() << std::endl;
array x = ones({10});
array out = distributed::all_sum(x, global_group);
std::cout << out << std::endl;
}

View File

@@ -21,4 +21,4 @@ python setup.py build_ext -j8 --inplace
```
python test.py
`
```

View File

@@ -25,6 +25,7 @@ else()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_cpu)
endif()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/distributed)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
if (MLX_BUILD_ACCELERATE)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate)

View File

@@ -32,8 +32,6 @@ DEFAULT(ArgReduce)
DEFAULT(ArgSort)
DEFAULT(AsStrided)
DEFAULT(BlockMaskedMM)
DEFAULT(BlockSparseMM)
DEFAULT(BlockSparseQMM)
DEFAULT(Broadcast)
DEFAULT(Ceil)
DEFAULT(Concatenate)
@@ -49,6 +47,8 @@ DEFAULT(ErfInv)
DEFAULT(FFT)
DEFAULT(Floor)
DEFAULT(Gather)
DEFAULT(GatherMM)
DEFAULT(GatherQMM)
DEFAULT(Greater)
DEFAULT(GreaterEqual)
DEFAULT(Less)
@@ -80,6 +80,7 @@ DEFAULT(StopGradient)
DEFAULT_MULTI(SVD)
DEFAULT(Transpose)
DEFAULT(Inverse)
DEFAULT(Cholesky)
void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);

View File

@@ -48,6 +48,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
@@ -56,6 +57,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
)

View File

@@ -1,6 +1,8 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include <cassert>
#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/common/utils.h"

View File

@@ -0,0 +1,101 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"
#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <lapack.h>
#endif
namespace mlx::core {
namespace {
// Delegate to the Cholesky factorization taking into account differences in
// LAPACK implementations (basically how to pass the 'uplo' string to fortran).
int spotrf_wrapper(char uplo, float* matrix, int N) {
int info;
#ifdef LAPACK_FORTRAN_STRLEN_END
spotrf_(
/* uplo = */ &uplo,
/* n = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* info = */ &info,
/* uplo_len = */ static_cast<size_t>(1));
#else
spotrf_(
/* uplo = */ &uplo,
/* n = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* info = */ &info);
#endif
return info;
}
} // namespace
void cholesky_impl(const array& a, array& factor, bool upper) {
// Lapack uses the column-major convention. We take advantage of the fact that
// the matrix should be symmetric:
// (A)ᵀ = A
// and that a column-major lower triangular matrix is a row-major upper
// triangular matrix, so uplo is the opposite of what we would expect from
// upper
char uplo = (upper) ? 'L' : 'U';
// The decomposition is computed in place, so just copy the input to the
// output.
copy(
a,
factor,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
const int N = a.shape(-1);
const size_t num_matrices = a.size() / (N * N);
float* matrix = factor.data<float>();
for (int i = 0; i < num_matrices; i++) {
// Compute Cholesky factorization.
int info = spotrf_wrapper(uplo, matrix, N);
// TODO: We do nothing when the matrix is not positive semi-definite
// because throwing an error would result in a crash. If we figure out how
// to catch errors from the implementation we should throw.
if (info < 0) {
std::stringstream msg;
msg << "[cholesky] Cholesky decomposition failed with error code "
<< info;
throw std::runtime_error(msg.str());
}
// Zero out the upper/lower triangle while advancing the pointer to the
// next matrix at the same time.
for (int row = 0; row < N; row++) {
if (upper) {
std::fill(matrix, matrix + row, 0);
} else {
std::fill(matrix + row + 1, matrix + N, 0);
}
matrix += N;
}
}
}
void Cholesky::eval(const std::vector<array>& inputs, array& output) {
if (inputs[0].dtype() != float32) {
throw std::runtime_error("[Cholesky::eval] only supports float32.");
}
cholesky_impl(inputs[0], output, upper_);
}
} // namespace mlx::core

View File

@@ -250,49 +250,6 @@ void Split::eval(
}
}
std::tuple<bool, int64_t, std::vector<int64_t>> Slice::prepare_slice(
const array& in) {
int64_t data_offset = 0;
bool copy_needed = false;
std::vector<int64_t> inp_strides(in.ndim(), 0);
for (int i = 0; i < in.ndim(); ++i) {
data_offset += start_indices_[i] * in.strides()[i];
inp_strides[i] = in.strides()[i] * strides_[i];
copy_needed |= strides_[i] < 0;
}
return std::make_tuple(copy_needed, data_offset, inp_strides);
}
void Slice::shared_buffer_slice(
const array& in,
const std::vector<size_t>& out_strides,
size_t data_offset,
array& out) {
// Compute row/col contiguity
auto [data_size, is_row_contiguous, is_col_contiguous] =
check_contiguity(out.shape(), out_strides);
auto flags = in.flags();
flags.row_contiguous = is_row_contiguous;
flags.col_contiguous = is_col_contiguous;
if (data_size == 1) {
// Broadcasted scalar array is contiguous.
flags.contiguous = true;
} else if (data_size == in.data_size()) {
// Means we sliced a broadcasted dimension so leave the "no holes" flag
// alone.
} else {
// We sliced something. So either we are row or col contiguous or we
// punched a hole.
flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
}
out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset);
}
std::tuple<int64_t, std::vector<int64_t>> SliceUpdate::prepare_slice(
const array& in) {
int64_t data_offset = 0;

View File

@@ -111,13 +111,17 @@ void slow_conv_2D(
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
const int iH = 1 + in_dilation[0] * (in.shape(1) - 1); // Input spatial dim
const int iW = 1 + in_dilation[1] * (in.shape(2) - 1); // Input spatial dim
const int C = in.shape(3); // In channels
const int oH = out.shape(1); // Output spatial dim
const int oW = out.shape(2); // Output spatial dim
const int O = wt.shape(0); // Out channels
const int C = wt.shape(3); // In channels
const int wH = wt.shape(1); // Weight spatial dim
const int wW = wt.shape(2); // Weight spatial dim
const int groups = C / wt.shape(3);
const int C_per_group = wt.shape(3);
const int O_per_group = O / groups;
const size_t in_stride_N = in.strides()[0];
const size_t in_stride_H = in.strides()[1];
const size_t in_stride_W = in.strides()[2];
@@ -141,33 +145,35 @@ void slow_conv_2D(
int ih_base = oh * wt_strides[0] - padding[0];
int iw_base = ow * wt_strides[1] - padding[1];
for (int o = 0; o < O; ++o) {
float r = 0.;
for (int g = 0; g < groups; ++g) {
for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
float r = 0.;
for (int wh = 0; wh < wH; ++wh) {
for (int ww = 0; ww < wW; ++ww) {
int wh_flip = flip ? wH - wh - 1 : wh;
int ww_flip = flip ? wW - ww - 1 : ww;
int ih = ih_base + wh_flip * wt_dilation[0];
int iw = iw_base + ww_flip * wt_dilation[1];
for (int wh = 0; wh < wH; ++wh) {
for (int ww = 0; ww < wW; ++ww) {
int wh_flip = flip ? wH - wh - 1 : wh;
int ww_flip = flip ? wW - ww - 1 : ww;
int ih = ih_base + wh_flip * wt_dilation[0];
int iw = iw_base + ww_flip * wt_dilation[1];
const T* wt_ptr_pt = wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
const T* in_ptr_pt = in_ptr + ih * in_stride_H + iw * in_stride_W;
const T* wt_ptr_pt =
wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
const T* in_ptr_pt =
in_ptr + ih * in_stride_H + iw * in_stride_W;
for (int c = 0; c < C; ++c) {
r += static_cast<float>(in_ptr_pt[0]) *
static_cast<float>(wt_ptr_pt[0]);
in_ptr_pt += in_stride_C;
wt_ptr_pt += wt_stride_C;
} // c
for (int c = g * C_per_group; c < (g + 1) * C_per_group; ++c) {
r += static_cast<float>(in_ptr_pt[c * in_stride_C]) *
static_cast<float>(
wt_ptr_pt[(c % C_per_group) * wt_stride_C]);
} // c
} // ww
} // wh
} // ww
} // wh
out_ptr[0] = static_cast<T>(r);
out_ptr += out_stride_O;
wt_ptr += wt_stride_O;
} // o
out_ptr[0] = static_cast<T>(r);
out_ptr += out_stride_O;
wt_ptr += wt_stride_O;
} // o
} // g
};
int jump_h = flip ? -wt_dilation[0] : wt_dilation[0];
@@ -219,41 +225,43 @@ void slow_conv_2D(
int wh_base = base_h[oh % f_out_jump_h];
int ww_base = base_w[ow % f_out_jump_w];
for (int o = 0; o < O; ++o) {
float r = 0.;
for (int g = 0; g < groups; ++g) {
for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
float r = 0.;
for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) {
for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) {
int wh_flip = flip ? wH - wh - 1 : wh;
int ww_flip = flip ? wW - ww - 1 : ww;
int ih = ih_base + wh_flip * wt_dilation[0];
int iw = iw_base + ww_flip * wt_dilation[1];
for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) {
for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) {
int wh_flip = flip ? wH - wh - 1 : wh;
int ww_flip = flip ? wW - ww - 1 : ww;
int ih = ih_base + wh_flip * wt_dilation[0];
int iw = iw_base + ww_flip * wt_dilation[1];
if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) {
const T* wt_ptr_pt =
wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) {
const T* wt_ptr_pt =
wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
int ih_dil = !is_idil_one ? (ih / in_dilation[0]) : ih;
int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw;
int ih_dil = !is_idil_one ? (ih / in_dilation[0]) : ih;
int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw;
const T* in_ptr_pt =
in_ptr + ih_dil * in_stride_H + iw_dil * in_stride_W;
const T* in_ptr_pt =
in_ptr + ih_dil * in_stride_H + iw_dil * in_stride_W;
for (int c = 0; c < C; ++c) {
r += static_cast<float>(in_ptr_pt[0]) *
static_cast<float>(wt_ptr_pt[0]);
in_ptr_pt += in_stride_C;
wt_ptr_pt += wt_stride_C;
} // c
for (int c = g * C_per_group; c < (g + 1) * C_per_group;
++c) {
r += static_cast<float>(in_ptr_pt[c * in_stride_C]) *
static_cast<float>(
wt_ptr_pt[(c % C_per_group) * wt_stride_C]);
} // c
} // ih, iw check
} // ww
} // wh
} // ih, iw check
} // ww
} // wh
out_ptr[0] = static_cast<T>(r);
out_ptr += out_stride_O;
wt_ptr += wt_stride_O;
} // o
out_ptr[0] = static_cast<T>(r);
out_ptr += out_stride_O;
wt_ptr += wt_stride_O;
} // o
} // g
};
int oH_border_0 = 0;

View File

@@ -256,7 +256,7 @@ void copy_general_general(
}
int size = std::accumulate(
data_shape.begin() - 5, data_shape.end(), 1, std::multiplies<int>());
data_shape.end() - 5, data_shape.end(), 1, std::multiplies<int>());
for (int i = 0; i < src.size(); i += size) {
stride_t src_offset = i_offset + elem_to_loc(i, data_shape, i_strides);
stride_t dst_offset = o_offset + elem_to_loc(i, dst.shape(), o_strides);

View File

@@ -5,7 +5,6 @@
#else
#include <cblas.h>
#endif
#include <cstring>
#include "mlx/array.h"
@@ -43,8 +42,8 @@ DEFAULT(AsType)
DEFAULT(AsStrided)
DEFAULT(Broadcast)
DEFAULT(BlockMaskedMM)
DEFAULT(BlockSparseMM)
DEFAULT(BlockSparseQMM)
DEFAULT(GatherMM)
DEFAULT(GatherQMM)
DEFAULT_MULTI(DivMod)
DEFAULT(Ceil)
DEFAULT(Concatenate)
@@ -113,6 +112,7 @@ DEFAULT(Tan)
DEFAULT(Tanh)
DEFAULT(Transpose)
DEFAULT(Inverse)
DEFAULT(Cholesky)
namespace {

View File

@@ -28,6 +28,7 @@ const char* get_kernel_preamble() {
return R"preamble(
$INCLUDES
$CONTENT
using namespace mlx::core;
using namespace mlx::core::detail;
)preamble";
}

View File

@@ -17,24 +17,25 @@ namespace mlx::core {
namespace {
template <typename T>
template <typename T, typename mask_t>
inline void mask_matrix(
T* data,
const bool* mask,
const mask_t* mask,
int block_size,
const int X,
const int Y,
const size_t X_data_str,
const size_t Y_data_str,
const size_t X_mask_str,
const size_t Y_mask_str) {
const size_t Y_mask_str,
const size_t mask_offset) {
int tX = (X + block_size - 1) / block_size;
int tY = (Y + block_size - 1) / block_size;
for (int i = 0; i < tX; i++) {
for (int j = 0; j < tY; j++) {
bool do_mask = mask[i * X_mask_str + j * Y_mask_str];
if (!do_mask) {
mask_t do_mask = mask[mask_offset + i * X_mask_str + j * Y_mask_str];
if (do_mask != 1) {
int loc_x = i * block_size;
int loc_y = j * block_size;
T* data_block = data + loc_x * X_data_str + loc_y * Y_data_str;
@@ -43,7 +44,11 @@ inline void mask_matrix(
int size_y = std::min(block_size, Y - loc_y);
for (int ii = 0; ii < size_x; ii++) {
for (int jj = 0; jj < size_y; jj++) {
data_block[ii * X_data_str + jj * Y_data_str] = T(0.);
if constexpr (std::is_same_v<mask_t, bool>) {
data_block[ii * X_data_str + jj * Y_data_str] = T(0.);
} else {
data_block[ii * X_data_str + jj * Y_data_str] *= do_mask;
}
}
}
}
@@ -62,36 +67,39 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
auto& a_pre = inputs[0];
auto& b_pre = inputs[1];
auto& out_mask = inputs[2];
auto check_transpose = [](const array& arr, bool do_copy) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (stx == arr.shape(-1) && sty == 1) {
if (do_copy) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::Vector);
return std::make_tuple(false, stx, arr_copy);
}
return std::make_tuple(false, stx, arr);
} else if (stx == 1 && sty == arr.shape(-2)) {
if (do_copy) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::Vector);
return std::make_tuple(true, sty, arr_copy);
}
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
size_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
}
};
auto check_transpose =
[](const array& arr, bool do_copy, bool expand_all = false) {
auto stx = arr.strides()[arr.ndim() - 2];
auto sty = arr.strides()[arr.ndim() - 1];
if (!expand_all && stx == arr.shape(-1) && sty == 1) {
if (do_copy) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::Vector);
return std::make_tuple(false, stx, arr_copy);
}
return std::make_tuple(false, stx, arr);
} else if (!expand_all && stx == 1 && sty == arr.shape(-2)) {
if (do_copy) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::Vector);
return std::make_tuple(true, sty, arr_copy);
}
return std::make_tuple(true, sty, arr);
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
size_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy);
}
};
bool has_op_mask = inputs.size() > 3;
auto [a_transposed, lda, a] = check_transpose(a_pre, has_op_mask);
auto [b_transposed, ldb, b] = check_transpose(b_pre, has_op_mask);
bool has_out_mask = inputs.size() == 3 || inputs.size() == 5;
auto [a_transposed, lda, a] =
check_transpose(a_pre, has_op_mask, inputs.back().dtype() != bool_);
auto [b_transposed, ldb, b] =
check_transpose(b_pre, has_op_mask, inputs.back().dtype() != bool_);
size_t M = a.shape(-2);
size_t N = b.shape(-1);
@@ -114,27 +122,42 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
int Y,
size_t X_data_str,
size_t Y_data_str) {
const bool* mask_ptr = mask.data<bool>() +
elem_to_loc(mask.shape(-1) * mask.shape(-2) * batch_idx,
mask.shape(),
mask.strides());
size_t mask_offset = elem_to_loc(
mask.shape(-1) * mask.shape(-2) * batch_idx,
mask.shape(),
mask.strides());
size_t X_mask_str = mask.strides()[mask.ndim() - 2];
size_t Y_mask_str = mask.strides()[mask.ndim() - 1];
return mask_matrix(
data,
mask_ptr,
block_size,
X,
Y,
X_data_str,
Y_data_str,
X_mask_str,
Y_mask_str);
if (mask.dtype() == bool_) {
return mask_matrix(
data,
mask.data<bool>(),
block_size,
X,
Y,
X_data_str,
Y_data_str,
X_mask_str,
Y_mask_str,
mask_offset);
} else {
return mask_matrix(
data,
mask.data<float>(),
block_size,
X,
Y,
X_data_str,
Y_data_str,
X_mask_str,
Y_mask_str,
mask_offset);
}
};
for (int i = 0; i < (a.size() / (M * K)); ++i) {
for (int i = 0; i < (out.size() / (M * size_t(N))); ++i) {
// Adjust pointer
float* ai =
a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides());
@@ -144,7 +167,7 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
// Zero out blocks in a and b if needed
if (has_op_mask) {
auto& a_mask = inputs[3];
auto& a_mask = inputs[inputs.size() - 2];
mask_array(
a_mask,
ai,
@@ -155,7 +178,7 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
a_transposed ? 1 : lda,
a_transposed ? lda : 1);
auto& b_mask = inputs[4];
auto& b_mask = inputs[inputs.size() - 1];
mask_array(
b_mask,
bi,
@@ -186,14 +209,16 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
);
// Zero out blocks in out
mask_array(out_mask, ci, block_size_, i, M, N, N, 1);
if (has_out_mask) {
mask_array(inputs[2], ci, block_size_, i, M, N, N, 1);
}
}
}
void BlockSparseMM::eval(const std::vector<array>& inputs, array& out) {
void GatherMM::eval(const std::vector<array>& inputs, array& out) {
if (out.dtype() != float32) {
throw std::runtime_error(
"[BlockSparseMM::eval] Currently only supports float32.");
"[GatherMM::eval] Currently only supports float32.");
}
out.set_data(allocator::malloc_or_wait(out.nbytes()));
@@ -277,4 +302,4 @@ void BlockSparseMM::eval(const std::vector<array>& inputs, array& out) {
}
}
} // namespace mlx::core
} // namespace mlx::core

View File

@@ -11,6 +11,7 @@
#include "mlx/backend/common/binary.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/ops.h"
#include "mlx/backend/common/slicing.h"
#include "mlx/backend/common/threefry.h"
#include "mlx/backend/common/unary.h"
#include "mlx/backend/common/utils.h"
@@ -492,7 +493,8 @@ void Slice::eval(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0];
// Calculate out strides, initial offset and if copy needs to be made
auto [copy_needed, data_offset, inp_strides] = prepare_slice(in);
auto [copy_needed, data_offset, inp_strides] =
prepare_slice(in, start_indices_, strides_);
// Do copy if needed
if (copy_needed) {
@@ -590,4 +592,36 @@ void Tanh::eval(const std::vector<array>& inputs, array& out) {
}
}
void View::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
auto ibytes = size_of(in.dtype());
auto obytes = size_of(out.dtype());
// Conditions for buffer copying (disjunction):
// - type size is the same
// - type size is smaller and the last axis is contiguous
// - the entire array is row contiguous
if (ibytes == obytes || obytes < ibytes && in.strides().back() == 1 ||
in.flags().row_contiguous) {
auto strides = in.strides();
for (int i = 0; i < strides.size() - 1; ++i) {
strides[i] *= ibytes;
strides[i] /= obytes;
}
out.copy_shared_buffer(
in, strides, in.flags(), in.data_size() * obytes / ibytes);
} else {
auto tmp = array(in.shape(), in.dtype(), nullptr, {});
tmp.set_data(allocator::malloc_or_wait(tmp.nbytes()));
copy_inplace(in, tmp, CopyType::General);
auto flags = out.flags();
flags.contiguous = true;
flags.row_contiguous = true;
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
out.move_shared_buffer(tmp, out.strides(), flags, out.size());
}
}
} // namespace mlx::core

View File

@@ -357,7 +357,7 @@ void QuantizedMatmul::eval(const std::vector<array>& inputs, array& out) {
_qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
}
void BlockSparseQMM::eval(const std::vector<array>& inputs, array& out) {
void GatherQMM::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 6);
auto& x_pre = inputs[0];

View File

@@ -234,7 +234,7 @@ void scan_dispatch(
auto op = [](U* o, const U* y, const T* x) { *o = (*x < *y) ? *y : *x; };
auto init = (issubdtype(input.dtype(), floating))
? static_cast<U>(-std::numeric_limits<float>::infinity())
: std::numeric_limits<U>::max();
: std::numeric_limits<U>::min();
auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init);
auto opss = DefaultStridedScan<T, U, decltype(op)>(op, init);
scan_op<T, U>(opcs, opss, input, output, axis, reverse, inclusive);

View File

@@ -0,0 +1,52 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/common/utils.h"
namespace mlx::core {
std::tuple<bool, int64_t, std::vector<int64_t>> prepare_slice(
const array& in,
std::vector<int>& start_indices,
std::vector<int>& strides) {
int64_t data_offset = 0;
bool copy_needed = false;
std::vector<int64_t> inp_strides(in.ndim(), 0);
for (int i = 0; i < in.ndim(); ++i) {
data_offset += start_indices[i] * in.strides()[i];
inp_strides[i] = in.strides()[i] * strides[i];
copy_needed |= strides[i] < 0;
}
return std::make_tuple(copy_needed, data_offset, inp_strides);
}
void shared_buffer_slice(
const array& in,
const std::vector<size_t>& out_strides,
size_t data_offset,
array& out) {
// Compute row/col contiguity
auto [data_size, is_row_contiguous, is_col_contiguous] =
check_contiguity(out.shape(), out_strides);
auto flags = in.flags();
flags.row_contiguous = is_row_contiguous;
flags.col_contiguous = is_col_contiguous;
if (data_size == 1) {
// Broadcasted scalar array is contiguous.
flags.contiguous = true;
} else if (data_size == in.data_size()) {
// Means we sliced a broadcasted dimension so leave the "no holes" flag
// alone.
} else {
// We sliced something. So either we are row or col contiguous or we
// punched a hole.
flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
}
out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset);
}
} // namespace mlx::core

View File

@@ -0,0 +1,20 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/array.h"
namespace mlx::core {
std::tuple<bool, int64_t, std::vector<int64_t>> prepare_slice(
const array& in,
std::vector<int>& start_indices,
std::vector<int>& strides);
void shared_buffer_slice(
const array& in,
const std::vector<size_t>& out_strides,
size_t data_offset,
array& out);
} // namespace mlx::core

View File

@@ -1,33 +1,130 @@
add_custom_command(
OUTPUT compiled_preamble.cpp
function(make_jit_source SRC_FILE)
# This function takes a metal header file,
# runs the C preprocessesor on it, and makes
# the processed contents available as a string in a C++ function
# mlx::core::metal::${SRC_NAME}()
#
# To use the function, declare it in jit/includes.h and
# include jit/includes.h.
#
# Additional arguments to this function are treated as dependencies
# in the Cmake build system.
get_filename_component(SRC_NAME ${SRC_FILE} NAME)
add_custom_command(
OUTPUT jit/${SRC_NAME}.cpp
COMMAND /bin/bash
${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
${CMAKE_CURRENT_BINARY_DIR}/jit
${CMAKE_C_COMPILER}
${PROJECT_SOURCE_DIR}
${SRC_FILE}
"-D${MLX_METAL_VERSION}"
DEPENDS make_compiled_preamble.sh
kernels/compiled_preamble.h
kernels/unary.h
kernels/binary.h
kernels/bf16.h
kernels/erf.h
kernels/expm1f.h
kernels/utils.h
kernels/bf16_math.h
)
kernels/${SRC_FILE}.h
${ARGN}
)
add_custom_target(${SRC_NAME} DEPENDS jit/${SRC_NAME}.cpp)
add_dependencies(mlx ${SRC_NAME})
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_BINARY_DIR}/jit/${SRC_NAME}.cpp
)
endfunction(make_jit_source)
add_custom_target(
compiled_preamble
DEPENDS compiled_preamble.cpp
make_jit_source(
utils
kernels/bf16.h
kernels/complex.h
kernels/defines.h
)
make_jit_source(
unary_ops
kernels/erf.h
kernels/expm1f.h
)
make_jit_source(binary_ops)
make_jit_source(ternary_ops)
make_jit_source(
reduce_utils
kernels/atomic.h
kernels/reduction/ops.h
)
make_jit_source(scatter)
make_jit_source(gather)
add_dependencies(mlx compiled_preamble)
if (MLX_METAL_JIT)
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/jit_kernels.cpp
)
make_jit_source(arange)
make_jit_source(copy)
make_jit_source(unary)
make_jit_source(binary)
make_jit_source(binary_two)
make_jit_source(
fft
kernels/fft/radix.h
kernels/fft/readwrite.h
)
make_jit_source(ternary)
make_jit_source(softmax)
make_jit_source(scan)
make_jit_source(sort)
make_jit_source(
reduce
kernels/reduction/reduce_all.h
kernels/reduction/reduce_col.h
kernels/reduction/reduce_row.h
)
make_jit_source(
steel/gemm/gemm
kernels/steel/utils.h
kernels/steel/gemm/loader.h
kernels/steel/gemm/mma.h
kernels/steel/gemm/params.h
kernels/steel/gemm/transforms.h
)
make_jit_source(steel/gemm/kernels/steel_gemm_fused)
make_jit_source(
steel/gemm/kernels/steel_gemm_masked
kernels/steel/defines.h
)
make_jit_source(steel/gemm/kernels/steel_gemm_splitk)
make_jit_source(
steel/conv/conv
kernels/steel/utils.h
kernels/steel/defines.h
kernels/steel/gemm/mma.h
kernels/steel/gemm/transforms.h
kernels/steel/conv/params.h
kernels/steel/conv/loader.h
kernels/steel/conv/loaders/loader_channel_l.h
kernels/steel/conv/loaders/loader_channel_n.h
)
make_jit_source(
steel/conv/kernels/steel_conv
)
make_jit_source(
steel/conv/kernels/steel_conv_general
kernels/steel/defines.h
kernels/steel/conv/loaders/loader_general.h
)
else()
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/nojit_kernels.cpp
)
endif()
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
@@ -43,10 +140,12 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/normalization.cpp
${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp
)
if (NOT MLX_METAL_PATH)

View File

@@ -0,0 +1,360 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/common/binary.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
constexpr int MAX_BINARY_SPECIALIZED_DIMS = 5;
void binary_op_gpu_inplace(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::string op,
const Stream& s) {
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
auto& out = outputs[0];
if (out.size() == 0) {
return;
}
// Try to collapse contiguous dims
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
auto& strides_a = strides[0];
auto& strides_b = strides[1];
auto& strides_out = strides[2];
std::string kernel_name;
{
std::ostringstream kname;
switch (bopt) {
case BinaryOpType::ScalarScalar:
kname << "ss";
break;
case BinaryOpType::ScalarVector:
kname << "sv";
break;
case BinaryOpType::VectorScalar:
kname << "vs";
break;
case BinaryOpType::VectorVector:
kname << "vv";
break;
case BinaryOpType::General:
kname << "g";
if (shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
kname << shape.size();
} else {
kname << "n";
}
break;
}
kname << op << type_to_name(a);
kernel_name = kname.str();
}
auto& d = metal::device(s.device);
auto kernel = get_binary_two_kernel(d, kernel_name, a, outputs[0]);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
// - If a is donated it goes to the first output
// - If b is donated it goes to the first output if a was not donated
// otherwise it goes to the second output
bool donate_a = a.data_shared_ptr() == nullptr;
bool donate_b = b.data_shared_ptr() == nullptr;
compute_encoder.set_input_array(donate_a ? outputs[0] : a, 0);
compute_encoder.set_input_array(
donate_b ? (donate_a ? outputs[1] : outputs[0]) : b, 1);
compute_encoder.set_output_array(outputs[0], 2);
compute_encoder.set_output_array(outputs[1], 3);
if (bopt == BinaryOpType::General) {
auto ndim = shape.size();
if (ndim > 3) {
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 4);
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 5);
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 6);
} else {
// The shape is implicit in the grid for <= 3D
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
}
if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
compute_encoder->setBytes(&ndim, sizeof(int), 7);
}
// Launch up to 3D grid of threads
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
size_t rest = out.size() / (dim0 * dim1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size != 1024) {
throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
}
auto group_dims = get_block_dims(dim0, dim1, rest);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
// Launch a 1D grid of threads
size_t nthreads = out.data_size();
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
}
void binary_op_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::string op,
const Stream& s) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, outputs[0], bopt, true);
set_binary_op_output_data(a, b, outputs[1], bopt, true);
binary_op_gpu_inplace(inputs, outputs, op, s);
}
void binary_op_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::string op) {
auto& s = outputs[0].primitive().stream();
binary_op_gpu(inputs, outputs, op, s);
}
void binary_op_gpu_inplace(
const std::vector<array>& inputs,
array& out,
const std::string op,
const Stream& s) {
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
if (out.size() == 0) {
return;
}
// Try to collapse contiguous dims
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
auto& strides_a = strides[0];
auto& strides_b = strides[1];
auto& strides_out = strides[2];
std::string kernel_name;
{
std::ostringstream kname;
switch (bopt) {
case BinaryOpType::ScalarScalar:
kname << "ss";
break;
case BinaryOpType::ScalarVector:
kname << "sv";
break;
case BinaryOpType::VectorScalar:
kname << "vs";
break;
case BinaryOpType::VectorVector:
kname << "vv";
break;
case BinaryOpType::General:
kname << "g";
if (shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
kname << shape.size();
} else {
kname << "n";
}
break;
}
kname << op << type_to_name(a);
kernel_name = kname.str();
}
auto& d = metal::device(s.device);
auto kernel = get_binary_kernel(d, kernel_name, a, out);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
bool donate_a = a.data_shared_ptr() == nullptr;
bool donate_b = b.data_shared_ptr() == nullptr;
compute_encoder.set_input_array(donate_a ? out : a, 0);
compute_encoder.set_input_array(donate_b ? out : b, 1);
compute_encoder.set_output_array(out, 2);
if (bopt == BinaryOpType::General) {
auto ndim = shape.size();
if (ndim > 3) {
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 3);
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
} else {
// The shape is implicit in the grid for <= 3D
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 3);
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 4);
}
if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
compute_encoder->setBytes(&ndim, sizeof(int), 6);
}
// Launch up to 3D grid of threads
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
size_t rest = out.size() / (dim0 * dim1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size != 1024) {
throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
}
auto group_dims = get_block_dims(dim0, dim1, rest);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
// Launch a 1D grid of threads
size_t nthreads =
bopt == BinaryOpType::General ? out.size() : out.data_size();
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
}
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
}
void binary_op_gpu(
const std::vector<array>& inputs,
array& out,
const std::string op,
const Stream& s) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt, true);
binary_op_gpu_inplace(inputs, out, op, s);
}
void binary_op_gpu(
const std::vector<array>& inputs,
array& out,
const std::string op) {
auto& s = out.primitive().stream();
binary_op_gpu(inputs, out, op, s);
}
void Add::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op_gpu(inputs, out, "add");
}
void ArcTan2::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op_gpu(inputs, out, "arctan2");
}
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
switch (op_) {
case BitwiseBinary::And:
binary_op_gpu(inputs, out, "bitwise_and");
break;
case BitwiseBinary::Or:
binary_op_gpu(inputs, out, "bitwise_or");
break;
case BitwiseBinary::Xor:
binary_op_gpu(inputs, out, "bitwise_xor");
break;
case BitwiseBinary::LeftShift:
binary_op_gpu(inputs, out, "left_shift");
break;
case BitwiseBinary::RightShift:
binary_op_gpu(inputs, out, "right_shift");
break;
}
}
void Divide::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op_gpu(inputs, out, "div");
}
void DivMod::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
binary_op_gpu(inputs, outputs, "divmod");
}
void Remainder::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op_gpu(inputs, out, "rem");
}
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op_gpu(inputs, out, equal_nan_ ? "naneq" : "eq");
}
void Greater::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op_gpu(inputs, out, "ge");
}
void GreaterEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op_gpu(inputs, out, "geq");
}
void Less::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op_gpu(inputs, out, "le");
}
void LessEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op_gpu(inputs, out, "leq");
}
void LogicalAnd::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op_gpu(inputs, out, "land");
}
void LogicalOr::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op_gpu(inputs, out, "lor");
}
void LogAddExp::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op_gpu(inputs, out, "lae");
}
void Maximum::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op_gpu(inputs, out, "max");
}
void Minimum::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op_gpu(inputs, out, "min");
}
void Multiply::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op_gpu(inputs, out, "mul");
}
void NotEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op_gpu(inputs, out, "neq");
}
void Power::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op_gpu(inputs, out, "pow");
}
void Subtract::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op_gpu(inputs, out, "sub");
}
} // namespace mlx::core

View File

@@ -0,0 +1,33 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/array.h"
namespace mlx::core {
void binary_op_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::string op,
const Stream& s);
void binary_op_gpu(
const std::vector<array>& inputs,
array& out,
const std::string op,
const Stream& s);
void binary_op_gpu_inplace(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::string op,
const Stream& s);
void binary_op_gpu_inplace(
const std::vector<array>& inputs,
array& out,
const std::string op,
const Stream& s);
} // namespace mlx::core

View File

@@ -4,8 +4,8 @@
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/metal/compiled_preamble.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/graph_utils.h"
#include "mlx/primitives.h"
@@ -56,12 +56,15 @@ inline void build_kernel(
} else {
add_indices = true;
os << " device const " << get_type_string(x.dtype()) << "* " << xname
<< " [[buffer(" << cnt++ << ")]]," << std::endl
<< " constant const size_t* " << xname << "_strides [[buffer("
<< cnt++ << ")]]," << std::endl;
<< " [[buffer(" << cnt++ << ")]]," << std::endl;
}
}
if (add_indices) {
os << " constant const size_t* in_strides [[buffer(" << cnt++
<< ")]],\n";
}
// Add the output arguments
for (auto& x : outputs) {
os << " device " << get_type_string(x.dtype()) << "* "
@@ -110,13 +113,17 @@ inline void build_kernel(
}
// Read the inputs in tmps
for (auto& x : inputs) {
int nc_in_count = 0;
for (int i = 0; i < inputs.size(); ++i) {
auto& x = inputs[i];
auto& xname = namer.get_name(x);
if (is_constant(x)) {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = ";
auto type_str = get_type_string(x.dtype());
os << " auto tmp_" << xname << " = static_cast<"
<< get_type_string(x.dtype()) << ">(";
print_constant(os, x);
os << ";" << std::endl;
os << ");" << std::endl;
} else if (is_scalar(x)) {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[0];" << std::endl;
@@ -124,17 +131,20 @@ inline void build_kernel(
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[index];" << std::endl;
} else if (!dynamic_dims) {
int offset = nc_in_count * ndim;
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[";
os << "index_0 * " << xname << "_strides[0]";
os << "index_0 * " << "in_strides[" << offset << "]";
for (int i = 1; i < ndim; i++) {
os << " + index_" << i << " * " << xname << "_strides[" << i << "]";
os << " + index_" << i << " * " << "in_strides[" << offset + i << "]";
}
os << "];" << std::endl;
nc_in_count++;
} else {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[elem_to_loc(index, output_shape, " << xname
<< "_strides, ndim)];" << std::endl;
<< xname << "[elem_to_loc(index, output_shape, in_strides + "
<< nc_in_count * ndim << ", ndim)];" << std::endl;
nc_in_count++;
}
}
@@ -190,7 +200,8 @@ void Compiled::eval_gpu(
// If not we have to build it ourselves
if (lib == nullptr) {
std::ostringstream kernel;
kernel << metal::get_kernel_preamble() << std::endl;
kernel << metal::utils() << metal::unary_ops() << metal::binary_ops()
<< metal::ternary_ops();
build_kernel(
kernel,
kernel_lib_ + "_contiguous",
@@ -295,6 +306,7 @@ void Compiled::eval_gpu(
// Put the inputs in
int cnt = 0;
int stride_idx = 1; // idx 0 is the output strides
std::vector<size_t> in_strides;
for (int i = 0; i < inputs.size(); i++) {
if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
continue;
@@ -302,13 +314,17 @@ void Compiled::eval_gpu(
auto& x = inputs[i];
compute_encoder.set_input_array(x, cnt++);
if (!contiguous && !is_scalar(x)) {
compute_encoder->setBytes(
strides[stride_idx].data(),
strides[stride_idx].size() * sizeof(size_t),
cnt++);
in_strides.insert(
in_strides.end(),
strides[stride_idx].begin(),
strides[stride_idx].end());
stride_idx++;
}
}
if (!in_strides.empty()) {
compute_encoder->setBytes(
in_strides.data(), in_strides.size() * sizeof(size_t), cnt++);
}
compiled_allocate_outputs(
inputs, outputs, inputs_, constant_ids_, contiguous, true);

View File

@@ -1,9 +0,0 @@
// Copyright © 2023-24 Apple Inc.
#pragma once
namespace mlx::core::metal {
const char* get_kernel_preamble();
}

View File

@@ -7,6 +7,7 @@
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h"
#include "mlx/backend/metal/matmul.h"
@@ -257,15 +258,19 @@ void implicit_gemm_conv_2D_gpu(
const array& wt,
array out,
const MLXConvParams<2>& conv_params) {
const int groups = conv_params.groups;
const int C_per_group = conv_params.C / conv_params.groups;
const int O_per_group = conv_params.O / conv_params.groups;
// Deduce implicit gemm size
int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1];
int implicit_N = conv_params.O;
int implicit_K = conv_params.wS[0] * conv_params.wS[1] * conv_params.C;
const int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1];
const int implicit_N = O_per_group;
const int implicit_K = conv_params.wS[0] * conv_params.wS[1] * C_per_group;
// Determine block and warp tiles
int wm = 2, wn = 2;
int bm = implicit_M >= 8192 && conv_params.C >= 64 ? 64 : 32;
int bm = implicit_M >= 8192 && C_per_group >= 64 ? 64 : 32;
int bn = (bm == 64 || implicit_N >= 64) ? 64 : 32;
int bk = 16;
@@ -281,15 +286,15 @@ void implicit_gemm_conv_2D_gpu(
// Fix small channel specialization
int n_channel_specialization = 0;
int channel_k_iters = ((conv_params.C + bk - 1) / bk);
int channel_k_iters = ((C_per_group + bk - 1) / bk);
int gemm_k_iters = conv_params.wS[0] * conv_params.wS[1] * channel_k_iters;
if (conv_params.C <= 2) {
if (C_per_group <= 2) {
gemm_k_iters = (implicit_K + bk - 1) / bk;
n_channel_specialization = conv_params.C;
} else if (conv_params.C <= 4) {
n_channel_specialization = C_per_group;
} else if (C_per_group <= 4) {
gemm_k_iters = ((conv_params.wS[0] * conv_params.wS[1] * 4) + bk - 1) / bk;
n_channel_specialization = conv_params.C;
n_channel_specialization = C_per_group;
}
bool small_filter = (!n_channel_specialization) &&
@@ -331,7 +336,17 @@ void implicit_gemm_conv_2D_gpu(
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
auto kernel = get_steel_conv_kernel(
d,
kname.str(),
out,
bm,
bn,
bk,
wm,
wn,
n_channel_specialization,
small_filter);
compute_encoder->setComputePipelineState(kernel);
// Deduce grid launch dimensions
@@ -340,7 +355,7 @@ void implicit_gemm_conv_2D_gpu(
size_t grid_dim_x = tn * tile;
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, 1);
MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, groups);
// Encode arrays
compute_encoder.set_input_array(in, 0);
@@ -484,7 +499,8 @@ void implicit_gemm_conv_2D_general_gpu(
// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
auto kernel =
get_steel_conv_general_kernel(d, kname.str(), out, bm, bn, bk, wm, wn);
compute_encoder->setComputePipelineState(kernel);
// Deduce grid launch dimensions
@@ -703,6 +719,7 @@ void conv_2D_gpu(
const std::vector<int>& wt_strides,
const std::vector<int>& wt_dilation,
const std::vector<int>& in_dilation,
const int groups,
bool flip,
std::vector<array>& copies) {
// Make conv params
@@ -718,12 +735,12 @@ void conv_2D_gpu(
/* const int kdil[NDIM] = */ {wt_dilation[0], wt_dilation[1]},
/* const int idil[NDIM] = */ {in_dilation[0], in_dilation[1]},
/* const size_t in_strides[NDIM + 2] = */
{in.strides()[0], in.strides()[1], in.strides()[2], in.strides()[3]},
{in.strides(0), in.strides(1), in.strides(2), in.strides(3)},
/* const size_t wt_strides[NDIM + 2] = */
{wt.strides()[0], wt.strides()[1], wt.strides()[2], wt.strides()[3]},
{wt.strides(0), wt.strides(1), wt.strides(2), wt.strides(3)},
/* const size_t out_strides[NDIM + 2] = */
{out.strides()[0], out.strides()[1], out.strides()[2], out.strides()[3]},
/* const int groups = */ 1,
{out.strides(0), out.strides(1), out.strides(2), out.strides(3)},
/* const int groups = */ groups,
/* const bool flip = */ flip,
};
@@ -735,6 +752,18 @@ void conv_2D_gpu(
bool channels_large = (conv_params.C + conv_params.O) >= 512;
bool channels_med = (conv_params.C + conv_params.O) >= 256;
if (groups > 1) {
const int C_per_group = conv_params.C / groups;
const int O_per_group = conv_params.O / groups;
if (is_idil_one && (C_per_group <= 4 || C_per_group % 16 == 0) &&
(O_per_group <= 16 || O_per_group % 16 == 0)) {
return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
} else {
return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params);
}
}
// Direct to winograd conv
if (!flip && is_stride_one && is_kdil_one && is_idil_one &&
conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
@@ -860,6 +889,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
kernel_strides_,
kernel_dilation_,
input_dilation_,
groups_,
flip_,
copies);
}

View File

@@ -4,12 +4,14 @@
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
constexpr int MAX_COPY_SPECIALIZED_DIMS = 5;
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
if (ctype == CopyType::Vector) {
// If the input is donateable, we are doing a vector copy and the types
@@ -31,9 +33,6 @@ void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
if (out.size() == 0) {
return;
}
if (ctype == CopyType::GeneralGeneral) {
ctype = CopyType::General;
}
@@ -55,6 +54,10 @@ void copy_gpu_inplace(
int64_t out_offset,
CopyType ctype,
const Stream& s) {
if (out.size() == 0) {
return;
}
// Try to collapse contiguous dims
auto [shape, strides] = collapse_contiguous_dims(
data_shape, std::vector{strides_in_pre, strides_out_pre});
@@ -62,27 +65,34 @@ void copy_gpu_inplace(
auto& strides_out_ = strides[1];
auto& d = metal::device(s.device);
std::ostringstream kname;
switch (ctype) {
case CopyType::Scalar:
kname << "scopy";
break;
case CopyType::Vector:
kname << "vcopy";
break;
case CopyType::General:
kname << "gcopy";
break;
case CopyType::GeneralGeneral:
kname << "ggcopy";
break;
std::string kernel_name;
{
std::ostringstream kname;
switch (ctype) {
case CopyType::Scalar:
kname << "s";
break;
case CopyType::Vector:
kname << "v";
break;
case CopyType::General:
kname << "g";
break;
case CopyType::GeneralGeneral:
kname << "gg";
break;
}
if ((ctype == CopyType::General || ctype == CopyType::GeneralGeneral) &&
shape.size() <= MAX_COPY_SPECIALIZED_DIMS) {
kname << shape.size();
}
kname << "_copy";
kname << type_to_name(in) << type_to_name(out);
kernel_name = kname.str();
}
kname << type_to_name(in) << type_to_name(out);
if ((ctype == CopyType::General || ctype == CopyType::GeneralGeneral) &&
shape.size() <= MAX_COPY_SPECIALIZED_DIMS) {
kname << "_" << shape.size();
}
auto kernel = d.get_kernel(kname.str());
auto kernel = get_copy_kernel(d, kernel_name, in, out);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
bool donate_in = in.data_shared_ptr() == nullptr;
@@ -106,7 +116,7 @@ void copy_gpu_inplace(
set_vector_bytes(compute_encoder, strides_out, ndim, 4);
}
if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
if (ndim > MAX_COPY_SPECIALIZED_DIMS) {
compute_encoder->setBytes(&ndim, sizeof(int), 5);
}

View File

@@ -285,7 +285,6 @@ MTL::Library* Device::get_library_(const std::string& source_string) {
NS::Error* error = nullptr;
auto options = MTL::CompileOptions::alloc()->init();
options->setFastMathEnabled(false);
options->setLanguageVersion(get_metal_version());
auto mtl_lib = device_->newLibrary(ns_code, options, &error);
options->release();

View File

@@ -63,7 +63,7 @@ struct CommandEncoder {
return enc;
}
void set_input_array(const array& a, int idx, int offset = 0) {
void set_input_array(const array& a, int idx, int64_t offset = 0) {
auto r_buf =
static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
if (auto it = outputs.find(r_buf); it != outputs.end()) {
@@ -80,7 +80,7 @@ struct CommandEncoder {
enc->setBuffer(a_buf, base_offset, idx);
}
void set_output_array(array& a, int idx, int offset = 0) {
void set_output_array(array& a, int idx, int64_t offset = 0) {
// Add barriers before adding the output to the output set
set_input_array(a, idx, offset);
auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());

View File

@@ -1,106 +1,794 @@
// Copyright © 2023 Apple Inc.
#include <cassert>
#include <complex>
#include <map>
#include <numeric>
#include <set>
#include "mlx/3rdparty/pocketfft.h"
#include "mlx/backend/metal/binary.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/slicing.h"
#include "mlx/backend/metal/unary.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/mlx.h"
#include "mlx/primitives.h"
namespace mlx::core {
void FFT::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = out.primitive().stream();
auto& d = metal::device(s.device);
using MTLFC = std::tuple<const void*, MTL::DataType, NS::UInteger>;
auto& in = inputs[0];
#define MAX_STOCKHAM_FFT_SIZE 4096
#define MAX_RADER_FFT_SIZE 2048
#define MAX_BLUESTEIN_FFT_SIZE 2048
// Threadgroup memory batching improves throughput for small n
#define MIN_THREADGROUP_MEM_SIZE 256
// For strided reads/writes, coalesce at least this many complex64s
#define MIN_COALESCE_WIDTH 4
if (axes_.size() == 0 || axes_.size() > 1 || inverse_ ||
in.dtype() != complex64 || out.dtype() != complex64) {
// Could also fallback to CPU implementation here.
throw std::runtime_error(
"GPU FFT is only implemented for 1D, forward, complex FFTs.");
inline const std::vector<int> supported_radices() {
// Ordered by preference in decomposition.
return {13, 11, 8, 7, 6, 5, 4, 3, 2};
}
std::vector<int> prime_factors(int n) {
int z = 2;
std::vector<int> factors;
while (z * z <= n) {
if (n % z == 0) {
factors.push_back(z);
n /= z;
} else {
z++;
}
}
if (n > 1) {
factors.push_back(n);
}
return factors;
}
struct FourStepParams {
bool required = false;
bool first_step = true;
int n1 = 0;
int n2 = 0;
};
// Forward Declaration
void fft_op(
const array& in,
array& out,
size_t axis,
bool inverse,
bool real,
const FourStepParams four_step_params,
bool inplace,
const Stream& s);
struct FFTPlan {
int n = 0;
// Number of steps for each radix in the Stockham decomposition
std::vector<int> stockham;
// Number of steps for each radix in the Rader decomposition
std::vector<int> rader;
// Rader factor, 1 if no rader factors
int rader_n = 1;
int bluestein_n = -1;
// Four step FFT
bool four_step = false;
int n1 = 0;
int n2 = 0;
};
int next_fast_n(int n) {
return next_power_of_2(n);
}
std::vector<int> plan_stockham_fft(int n) {
auto radices = supported_radices();
std::vector<int> plan(radices.size(), 0);
int orig_n = n;
if (n == 1) {
return plan;
}
for (int i = 0; i < radices.size(); i++) {
int radix = radices[i];
// Manually tuned radices for powers of 2
if (is_power_of_2(orig_n) && orig_n < 512 && radix > 4) {
continue;
}
while (n % radix == 0) {
plan[i] += 1;
n /= radix;
if (n == 1) {
return plan;
}
}
}
throw std::runtime_error("Unplannable");
}
FFTPlan plan_fft(int n) {
auto radices = supported_radices();
std::set<int> radices_set(radices.begin(), radices.end());
FFTPlan plan;
plan.n = n;
plan.rader = std::vector<int>(radices.size(), 0);
auto factors = prime_factors(n);
int remaining_n = n;
// Four Step FFT when N is too large for shared mem.
if (n > MAX_STOCKHAM_FFT_SIZE && is_power_of_2(n)) {
// For power's of two we have a fast, no transpose four step implementation.
plan.four_step = true;
// Rough heuristic for choosing faster powers of two when we can
plan.n2 = n > 65536 ? 1024 : 64;
plan.n1 = n / plan.n2;
return plan;
} else if (n > MAX_STOCKHAM_FFT_SIZE) {
// Otherwise we use a multi-upload Bluestein's
plan.four_step = true;
plan.bluestein_n = next_fast_n(2 * n - 1);
return plan;
}
size_t n = in.shape(axes_[0]);
for (int factor : factors) {
// Make sure the factor is a supported radix
if (radices_set.find(factor) == radices_set.end()) {
// We only support a single Rader factor currently
// TODO(alexbarron) investigate weirdness with large
// Rader sizes -- possibly a compiler issue?
if (plan.rader_n > 1 || n > MAX_RADER_FFT_SIZE) {
plan.four_step = n > MAX_BLUESTEIN_FFT_SIZE;
plan.bluestein_n = next_fast_n(2 * n - 1);
plan.stockham = plan_stockham_fft(plan.bluestein_n);
plan.rader = std::vector<int>(radices.size(), 0);
return plan;
}
// See if we can use Rader's algorithm to Stockham decompose n - 1
auto rader_factors = prime_factors(factor - 1);
int last_factor = -1;
for (int rf : rader_factors) {
// We don't nest Rader's algorithm so if `factor - 1`
// isn't Stockham decomposable we give up and do Bluestein's.
if (radices_set.find(rf) == radices_set.end()) {
plan.four_step = n > MAX_BLUESTEIN_FFT_SIZE;
plan.bluestein_n = next_fast_n(2 * n - 1);
plan.stockham = plan_stockham_fft(plan.bluestein_n);
plan.rader = std::vector<int>(radices.size(), 0);
return plan;
}
}
plan.rader = plan_stockham_fft(factor - 1);
plan.rader_n = factor;
remaining_n /= factor;
}
}
if (!is_power_of_2(n) || n > 2048 || n < 4) {
throw std::runtime_error(
"GPU FFT is only implemented for the powers of 2 from 4 -> 2048");
plan.stockham = plan_stockham_fft(remaining_n);
return plan;
}
int compute_elems_per_thread(FFTPlan plan) {
// Heuristics for selecting an efficient number
// of threads to use for a particular mixed-radix FFT.
auto n = plan.n;
std::vector<int> steps;
auto radices = supported_radices();
steps.insert(steps.end(), plan.stockham.begin(), plan.stockham.end());
steps.insert(steps.end(), plan.rader.begin(), plan.rader.end());
std::set<int> used_radices;
for (int i = 0; i < steps.size(); i++) {
int radix = radices[i % radices.size()];
if (steps[i] > 0) {
used_radices.insert(radix);
}
}
// Manual tuning for 7/11/13
if (used_radices.find(7) != used_radices.end() &&
(used_radices.find(11) != used_radices.end() ||
used_radices.find(13) != used_radices.end())) {
return 7;
} else if (
used_radices.find(11) != used_radices.end() &&
used_radices.find(13) != used_radices.end()) {
return 11;
}
// TODO(alexbarron) Some really weird stuff is going on
// for certain `elems_per_thread` on large composite n.
// Possibly a compiler issue?
if (n == 3159)
return 13;
if (n == 3645)
return 5;
if (n == 3969)
return 7;
if (n == 1982)
return 5;
if (used_radices.size() == 1) {
return *(used_radices.begin());
}
if (used_radices.size() == 2) {
if (used_radices.find(11) != used_radices.end() ||
used_radices.find(13) != used_radices.end()) {
return std::accumulate(used_radices.begin(), used_radices.end(), 0) / 2;
}
std::vector<int> radix_vec(used_radices.begin(), used_radices.end());
return radix_vec[1];
}
// In all other cases use the second smallest radix.
std::vector<int> radix_vec(used_radices.begin(), used_radices.end());
return radix_vec[1];
}
// Rader
int mod_exp(int x, int y, int n) {
int out = 1;
while (y) {
if (y & 1) {
out = out * x % n;
}
y >>= 1;
x = x * x % n;
}
return out;
}
int primitive_root(int n) {
auto factors = prime_factors(n - 1);
for (int r = 2; r < n - 1; r++) {
bool found = true;
for (int factor : factors) {
if (mod_exp(r, (n - 1) / factor, n) == 1) {
found = false;
break;
}
}
if (found) {
return r;
}
}
return -1;
}
std::tuple<array, array, array> compute_raders_constants(
int rader_n,
const Stream& s) {
int proot = primitive_root(rader_n);
// Fermat's little theorem
int inv = mod_exp(proot, rader_n - 2, rader_n);
std::vector<short> g_q(rader_n - 1);
std::vector<short> g_minus_q(rader_n - 1);
for (int i = 0; i < rader_n - 1; i++) {
g_q[i] = mod_exp(proot, i, rader_n);
g_minus_q[i] = mod_exp(inv, i, rader_n);
}
array g_q_arr(g_q.begin(), {rader_n - 1});
array g_minus_q_arr(g_minus_q.begin(), {rader_n - 1});
std::vector<std::complex<float>> b_q(rader_n - 1);
for (int i = 0; i < rader_n - 1; i++) {
float pi_i = (float)g_minus_q[i] * -2.0 * M_PI / rader_n;
b_q[i] = std::exp(std::complex<float>(0, pi_i));
}
array b_q_fft({rader_n - 1}, complex64, nullptr, {});
b_q_fft.set_data(allocator::malloc_or_wait(b_q_fft.nbytes()));
auto b_q_fft_ptr =
reinterpret_cast<std::complex<float>*>(b_q_fft.data<complex64_t>());
std::ptrdiff_t item_size = b_q_fft.itemsize();
size_t fft_size = rader_n - 1;
// This FFT is always small (<4096, batch 1) so save some overhead
// and do it on the CPU
pocketfft::c2c(
/* shape= */ {fft_size},
/* stride_in= */ {item_size},
/* stride_out= */ {item_size},
/* axes= */ {0},
/* forward= */ true,
/* data_in= */ b_q.data(),
/* data_out= */ b_q_fft_ptr,
/* scale= */ 1.0f);
return std::make_tuple(b_q_fft, g_q_arr, g_minus_q_arr);
}
// Bluestein
std::pair<array, array> compute_bluestein_constants(int n, int bluestein_n) {
// We need to calculate the Bluestein twiddle factors
// in double precision for the overall numerical stability
// of Bluestein's FFT algorithm to be acceptable.
//
// Metal doesn't support float64, so instead we
// manually implement the required operations on cpu.
//
// In numpy:
// w_k = np.exp(-1j * np.pi / N * (np.arange(-N + 1, N) ** 2))
// w_q = np.fft.fft(1/w_k)
// return w_k, w_q
int length = 2 * n - 1;
std::vector<std::complex<float>> w_k_vec(n);
std::vector<std::complex<float>> w_q_vec(bluestein_n, 0);
for (int i = -n + 1; i < n; i++) {
double theta = pow(i, 2) * M_PI / (double)n;
w_q_vec[i + n - 1] = std::exp(std::complex<double>(0, theta));
if (i >= 0) {
w_k_vec[i] = std::exp(std::complex<double>(0, -theta));
}
}
array w_k({n}, complex64, nullptr, {});
w_k.set_data(allocator::malloc_or_wait(w_k.nbytes()));
std::copy(w_k_vec.begin(), w_k_vec.end(), w_k.data<complex64_t>());
array w_q({bluestein_n}, complex64, nullptr, {});
w_q.set_data(allocator::malloc_or_wait(w_q.nbytes()));
auto w_q_ptr =
reinterpret_cast<std::complex<float>*>(w_q.data<complex64_t>());
std::ptrdiff_t item_size = w_q.itemsize();
size_t fft_size = bluestein_n;
pocketfft::c2c(
/* shape= */ {fft_size},
/* stride_in= */ {item_size},
/* stride_out= */ {item_size},
/* axes= */ {0},
/* forward= */ true,
/* data_in= */ w_q_vec.data(),
/* data_out= */ w_q_ptr,
/* scale= */ 1.0f);
return std::make_tuple(w_k, w_q);
}
void multi_upload_bluestein_fft(
const array& in,
array& out,
size_t axis,
bool inverse,
bool real,
FFTPlan& plan,
std::vector<array> copies,
const Stream& s) {
// TODO(alexbarron) Implement fused kernels for mutli upload bluestein's
// algorithm
int n = inverse ? out.shape(axis) : in.shape(axis);
auto [w_k, w_q] = compute_bluestein_constants(n, plan.bluestein_n);
// Broadcast w_q and w_k to the batch size
std::vector<size_t> b_strides(in.ndim(), 0);
b_strides[axis] = 1;
array w_k_broadcast({}, complex64, nullptr, {});
array w_q_broadcast({}, complex64, nullptr, {});
w_k_broadcast.copy_shared_buffer(w_k, b_strides, {}, w_k.data_size());
w_q_broadcast.copy_shared_buffer(w_q, b_strides, {}, w_q.data_size());
auto temp_shape = inverse ? out.shape() : in.shape();
array temp(temp_shape, complex64, nullptr, {});
array temp1(temp_shape, complex64, nullptr, {});
if (real && !inverse) {
// Convert float32->complex64
copy_gpu(in, temp, CopyType::General, s);
} else if (real && inverse) {
int back_offset = n % 2 == 0 ? 2 : 1;
auto slice_shape = in.shape();
slice_shape[axis] -= back_offset;
array slice_temp(slice_shape, complex64, nullptr, {});
array conj_temp(in.shape(), complex64, nullptr, {});
copies.push_back(slice_temp);
copies.push_back(conj_temp);
std::vector<int> rstarts(in.ndim(), 0);
std::vector<int> rstrides(in.ndim(), 1);
rstarts[axis] = in.shape(axis) - back_offset;
rstrides[axis] = -1;
unary_op_gpu({in}, conj_temp, "conj", s);
slice_gpu(in, slice_temp, rstarts, rstrides, s);
concatenate_gpu({conj_temp, slice_temp}, temp, (int)axis, s);
} else if (inverse) {
unary_op_gpu({in}, temp, "conj", s);
} else {
temp.copy_shared_buffer(in);
}
binary_op_gpu({temp, w_k_broadcast}, temp1, "mul", s);
std::vector<std::pair<int, int>> pads;
auto padded_shape = out.shape();
padded_shape[axis] = plan.bluestein_n;
array pad_temp(padded_shape, complex64, nullptr, {});
pad_gpu(temp1, array(complex64_t{0.0f, 0.0f}), pad_temp, {(int)axis}, {0}, s);
array pad_temp1(padded_shape, complex64, nullptr, {});
fft_op(
pad_temp,
pad_temp1,
axis,
/*inverse=*/false,
/*real=*/false,
FourStepParams(),
/*inplace=*/false,
s);
binary_op_gpu_inplace({pad_temp1, w_q_broadcast}, pad_temp, "mul", s);
fft_op(
pad_temp,
pad_temp1,
axis,
/* inverse= */ true,
/* real= */ false,
FourStepParams(),
/*inplace=*/true,
s);
int offset = plan.bluestein_n - (2 * n - 1);
std::vector<int> starts(in.ndim(), 0);
std::vector<int> strides(in.ndim(), 1);
starts[axis] = plan.bluestein_n - offset - n;
slice_gpu(pad_temp1, temp, starts, strides, s);
binary_op_gpu_inplace({temp, w_k_broadcast}, temp1, "mul", s);
if (real && !inverse) {
std::vector<int> rstarts(in.ndim(), 0);
std::vector<int> rstrides(in.ndim(), 1);
slice_gpu(temp1, out, rstarts, strides, s);
} else if (real && inverse) {
std::vector<size_t> b_strides(in.ndim(), 0);
auto inv_n = array({1.0f / n}, {1}, float32);
array temp_float(out.shape(), out.dtype(), nullptr, {});
copies.push_back(temp_float);
copies.push_back(inv_n);
copy_gpu(temp1, temp_float, CopyType::General, s);
binary_op_gpu({temp_float, inv_n}, out, "mul", s);
} else if (inverse) {
auto inv_n = array({1.0f / n}, {1}, complex64);
unary_op_gpu({temp1}, temp, "conj", s);
binary_op_gpu({temp, inv_n}, out, "mul", s);
copies.push_back(inv_n);
} else {
out.copy_shared_buffer(temp1);
}
copies.push_back(w_k);
copies.push_back(w_q);
copies.push_back(w_k_broadcast);
copies.push_back(w_q_broadcast);
copies.push_back(temp);
copies.push_back(temp1);
copies.push_back(pad_temp);
copies.push_back(pad_temp1);
}
void four_step_fft(
const array& in,
array& out,
size_t axis,
bool inverse,
bool real,
FFTPlan& plan,
std::vector<array> copies,
const Stream& s) {
auto& d = metal::device(s.device);
if (plan.bluestein_n == -1) {
// Fast no transpose implementation for powers of 2.
FourStepParams four_step_params = {
/* required= */ true, /* first_step= */ true, plan.n1, plan.n2};
auto temp_shape = (real && inverse) ? out.shape() : in.shape();
array temp(temp_shape, complex64, nullptr, {});
fft_op(
in, temp, axis, inverse, real, four_step_params, /*inplace=*/false, s);
four_step_params.first_step = false;
fft_op(
temp, out, axis, inverse, real, four_step_params, /*inplace=*/false, s);
copies.push_back(temp);
} else {
multi_upload_bluestein_fft(in, out, axis, inverse, real, plan, copies, s);
}
}
void fft_op(
const array& in,
array& out,
size_t axis,
bool inverse,
bool real,
const FourStepParams four_step_params,
bool inplace,
const Stream& s) {
auto& d = metal::device(s.device);
size_t n = out.dtype() == float32 ? out.shape(axis) : in.shape(axis);
if (n == 1) {
out.copy_shared_buffer(in);
return;
}
if (four_step_params.required) {
// Four Step FFT decomposes into two FFTs: n1 on columns, n2 on rows
n = four_step_params.first_step ? four_step_params.n1 : four_step_params.n2;
}
// Make sure that the array is contiguous and has stride 1 in the FFT dim
std::vector<array> copies;
auto check_input = [this, &copies, &s](const array& x) {
auto check_input = [&axis, &copies, &s](const array& x) {
// TODO: Pass the strides to the kernel so
// we can avoid the copy when x is not contiguous.
bool no_copy = x.strides()[axes_[0]] == 1 && x.flags().row_contiguous ||
x.flags().col_contiguous;
bool no_copy = x.strides()[axis] == 1 &&
(x.flags().row_contiguous || x.flags().col_contiguous);
if (no_copy) {
return x;
} else {
array x_copy(x.shape(), x.dtype(), nullptr, {});
std::vector<size_t> strides;
size_t cur_stride = x.shape(axes_[0]);
for (int axis = 0; axis < x.ndim(); axis++) {
if (axis == axes_[0]) {
size_t cur_stride = x.shape(axis);
for (int a = 0; a < x.ndim(); a++) {
if (a == axis) {
strides.push_back(1);
} else {
strides.push_back(cur_stride);
cur_stride *= x.shape(axis);
cur_stride *= x.shape(a);
}
}
auto flags = x.flags();
size_t f_stride = 1;
size_t b_stride = 1;
flags.col_contiguous = true;
flags.row_contiguous = true;
for (int i = 0, ri = x.ndim() - 1; i < x.ndim(); ++i, --ri) {
flags.col_contiguous &= (strides[i] == f_stride || x.shape(i) == 1);
f_stride *= x.shape(i);
flags.row_contiguous &= (strides[ri] == b_stride || x.shape(ri) == 1);
b_stride *= x.shape(ri);
}
// This is probably over-conservative
flags.contiguous = false;
auto [data_size, is_row_contiguous, is_col_contiguous] =
check_contiguity(x.shape(), strides);
flags.col_contiguous = is_row_contiguous;
flags.row_contiguous = is_col_contiguous;
flags.contiguous = data_size == x_copy.size();
x_copy.set_data(
allocator::malloc_or_wait(x.nbytes()), x.data_size(), strides, flags);
allocator::malloc_or_wait(x.nbytes()), data_size, strides, flags);
copy_gpu_inplace(x, x_copy, CopyType::GeneralGeneral, s);
copies.push_back(x_copy);
return x_copy;
}
};
const array& in_contiguous = check_input(inputs[0]);
const array& in_contiguous = check_input(in);
// real to complex: n -> (n/2)+1
// complex to real: (n/2)+1 -> n
auto out_strides = in_contiguous.strides();
size_t out_data_size = in_contiguous.data_size();
if (in.shape(axis) != out.shape(axis)) {
for (int i = 0; i < out_strides.size(); i++) {
if (out_strides[i] != 1) {
out_strides[i] = out_strides[i] / in.shape(axis) * out.shape(axis);
}
}
out_data_size = out_data_size / in.shape(axis) * out.shape(axis);
}
auto plan = plan_fft(n);
if (plan.four_step) {
four_step_fft(in, out, axis, inverse, real, plan, copies, s);
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
return;
}
// TODO: allow donation here
out.set_data(
allocator::malloc_or_wait(out.nbytes()),
in_contiguous.data_size(),
in_contiguous.strides(),
in_contiguous.flags());
if (!inplace) {
out.set_data(
allocator::malloc_or_wait(out.nbytes()),
out_data_size,
out_strides,
in_contiguous.flags());
}
// We use n / 4 threads by default since radix-4
// is the largest single threaded radix butterfly
// we currently implement.
size_t m = n / 4;
size_t batch = in.size() / in.shape(axes_[0]);
auto radices = supported_radices();
int fft_size = plan.bluestein_n > 0 ? plan.bluestein_n : n;
// Setup function constants
bool power_of_2 = is_power_of_2(fft_size);
auto make_int = [](int* a, int i) {
return std::make_tuple(a, MTL::DataType::DataTypeInt, i);
};
auto make_bool = [](bool* a, int i) {
return std::make_tuple(a, MTL::DataType::DataTypeBool, i);
};
std::vector<MTLFC> func_consts = {
make_bool(&inverse, 0), make_bool(&power_of_2, 1)};
// Start of radix/rader step constants
int index = 4;
for (int i = 0; i < plan.stockham.size(); i++) {
func_consts.push_back(make_int(&plan.stockham[i], index));
index += 1;
}
for (int i = 0; i < plan.rader.size(); i++) {
func_consts.push_back(make_int(&plan.rader[i], index));
index += 1;
}
int elems_per_thread = compute_elems_per_thread(plan);
func_consts.push_back(make_int(&elems_per_thread, 2));
int rader_m = n / plan.rader_n;
func_consts.push_back(make_int(&rader_m, 3));
// The overall number of FFTs we're going to compute for this input
int size = out.dtype() == float32 ? out.size() : in.size();
if (real && inverse && four_step_params.required) {
size = out.size();
}
int total_batch_size = size / n;
int threads_per_fft = (fft_size + elems_per_thread - 1) / elems_per_thread;
// We batch among threadgroups for improved efficiency when n is small
int threadgroup_batch_size = std::max(MIN_THREADGROUP_MEM_SIZE / fft_size, 1);
if (four_step_params.required) {
// Require a threadgroup batch size of at least 4 for four step FFT
// so we can coalesce the memory accesses.
threadgroup_batch_size =
std::max(threadgroup_batch_size, MIN_COALESCE_WIDTH);
}
int threadgroup_mem_size = next_power_of_2(threadgroup_batch_size * fft_size);
// FFTs up to 2^20 are currently supported
assert(threadgroup_mem_size <= MAX_STOCKHAM_FFT_SIZE);
// ceil divide
int batch_size =
(total_batch_size + threadgroup_batch_size - 1) / threadgroup_batch_size;
if (real && !four_step_params.required) {
// We can perform 2 RFFTs at once so the batch size is halved.
batch_size = (batch_size + 2 - 1) / 2;
}
int out_buffer_size = out.size();
auto& compute_encoder = d.get_command_encoder(s.index);
auto in_type_str = in.dtype() == float32 ? "float" : "float2";
auto out_type_str = out.dtype() == float32 ? "float" : "float2";
// Only required by four step
int step = -1;
{
std::ostringstream kname;
kname << "fft_" << n;
auto kernel = d.get_kernel(kname.str());
std::string inv_string = inverse ? "true" : "false";
std::string real_string = real ? "true" : "false";
if (plan.bluestein_n > 0) {
kname << "bluestein_fft_mem_" << threadgroup_mem_size << "_"
<< in_type_str << "_" << out_type_str;
} else if (plan.rader_n > 1) {
kname << "rader_fft_mem_" << threadgroup_mem_size << "_" << in_type_str
<< "_" << out_type_str;
} else if (four_step_params.required) {
step = four_step_params.first_step ? 0 : 1;
kname << "four_step_mem_" << threadgroup_mem_size << "_" << in_type_str
<< "_" << out_type_str << "_" << step << "_" << real_string;
} else {
kname << "fft_mem_" << threadgroup_mem_size << "_" << in_type_str << "_"
<< out_type_str;
}
std::string base_name = kname.str();
// We use a specialized kernel for each FFT size
kname << "_n" << fft_size << "_inv_" << inverse;
std::string hash_name = kname.str();
auto kernel = get_fft_kernel(
d,
base_name,
hash_name,
threadgroup_mem_size,
in_type_str,
out_type_str,
step,
real,
func_consts);
bool donated = in.data_shared_ptr() == nullptr;
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(in_contiguous, 0);
compute_encoder.set_output_array(out, 1);
auto group_dims = MTL::Size(1, m, 1);
auto grid_dims = MTL::Size(batch, m, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
if (plan.bluestein_n > 0) {
// Precomputed twiddle factors for Bluestein's
auto [w_k, w_q] = compute_bluestein_constants(n, plan.bluestein_n);
copies.push_back(w_q);
copies.push_back(w_k);
compute_encoder.set_input_array(w_q, 2); // w_q
compute_encoder.set_input_array(w_k, 3); // w_k
compute_encoder->setBytes(&n, sizeof(int), 4);
compute_encoder->setBytes(&plan.bluestein_n, sizeof(int), 5);
compute_encoder->setBytes(&total_batch_size, sizeof(int), 6);
} else if (plan.rader_n > 1) {
auto [b_q, g_q, g_minus_q] = compute_raders_constants(plan.rader_n, s);
copies.push_back(b_q);
copies.push_back(g_q);
copies.push_back(g_minus_q);
compute_encoder.set_input_array(b_q, 2);
compute_encoder.set_input_array(g_q, 3);
compute_encoder.set_input_array(g_minus_q, 4);
compute_encoder->setBytes(&n, sizeof(int), 5);
compute_encoder->setBytes(&total_batch_size, sizeof(int), 6);
compute_encoder->setBytes(&plan.rader_n, sizeof(int), 7);
} else if (four_step_params.required) {
compute_encoder->setBytes(&four_step_params.n1, sizeof(int), 2);
compute_encoder->setBytes(&four_step_params.n2, sizeof(int), 3);
compute_encoder->setBytes(&total_batch_size, sizeof(int), 4);
} else {
compute_encoder->setBytes(&n, sizeof(int), 2);
compute_encoder->setBytes(&total_batch_size, sizeof(int), 3);
}
auto group_dims = MTL::Size(1, threadgroup_batch_size, threads_per_fft);
auto grid_dims =
MTL::Size(batch_size, threadgroup_batch_size, threads_per_fft);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}
void fft_op(
const array& in,
array& out,
size_t axis,
bool inverse,
bool real,
bool inplace,
const Stream& s) {
fft_op(in, out, axis, inverse, real, FourStepParams(), inplace, s);
}
void nd_fft_op(
const array& in,
array& out,
const std::vector<size_t>& axes,
bool inverse,
bool real,
const Stream& s) {
// Perform ND FFT on GPU as a series of 1D FFTs
auto temp_shape = inverse ? in.shape() : out.shape();
array temp1(temp_shape, complex64, nullptr, {});
array temp2(temp_shape, complex64, nullptr, {});
std::vector<array> temp_arrs = {temp1, temp2};
for (int i = axes.size() - 1; i >= 0; i--) {
int reverse_index = axes.size() - i - 1;
// For 5D and above, we don't want to reallocate our two temporary arrays
bool inplace = reverse_index >= 3 && i != 0;
// Opposite order for fft vs ifft
int index = inverse ? reverse_index : i;
size_t axis = axes[index];
// Mirror np.fft.(i)rfftn and perform a real transform
// only on the final axis.
bool step_real = (real && index == axes.size() - 1);
int step_shape = inverse ? out.shape(axis) : in.shape(axis);
const array& in_arr = i == axes.size() - 1 ? in : temp_arrs[1 - i % 2];
array& out_arr = i == 0 ? out : temp_arrs[i % 2];
fft_op(in_arr, out_arr, axis, inverse, step_real, inplace, s);
}
std::vector<array> copies = {temp1, temp2};
auto& d = metal::device(s.device);
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}
void FFT::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = stream();
auto& in = inputs[0];
if (axes_.size() > 1) {
nd_fft_op(in, out, axes_, inverse_, real_, s);
} else {
fft_op(in, out, axes_[0], inverse_, real_, /*inplace=*/false, s);
}
}
} // namespace mlx::core

View File

@@ -1,24 +1,35 @@
// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include <cassert>
#include <numeric>
#include <sstream>
#include <fmt/format.h>
#include "mlx/backend/common/binary.h"
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/jit/indexing.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core {
namespace {
constexpr int METAL_MAX_INDEX_ARRAYS = 20;
constexpr int METAL_MAX_INDEX_ARRAYS = 10;
} // namespace
std::pair<std::string, std::string> make_index_args(
const std::string& idx_type,
int nidx) {
std::ostringstream idx_args;
std::ostringstream idx_arr;
for (int i = 0; i < nidx; ++i) {
idx_args << fmt::format(
"const device {0} *idx{1} [[buffer({2})]],", idx_type, i, 20 + i);
idx_arr << fmt::format("idx{0}", i);
if (i < nidx - 1) {
idx_args << "\n";
idx_arr << ",";
}
}
return {idx_args.str(), idx_arr.str()};
}
void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& src = inputs[0];
@@ -42,15 +53,41 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
int idx_ndim = nidx ? inputs[1].ndim() : 0;
size_t ndim = src.ndim();
std::ostringstream kname;
std::string lib_name;
std::string kernel_name;
std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
kname << "gather" << type_to_name(src) << idx_type_name << "_" << nidx;
if (idx_ndim <= 1) {
kname << "_" << idx_ndim;
{
std::ostringstream kname;
kname << "gather" << type_to_name(out) << idx_type_name << "_" << nidx
<< "_" << idx_ndim;
lib_name = kname.str();
kernel_name = lib_name;
}
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::gather();
std::string out_type_str = get_type_string(out.dtype());
std::string idx_type_str =
nidx ? get_type_string(inputs[1].dtype()) : "bool";
auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx);
// Index dimension specializations
kernel_source << fmt::format(
gather_kernels,
type_to_name(out) + idx_type_name,
out_type_str,
idx_type_str,
nidx,
idx_args,
idx_arr,
idx_ndim);
lib = d.get_library(lib_name, kernel_source.str());
}
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
auto kernel = d.get_kernel(kernel_name, lib);
compute_encoder->setComputePipelineState(kernel);
size_t slice_size = 1;
@@ -102,8 +139,8 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder->setBytes(&idx_ndim, sizeof(int), 9);
// Set index buffers
for (int i = 1; i < nidx + 1; ++i) {
compute_encoder.set_input_array(inputs[i], 20 + i);
for (int i = 0; i < nidx; ++i) {
compute_encoder.set_input_array(inputs[i + 1], 20 + i);
}
// Launch grid
@@ -139,10 +176,6 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = stream();
auto& d = metal::device(s.device);
// Get kernel name
std::ostringstream kname;
std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
int idx_ndim = nidx ? inputs[1].ndim() : 0;
bool index_nd1_specialization = (idx_ndim == 1);
@@ -159,32 +192,86 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
index_nd1_specialization &= inputs[i].flags().row_contiguous;
}
if (index_nd1_specialization) {
kname << "scatter_1d_index" << type_to_name(out) << idx_type_name;
} else {
kname << "scatter" << type_to_name(out) << idx_type_name;
}
std::string lib_name;
std::string kernel_name;
std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
std::string op_name;
switch (reduce_type_) {
case Scatter::None:
kname << "_none";
op_name = "none";
break;
case Scatter::Sum:
kname << "_sum";
op_name = "sum";
break;
case Scatter::Prod:
kname << "_prod";
op_name = "prod";
break;
case Scatter::Max:
kname << "_max";
op_name = "max";
break;
case Scatter::Min:
kname << "_min";
op_name = "min";
break;
}
kname << "_" << nidx;
{
std::ostringstream kname;
if (index_nd1_specialization) {
kname << "scatter_1d_index" << type_to_name(out) << idx_type_name;
} else {
kname << "scatter" << type_to_name(out) << idx_type_name;
}
kname << "_" << op_name << "_" << nidx;
lib_name = kname.str();
kernel_name = kname.str();
}
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::reduce_utils()
<< metal::scatter();
std::string out_type_str = get_type_string(out.dtype());
std::string idx_type_str =
nidx ? get_type_string(inputs[1].dtype()) : "bool";
std::string op_type;
switch (reduce_type_) {
case Scatter::None:
op_type = "None";
break;
case Scatter::Sum:
op_type = "Sum<{0}>";
break;
case Scatter::Prod:
op_type = "Prod<{0}>";
break;
case Scatter::Max:
op_type = "Max<{0}>";
break;
case Scatter::Min:
op_type = "Min<{0}>";
break;
}
if (reduce_type_ != Scatter::None) {
op_type = fmt::format(op_type, out_type_str);
}
auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx);
kernel_source << fmt::format(
scatter_kernels,
type_to_name(out) + idx_type_name + "_" + op_name,
out_type_str,
idx_type_str,
op_type,
nidx,
idx_args,
idx_arr);
lib = d.get_library(lib_name, kernel_source.str());
}
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
auto kernel = d.get_kernel(kernel_name, lib);
auto& upd = inputs.back();
size_t nthreads = upd.size();
@@ -209,8 +296,8 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder->setBytes(&upd_size, sizeof(size_t), 5);
// Set index buffers
for (int i = 1; i < nidx + 1; ++i) {
compute_encoder.set_input_array(inputs[i], 20 + i);
for (int i = 0; i < nidx; ++i) {
compute_encoder.set_input_array(inputs[i + 1], 20 + i);
}
// Launch grid
@@ -279,8 +366,8 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder->setBytes(&idx_ndim, sizeof(int), 13);
// Set index buffers
for (int i = 1; i < nidx + 1; ++i) {
compute_encoder.set_input_array(inputs[i], 20 + i);
for (int i = 0; i < nidx; ++i) {
compute_encoder.set_input_array(inputs[i + 1], 20 + i);
}
// Launch grid

View File

@@ -0,0 +1,9 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view arange_kernels = R"(
template [[host_name("{0}")]] [[kernel]] void arange<{1}>(
constant const {1}& start,
constant const {1}& step,
device {1}* out,
uint index [[thread_position_in_grid]]);
)";

View File

@@ -0,0 +1,87 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view binary_kernels = R"(
template [[host_name("ss{0}")]] [[kernel]]
void binary_ss<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
uint index [[thread_position_in_grid]]);
template [[host_name("vs{0}")]] [[kernel]]
void binary_vs<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
uint index [[thread_position_in_grid]]);
template [[host_name("sv{0}")]] [[kernel]]
void binary_sv<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
uint index [[thread_position_in_grid]]);
template [[host_name("vv{0}")]] [[kernel]]
void binary_vv<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
uint index [[thread_position_in_grid]]);
template [[host_name("g4{0}")]] [[kernel]] void
binary_g_nd<{1}, {2}, {3}, 4>(
device const {1}* a,
device const {1}* b,
device {2}* c,
constant const int shape[4],
constant const size_t a_strides[4],
constant const size_t b_strides[4],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("g5{0}")]] [[kernel]] void
binary_g_nd<{1}, {2}, {3}, 5>(
device const {1}* a,
device const {1}* b,
device {2}* c,
constant const int shape[5],
constant const size_t a_strides[5],
constant const size_t b_strides[5],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("g1{0}")]] [[kernel]] void
binary_g_nd1<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
constant const size_t& a_stride,
constant const size_t& b_stride,
uint index [[thread_position_in_grid]]);
template [[host_name("g2{0}")]] [[kernel]] void
binary_g_nd2<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
constant const size_t a_strides[2],
constant const size_t b_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]);
template [[host_name("g3{0}")]] [[kernel]] void
binary_g_nd3<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
constant const size_t a_strides[3],
constant const size_t b_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("gn{0}")]] [[kernel]]
void binary_g<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
)";

View File

@@ -0,0 +1,98 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view binary_two_kernels = R"(
template [[host_name("ss{0}")]] [[kernel]]
void binary_ss<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
device {2}* d,
uint index [[thread_position_in_grid]]);
template [[host_name("vs{0}")]] [[kernel]]
void binary_vs<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
device {2}* d,
uint index [[thread_position_in_grid]]);
template [[host_name("sv{0}")]] [[kernel]]
void binary_sv<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
device {2}* d,
uint index [[thread_position_in_grid]]);
template [[host_name("vv{0}")]] [[kernel]]
void binary_vv<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
device {2}* d,
uint index [[thread_position_in_grid]]);
template [[host_name("g4{0}")]] [[kernel]] void
binary_g_nd<{1}, {2}, {3}, 4>(
device const {1}* a,
device const {1}* b,
device {2}* c,
device {2}* d,
constant const int shape[4],
constant const size_t a_strides[4],
constant const size_t b_strides[4],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("g5{0}")]] [[kernel]] void
binary_g_nd<{1}, {2}, {3}, 5>(
device const {1}* a,
device const {1}* b,
device {2}* c,
device {2}* d,
constant const int shape[5],
constant const size_t a_strides[5],
constant const size_t b_strides[5],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("g1{0}")]] [[kernel]] void
binary_g_nd1<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
device {2}* d,
constant const size_t& a_stride,
constant const size_t& b_stride,
uint index [[thread_position_in_grid]]);
template [[host_name("g2{0}")]] [[kernel]] void
binary_g_nd2<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
device {2}* d,
constant const size_t a_strides[2],
constant const size_t b_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]);
template [[host_name("g3{0}")]] [[kernel]] void
binary_g_nd3<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
device {2}* d,
constant const size_t a_strides[3],
constant const size_t b_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("gn{0}")]] [[kernel]]
void binary_g<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
device {2}* d,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
)";

View File

@@ -0,0 +1,100 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view copy_kernels = R"(
template [[host_name("s_{0}")]] [[kernel]] void copy_s<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
uint index [[thread_position_in_grid]]);
template [[host_name("v_{0}")]] [[kernel]] void copy_v<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
uint index [[thread_position_in_grid]]);
template [[host_name("g4_{0}")]] [[kernel]] void
copy_g_nd<{1}, {2}, 4>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("gg4_{0}")]] [[kernel]] void
copy_gg_nd<{1}, {2}, 4>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint3 index [[thread_position_in_grid]]);
template [[host_name("g5_{0}")]] [[kernel]] void
copy_g_nd<{1}, {2}, 5>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("gg5_{0}")]] [[kernel]] void
copy_gg_nd<{1}, {2}, 5>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint3 index [[thread_position_in_grid]]);
template [[host_name("g1_{0}")]] [[kernel]] void copy_g_nd1<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int64_t& src_stride [[buffer(3)]],
uint index [[thread_position_in_grid]]);
template [[host_name("g2_{0}")]] [[kernel]] void copy_g_nd2<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]);
template [[host_name("g3_{0}")]] [[kernel]] void copy_g_nd3<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("gg1_{0}")]] [[kernel]] void
copy_gg_nd1<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int64_t& src_stride [[buffer(3)]],
constant const int64_t& dst_stride [[buffer(4)]],
uint index [[thread_position_in_grid]]);
template [[host_name("gg2_{0}")]] [[kernel]] void
copy_gg_nd2<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint2 index [[thread_position_in_grid]]);
template [[host_name("gg3_{0}")]] [[kernel]] void
copy_gg_nd3<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint3 index [[thread_position_in_grid]]);
template [[host_name("g_{0}")]] [[kernel]] void copy_g<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int& ndim [[buffer(5)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("gg_{0}")]] [[kernel]] void copy_gg<{1}, {2}>(
device const {1}* src [[buffer(0)]],
device {2}* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
constant const int& ndim [[buffer(5)]],
uint3 index [[thread_position_in_grid]]);
)";

View File

@@ -0,0 +1,53 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view fft_kernel = R"(
template [[host_name("{name}")]] [[kernel]] void
fft<{tg_mem_size}, {in_T}, {out_T}>(
const device {in_T}* in [[buffer(0)]],
device {out_T}* out [[buffer(1)]],
constant const int& n,
constant const int& batch_size,
uint3 elem [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]);
)";
constexpr std::string_view rader_fft_kernel = R"(
template [[host_name("{name}")]] [[kernel]] void
rader_fft<{tg_mem_size}, {in_T}, {out_T}>(
const device {in_T}* in [[buffer(0)]],
device {out_T}* out [[buffer(1)]],
const device float2* raders_b_q [[buffer(2)]],
const device short* raders_g_q [[buffer(3)]],
const device short* raders_g_minus_q [[buffer(4)]],
constant const int& n,
constant const int& batch_size,
constant const int& rader_n,
uint3 elem [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]);
)";
constexpr std::string_view bluestein_fft_kernel = R"(
template [[host_name("{name}")]] [[kernel]] void
bluestein_fft<{tg_mem_size}, {in_T}, {out_T}>(
const device {in_T}* in [[buffer(0)]],
device {out_T}* out [[buffer(1)]],
const device float2* w_q [[buffer(2)]],
const device float2* w_k [[buffer(3)]],
constant const int& length,
constant const int& n,
constant const int& batch_size,
uint3 elem [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]);
)";
constexpr std::string_view four_step_fft_kernel = R"(
template [[host_name("{name}")]] [[kernel]] void
four_step_fft<{tg_mem_size}, {in_T}, {out_T}, {step}, {real}>(
const device {in_T}* in [[buffer(0)]],
device {out_T}* out [[buffer(1)]],
constant const int& n1,
constant const int& n2,
constant const int& batch_size,
uint3 elem [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]);
)";

View File

@@ -0,0 +1,35 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
namespace mlx::core::metal {
const char* utils();
const char* binary_ops();
const char* unary_ops();
const char* ternary_ops();
const char* reduce_utils();
const char* gather();
const char* scatter();
const char* arange();
const char* unary();
const char* binary();
const char* binary_two();
const char* copy();
const char* fft();
const char* ternary();
const char* scan();
const char* softmax();
const char* sort();
const char* reduce();
const char* gemm();
const char* steel_gemm_fused();
const char* steel_gemm_masked();
const char* steel_gemm_splitk();
const char* conv();
const char* steel_conv();
const char* steel_conv_general();
} // namespace mlx::core::metal

View File

@@ -0,0 +1,81 @@
// Copyright © 2023-2024 Apple Inc.
constexpr std::string_view gather_kernels = R"(
[[kernel]] void gather{0}_{3}_{6}(
const device {1}* src [[buffer(0)]],
device {1}* out [[buffer(1)]],
const constant int* src_shape [[buffer(2)]],
const constant size_t* src_strides [[buffer(3)]],
const constant size_t& src_ndim [[buffer(4)]],
const constant int* slice_sizes [[buffer(5)]],
const constant int* axes [[buffer(6)]],
const constant int* idx_shapes [[buffer(7)]],
const constant size_t* idx_strides [[buffer(8)]],
const constant int& idx_ndim [[buffer(9)]],
{4}
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {{
Indices<{2}, {3}> idxs{{
{{ {5} }}, idx_shapes, idx_strides, idx_ndim}};
return gather_impl<{1}, {2}, {3}, {6}>(
src,
out,
src_shape,
src_strides,
src_ndim,
slice_sizes,
axes,
idxs,
index,
grid_dim);
}}
)";
constexpr std::string_view scatter_kernels = R"(
[[kernel]] void scatter_1d_index{0}_{4}(
const device {1}* updates [[buffer(1)]],
device mlx_atomic<{1}>* out [[buffer(2)]],
const constant int* out_shape [[buffer(3)]],
const constant size_t* out_strides [[buffer(4)]],
const constant size_t& upd_size [[buffer(5)]],
{5}
uint2 gid [[thread_position_in_grid]]) {{
const array<const device {2}*, {4}> idx_buffers = {{ {6} }};
return scatter_1d_index_impl<{1}, {2}, {3}, {4}>(
updates, out, out_shape, out_strides, upd_size, idx_buffers, gid);
}}
[[kernel]] void scatter{0}_{4}(
const device {1}* updates [[buffer(1)]],
device mlx_atomic<{1}>* out [[buffer(2)]],
const constant int* upd_shape [[buffer(3)]],
const constant size_t* upd_strides [[buffer(4)]],
const constant size_t& upd_ndim [[buffer(5)]],
const constant size_t& upd_size [[buffer(6)]],
const constant int* out_shape [[buffer(7)]],
const constant size_t* out_strides [[buffer(8)]],
const constant size_t& out_ndim [[buffer(9)]],
const constant int* axes [[buffer(10)]],
const constant int* idx_shapes [[buffer(11)]],
const constant size_t* idx_strides [[buffer(12)]],
const constant int& idx_ndim [[buffer(13)]],
{5}
uint2 gid [[thread_position_in_grid]]) {{
Indices<{2}, {4}> idxs{{ {{ {6} }}, idx_shapes, idx_strides, idx_ndim}};
return scatter_impl<{1}, {2}, {3}, {4}>(
updates,
out,
upd_shape,
upd_strides,
upd_ndim,
upd_size,
out_shape,
out_strides,
out_ndim,
axes,
idxs,
gid);
}}
)";

View File

@@ -0,0 +1,168 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view reduce_init_kernels = R"(
[[kernel]] void {0}(
device {1}* out [[buffer(0)]],
uint tid [[thread_position_in_grid]]) {{
out[tid] = {2}<{1}>::init;
}}
)";
constexpr std::string_view reduce_kernels = R"(
template [[host_name("all_{0}")]] [[kernel]] void
all_reduce<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device mlx_atomic<{2}>* out [[buffer(1)]],
const device size_t& in_size [[buffer(2)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint grid_size [[threads_per_grid]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
template [[host_name("colGeneral_{0}")]] [[kernel]] void
col_reduce_general<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device mlx_atomic<{2}>* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
threadgroup {2}* local_data [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]]);
template [[host_name("colSmall_{0}")]] [[kernel]] void
col_reduce_small<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
const constant size_t& non_col_reductions [[buffer(8)]],
const constant int* non_col_shapes [[buffer(9)]],
const constant size_t* non_col_strides [[buffer(10)]],
const constant int& non_col_ndim [[buffer(11)]],
uint tid [[thread_position_in_grid]]);
template [[host_name("rowGeneralSmall_{0}")]] [[kernel]] void
row_reduce_general_small<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant size_t& non_row_reductions [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
uint lid [[thread_position_in_grid]]);
template [[host_name("rowGeneralMed_{0}")]] [[kernel]] void
row_reduce_general_med<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant size_t& non_row_reductions [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
uint tid [[threadgroup_position_in_grid]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_per_group [[dispatch_simdgroups_per_threadgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
template [[host_name("rowGeneral_{0}")]] [[kernel]] void
row_reduce_general<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device mlx_atomic<{2}>* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant size_t& non_row_reductions [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
)";
constexpr std::string_view reduce_non_atomic_kernels = R"(
template [[host_name("allNoAtomics_{0}")]] [[kernel]] void
all_reduce_no_atomics<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const device size_t& in_size [[buffer(2)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint grid_size [[threads_per_grid]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint thread_group_id [[threadgroup_position_in_grid]]);
template [[host_name("colGeneralNoAtomics_{0}")]] [[kernel]] void
col_reduce_general_no_atomics<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
threadgroup {2}* local_data [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 gid [[thread_position_in_grid]],
uint3 lsize [[threads_per_threadgroup]],
uint3 gsize [[threads_per_grid]]);
template [[host_name("colSmall_{0}")]] [[kernel]] void
col_reduce_small<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]],
const constant size_t& out_size [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
const constant size_t& non_col_reductions [[buffer(8)]],
const constant int* non_col_shapes [[buffer(9)]],
const constant size_t* non_col_strides [[buffer(10)]],
const constant int& non_col_ndim [[buffer(11)]],
uint tid [[thread_position_in_grid]]);
template [[host_name("rowGeneralSmall_{0}")]] [[kernel]] void
row_reduce_general_small<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant size_t& non_row_reductions [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
uint lid [[thread_position_in_grid]]);
template [[host_name("rowGeneralNoAtomics_{0}")]] [[kernel]] void
row_reduce_general_no_atomics<{1}, {2}, {3}<{2}>>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]],
const constant size_t& non_row_reductions [[buffer(4)]],
const constant int* shape [[buffer(5)]],
const constant size_t* strides [[buffer(6)]],
const constant int& ndim [[buffer(7)]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]],
uint3 gsize [[threads_per_grid]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_per_group [[simdgroups_per_threadgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
)";

View File

@@ -0,0 +1,26 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view scan_kernels = R"(
template [[host_name("contig_{0}")]] [[kernel]] void
contiguous_scan<{1}, {2}, {3}<{2}>, 4, {4}, {5}>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant size_t& axis_size [[buffer(2)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
uint simd_size [[threads_per_simdgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
template [[host_name("strided_{0}")]] [[kernel]] void
strided_scan<{1}, {2}, {3}<{2}>, 4, {4}, {5}>(
const device {1}* in [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant size_t& axis_size [[buffer(2)]],
const constant size_t& stride [[buffer(3)]],
uint2 gid [[thread_position_in_grid]],
uint2 lid [[thread_position_in_threadgroup]],
uint2 lsize [[threads_per_threadgroup]],
uint simd_size [[threads_per_simdgroup]]);
)";

View File

@@ -0,0 +1,23 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view softmax_kernels = R"(
template [[host_name("block_{0}")]] [[kernel]] void
softmax_single_row<{1}, {2}>(
const device {1}* in,
device {1}* out,
constant int& axis_size,
uint gid [[thread_position_in_grid]],
uint _lid [[thread_position_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
template [[host_name("looped_{0}")]] [[kernel]] void
softmax_looped<{1}, {2}>(
const device {1}* in,
device {1}* out,
constant int& axis_size,
uint gid [[threadgroup_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
)";

View File

@@ -0,0 +1,81 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view block_sort_kernels = R"(
template [[host_name("carg_{0}")]] [[kernel]] void
block_sort<{1}, {2}, true, {3}, {4}>(
const device {1}* inp [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant int& size_sorted_axis [[buffer(2)]],
const constant int& stride_sorted_axis [[buffer(3)]],
const constant int& stride_segment_axis [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
template [[host_name("ncarg_{0}")]] [[kernel]] void
block_sort_nc<{1}, {2}, true, {3}, {4}>(
const device {1}* inp [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant int& size_sorted_axis [[buffer(2)]],
const constant int& stride_sorted_axis [[buffer(3)]],
const constant int& nc_dim [[buffer(4)]],
const device int* nc_shape [[buffer(5)]],
const device size_t* nc_strides [[buffer(6)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
template [[host_name("c_{0}")]] [[kernel]] void
block_sort<{1}, {2}, false, {3}, {4}>(
const device {1}* inp [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant int& size_sorted_axis [[buffer(2)]],
const constant int& stride_sorted_axis [[buffer(3)]],
const constant int& stride_segment_axis [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
template [[host_name("nc_{0}")]] [[kernel]] void
block_sort_nc<{1}, {2}, false, {3}, {4}>(
const device {1}* inp [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant int& size_sorted_axis [[buffer(2)]],
const constant int& stride_sorted_axis [[buffer(3)]],
const constant int& nc_dim [[buffer(4)]],
const device int* nc_shape [[buffer(5)]],
const device size_t* nc_strides [[buffer(6)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
)";
constexpr std::string_view multiblock_sort_kernels = R"(
template [[host_name("sort_{0}")]] [[kernel]] void
mb_block_sort<{1}, {2}, true, {3}, {4}>(
const device {1}* inp [[buffer(0)]],
device {1}* out_vals [[buffer(1)]],
device {2}* out_idxs [[buffer(2)]],
const constant int& size_sorted_axis [[buffer(3)]],
const constant int& stride_sorted_axis [[buffer(4)]],
const constant int& nc_dim [[buffer(5)]],
const device int* nc_shape [[buffer(6)]],
const device size_t* nc_strides [[buffer(7)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
template [[host_name("partition_{0}")]] [[kernel]] void
mb_block_partition<{1}, {2}, true, {3}, {4}>(
device {2}* block_partitions [[buffer(0)]],
const device {1}* dev_vals [[buffer(1)]],
const device {2}* dev_idxs [[buffer(2)]],
const constant int& size_sorted_axis [[buffer(3)]],
const constant int& merge_tiles [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 tgp_dims [[threads_per_threadgroup]]);
template [[host_name("merge_{0}")]] [[kernel]] void
mb_block_merge<{1}, {2}, true, {3}, {4}>(
const device {2}* block_partitions [[buffer(0)]],
const device {1}* dev_vals_in [[buffer(1)]],
const device {2}* dev_idxs_in [[buffer(2)]],
device {1}* dev_vals_out [[buffer(3)]],
device {2}* dev_idxs_out [[buffer(4)]],
const constant int& size_sorted_axis [[buffer(5)]],
const constant int& merge_tiles [[buffer(6)]],
const constant int& num_tiles [[buffer(7)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
)";

View File

@@ -0,0 +1,32 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view steel_conv_kernels = R"(
template [[host_name("{name}")]] [[kernel]] void
implicit_gemm_conv_2d<{itype}, {bm}, {bn}, {bk}, {wm}, {wn}, {n_channels}, {small_filter}>(
const device {itype}* A [[buffer(0)]],
const device {itype}* B [[buffer(1)]],
device {itype}* C [[buffer(2)]],
const constant MLXConvParams<2>* params [[buffer(3)]],
const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]);
)";
constexpr std::string_view steel_conv_general_kernels = R"(
template [[host_name("{name}")]] [[kernel]] void
implicit_gemm_conv_2d_general<{itype}, {bm}, {bn}, {bk}, {wm}, {wn}>(
const device {itype}* A [[buffer(0)]],
const device {itype}* B [[buffer(1)]],
device {itype}* C [[buffer(2)]],
const constant MLXConvParams<2>* params [[buffer(3)]],
const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
const constant Conv2DGeneralJumpParams* jump_params [[buffer(5)]],
const constant Conv2DGeneralBaseInfo* base_h [[buffer(6)]],
const constant Conv2DGeneralBaseInfo* base_w [[buffer(7)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]);
)";

View File

@@ -0,0 +1,106 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view steel_gemm_fused_kernels = R"(
template [[host_name("{name}")]]
[[kernel]] void gemm<{itype}, {bm}, {bn}, {bk}, {wm}, {wn}, {trans_a}, {trans_b}, float>(
const device {itype} *A [[buffer(0)]],
const device {itype} *B [[buffer(1)]],
const device {itype} *C [[buffer(2), function_constant(use_out_source)]],
device {itype} *D [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]],
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
const constant int* batch_shape [[buffer(6)]],
const constant size_t* batch_strides [[buffer(7)]],
const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]],
const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]],
const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]],
const constant int* operand_shape [[buffer(13), function_constant(do_gather)]],
const constant size_t* operand_strides [[buffer(14), function_constant(do_gather)]],
const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
)";
constexpr std::string_view steel_gemm_masked_kernels = R"(
template [[host_name("{name}")]] [[kernel]] void
block_masked_gemm<
{itype},
{outmasktype},
{opmasktype},
{bm},
{bn},
{bk},
{wm},
{wn},
{trans_a},
{trans_b},
{mn_aligned},
{k_aligned}>(
const device {itype}* A [[buffer(0)]],
const device {itype}* B [[buffer(1)]],
device {itype}* D [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]],
const constant int* batch_shape [[buffer(6)]],
const constant size_t* batch_strides [[buffer(7)]],
const device {outmasktype}* out_mask [[buffer(10)]],
const device {opmasktype}* lhs_mask [[buffer(11)]],
const device {opmasktype}* rhs_mask [[buffer(12)]],
const constant int* mask_strides [[buffer(13)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
)";
constexpr std::string_view steel_gemm_splitk_kernels = R"(
template [[host_name("{name}")]] [[kernel]] void
gemm_splitk<
{itype},
{otype},
{bm},
{bn},
{bk},
{wm},
{wn},
{trans_a},
{trans_b},
{mn_aligned},
{k_aligned}>(
const device {itype}* A [[buffer(0)]],
const device {itype}* B [[buffer(1)]],
device {otype}* C [[buffer(2)]],
const constant GEMMSpiltKParams* params [[buffer(3)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
)";
constexpr std::string_view steel_gemm_splitk_accum_kernels = R"(
template [[host_name("{name}")]] [[kernel]] void
gemm_splitk_accum<{atype}, {otype}>(
const device {atype}* C_split [[buffer(0)]],
device {otype}* D [[buffer(1)]],
const constant int& k_partitions [[buffer(2)]],
const constant int& partition_stride [[buffer(3)]],
const constant int& ldd [[buffer(4)]],
uint2 gid [[thread_position_in_grid]]);
)";
constexpr std::string_view steel_gemm_splitk_accum_axbpy_kernels = R"(
template [[host_name("{name}")]] [[kernel]] void
gemm_splitk_accum_axpby<{atype}, {otype}>(
const device {atype}* C_split [[buffer(0)]],
device {otype}* D [[buffer(1)]],
const constant int& k_partitions [[buffer(2)]],
const constant int& partition_stride [[buffer(3)]],
const constant int& ldd [[buffer(4)]],
const device {otype}* C [[buffer(5)]],
const constant int& ldc [[buffer(6)]],
const constant int& fdc [[buffer(7)]],
const constant float& alpha [[buffer(8)]],
const constant float& beta [[buffer(9)]],
uint2 gid [[thread_position_in_grid]]);
)";

View File

@@ -0,0 +1,80 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view ternary_kernels = R"(
template [[host_name("v_{0}")]] [[kernel]] void ternary_v<{1}, {2}>(
device const bool* a,
device const {1}* b,
device const {1}* c,
device {1}* d,
uint index [[thread_position_in_grid]]);
template [[host_name("g_{0}")]] [[kernel]] void ternary_g<{1}, {2}>(
device const bool* a,
device const {1}* b,
device const {1}* c,
device {1}* d,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const size_t* c_strides,
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("g1_{0}")]] [[kernel]] void
ternary_g_nd1<{1}, {2}>(
device const bool* a,
device const {1}* b,
device const {1}* c,
device {1}* d,
constant const size_t& a_strides,
constant const size_t& b_strides,
constant const size_t& c_strides,
uint index [[thread_position_in_grid]]);
template [[host_name("g2_{0}")]] [[kernel]] void
ternary_g_nd2<{1}, {2}>(
device const bool* a,
device const {1}* b,
device const {1}* c,
device {1}* d,
constant const size_t a_strides[2],
constant const size_t b_strides[2],
constant const size_t c_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]);
template [[host_name("g3_{0}")]] [[kernel]] void
ternary_g_nd3<{1}, {2}>(
device const bool* a,
device const {1}* b,
device const {1}* c,
device {1}* d,
constant const size_t a_strides[3],
constant const size_t b_strides[3],
constant const size_t c_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("g4_{0}")]] [[kernel]] void
ternary_g_nd<{1}, {2}, 4>(
device const bool* a,
device const {1}* b,
device const {1}* c,
device {1}* d,
constant const int shape[4],
constant const size_t a_strides[4],
constant const size_t b_strides[4],
constant const size_t c_strides[4],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("g5_{0}")]] [[kernel]] void
ternary_g_nd<{1}, {2}, 5>(
device const bool* a,
device const {1}* b,
device const {1}* c,
device {1}* d,
constant const int shape[5],
constant const size_t a_strides[5],
constant const size_t b_strides[5],
constant const size_t c_strides[5],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
)";

View File

@@ -0,0 +1,16 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view unary_kernels = R"(
template [[host_name("v{0}")]] [[kernel]] void unary_v<{1}, {2}>(
device const {1}* in,
device {1}* out,
uint index [[thread_position_in_grid]]);
template [[host_name("g{0}")]] [[kernel]] void unary_g<{1}, {2}>(
device const {1}* in,
device {1}* out,
device const int* in_shape,
device const size_t* in_strides,
device const int& ndim,
uint index [[thread_position_in_grid]]);
)";

View File

@@ -0,0 +1,540 @@
// Copyright © 2024 Apple Inc.
#include <fmt/format.h>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/jit/arange.h"
#include "mlx/backend/metal/jit/binary.h"
#include "mlx/backend/metal/jit/binary_two.h"
#include "mlx/backend/metal/jit/copy.h"
#include "mlx/backend/metal/jit/fft.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/jit/reduce.h"
#include "mlx/backend/metal/jit/scan.h"
#include "mlx/backend/metal/jit/softmax.h"
#include "mlx/backend/metal/jit/sort.h"
#include "mlx/backend/metal/jit/steel_conv.h"
#include "mlx/backend/metal/jit/steel_gemm.h"
#include "mlx/backend/metal/jit/ternary.h"
#include "mlx/backend/metal/jit/unary.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
using namespace fmt::literals;
namespace mlx::core {
std::string op_name(const array& arr) {
std::ostringstream op_t;
arr.primitive().print(op_t);
return op_t.str();
}
MTL::ComputePipelineState* get_arange_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source
<< metal::utils() << metal::arange()
<< fmt::format(arange_kernels, lib_name, get_type_string(out.dtype()));
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_unary_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out) {
std::string lib_name = kernel_name.substr(1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::unary_ops() << metal::unary()
<< fmt::format(
unary_kernels,
lib_name,
get_type_string(out.dtype()),
op_name(out));
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_binary_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out) {
std::string lib_name = kernel_name.substr(2);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::binary_ops() << metal::binary()
<< fmt::format(
binary_kernels,
lib_name,
get_type_string(in.dtype()),
get_type_string(out.dtype()),
op_name(out));
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_binary_two_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out) {
std::string lib_name = kernel_name.substr(2);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::binary_ops()
<< metal::binary_two()
<< fmt::format(
binary_two_kernels,
lib_name,
get_type_string(in.dtype()),
get_type_string(out.dtype()),
op_name(out));
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_ternary_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::ternary_ops() << metal::ternary()
<< fmt::format(
ternary_kernels,
lib_name,
get_type_string(out.dtype()),
op_name(out));
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_copy_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::copy()
<< fmt::format(
copy_kernels,
lib_name,
get_type_string(in.dtype()),
get_type_string(out.dtype()));
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_softmax_kernel(
metal::Device& d,
const std::string& kernel_name,
bool precise,
const array& out) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::softmax()
<< fmt::format(
softmax_kernels,
lib_name,
get_type_string(out.dtype()),
get_type_string(precise ? float32 : out.dtype()));
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_scan_kernel(
metal::Device& d,
const std::string& kernel_name,
bool reverse,
bool inclusive,
const std::string& reduce_type,
const array& in,
const array& out) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::string op_name = "Cum" + reduce_type;
op_name[3] = toupper(op_name[3]);
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::scan()
<< fmt::format(
scan_kernels,
lib_name,
get_type_string(in.dtype()),
get_type_string(out.dtype()),
op_name,
inclusive,
reverse);
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_sort_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out,
int bn,
int tn) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::sort()
<< fmt::format(
block_sort_kernels,
lib_name,
get_type_string(in.dtype()),
get_type_string(out.dtype()),
bn,
tn);
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_mb_sort_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& idx,
int bn,
int tn) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::sort()
<< fmt::format(
multiblock_sort_kernels,
lib_name,
get_type_string(in.dtype()),
get_type_string(idx.dtype()),
bn,
tn);
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_reduce_init_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out) {
auto lib = d.get_library(kernel_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::reduce_utils()
<< fmt::format(
reduce_init_kernels,
kernel_name,
get_type_string(out.dtype()),
op_name(out));
lib = d.get_library(kernel_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_reduce_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& op_name,
const array& in,
const array& out) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::string op_type = op_name;
op_type[0] = std::toupper(op_name[0]);
bool non_atomic = out.dtype() == int64 || out.dtype() == uint64;
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce()
<< fmt::format(
non_atomic ? reduce_non_atomic_kernels
: reduce_kernels,
lib_name,
get_type_string(in.dtype()),
get_type_string(out.dtype()),
op_type);
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const array& out,
bool transpose_a,
bool transpose_b,
int bm,
int bn,
int bk,
int wm,
int wn) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::gemm()
<< metal::steel_gemm_fused()
<< fmt::format(
steel_gemm_fused_kernels,
"name"_a = lib_name,
"itype"_a = get_type_string(out.dtype()),
"bm"_a = bm,
"bn"_a = bn,
"bk"_a = bk,
"wm"_a = wm,
"wn"_a = wn,
"trans_a"_a = transpose_a,
"trans_b"_a = transpose_b);
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
}
MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out,
bool transpose_a,
bool transpose_b,
int bm,
int bn,
int bk,
int wm,
int wn,
bool mn_aligned,
bool k_aligned) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::gemm()
<< metal::steel_gemm_splitk()
<< fmt::format(
steel_gemm_splitk_kernels,
"name"_a = lib_name,
"itype"_a = get_type_string(in.dtype()),
"otype"_a = get_type_string(out.dtype()),
"bm"_a = bm,
"bn"_a = bn,
"bk"_a = bk,
"wm"_a = wm,
"wn"_a = wn,
"trans_a"_a = transpose_a,
"trans_b"_a = transpose_b,
"mn_aligned"_a = mn_aligned,
"k_aligned"_a = k_aligned);
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out,
bool axbpy) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::gemm()
<< metal::steel_gemm_splitk()
<< fmt::format(
axbpy ? steel_gemm_splitk_accum_axbpy_kernels
: steel_gemm_splitk_accum_kernels,
"name"_a = lib_name,
"atype"_a = get_type_string(in.dtype()),
"otype"_a = get_type_string(out.dtype()));
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out,
const std::optional<array>& mask_out,
const std::optional<array>& mask_op,
bool transpose_a,
bool transpose_b,
int bm,
int bn,
int bk,
int wm,
int wn,
bool mn_aligned,
bool k_aligned) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
auto out_mask_type = mask_out.has_value()
? get_type_string((*mask_out).dtype())
: "nomask_t";
auto op_mask_type =
mask_op.has_value() ? get_type_string((*mask_op).dtype()) : "nomask_t";
kernel_source << metal::utils() << metal::gemm()
<< metal::steel_gemm_masked()
<< fmt::format(
steel_gemm_masked_kernels,
"name"_a = lib_name,
"itype"_a = get_type_string(out.dtype()),
"outmasktype"_a = out_mask_type,
"opmasktype"_a = op_mask_type,
"bm"_a = bm,
"bn"_a = bn,
"bk"_a = bk,
"wm"_a = wm,
"wn"_a = wn,
"trans_a"_a = transpose_a,
"trans_b"_a = transpose_b,
"mn_aligned"_a = mn_aligned,
"k_aligned"_a = k_aligned);
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_steel_conv_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out,
int bm,
int bn,
int bk,
int wm,
int wn,
int n_channel_specialization,
bool small_filter) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::conv() << metal::steel_conv()
<< fmt::format(
steel_conv_kernels,
"name"_a = lib_name,
"itype"_a = get_type_string(out.dtype()),
"bm"_a = bm,
"bn"_a = bn,
"bk"_a = bk,
"wm"_a = wm,
"wn"_a = wn,
"n_channels"_a = n_channel_specialization,
"small_filter"_a = small_filter);
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_steel_conv_general_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out,
int bm,
int bn,
int bk,
int wm,
int wn) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::conv()
<< metal::steel_conv_general()
<< fmt::format(
steel_conv_general_kernels,
"name"_a = lib_name,
"itype"_a = get_type_string(out.dtype()),
"bm"_a = bm,
"bn"_a = bn,
"bk"_a = bk,
"wm"_a = wm,
"wn"_a = wn);
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_fft_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const int tg_mem_size,
const std::string& in_type,
const std::string& out_type,
int step,
bool real,
const metal::MTLFCList& func_consts) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
std::string kernel_string;
if (lib_name.find("bluestein") != std::string::npos) {
kernel_string = bluestein_fft_kernel;
} else if (lib_name.find("rader") != std::string::npos) {
kernel_string = rader_fft_kernel;
} else if (lib_name.find("four_step") != std::string::npos) {
kernel_string = four_step_fft_kernel;
} else {
kernel_string = fft_kernel;
}
kernel_source << metal::fft();
if (lib_name.find("four_step") != std::string::npos) {
kernel_source << fmt::format(
kernel_string,
"name"_a = lib_name,
"tg_mem_size"_a = tg_mem_size,
"in_T"_a = in_type,
"out_T"_a = out_type,
"step"_a = step,
"real"_a = real);
} else {
kernel_source << fmt::format(
kernel_string,
"name"_a = lib_name,
"tg_mem_size"_a = tg_mem_size,
"in_T"_a = in_type,
"out_T"_a = out_type);
}
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
}
} // namespace mlx::core

169
mlx/backend/metal/kernels.h Normal file
View File

@@ -0,0 +1,169 @@
// Copyright © 2024 Apple Inc.
#include "mlx/array.h"
#include "mlx/backend/metal/device.h"
namespace mlx::core {
MTL::ComputePipelineState* get_arange_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out);
MTL::ComputePipelineState* get_unary_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out);
MTL::ComputePipelineState* get_binary_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out);
MTL::ComputePipelineState* get_binary_two_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out);
MTL::ComputePipelineState* get_ternary_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out);
MTL::ComputePipelineState* get_copy_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out);
MTL::ComputePipelineState* get_softmax_kernel(
metal::Device& d,
const std::string& kernel_name,
bool precise,
const array& out);
MTL::ComputePipelineState* get_scan_kernel(
metal::Device& d,
const std::string& kernel_name,
bool reverse,
bool inclusive,
const std::string& reduce_type,
const array& in,
const array& out);
MTL::ComputePipelineState* get_sort_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out,
int bn,
int tn);
MTL::ComputePipelineState* get_mb_sort_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& idx,
int bn,
int tn);
MTL::ComputePipelineState* get_reduce_init_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out);
MTL::ComputePipelineState* get_reduce_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& op_name,
const array& in,
const array& out);
MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const array& out,
bool transpose_a,
bool transpose_b,
int bm,
int bn,
int bk,
int wm,
int wn);
MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out,
bool transpose_a,
bool transpose_b,
int bm,
int bn,
int bk,
int wm,
int wn,
bool mn_aligned,
bool k_aligned);
MTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& in,
const array& out,
bool axbpy);
MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out,
const std::optional<array>& mask_out,
const std::optional<array>& mask_op,
bool transpose_a,
bool transpose_b,
int bm,
int bn,
int bk,
int wm,
int wn,
bool mn_aligned,
bool k_aligned);
MTL::ComputePipelineState* get_steel_conv_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out,
int bm,
int bn,
int bk,
int wm,
int wn,
int n_channel_specialization,
bool small_filter);
MTL::ComputePipelineState* get_steel_conv_general_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out,
int bm,
int bn,
int bk,
int wm,
int wn);
MTL::ComputePipelineState* get_fft_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const int tg_mem_size,
const std::string& in_type,
const std::string& out_type,
int step,
bool real,
const metal::MTLFCList& func_consts);
} // namespace mlx::core

View File

@@ -1,26 +1,17 @@
set(
HEADERS
${CMAKE_CURRENT_SOURCE_DIR}/atomic.h
${CMAKE_CURRENT_SOURCE_DIR}/bf16.h
${CMAKE_CURRENT_SOURCE_DIR}/bf16_math.h
${CMAKE_CURRENT_SOURCE_DIR}/binary.h
${CMAKE_CURRENT_SOURCE_DIR}/complex.h
${CMAKE_CURRENT_SOURCE_DIR}/defines.h
${CMAKE_CURRENT_SOURCE_DIR}/erf.h
${CMAKE_CURRENT_SOURCE_DIR}/expm1f.h
${CMAKE_CURRENT_SOURCE_DIR}/indexing.h
${CMAKE_CURRENT_SOURCE_DIR}/unary.h
${CMAKE_CURRENT_SOURCE_DIR}/utils.h
bf16.h
bf16_math.h
complex.h
defines.h
utils.h
steel/conv/params.h
)
set(
KERNELS
"arange"
"arg_reduce"
"binary"
"binary_two"
"conv"
"copy"
"fft"
"gemv"
"quantized"
@@ -28,15 +19,48 @@ set(
"rms_norm"
"layer_norm"
"rope"
"scan"
"scaled_dot_product_attention"
)
if (NOT MLX_METAL_JIT)
set(
KERNELS
${KERNELS}
"arange"
"binary"
"binary_two"
"unary"
"ternary"
"copy"
"softmax"
"sort"
"ternary"
"unary"
"gather"
"scatter"
"scan"
"reduce"
)
set(
HEADERS
${HEADERS}
atomic.h
arange.h
unary_ops.h
unary.h
binary_ops.h
binary.h
ternary.h
copy.h
fft.h
fft/radix.h
fft/readwrite.h
softmax.h
sort.h
scan.h
reduction/ops.h
reduction/reduce_init.h
reduction/reduce_all.h
reduction/reduce_col.h
reduction/reduce_row.h
)
endif()
function(build_kernel_base TARGET SRCFILE DEPS)
set(METAL_FLAGS -Wall -Wextra -fno-fast-math -D${MLX_METAL_VERSION})
@@ -68,23 +92,40 @@ foreach(KERNEL ${KERNELS})
set(KERNEL_AIR ${KERNEL}.air ${KERNEL_AIR})
endforeach()
file(GLOB_RECURSE STEEL_KERNELS ${CMAKE_CURRENT_SOURCE_DIR}/steel/*.metal)
file(GLOB_RECURSE STEEL_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/steel/*.h)
foreach(KERNEL ${STEEL_KERNELS})
cmake_path(GET KERNEL STEM TARGET)
build_kernel_base(${TARGET} ${KERNEL} "${STEEL_HEADERS}")
set(KERNEL_AIR ${TARGET}.air ${KERNEL_AIR})
endforeach()
file(GLOB_RECURSE REDUCE_KERNELS ${CMAKE_CURRENT_SOURCE_DIR}/reduction/*.metal)
file(GLOB_RECURSE REDUCE_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/reduction/*.h)
foreach(KERNEL ${REDUCE_KERNELS})
cmake_path(GET KERNEL STEM TARGET)
build_kernel_base(${TARGET} ${KERNEL} "${REDUCE_HEADERS}")
set(KERNEL_AIR ${TARGET}.air ${KERNEL_AIR})
endforeach()
if (NOT MLX_METAL_JIT)
set(
STEEL_KERNELS
${CMAKE_CURRENT_SOURCE_DIR}/steel/conv/kernels/steel_conv.metal
${CMAKE_CURRENT_SOURCE_DIR}/steel/conv/kernels/steel_conv_general.metal
${CMAKE_CURRENT_SOURCE_DIR}/steel/gemm/kernels/steel_gemm_fused.metal
${CMAKE_CURRENT_SOURCE_DIR}/steel/gemm/kernels/steel_gemm_masked.metal
${CMAKE_CURRENT_SOURCE_DIR}/steel/gemm/kernels/steel_gemm_splitk.metal
)
set(
STEEL_HEADERS
steel/defines.h
steel/utils.h
steel/conv/conv.h
steel/conv/loader.h
steel/conv/loaders/loader_channel_l.h
steel/conv/loaders/loader_channel_n.h
steel/conv/loaders/loader_general.h
steel/conv/kernels/steel_conv.h
steel/conv/kernels/steel_conv_general.h
steel/gemm/gemm.h
steel/gemm/mma.h
steel/gemm/loader.h
steel/gemm/transforms.h
steel/gemm/kernels/steel_gemm_fused.h
steel/gemm/kernels/steel_gemm_masked.h
steel/gemm/kernels/steel_gemm_splitk.h
)
foreach(KERNEL ${STEEL_KERNELS})
cmake_path(GET KERNEL STEM TARGET)
build_kernel_base(${TARGET} ${KERNEL} "${STEEL_HEADERS}")
set(KERNEL_AIR ${TARGET}.air ${KERNEL_AIR})
endforeach()
endif()
add_custom_command(
OUTPUT ${MLX_METAL_PATH}/mlx.metallib

View File

@@ -0,0 +1,9 @@
// Copyright © 2023-2024 Apple Inc.
template <typename T>
[[kernel]] void arange(
constant const T& start,
constant const T& step,
device T* out,
uint index [[thread_position_in_grid]]) {
out[index] = start + index * step;
}

View File

@@ -1,15 +1,8 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
// clang-format off
#include "mlx/backend/metal/kernels/bf16.h"
template <typename T>
[[kernel]] void arange(
constant const T& start,
constant const T& step,
device T* out,
uint index [[thread_position_in_grid]]) {
out[index] = start + index * step;
}
#include "mlx/backend/metal/kernels/arange.h"
#define instantiate_arange(tname, type) \
template [[host_name("arange" #tname)]] [[kernel]] void arange<type>( \
@@ -18,7 +11,6 @@ template <typename T>
device type* out, \
uint index [[thread_position_in_grid]]);
// clang-format off
instantiate_arange(uint8, uint8_t)
instantiate_arange(uint16, uint16_t)
instantiate_arange(uint32, uint32_t)

View File

@@ -1,6 +1,5 @@
// Copyright © 2023 Apple Inc.
#include <metal_atomic>
#include <metal_simdgroup>
#include "mlx/backend/metal/kernels/utils.h"
@@ -194,4 +193,4 @@ instantiate_arg_reduce(int32, int32_t)
instantiate_arg_reduce(int64, int64_t)
instantiate_arg_reduce(float16, half)
instantiate_arg_reduce(float32, float)
instantiate_arg_reduce(bfloat16, bfloat16_t) // clang-format on
instantiate_arg_reduce(bfloat16, bfloat16_t) // clang-format on

View File

@@ -4,7 +4,6 @@
#include <metal_atomic>
#include <metal_stdlib>
#include "mlx/backend/metal/kernels/bf16.h"
using namespace metal;

View File

@@ -6,9 +6,7 @@
using namespace metal;
// No support for less than metal 3.0
// anything greater has native bfloat
#ifndef METAL_3_0
#if defined METAL_3_1 || (__METAL_VERSION__ >= 310)
typedef bfloat bfloat16_t;

View File

@@ -369,7 +369,7 @@ instantiate_metal_math_funcs(
return static_cast<otype>(__metal_simd_xor(static_cast<ctype>(data))); \
}
#ifndef METAL_3_0
#if defined METAL_3_1 || (__METAL_VERSION__ >= 310)
#define bfloat16_to_uint16(x) as_type<uint16_t>(x)
#define uint16_to_bfloat16(x) as_type<bfloat16_t>(x)

View File

@@ -1,273 +1,113 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2024 Apple Inc.
#pragma once
template <typename T, typename U, typename Op>
[[kernel]] void binary_ss(
device const T* a,
device const T* b,
device U* c,
uint index [[thread_position_in_grid]]) {
c[index] = Op()(a[0], b[0]);
}
#include <metal_integer>
#include <metal_math>
template <typename T, typename U, typename Op>
[[kernel]] void binary_sv(
device const T* a,
device const T* b,
device U* c,
uint index [[thread_position_in_grid]]) {
c[index] = Op()(a[0], b[index]);
}
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"
template <typename T, typename U, typename Op>
[[kernel]] void binary_vs(
device const T* a,
device const T* b,
device U* c,
uint index [[thread_position_in_grid]]) {
c[index] = Op()(a[index], b[0]);
}
struct Add {
template <typename T>
T operator()(T x, T y) {
return x + y;
}
};
template <typename T, typename U, typename Op>
[[kernel]] void binary_vv(
device const T* a,
device const T* b,
device U* c,
uint index [[thread_position_in_grid]]) {
c[index] = Op()(a[index], b[index]);
}
struct Divide {
template <typename T>
T operator()(T x, T y) {
return x / y;
}
};
template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd1(
device const T* a,
device const T* b,
device U* c,
constant const size_t& a_stride,
constant const size_t& b_stride,
uint index [[thread_position_in_grid]]) {
auto a_idx = elem_to_loc_1(index, a_stride);
auto b_idx = elem_to_loc_1(index, b_stride);
c[index] = Op()(a[a_idx], b[b_idx]);
}
struct Remainder {
template <typename T>
metal::enable_if_t<metal::is_integral_v<T> & !metal::is_signed_v<T>, T>
operator()(T x, T y) {
return x % y;
}
template <typename T>
metal::enable_if_t<metal::is_integral_v<T> & metal::is_signed_v<T>, T>
operator()(T x, T y) {
auto r = x % y;
if (r != 0 && (r < 0 != y < 0)) {
r += y;
}
return r;
}
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
T r = fmod(x, y);
if (r != 0 && (r < 0 != y < 0)) {
r += y;
}
return r;
}
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
return x % y;
}
};
template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd2(
device const T* a,
device const T* b,
device U* c,
constant const size_t a_strides[2],
constant const size_t b_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_2(index, a_strides);
auto b_idx = elem_to_loc_2(index, b_strides);
size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
c[out_idx] = Op()(a[a_idx], b[b_idx]);
}
struct Equal {
template <typename T>
bool operator()(T x, T y) {
return x == y;
}
};
template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd3(
device const T* a,
device const T* b,
device U* c,
constant const size_t a_strides[3],
constant const size_t b_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_3(index, a_strides);
auto b_idx = elem_to_loc_3(index, b_strides);
size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
c[out_idx] = Op()(a[a_idx], b[b_idx]);
}
struct NaNEqual {
template <typename T>
bool operator()(T x, T y) {
return x == y || (metal::isnan(x) && metal::isnan(y));
}
template <>
bool operator()(complex64_t x, complex64_t y) {
return x == y ||
(metal::isnan(x.real) && metal::isnan(y.real) && metal::isnan(x.imag) &&
metal::isnan(y.imag)) ||
(x.real == y.real && metal::isnan(x.imag) && metal::isnan(y.imag)) ||
(metal::isnan(x.real) && metal::isnan(y.real) && x.imag == y.imag);
}
};
template <typename T, typename U, typename Op, int DIM>
[[kernel]] void binary_g_nd(
device const T* a,
device const T* b,
device U* c,
constant const int shape[DIM],
constant const size_t a_strides[DIM],
constant const size_t b_strides[DIM],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
c[out_idx] = Op()(a[idx.x], b[idx.y]);
}
struct Greater {
template <typename T>
bool operator()(T x, T y) {
return x > y;
}
};
struct GreaterEqual {
template <typename T>
bool operator()(T x, T y) {
return x >= y;
}
};
struct Less {
template <typename T>
bool operator()(T x, T y) {
return x < y;
}
};
struct LessEqual {
template <typename T>
bool operator()(T x, T y) {
return x <= y;
}
};
struct LogAddExp {
template <typename T>
T operator()(T x, T y) {
if (metal::isnan(x) || metal::isnan(y)) {
return metal::numeric_limits<T>::quiet_NaN();
}
constexpr T inf = metal::numeric_limits<T>::infinity();
T maxval = metal::max(x, y);
T minval = metal::min(x, y);
return (minval == -inf || maxval == inf)
? maxval
: (maxval + log1p(metal::exp(minval - maxval)));
};
};
struct Maximum {
template <typename T>
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {
return metal::max(x, y);
}
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
if (metal::isnan(x)) {
return x;
}
return x > y ? x : y;
}
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
if (metal::isnan(x.real) || metal::isnan(x.imag)) {
return x;
}
return x > y ? x : y;
}
};
struct Minimum {
template <typename T>
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {
return metal::min(x, y);
}
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
if (metal::isnan(x)) {
return x;
}
return x < y ? x : y;
}
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
if (metal::isnan(x.real) || metal::isnan(x.imag)) {
return x;
}
return x < y ? x : y;
}
};
struct Multiply {
template <typename T>
T operator()(T x, T y) {
return x * y;
}
};
struct NotEqual {
template <typename T>
bool operator()(T x, T y) {
return x != y;
}
template <>
bool operator()(complex64_t x, complex64_t y) {
return x.real != y.real || x.imag != y.imag;
}
};
struct Power {
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T base, T exp) {
return metal::pow(base, exp);
}
template <typename T>
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T base, T exp) {
T res = 1;
while (exp) {
if (exp & 1) {
res *= base;
}
exp >>= 1;
base *= base;
}
return res;
}
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
auto x_theta = metal::atan(x.imag / x.real);
auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag);
auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta);
auto phase = y.imag * x_ln_r + y.real * x_theta;
return {mag * metal::cos(phase), mag * metal::sin(phase)};
}
};
struct Subtract {
template <typename T>
T operator()(T x, T y) {
return x - y;
}
};
struct LogicalAnd {
template <typename T>
T operator()(T x, T y) {
return x && y;
};
};
struct LogicalOr {
template <typename T>
T operator()(T x, T y) {
return x || y;
};
};
struct BitwiseAnd {
template <typename T>
T operator()(T x, T y) {
return x & y;
};
};
struct BitwiseOr {
template <typename T>
T operator()(T x, T y) {
return x | y;
};
};
struct BitwiseXor {
template <typename T>
T operator()(T x, T y) {
return x ^ y;
};
};
struct LeftShift {
template <typename T>
T operator()(T x, T y) {
return x << y;
};
};
struct RightShift {
template <typename T>
T operator()(T x, T y) {
return x >> y;
};
};
struct ArcTan2 {
template <typename T>
T operator()(T y, T x) {
return metal::precise::atan2(y, x);
}
};
template <typename T, typename U, typename Op>
[[kernel]] void binary_g(
device const T* a,
device const T* b,
device U* c,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim);
size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
c[out_idx] = Op()(a[idx.x], b[idx.y]);
}

View File

@@ -1,130 +1,24 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2024 Apple Inc.
#include <metal_integer>
#include <metal_math>
// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/binary_ops.h"
#include "mlx/backend/metal/kernels/binary.h"
template <typename T, typename U, typename Op>
[[kernel]] void binary_op_ss(
device const T* a,
device const T* b,
device U* c,
uint index [[thread_position_in_grid]]) {
c[index] = Op()(a[0], b[0]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_op_sv(
device const T* a,
device const T* b,
device U* c,
uint index [[thread_position_in_grid]]) {
c[index] = Op()(a[0], b[index]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_op_vs(
device const T* a,
device const T* b,
device U* c,
uint index [[thread_position_in_grid]]) {
c[index] = Op()(a[index], b[0]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_op_vv(
device const T* a,
device const T* b,
device U* c,
uint index [[thread_position_in_grid]]) {
c[index] = Op()(a[index], b[index]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_op_g_nd1(
device const T* a,
device const T* b,
device U* c,
constant const size_t& a_stride,
constant const size_t& b_stride,
uint index [[thread_position_in_grid]]) {
auto a_idx = elem_to_loc_1(index, a_stride);
auto b_idx = elem_to_loc_1(index, b_stride);
c[index] = Op()(a[a_idx], b[b_idx]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_op_g_nd2(
device const T* a,
device const T* b,
device U* c,
constant const size_t a_strides[2],
constant const size_t b_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_2(index, a_strides);
auto b_idx = elem_to_loc_2(index, b_strides);
size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
c[out_idx] = Op()(a[a_idx], b[b_idx]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_op_g_nd3(
device const T* a,
device const T* b,
device U* c,
constant const size_t a_strides[3],
constant const size_t b_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_3(index, a_strides);
auto b_idx = elem_to_loc_3(index, b_strides);
size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
c[out_idx] = Op()(a[a_idx], b[b_idx]);
}
template <typename T, typename U, typename Op, int DIM>
[[kernel]] void binary_op_g_nd(
device const T* a,
device const T* b,
device U* c,
constant const int shape[DIM],
constant const size_t a_strides[DIM],
constant const size_t b_strides[DIM],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
c[out_idx] = Op()(a[idx.x], b[idx.y]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_op_g(
device const T* a,
device const T* b,
device U* c,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim);
size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
c[out_idx] = Op()(a[idx.x], b[idx.y]);
}
#define instantiate_binary(name, itype, otype, op, bopt) \
template \
[[host_name(name)]] [[kernel]] void binary_op_##bopt<itype, otype, op>( \
[[host_name(name)]] [[kernel]] void binary_##bopt<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
uint index [[thread_position_in_grid]]);
#define instantiate_binary_g_dim(name, itype, otype, op, dims) \
template [[host_name(name "_" #dims)]] [[kernel]] void \
binary_op_g_nd<itype, otype, op, dims>( \
template [[host_name("g" #dims name)]] [[kernel]] void \
binary_g_nd<itype, otype, op, dims>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
@@ -135,16 +29,16 @@ template <typename T, typename U, typename Op>
uint3 grid_dim [[threads_per_grid]]);
#define instantiate_binary_g_nd(name, itype, otype, op) \
template [[host_name(name "_1")]] [[kernel]] void \
binary_op_g_nd1<itype, otype, op>( \
template [[host_name("g1" name)]] [[kernel]] void \
binary_g_nd1<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
constant const size_t& a_stride, \
constant const size_t& b_stride, \
uint index [[thread_position_in_grid]]); \
template [[host_name(name "_2")]] [[kernel]] void \
binary_op_g_nd2<itype, otype, op>( \
template [[host_name("g2" name)]] [[kernel]] void \
binary_g_nd2<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
@@ -152,8 +46,8 @@ template <typename T, typename U, typename Op>
constant const size_t b_strides[2], \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]); \
template [[host_name(name "_3")]] [[kernel]] void \
binary_op_g_nd3<itype, otype, op>( \
template [[host_name("g3" name)]] [[kernel]] void \
binary_g_nd3<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
@@ -162,30 +56,28 @@ template <typename T, typename U, typename Op>
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
instantiate_binary_g_dim(name, itype, otype, op, 4) \
instantiate_binary_g_dim(name, itype, otype, op, 5)
instantiate_binary_g_dim(name, itype, otype, op, 5)
#define instantiate_binary_g(name, itype, otype, op) \
template [[host_name(name)]] [[kernel]] void binary_op_g<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
constant const int* shape, \
constant const size_t* a_strides, \
constant const size_t* b_strides, \
constant const int& ndim, \
uint3 index [[thread_position_in_grid]], \
#define instantiate_binary_g(name, itype, otype, op) \
template [[host_name("gn" name)]] [[kernel]] void binary_g<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
constant const int* shape, \
constant const size_t* a_strides, \
constant const size_t* b_strides, \
constant const int& ndim, \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]);
// clang-format off
#define instantiate_binary_all(name, tname, itype, otype, op) \
instantiate_binary("ss" #name #tname, itype, otype, op, ss) \
instantiate_binary("sv" #name #tname, itype, otype, op, sv) \
instantiate_binary("vs" #name #tname, itype, otype, op, vs) \
instantiate_binary("vv" #name #tname, itype, otype, op, vv) \
instantiate_binary_g("g" #name #tname, itype, otype, op) \
instantiate_binary_g_nd("g" #name #tname, itype, otype, op) // clang-format on
instantiate_binary_g(#name #tname, itype, otype, op) \
instantiate_binary_g_nd(#name #tname, itype, otype, op)
// clang-format off
#define instantiate_binary_integer(name, op) \
instantiate_binary_all(name, uint8, uint8_t, uint8_t, op) \
instantiate_binary_all(name, uint16, uint16_t, uint16_t, op) \
@@ -194,22 +86,19 @@ template <typename T, typename U, typename Op>
instantiate_binary_all(name, int8, int8_t, int8_t, op) \
instantiate_binary_all(name, int16, int16_t, int16_t, op) \
instantiate_binary_all(name, int32, int32_t, int32_t, op) \
instantiate_binary_all(name, int64, int64_t, int64_t, op) // clang-format on
instantiate_binary_all(name, int64, int64_t, int64_t, op)
// clang-format off
#define instantiate_binary_float(name, op) \
instantiate_binary_all(name, float16, half, half, op) \
instantiate_binary_all(name, float32, float, float, op) \
instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op) // clang-format on
instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op)
// clang-format off
#define instantiate_binary_types(name, op) \
instantiate_binary_all(name, bool_, bool, bool, op) \
instantiate_binary_integer(name, op) \
instantiate_binary_all(name, complex64, complex64_t, complex64_t, op) \
instantiate_binary_float(name, op) // clang-format on
instantiate_binary_float(name, op)
// clang-format off
#define instantiate_binary_types_bool(name, op) \
instantiate_binary_all(name, bool_, bool, bool, op) \
instantiate_binary_all(name, uint8, uint8_t, bool, op) \
@@ -223,9 +112,8 @@ template <typename T, typename U, typename Op>
instantiate_binary_all(name, float16, half, bool, op) \
instantiate_binary_all(name, float32, float, bool, op) \
instantiate_binary_all(name, bfloat16, bfloat16_t, bool, op) \
instantiate_binary_all(name, complex64, complex64_t, bool, op) // clang-format on
instantiate_binary_all(name, complex64, complex64_t, bool, op)
// clang-format off
instantiate_binary_types(add, Add)
instantiate_binary_types(div, Divide)
instantiate_binary_types_bool(eq, Equal)

View File

@@ -0,0 +1,296 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
#include <metal_integer>
#include <metal_math>
struct Add {
template <typename T>
T operator()(T x, T y) {
return x + y;
}
};
struct FloorDivide {
template <typename T>
T operator()(T x, T y) {
return x / y;
}
template <>
float operator()(float x, float y) {
return trunc(x / y);
}
template <>
half operator()(half x, half y) {
return trunc(x / y);
}
template <>
bfloat16_t operator()(bfloat16_t x, bfloat16_t y) {
return trunc(x / y);
}
};
struct Divide {
template <typename T>
T operator()(T x, T y) {
return x / y;
}
};
struct Remainder {
template <typename T>
metal::enable_if_t<metal::is_integral_v<T> & !metal::is_signed_v<T>, T>
operator()(T x, T y) {
return x % y;
}
template <typename T>
metal::enable_if_t<metal::is_integral_v<T> & metal::is_signed_v<T>, T>
operator()(T x, T y) {
auto r = x % y;
if (r != 0 && (r < 0 != y < 0)) {
r += y;
}
return r;
}
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
T r = fmod(x, y);
if (r != 0 && (r < 0 != y < 0)) {
r += y;
}
return r;
}
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
return x % y;
}
};
struct Equal {
template <typename T>
bool operator()(T x, T y) {
return x == y;
}
};
struct NaNEqual {
template <typename T>
bool operator()(T x, T y) {
return x == y || (metal::isnan(x) && metal::isnan(y));
}
template <>
bool operator()(complex64_t x, complex64_t y) {
return x == y ||
(metal::isnan(x.real) && metal::isnan(y.real) && metal::isnan(x.imag) &&
metal::isnan(y.imag)) ||
(x.real == y.real && metal::isnan(x.imag) && metal::isnan(y.imag)) ||
(metal::isnan(x.real) && metal::isnan(y.real) && x.imag == y.imag);
}
};
struct Greater {
template <typename T>
bool operator()(T x, T y) {
return x > y;
}
};
struct GreaterEqual {
template <typename T>
bool operator()(T x, T y) {
return x >= y;
}
};
struct Less {
template <typename T>
bool operator()(T x, T y) {
return x < y;
}
};
struct LessEqual {
template <typename T>
bool operator()(T x, T y) {
return x <= y;
}
};
struct LogAddExp {
template <typename T>
T operator()(T x, T y) {
if (metal::isnan(x) || metal::isnan(y)) {
return metal::numeric_limits<T>::quiet_NaN();
}
constexpr T inf = metal::numeric_limits<T>::infinity();
T maxval = metal::max(x, y);
T minval = metal::min(x, y);
return (minval == -inf || maxval == inf)
? maxval
: (maxval + log1p(metal::exp(minval - maxval)));
};
};
struct Maximum {
template <typename T>
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {
return metal::max(x, y);
}
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
if (metal::isnan(x)) {
return x;
}
return x > y ? x : y;
}
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
if (metal::isnan(x.real) || metal::isnan(x.imag)) {
return x;
}
return x > y ? x : y;
}
};
struct Minimum {
template <typename T>
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {
return metal::min(x, y);
}
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
if (metal::isnan(x)) {
return x;
}
return x < y ? x : y;
}
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
if (metal::isnan(x.real) || metal::isnan(x.imag)) {
return x;
}
return x < y ? x : y;
}
};
struct Multiply {
template <typename T>
T operator()(T x, T y) {
return x * y;
}
};
struct NotEqual {
template <typename T>
bool operator()(T x, T y) {
return x != y;
}
template <>
bool operator()(complex64_t x, complex64_t y) {
return x.real != y.real || x.imag != y.imag;
}
};
struct Power {
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T base, T exp) {
return metal::pow(base, exp);
}
template <typename T>
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T base, T exp) {
T res = 1;
while (exp) {
if (exp & 1) {
res *= base;
}
exp >>= 1;
base *= base;
}
return res;
}
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
auto x_theta = metal::atan(x.imag / x.real);
auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag);
auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta);
auto phase = y.imag * x_ln_r + y.real * x_theta;
return {mag * metal::cos(phase), mag * metal::sin(phase)};
}
};
struct Subtract {
template <typename T>
T operator()(T x, T y) {
return x - y;
}
};
struct LogicalAnd {
template <typename T>
T operator()(T x, T y) {
return x && y;
};
};
struct LogicalOr {
template <typename T>
T operator()(T x, T y) {
return x || y;
};
};
struct BitwiseAnd {
template <typename T>
T operator()(T x, T y) {
return x & y;
};
};
struct BitwiseOr {
template <typename T>
T operator()(T x, T y) {
return x | y;
};
};
struct BitwiseXor {
template <typename T>
T operator()(T x, T y) {
return x ^ y;
};
};
struct LeftShift {
template <typename T>
T operator()(T x, T y) {
return x << y;
};
};
struct RightShift {
template <typename T>
T operator()(T x, T y) {
return x >> y;
};
};
struct ArcTan2 {
template <typename T>
T operator()(T y, T x) {
return metal::precise::atan2(y, x);
}
};
struct DivMod {
template <typename T>
metal::array<T, 2> operator()(T x, T y) {
return {FloorDivide{}(x, y), Remainder{}(x, y)};
};
};

View File

@@ -0,0 +1,140 @@
// Copyright © 2024 Apple Inc.
template <typename T, typename U, typename Op>
[[kernel]] void binary_ss(
device const T* a,
device const T* b,
device U* c,
device U* d,
uint index [[thread_position_in_grid]]) {
auto out = Op()(a[0], b[0]);
c[index] = out[0];
d[index] = out[1];
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_sv(
device const T* a,
device const T* b,
device U* c,
device U* d,
uint index [[thread_position_in_grid]]) {
auto out = Op()(a[0], b[index]);
c[index] = out[0];
d[index] = out[1];
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_vs(
device const T* a,
device const T* b,
device U* c,
device U* d,
uint index [[thread_position_in_grid]]) {
auto out = Op()(a[index], b[0]);
c[index] = out[0];
d[index] = out[1];
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_vv(
device const T* a,
device const T* b,
device U* c,
device U* d,
uint index [[thread_position_in_grid]]) {
auto out = Op()(a[index], b[index]);
c[index] = out[0];
d[index] = out[1];
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd1(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant const size_t& a_stride,
constant const size_t& b_stride,
uint index [[thread_position_in_grid]]) {
auto a_idx = elem_to_loc_1(index, a_stride);
auto b_idx = elem_to_loc_1(index, b_stride);
auto out = Op()(a[a_idx], b[b_idx]);
c[index] = out[0];
d[index] = out[1];
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd2(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant const size_t a_strides[2],
constant const size_t b_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_2(index, a_strides);
auto b_idx = elem_to_loc_2(index, b_strides);
size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
auto out = Op()(a[a_idx], b[b_idx]);
c[out_idx] = out[0];
d[out_idx] = out[1];
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd3(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant const size_t a_strides[3],
constant const size_t b_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_3(index, a_strides);
auto b_idx = elem_to_loc_3(index, b_strides);
size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
auto out = Op()(a[a_idx], b[b_idx]);
c[out_idx] = out[0];
d[out_idx] = out[1];
}
template <typename T, typename U, typename Op, int DIM>
[[kernel]] void binary_g_nd(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant const int shape[DIM],
constant const size_t a_strides[DIM],
constant const size_t b_strides[DIM],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
auto out = Op()(a[idx.x], b[idx.y]);
c[out_idx] = out[0];
d[out_idx] = out[1];
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_g(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim);
size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
auto out = Op()(a[idx.x], b[idx.y]);
c[out_idx] = out[0];
d[out_idx] = out[1];
}

View File

@@ -1,212 +1,24 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2024 Apple Inc.
#include <metal_integer>
#include <metal_math>
#include "mlx/backend/metal/kernels/bf16.h"
// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/binary_ops.h"
#include "mlx/backend/metal/kernels/binary_two.h"
struct FloorDivide {
template <typename T>
T operator()(T x, T y) {
return x / y;
}
template <>
float operator()(float x, float y) {
return trunc(x / y);
}
template <>
half operator()(half x, half y) {
return trunc(x / y);
}
template <>
bfloat16_t operator()(bfloat16_t x, bfloat16_t y) {
return trunc(x / y);
}
};
struct Remainder {
template <typename T>
metal::enable_if_t<metal::is_integral_v<T> & !metal::is_signed_v<T>, T>
operator()(T x, T y) {
return x % y;
}
template <typename T>
metal::enable_if_t<metal::is_integral_v<T> & metal::is_signed_v<T>, T>
operator()(T x, T y) {
auto r = x % y;
if (r != 0 && (r < 0 != y < 0)) {
r += y;
}
return r;
}
template <typename T>
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
T r = fmod(x, y);
if (r != 0 && (r < 0 != y < 0)) {
r += y;
}
return r;
}
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
return x % y;
}
};
template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_s2s(
device const T* a,
device const T* b,
device U* c,
device U* d,
uint index [[thread_position_in_grid]]) {
c[index] = Op1()(a[0], b[0]);
d[index] = Op2()(a[0], b[0]);
}
template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_ss(
device const T* a,
device const T* b,
device U* c,
device U* d,
uint index [[thread_position_in_grid]]) {
c[index] = Op1()(a[0], b[0]);
d[index] = Op2()(a[0], b[0]);
}
template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_sv(
device const T* a,
device const T* b,
device U* c,
device U* d,
uint index [[thread_position_in_grid]]) {
c[index] = Op1()(a[0], b[index]);
d[index] = Op2()(a[0], b[index]);
}
template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_vs(
device const T* a,
device const T* b,
device U* c,
device U* d,
uint index [[thread_position_in_grid]]) {
c[index] = Op1()(a[index], b[0]);
d[index] = Op2()(a[index], b[0]);
}
template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_vv(
device const T* a,
device const T* b,
device U* c,
device U* d,
uint index [[thread_position_in_grid]]) {
c[index] = Op1()(a[index], b[index]);
d[index] = Op2()(a[index], b[index]);
}
template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_g_nd1(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant const size_t& a_stride,
constant const size_t& b_stride,
uint index [[thread_position_in_grid]]) {
auto a_idx = elem_to_loc_1(index, a_stride);
auto b_idx = elem_to_loc_1(index, b_stride);
c[index] = Op1()(a[a_idx], b[b_idx]);
d[index] = Op2()(a[a_idx], b[b_idx]);
}
template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_g_nd2(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant const size_t a_strides[2],
constant const size_t b_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_2(index, a_strides);
auto b_idx = elem_to_loc_2(index, b_strides);
size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
c[out_idx] = Op1()(a[a_idx], b[b_idx]);
d[out_idx] = Op2()(a[a_idx], b[b_idx]);
}
template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_g_nd3(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant const size_t a_strides[3],
constant const size_t b_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_3(index, a_strides);
auto b_idx = elem_to_loc_3(index, b_strides);
size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
c[out_idx] = Op1()(a[a_idx], b[b_idx]);
d[out_idx] = Op2()(a[a_idx], b[b_idx]);
}
template <typename T, typename U, typename Op1, typename Op2, int DIM>
[[kernel]] void binary_op_g_nd(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant const int shape[DIM],
constant const size_t a_strides[DIM],
constant const size_t b_strides[DIM],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
c[out_idx] = Op1()(a[idx.x], b[idx.y]);
d[out_idx] = Op2()(a[idx.x], b[idx.y]);
}
template <typename T, typename U, typename Op1, typename Op2>
[[kernel]] void binary_op_g(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim);
size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
c[out_idx] = Op1()(a[idx.x], b[idx.y]);
d[out_idx] = Op2()(a[idx.x], b[idx.y]);
}
#define instantiate_binary(name, itype, otype, op1, op2, bopt) \
#define instantiate_binary(name, itype, otype, op, bopt) \
template [[host_name(name)]] [[kernel]] void \
binary_op_##bopt<itype, otype, op1, op2>( \
binary_##bopt<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
device otype* d, \
uint index [[thread_position_in_grid]]);
#define instantiate_binary_g_dim(name, itype, otype, op1, op2, dims) \
template [[host_name(name "_" #dims)]] [[kernel]] void \
binary_op_g_nd<itype, otype, op1, op2, dims>( \
#define instantiate_binary_g_dim(name, itype, otype, op, dims) \
template [[host_name("g" #dims name)]] [[kernel]] void \
binary_g_nd<itype, otype, op, dims>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
@@ -217,10 +29,9 @@ template <typename T, typename U, typename Op1, typename Op2>
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]);
// clang-format off
#define instantiate_binary_g_nd(name, itype, otype, op1, op2) \
template [[host_name(name "_1")]] [[kernel]] void \
binary_op_g_nd1<itype, otype, op1, op2>( \
#define instantiate_binary_g_nd(name, itype, otype, op) \
template [[host_name("g1" name)]] [[kernel]] void \
binary_g_nd1<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
@@ -228,8 +39,8 @@ template <typename T, typename U, typename Op1, typename Op2>
constant const size_t& a_stride, \
constant const size_t& b_stride, \
uint index [[thread_position_in_grid]]); \
template [[host_name(name "_2")]] [[kernel]] void \
binary_op_g_nd2<itype, otype, op1, op2>( \
template [[host_name("g2" name)]] [[kernel]] void \
binary_g_nd2<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
@@ -238,8 +49,8 @@ template <typename T, typename U, typename Op1, typename Op2>
constant const size_t b_strides[2], \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]); \
template [[host_name(name "_3")]] [[kernel]] void \
binary_op_g_nd3<itype, otype, op1, op2>( \
template [[host_name("g3" name)]] [[kernel]] void \
binary_g_nd3<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
@@ -248,12 +59,12 @@ template <typename T, typename U, typename Op1, typename Op2>
constant const size_t b_strides[3], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
instantiate_binary_g_dim(name, itype, otype, op1, op2, 4) \
instantiate_binary_g_dim(name, itype, otype, op1, op2, 5) // clang-format on
instantiate_binary_g_dim(name, itype, otype, op, 4) \
instantiate_binary_g_dim(name, itype, otype, op, 5)
#define instantiate_binary_g(name, itype, otype, op1, op2) \
template [[host_name(name)]] [[kernel]] void \
binary_op_g<itype, otype, op2, op2>( \
#define instantiate_binary_g(name, itype, otype, op) \
template [[host_name("gn" name)]] [[kernel]] void \
binary_g<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
@@ -265,33 +76,30 @@ template <typename T, typename U, typename Op1, typename Op2>
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]);
// clang-format off
#define instantiate_binary_all(name, tname, itype, otype, op1, op2) \
instantiate_binary("ss" #name #tname, itype, otype, op1, op2, ss) \
instantiate_binary("sv" #name #tname, itype, otype, op1, op2, sv) \
instantiate_binary("vs" #name #tname, itype, otype, op1, op2, vs) \
instantiate_binary("vv" #name #tname, itype, otype, op1, op2, vv) \
instantiate_binary_g("g" #name #tname, itype, otype, op1, op2) \
instantiate_binary_g_nd("g" #name #tname, itype, otype, op1, op2) // clang-format on
#define instantiate_binary_all(name, tname, itype, otype, op) \
instantiate_binary("ss" #name #tname, itype, otype, op, ss) \
instantiate_binary("sv" #name #tname, itype, otype, op, sv) \
instantiate_binary("vs" #name #tname, itype, otype, op, vs) \
instantiate_binary("vv" #name #tname, itype, otype, op, vv) \
instantiate_binary_g(#name #tname, itype, otype, op) \
instantiate_binary_g_nd(#name #tname, itype, otype, op)
// clang-format off
#define instantiate_binary_float(name, op1, op2) \
instantiate_binary_all(name, float16, half, half, op1, op2) \
instantiate_binary_all(name, float32, float, float, op1, op2) \
instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op1, op2) // clang-format on
#define instantiate_binary_float(name, op) \
instantiate_binary_all(name, float16, half, half, op) \
instantiate_binary_all(name, float32, float, float, op) \
instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op)
// clang-format off
#define instantiate_binary_types(name, op1, op2) \
instantiate_binary_all(name, bool_, bool, bool, op1, op2) \
instantiate_binary_all(name, uint8, uint8_t, uint8_t, op1, op2) \
instantiate_binary_all(name, uint16, uint16_t, uint16_t, op1, op2) \
instantiate_binary_all(name, uint32, uint32_t, uint32_t, op1, op2) \
instantiate_binary_all(name, uint64, uint64_t, uint64_t, op1, op2) \
instantiate_binary_all(name, int8, int8_t, int8_t, op1, op2) \
instantiate_binary_all(name, int16, int16_t, int16_t, op1, op2) \
instantiate_binary_all(name, int32, int32_t, int32_t, op1, op2) \
instantiate_binary_all(name, int64, int64_t, int64_t, op1, op2) \
instantiate_binary_all(name, complex64, complex64_t, complex64_t, op1, op2) \
instantiate_binary_float(name, op1, op2)
#define instantiate_binary_types(name, op) \
instantiate_binary_all(name, bool_, bool, bool, op) \
instantiate_binary_all(name, uint8, uint8_t, uint8_t, op) \
instantiate_binary_all(name, uint16, uint16_t, uint16_t, op) \
instantiate_binary_all(name, uint32, uint32_t, uint32_t, op) \
instantiate_binary_all(name, uint64, uint64_t, uint64_t, op) \
instantiate_binary_all(name, int8, int8_t, int8_t, op) \
instantiate_binary_all(name, int16, int16_t, int16_t, op) \
instantiate_binary_all(name, int32, int32_t, int32_t, op) \
instantiate_binary_all(name, int64, int64_t, int64_t, op) \
instantiate_binary_all(name, complex64, complex64_t, complex64_t, op) \
instantiate_binary_float(name, op)
instantiate_binary_types(divmod, FloorDivide, Remainder) // clang-format on
instantiate_binary_types(divmod, DivMod) // clang-format on

View File

@@ -1,7 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/kernels/binary.h"
#include "mlx/backend/metal/kernels/ternary.h"
#include "mlx/backend/metal/kernels/unary.h"
typedef half float16_t;

View File

@@ -109,6 +109,7 @@ template <typename T, int N>
bool valid = n < params->N;
// Unroll dimensions
int kernel_stride = 1;
for (int i = N - 1; i >= 0; --i) {
int os_ = (oS % params->oS[i]);
int ws_ = (wS % params->wS[i]);
@@ -125,7 +126,8 @@ template <typename T, int N>
oS /= params->oS[i];
wS /= params->wS[i];
out += ws_ * params->str[i];
out += ws_ * kernel_stride;
kernel_stride *= params->wS[i];
}
if (valid) {
@@ -648,4 +650,4 @@ winograd_conv_2d_output_transform(
// clang-format off
instantiate_winograd_conv_2d(float32, float);
instantiate_winograd_conv_2d(float16, half); // clang-format on
instantiate_winograd_conv_2d(float16, half); // clang-format on

View File

@@ -0,0 +1,144 @@
// Copyright © 2024 Apple Inc.
template <typename T, typename U>
[[kernel]] void copy_s(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
uint index [[thread_position_in_grid]]) {
dst[index] = static_cast<U>(src[0]);
}
template <typename T, typename U>
[[kernel]] void copy_v(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
uint index [[thread_position_in_grid]]) {
dst[index] = static_cast<U>(src[index]);
}
template <typename T, typename U>
[[kernel]] void copy_g_nd1(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t& src_stride [[buffer(3)]],
uint index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_1(index, src_stride);
dst[index] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_g_nd2(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_2(index, src_strides);
int64_t dst_idx = index.x + (int64_t)grid_dim.x * index.y;
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_g_nd3(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_3(index, src_strides);
int64_t dst_idx =
index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U, int DIM>
[[kernel]] void copy_g_nd(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
int64_t dst_idx =
index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_g(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int& ndim [[buffer(5)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim);
int64_t dst_idx =
index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_gg_nd1(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t& src_stride [[buffer(3)]],
constant const int64_t& dst_stride [[buffer(4)]],
uint index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_1(index, src_stride);
auto dst_idx = elem_to_loc_1(index, dst_stride);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_gg_nd2(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint2 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_2(index, src_strides);
auto dst_idx = elem_to_loc_2(index, dst_strides);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_gg_nd3(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint3 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_3(index, src_strides);
auto dst_idx = elem_to_loc_3(index, dst_strides);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U, int DIM>
[[kernel]] void copy_gg_nd(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint3 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
auto dst_idx = elem_to_loc_nd<DIM>(index, src_shape, dst_strides);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_gg(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
constant const int& ndim [[buffer(5)]],
uint3 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim);
auto dst_idx = elem_to_loc(index, src_shape, dst_strides, ndim);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}

View File

@@ -1,150 +1,9 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2024 Apple Inc.
#include "mlx/backend/metal/kernels/bf16.h"
// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
template <typename T, typename U>
[[kernel]] void copy_s(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
uint index [[thread_position_in_grid]]) {
dst[index] = static_cast<U>(src[0]);
}
template <typename T, typename U>
[[kernel]] void copy_v(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
uint index [[thread_position_in_grid]]) {
dst[index] = static_cast<U>(src[index]);
}
template <typename T, typename U>
[[kernel]] void copy_g_nd1(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t& src_stride [[buffer(3)]],
uint index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_1(index, src_stride);
dst[index] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_g_nd2(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_2(index, src_strides);
int64_t dst_idx = index.x + (int64_t)grid_dim.x * index.y;
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_g_nd3(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_3(index, src_strides);
int64_t dst_idx =
index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U, int DIM>
[[kernel]] void copy_g_nd(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
int64_t dst_idx =
index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_g(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int& ndim [[buffer(5)]],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim);
int64_t dst_idx =
index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_gg_nd1(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t& src_stride [[buffer(3)]],
constant const int64_t& dst_stride [[buffer(4)]],
uint index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_1(index, src_stride);
auto dst_idx = elem_to_loc_1(index, dst_stride);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_gg_nd2(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint2 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_2(index, src_strides);
auto dst_idx = elem_to_loc_2(index, dst_strides);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_gg_nd3(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint3 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_3(index, src_strides);
auto dst_idx = elem_to_loc_3(index, dst_strides);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U, int DIM>
[[kernel]] void copy_gg_nd(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
uint3 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
auto dst_idx = elem_to_loc_nd<DIM>(index, src_shape, dst_strides);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
template <typename T, typename U>
[[kernel]] void copy_gg(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
constant const int* src_shape [[buffer(2)]],
constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]],
constant const int& ndim [[buffer(5)]],
uint3 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim);
auto dst_idx = elem_to_loc(index, src_shape, dst_strides, ndim);
dst[dst_idx] = static_cast<U>(src[src_idx]);
}
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/copy.h"
#define instantiate_copy(name, itype, otype, ctype) \
template [[host_name(name)]] [[kernel]] void copy_##ctype<itype, otype>( \
@@ -152,92 +11,90 @@ template <typename T, typename U>
device otype* dst [[buffer(1)]], \
uint index [[thread_position_in_grid]]);
#define instantiate_copy_g_dim(name, itype, otype, dims) \
template [[host_name(name "_" #dims)]] [[kernel]] void \
copy_g_nd<itype, otype, dims>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \
constant const int64_t* src_strides [[buffer(3)]], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
template [[host_name("g" name "_" #dims)]] [[kernel]] void \
copy_gg_nd<itype, otype, dims>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int64_t* dst_strides [[buffer(4)]], \
#define instantiate_copy_g_dim(name, itype, otype, dims) \
template [[host_name("g" #dims "_" name)]] [[kernel]] void \
copy_g_nd<itype, otype, dims>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \
constant const int64_t* src_strides [[buffer(3)]], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
template [[host_name("gg" #dims "_" name)]] [[kernel]] void \
copy_gg_nd<itype, otype, dims>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int64_t* dst_strides [[buffer(4)]], \
uint3 index [[thread_position_in_grid]]);
#define instantiate_copy_g_nd(name, itype, otype) \
template [[host_name(name "_1")]] [[kernel]] void copy_g_nd1<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t& src_stride [[buffer(3)]], \
uint index [[thread_position_in_grid]]); \
template [[host_name(name "_2")]] [[kernel]] void copy_g_nd2<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]); \
template [[host_name(name "_3")]] [[kernel]] void copy_g_nd3<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
template [[host_name("g" name "_1")]] [[kernel]] void \
copy_gg_nd1<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t& src_stride [[buffer(3)]], \
constant const int64_t& dst_stride [[buffer(4)]], \
uint index [[thread_position_in_grid]]); \
template [[host_name("g" name "_2")]] [[kernel]] void \
copy_gg_nd2<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int64_t* dst_strides [[buffer(4)]], \
uint2 index [[thread_position_in_grid]]); \
template [[host_name("g" name "_3")]] [[kernel]] void \
copy_gg_nd3<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int64_t* dst_strides [[buffer(4)]], \
uint3 index [[thread_position_in_grid]]); \
instantiate_copy_g_dim(name, itype, otype, 4) \
instantiate_copy_g_dim(name, itype, otype, 5)
#define instantiate_copy_g_nd(name, itype, otype) \
template [[host_name("g1_" name)]] [[kernel]] void copy_g_nd1<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t& src_stride [[buffer(3)]], \
uint index [[thread_position_in_grid]]); \
template [[host_name("g2_" name)]] [[kernel]] void copy_g_nd2<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]); \
template [[host_name("g3_" name)]] [[kernel]] void copy_g_nd3<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
template [[host_name("gg1_" name )]] [[kernel]] void \
copy_gg_nd1<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t& src_stride [[buffer(3)]], \
constant const int64_t& dst_stride [[buffer(4)]], \
uint index [[thread_position_in_grid]]); \
template [[host_name("gg2_" name)]] [[kernel]] void \
copy_gg_nd2<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int64_t* dst_strides [[buffer(4)]], \
uint2 index [[thread_position_in_grid]]); \
template [[host_name("gg3_" name)]] [[kernel]] void \
copy_gg_nd3<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int64_t* dst_strides [[buffer(4)]], \
uint3 index [[thread_position_in_grid]]); \
instantiate_copy_g_dim(name, itype, otype, 4) \
instantiate_copy_g_dim(name, itype, otype, 5)
#define instantiate_copy_g(name, itype, otype) \
template [[host_name(name)]] [[kernel]] void copy_g<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int& ndim [[buffer(5)]], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
template [[host_name("g" name)]] [[kernel]] void copy_gg<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int64_t* dst_strides [[buffer(4)]], \
constant const int& ndim [[buffer(5)]], \
#define instantiate_copy_g(name, itype, otype) \
template [[host_name("g_" name)]] [[kernel]] void copy_g<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int& ndim [[buffer(5)]], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
template [[host_name("gg_" name)]] [[kernel]] void copy_gg<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int64_t* dst_strides [[buffer(4)]], \
constant const int& ndim [[buffer(5)]], \
uint3 index [[thread_position_in_grid]]);
// clang-format off
#define instantiate_copy_all(tname, itype, otype) \
instantiate_copy("scopy" #tname, itype, otype, s) \
instantiate_copy("vcopy" #tname, itype, otype, v) \
instantiate_copy_g("gcopy" #tname, itype, otype) \
instantiate_copy_g_nd("gcopy" #tname, itype, otype) // clang-format on
#define instantiate_copy_all(tname, itype, otype) \
instantiate_copy("s_copy" #tname, itype, otype, s) \
instantiate_copy("v_copy" #tname, itype, otype, v) \
instantiate_copy_g("copy" #tname, itype, otype) \
instantiate_copy_g_nd("copy" #tname, itype, otype)
// clang-format off
#define instantiate_copy_itype(itname, itype) \
instantiate_copy_all(itname ##bool_, itype, bool) \
instantiate_copy_all(itname ##uint8, itype, uint8_t) \

View File

@@ -2,17 +2,14 @@
#pragma once
#ifdef __METAL__
#if defined __METAL__ || defined MLX_METAL_JIT
#define MTL_CONST constant
#else
#define MTL_CONST
#endif
static MTL_CONST constexpr int MAX_BINARY_SPECIALIZED_DIMS = 5;
static MTL_CONST constexpr int MAX_COPY_SPECIALIZED_DIMS = 5;
static MTL_CONST constexpr int MAX_REDUCE_SPECIALIZED_DIMS = 4;
static MTL_CONST constexpr int REDUCE_N_READS = 16;
static MTL_CONST constexpr int SOFTMAX_N_READS = 4;
static MTL_CONST constexpr int SOFTMAX_LOOPED_LIMIT = 4096;
static MTL_CONST constexpr int RMS_N_READS = 4;
static MTL_CONST constexpr int RMS_LOOPED_LIMIT = 4096;

View File

@@ -1,7 +1,6 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include <metal_math>
/*
@@ -67,4 +66,4 @@ float erfinv(float a) {
p = metal::fma(p, t, 8.86226892e-1f); // 0x1.c5bf88p-1
}
return a * p;
}
}

View File

@@ -0,0 +1,486 @@
// Copyright © 2024 Apple Inc.
// Metal FFT using Stockham's algorithm
//
// References:
// - VkFFT (https://github.com/DTolm/VkFFT)
// - Eric Bainville's excellent page (http://www.bealto.com/gpu-fft.html)
#include <metal_common>
#include "mlx/backend/metal/kernels/fft/radix.h"
#include "mlx/backend/metal/kernels/fft/readwrite.h"
#include "mlx/backend/metal/kernels/steel/defines.h"
using namespace metal;
#define MAX_RADIX 13
// Reached when elems_per_thread_ = 6, max_radix = 13
// and some threads have to do 3 radix 6s requiring 18 float2s.
#define MAX_OUTPUT_SIZE 18
// Specialize for a particular value of N at runtime
STEEL_CONST bool inv_ [[function_constant(0)]];
STEEL_CONST bool is_power_of_2_ [[function_constant(1)]];
STEEL_CONST int elems_per_thread_ [[function_constant(2)]];
// rader_m = n / rader_n
STEEL_CONST int rader_m_ [[function_constant(3)]];
// Stockham steps
STEEL_CONST int radix_13_steps_ [[function_constant(4)]];
STEEL_CONST int radix_11_steps_ [[function_constant(5)]];
STEEL_CONST int radix_8_steps_ [[function_constant(6)]];
STEEL_CONST int radix_7_steps_ [[function_constant(7)]];
STEEL_CONST int radix_6_steps_ [[function_constant(8)]];
STEEL_CONST int radix_5_steps_ [[function_constant(9)]];
STEEL_CONST int radix_4_steps_ [[function_constant(10)]];
STEEL_CONST int radix_3_steps_ [[function_constant(11)]];
STEEL_CONST int radix_2_steps_ [[function_constant(12)]];
// Rader steps
STEEL_CONST int rader_13_steps_ [[function_constant(13)]];
STEEL_CONST int rader_11_steps_ [[function_constant(14)]];
STEEL_CONST int rader_8_steps_ [[function_constant(15)]];
STEEL_CONST int rader_7_steps_ [[function_constant(16)]];
STEEL_CONST int rader_6_steps_ [[function_constant(17)]];
STEEL_CONST int rader_5_steps_ [[function_constant(18)]];
STEEL_CONST int rader_4_steps_ [[function_constant(19)]];
STEEL_CONST int rader_3_steps_ [[function_constant(20)]];
STEEL_CONST int rader_2_steps_ [[function_constant(21)]];
// See "radix.h" for radix codelets
typedef void (*RadixFunc)(thread float2*, thread float2*);
// Perform a single radix n butterfly with appropriate twiddles
template <int radix, RadixFunc radix_func>
METAL_FUNC void radix_butterfly(
int i,
int p,
thread float2* x,
thread short* indices,
thread float2* y) {
// i: the index in the overall DFT that we're processing.
// p: the size of the DFTs we're merging at this step.
// m: how many threads are working on this DFT.
int k, j;
// Use faster bitwise operations when working with powers of two
constexpr bool radix_p_2 = (radix & (radix - 1)) == 0;
if (radix_p_2 && is_power_of_2_) {
constexpr short power = __builtin_ctz(radix);
k = i & (p - 1);
j = ((i - k) << power) + k;
} else {
k = i % p;
j = (i / p) * radix * p + k;
}
// Apply twiddles
if (p > 1) {
float2 twiddle_1 = get_twiddle(k, radix * p);
float2 twiddle = twiddle_1;
x[1] = complex_mul(x[1], twiddle);
STEEL_PRAGMA_UNROLL
for (int t = 2; t < radix; t++) {
twiddle = complex_mul(twiddle, twiddle_1);
x[t] = complex_mul(x[t], twiddle);
}
}
radix_func(x, y);
STEEL_PRAGMA_UNROLL
for (int t = 0; t < radix; t++) {
indices[t] = j + t * p;
}
}
// Perform all the radix steps required for a
// particular radix size n.
template <int radix, RadixFunc radix_func>
METAL_FUNC void radix_n_steps(
int i,
thread int* p,
int m,
int n,
int num_steps,
thread float2* inputs,
thread short* indices,
thread float2* values,
threadgroup float2* buf) {
int m_r = n / radix;
// When combining different sized radices, we have to do
// multiple butterflies in a single thread.
// E.g. n = 28 = 4 * 7
// 4 threads, 7 elems_per_thread
// All threads do 1 radix7 butterfly.
// 3 threads do 2 radix4 butterflies.
// 1 thread does 1 radix4 butterfly.
int max_radices_per_thread = (elems_per_thread_ + radix - 1) / radix;
int index = 0;
int r_index = 0;
for (int s = 0; s < num_steps; s++) {
for (int t = 0; t < max_radices_per_thread; t++) {
index = i + t * m;
if (index < m_r) {
for (int r = 0; r < radix; r++) {
inputs[r] = buf[index + r * m_r];
}
radix_butterfly<radix, radix_func>(
index, *p, inputs, indices + t * radix, values + t * radix);
}
}
// Wait until all threads have read their inputs into thread local mem
threadgroup_barrier(mem_flags::mem_threadgroup);
for (int t = 0; t < max_radices_per_thread; t++) {
index = i + t * m;
if (index < m_r) {
for (int r = 0; r < radix; r++) {
r_index = t * radix + r;
buf[indices[r_index]] = values[r_index];
}
}
}
// Wait until all threads have written back to threadgroup mem
threadgroup_barrier(mem_flags::mem_threadgroup);
*p *= radix;
}
}
#define RADIX_STEP(radix, radix_func, num_steps) \
radix_n_steps<radix, radix_func>( \
fft_idx, p, m, n, num_steps, inputs, indices, values, buf);
template <bool rader = false>
METAL_FUNC void
perform_fft(int fft_idx, thread int* p, int m, int n, threadgroup float2* buf) {
float2 inputs[MAX_RADIX];
short indices[MAX_OUTPUT_SIZE];
float2 values[MAX_OUTPUT_SIZE];
RADIX_STEP(2, radix2, rader ? rader_2_steps_ : radix_2_steps_);
RADIX_STEP(3, radix3, rader ? rader_3_steps_ : radix_3_steps_);
RADIX_STEP(4, radix4, rader ? rader_4_steps_ : radix_4_steps_);
RADIX_STEP(5, radix5, rader ? rader_5_steps_ : radix_5_steps_);
RADIX_STEP(6, radix6, rader ? rader_6_steps_ : radix_6_steps_);
RADIX_STEP(7, radix7, rader ? rader_7_steps_ : radix_7_steps_);
RADIX_STEP(8, radix8, rader ? rader_8_steps_ : radix_8_steps_);
RADIX_STEP(11, radix11, rader ? rader_11_steps_ : radix_11_steps_);
RADIX_STEP(13, radix13, rader ? rader_13_steps_ : radix_13_steps_);
}
// Each FFT is computed entirely in shared GPU memory.
//
// N is decomposed into radix-n DFTs:
// e.g. 128 = 2 * 4 * 4 * 4
template <int tg_mem_size, typename in_T, typename out_T>
[[kernel]] void fft(
const device in_T* in [[buffer(0)]],
device out_T* out [[buffer(1)]],
constant const int& n,
constant const int& batch_size,
uint3 elem [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]) {
threadgroup float2 shared_in[tg_mem_size];
thread ReadWriter<in_T, out_T> read_writer = ReadWriter<in_T, out_T>(
in,
&shared_in[0],
out,
n,
batch_size,
elems_per_thread_,
elem,
grid,
inv_);
if (read_writer.out_of_bounds()) {
return;
};
read_writer.load();
threadgroup_barrier(mem_flags::mem_threadgroup);
int p = 1;
int fft_idx = elem.z; // Thread index in DFT
int m = grid.z; // Threads per DFT
int tg_idx = elem.y * n; // Index of this DFT in threadgroup
threadgroup float2* buf = &shared_in[tg_idx];
perform_fft(fft_idx, &p, m, n, buf);
read_writer.write();
}
template <int tg_mem_size, typename in_T, typename out_T>
[[kernel]] void rader_fft(
const device in_T* in [[buffer(0)]],
device out_T* out [[buffer(1)]],
const device float2* raders_b_q [[buffer(2)]],
const device short* raders_g_q [[buffer(3)]],
const device short* raders_g_minus_q [[buffer(4)]],
constant const int& n,
constant const int& batch_size,
constant const int& rader_n,
uint3 elem [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]) {
// Use Rader's algorithm to compute fast FFTs
// when a prime factor `p` of `n` is greater than 13 but
// has `p - 1` Stockham decomposable into to prime factors <= 13.
//
// E.g. n = 102
// = 2 * 3 * 17
// . = 2 * 3 * RADER(16)
// . = 2 * 3 * RADER(4 * 4)
//
// In numpy:
// x_perm = x[g_q]
// y = np.fft.fft(x_perm) * b_q
// z = np.fft.ifft(y) + x[0]
// out = z[g_minus_q]
// out[0] = x[1:].sum()
//
// Where the g_q and g_minus_q are permutations formed
// by the group under multiplicative modulo N using the
// primitive root of N and b_q is a constant.
// See https://en.wikipedia.org/wiki/Rader%27s_FFT_algorithm
//
// Rader's uses fewer operations than Bluestein's and so
// is more accurate. It's also faster in most cases.
threadgroup float2 shared_in[tg_mem_size];
thread ReadWriter<in_T, out_T> read_writer = ReadWriter<in_T, out_T>(
in,
&shared_in[0],
out,
n,
batch_size,
elems_per_thread_,
elem,
grid,
inv_);
if (read_writer.out_of_bounds()) {
return;
};
read_writer.load();
threadgroup_barrier(mem_flags::mem_threadgroup);
// The number of the threads we're using for each DFT
int m = grid.z;
int fft_idx = elem.z;
int tg_idx = elem.y * n;
threadgroup float2* buf = &shared_in[tg_idx];
// rader_m = n / rader_n;
int rader_m = rader_m_;
// We have to load two x_0s for each thread since sometimes
// elems_per_thread_ crosses a boundary.
// E.g. with n = 34, rader_n = 17, elems_per_thread_ = 4
// 0 0 0 0 1 1 1 1 2 2 2 2 3 3 3 3 4 4 4 4 5 5 5 5 6 6 6 6 7 7 7 7 8 8
// 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
short x_0_index =
metal::min(fft_idx * elems_per_thread_ / (rader_n - 1), rader_m - 1);
float2 x_0[2] = {buf[x_0_index], buf[x_0_index + 1]};
// Do the Rader permutation in shared memory
float2 temp[MAX_RADIX];
int max_index = n - rader_m - 1;
for (int e = 0; e < elems_per_thread_; e++) {
short index = metal::min(fft_idx * elems_per_thread_ + e, max_index);
short g_q = raders_g_q[index / rader_m];
temp[e] = buf[rader_m + (g_q - 1) * rader_m + index % rader_m];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
for (int e = 0; e < elems_per_thread_; e++) {
short index = metal::min(fft_idx * elems_per_thread_ + e, max_index);
buf[index + rader_m] = temp[e];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Rader FFT on x[rader_m:]
int p = 1;
perform_fft</*rader=*/true>(fft_idx, &p, m, n - rader_m, buf + rader_m);
// x_1 + ... + x_n is computed for us in the first FFT step so
// we save it in the first rader_m indices of the array for later.
int x_sum_index = metal::min(fft_idx, rader_m - 1);
buf[x_sum_index] = buf[rader_m + x_sum_index * (rader_n - 1)];
float2 inv = {1.0f, -1.0f};
for (int e = 0; e < elems_per_thread_; e++) {
short index = metal::min(fft_idx * elems_per_thread_ + e, max_index);
short interleaved_index =
index / rader_m + (index % rader_m) * (rader_n - 1);
temp[e] = complex_mul(
buf[rader_m + interleaved_index],
raders_b_q[interleaved_index % (rader_n - 1)]);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
for (int e = 0; e < elems_per_thread_; e++) {
short index = metal::min(fft_idx * elems_per_thread_ + e, max_index);
buf[rader_m + index] = temp[e] * inv;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Rader IFFT on x[rader_m:]
p = 1;
perform_fft</*rader=*/true>(fft_idx, &p, m, n - rader_m, buf + rader_m);
float2 rader_inv_factor = {1.0f / (rader_n - 1), -1.0f / (rader_n - 1)};
for (int e = 0; e < elems_per_thread_; e++) {
short index = metal::min(fft_idx * elems_per_thread_ + e, n - rader_m - 1);
short diff_index = index / (rader_n - 1) - x_0_index;
temp[e] = buf[rader_m + index] * rader_inv_factor + x_0[diff_index];
}
// Use the sum of elements that was computed in the first FFT
float2 x_sum = buf[x_0_index] + x_0[0];
threadgroup_barrier(mem_flags::mem_threadgroup);
for (int e = 0; e < elems_per_thread_; e++) {
short index = metal::min(fft_idx * elems_per_thread_ + e, max_index);
short g_q_index = index % (rader_n - 1);
short g_q = raders_g_minus_q[g_q_index];
short out_index = index - g_q_index + g_q + (index / (rader_n - 1));
buf[out_index] = temp[e];
}
buf[x_0_index * rader_n] = x_sum;
threadgroup_barrier(mem_flags::mem_threadgroup);
p = rader_n;
perform_fft(fft_idx, &p, m, n, buf);
read_writer.write();
}
template <int tg_mem_size, typename in_T, typename out_T>
[[kernel]] void bluestein_fft(
const device in_T* in [[buffer(0)]],
device out_T* out [[buffer(1)]],
const device float2* w_q [[buffer(2)]],
const device float2* w_k [[buffer(3)]],
constant const int& length,
constant const int& n,
constant const int& batch_size,
uint3 elem [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]) {
// Computes arbitrary length FFTs with Bluestein's algorithm
//
// In numpy:
// bluestein_n = next_power_of_2(2*n - 1)
// out = w_k * np.fft.ifft(np.fft.fft(w_k * in, bluestein_n) * w_q)
//
// Where w_k and w_q are precomputed on CPU in high precision as:
// w_k = np.exp(-1j * np.pi / n * (np.arange(-n + 1, n) ** 2))
// w_q = np.fft.fft(1/w_k[-n:])
threadgroup float2 shared_in[tg_mem_size];
thread ReadWriter<in_T, out_T> read_writer = ReadWriter<in_T, out_T>(
in,
&shared_in[0],
out,
n,
batch_size,
elems_per_thread_,
elem,
grid,
inv_);
if (read_writer.out_of_bounds()) {
return;
};
read_writer.load_padded(length, w_k);
threadgroup_barrier(mem_flags::mem_threadgroup);
int p = 1;
int fft_idx = elem.z; // Thread index in DFT
int m = grid.z; // Threads per DFT
int tg_idx = elem.y * n; // Index of this DFT in threadgroup
threadgroup float2* buf = &shared_in[tg_idx];
// fft
perform_fft(fft_idx, &p, m, n, buf);
float2 inv = float2(1.0f, -1.0f);
for (int t = 0; t < elems_per_thread_; t++) {
int index = fft_idx + t * m;
buf[index] = complex_mul(buf[index], w_q[index]) * inv;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// ifft
p = 1;
perform_fft(fft_idx, &p, m, n, buf);
read_writer.write_padded(length, w_k);
}
template <
int tg_mem_size,
typename in_T,
typename out_T,
int step,
bool real = false>
[[kernel]] void four_step_fft(
const device in_T* in [[buffer(0)]],
device out_T* out [[buffer(1)]],
constant const int& n1,
constant const int& n2,
constant const int& batch_size,
uint3 elem [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]) {
// Fast four step FFT implementation for powers of 2.
int overall_n = n1 * n2;
int n = step == 0 ? n1 : n2;
int stride = step == 0 ? n2 : n1;
// The number of the threads we're using for each DFT
int m = grid.z;
int fft_idx = elem.z;
threadgroup float2 shared_in[tg_mem_size];
threadgroup float2* buf = &shared_in[elem.y * n];
using read_writer_t = ReadWriter<in_T, out_T, step, real>;
read_writer_t read_writer = read_writer_t(
in,
&shared_in[0],
out,
n,
batch_size,
elems_per_thread_,
elem,
grid,
inv_);
if (read_writer.out_of_bounds()) {
return;
};
read_writer.load_strided(stride, overall_n);
threadgroup_barrier(mem_flags::mem_threadgroup);
int p = 1;
perform_fft(fft_idx, &p, m, n, buf);
read_writer.write_strided(stride, overall_n);
}

View File

@@ -1,199 +1,84 @@
// Copyright © 2024 Apple Inc.
// Metal FFT using Stockham's algorithm
//
// References:
// - VkFFT (https://github.com/DTolm/VkFFT)
// - Eric Bainville's excellent page (http://www.bealto.com/gpu-fft.html)
#include "mlx/backend/metal/kernels/fft.h"
#include <metal_common>
#include <metal_math>
#define instantiate_fft(tg_mem_size, in_T, out_T) \
template [[host_name("fft_mem_" #tg_mem_size "_" #in_T \
"_" #out_T)]] [[kernel]] void \
fft<tg_mem_size, in_T, out_T>( \
const device in_T* in [[buffer(0)]], \
device out_T* out [[buffer(1)]], \
constant const int& n, \
constant const int& batch_size, \
uint3 elem [[thread_position_in_grid]], \
uint3 grid [[threads_per_grid]]);
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"
#define instantiate_rader(tg_mem_size, in_T, out_T) \
template [[host_name("rader_fft_mem_" #tg_mem_size "_" #in_T \
"_" #out_T)]] [[kernel]] void \
rader_fft<tg_mem_size, in_T, out_T>( \
const device in_T* in [[buffer(0)]], \
device out_T* out [[buffer(1)]], \
const device float2* raders_b_q [[buffer(2)]], \
const device short* raders_g_q [[buffer(3)]], \
const device short* raders_g_minus_q [[buffer(4)]], \
constant const int& n, \
constant const int& batch_size, \
constant const int& rader_n, \
uint3 elem [[thread_position_in_grid]], \
uint3 grid [[threads_per_grid]]);
using namespace metal;
#define instantiate_bluestein(tg_mem_size, in_T, out_T) \
template [[host_name("bluestein_fft_mem_" #tg_mem_size "_" #in_T \
"_" #out_T)]] [[kernel]] void \
bluestein_fft<tg_mem_size, in_T, out_T>( \
const device in_T* in [[buffer(0)]], \
device out_T* out [[buffer(1)]], \
const device float2* w_q [[buffer(2)]], \
const device float2* w_k [[buffer(3)]], \
constant const int& length, \
constant const int& n, \
constant const int& batch_size, \
uint3 elem [[thread_position_in_grid]], \
uint3 grid [[threads_per_grid]]);
float2 complex_mul(float2 a, float2 b) {
float2 c;
c.x = a.x * b.x - a.y * b.y;
c.y = a.x * b.y + a.y * b.x;
return c;
}
#define instantiate_four_step(tg_mem_size, in_T, out_T, step, real) \
template [[host_name("four_step_mem_" #tg_mem_size "_" #in_T "_" #out_T \
"_" #step "_" #real)]] [[kernel]] void \
four_step_fft<tg_mem_size, in_T, out_T, step, real>( \
const device in_T* in [[buffer(0)]], \
device out_T* out [[buffer(1)]], \
constant const int& n1, \
constant const int& n2, \
constant const int& batch_size, \
uint3 elem [[thread_position_in_grid]], \
uint3 grid [[threads_per_grid]]);
float2 get_twiddle(int k, int p) {
float theta = -1.0f * k * M_PI_F / (2 * p);
float2 twiddle;
twiddle.x = metal::fast::cos(theta);
twiddle.y = metal::fast::sin(theta);
return twiddle;
}
// single threaded radix2 implemetation
void radix2(
int i,
int p,
int m,
threadgroup float2* read_buf,
threadgroup float2* write_buf) {
float2 x_0 = read_buf[i];
float2 x_1 = read_buf[i + m];
// The index within this sub-DFT
int k = i & (p - 1);
float2 twiddle = get_twiddle(k, p);
float2 z = complex_mul(x_1, twiddle);
float2 y_0 = x_0 + z;
float2 y_1 = x_0 - z;
int j = (i << 1) - k;
write_buf[j] = y_0;
write_buf[j + p] = y_1;
}
// single threaded radix4 implemetation
void radix4(
int i,
int p,
int m,
threadgroup float2* read_buf,
threadgroup float2* write_buf) {
float2 x_0 = read_buf[i];
float2 x_1 = read_buf[i + m];
float2 x_2 = read_buf[i + 2 * m];
float2 x_3 = read_buf[i + 3 * m];
// The index within this sub-DFT
int k = i & (p - 1);
float2 twiddle = get_twiddle(k, p);
// e^a * e^b = e^(a + b)
float2 twiddle_2 = complex_mul(twiddle, twiddle);
float2 twiddle_3 = complex_mul(twiddle, twiddle_2);
x_1 = complex_mul(x_1, twiddle);
x_2 = complex_mul(x_2, twiddle_2);
x_3 = complex_mul(x_3, twiddle_3);
float2 minus_i;
minus_i.x = 0;
minus_i.y = -1;
// Hard coded twiddle factors for DFT4
float2 z_0 = x_0 + x_2;
float2 z_1 = x_0 - x_2;
float2 z_2 = x_1 + x_3;
float2 z_3 = complex_mul(x_1 - x_3, minus_i);
float2 y_0 = z_0 + z_2;
float2 y_1 = z_1 + z_3;
float2 y_2 = z_0 - z_2;
float2 y_3 = z_1 - z_3;
int j = ((i - k) << 2) + k;
write_buf[j] = y_0;
write_buf[j + p] = y_1;
write_buf[j + 2 * p] = y_2;
write_buf[j + 3 * p] = y_3;
}
// Each FFT is computed entirely in shared GPU memory.
//
// N is decomposed into radix-2 and radix-4 DFTs:
// e.g. 128 = 2 * 4 * 4 * 4
//
// At each step we use n / 4 threads, each performing
// a single-threaded radix-4 or radix-2 DFT.
//
// We provide the number of radix-2 and radix-4
// steps at compile time for a ~20% performance boost.
template <size_t n, size_t radix_2_steps, size_t radix_4_steps>
[[kernel]] void fft(
const device float2* in [[buffer(0)]],
device float2* out [[buffer(1)]],
uint3 thread_position_in_grid [[thread_position_in_grid]],
uint3 threads_per_grid [[threads_per_grid]]) {
// Index of the DFT in batch
int batch_idx = thread_position_in_grid.x * n;
// The index in the DFT we're working on
int i = thread_position_in_grid.y;
// The number of the threads we're using for each DFT
int m = threads_per_grid.y;
// Allocate 2 shared memory buffers for Stockham.
// We alternate reading from one and writing to the other at each radix step.
threadgroup float2 shared_in[n];
threadgroup float2 shared_out[n];
// Pointers to facilitate Stockham buffer swapping
threadgroup float2* read_buf = shared_in;
threadgroup float2* write_buf = shared_out;
threadgroup float2* tmp;
// Copy input into shared memory
shared_in[i] = in[batch_idx + i];
shared_in[i + m] = in[batch_idx + i + m];
shared_in[i + 2 * m] = in[batch_idx + i + 2 * m];
shared_in[i + 3 * m] = in[batch_idx + i + 3 * m];
threadgroup_barrier(mem_flags::mem_threadgroup);
int p = 1;
for (size_t r = 0; r < radix_2_steps; r++) {
radix2(i, p, m * 2, read_buf, write_buf);
radix2(i + m, p, m * 2, read_buf, write_buf);
p *= 2;
threadgroup_barrier(mem_flags::mem_threadgroup);
// Stockham switch of buffers
tmp = write_buf;
write_buf = read_buf;
read_buf = tmp;
}
for (size_t r = 0; r < radix_4_steps; r++) {
radix4(i, p, m, read_buf, write_buf);
p *= 4;
threadgroup_barrier(mem_flags::mem_threadgroup);
// Stockham switch of buffers
tmp = write_buf;
write_buf = read_buf;
read_buf = tmp;
}
// Copy shared memory to output
out[batch_idx + i] = read_buf[i];
out[batch_idx + i + m] = read_buf[i + m];
out[batch_idx + i + 2 * m] = read_buf[i + 2 * m];
out[batch_idx + i + 3 * m] = read_buf[i + 3 * m];
}
#define instantiate_fft(name, n, radix_2_steps, radix_4_steps) \
template [[host_name("fft_" #name)]] [[kernel]] void \
fft<n, radix_2_steps, radix_4_steps>( \
const device float2* in [[buffer(0)]], \
device float2* out [[buffer(1)]], \
uint3 thread_position_in_grid [[thread_position_in_grid]], \
uint3 threads_per_grid [[threads_per_grid]]);
// Explicitly define kernels for each power of 2.
// clang-format off
instantiate_fft(4, /* n= */ 4, /* radix_2_steps= */ 0, /* radix_4_steps= */ 1)
instantiate_fft(8, 8, 1, 1) instantiate_fft(16, 16, 0, 2)
instantiate_fft(32, 32, 1, 2) instantiate_fft(64, 64, 0, 3)
instantiate_fft(128, 128, 1, 3) instantiate_fft(256, 256, 0, 4)
instantiate_fft(512, 512, 1, 4)
instantiate_fft(1024, 1024, 0, 5)
// 2048 is the max that will fit into 32KB of threadgroup memory.
// TODO: implement 4 step FFT for larger n.
instantiate_fft(2048, 2048, 1, 5) // clang-format on
#define instantiate_ffts(tg_mem_size) \
instantiate_fft(tg_mem_size, float2, float2) \
instantiate_fft(tg_mem_size, float, float2) \
instantiate_fft(tg_mem_size, float2, float) \
instantiate_rader(tg_mem_size, float2, float2) \
instantiate_rader(tg_mem_size, float, float2) \
instantiate_rader(tg_mem_size, float2, float) \
instantiate_bluestein(tg_mem_size, float2, float2) \
instantiate_bluestein(tg_mem_size, float, float2) \
instantiate_bluestein(tg_mem_size, float2, float) \
instantiate_four_step(tg_mem_size, float2, float2, 0, /*real=*/false) \
instantiate_four_step(tg_mem_size, float2, float2, 1, /*real=*/false) \
instantiate_four_step(tg_mem_size, float, float2, 0, /*real=*/true) \
instantiate_four_step(tg_mem_size, float2, float2, 1, /*real=*/true) \
instantiate_four_step(tg_mem_size, float2, float2, 0, /*real=*/true) \
instantiate_four_step(tg_mem_size, float2, float, 1, /*real=*/true)
// It's substantially faster to statically define the
// threadgroup memory size rather than using
// `setThreadgroupMemoryLength` on the compute encoder.
// For non-power of 2 sizes we round up the shared memory.
instantiate_ffts(256)
instantiate_ffts(512)
instantiate_ffts(1024)
instantiate_ffts(2048)
// 4096 is the max that will fit into 32KB of threadgroup memory.
instantiate_ffts(4096) // clang-format on

View File

@@ -0,0 +1,328 @@
// Copyright © 2024 Apple Inc.
/* Radix kernels
We provide optimized, single threaded Radix codelets
for n=2,3,4,5,6,7,8,10,11,12,13.
For n=2,3,4,5,6 we hand write the codelets.
For n=8,10,12 we combine smaller codelets.
For n=7,11,13 we use Rader's algorithm which decomposes
them into (n-1)=6,10,12 codelets. */
#pragma once
#include <metal_common>
#include <metal_math>
#include <metal_stdlib>
METAL_FUNC float2 complex_mul(float2 a, float2 b) {
return float2(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x);
}
// Complex mul followed by conjugate
METAL_FUNC float2 complex_mul_conj(float2 a, float2 b) {
return float2(a.x * b.x - a.y * b.y, -a.x * b.y - a.y * b.x);
}
// Compute an FFT twiddle factor
METAL_FUNC float2 get_twiddle(int k, int p) {
float theta = -2.0f * k * M_PI_F / p;
float2 twiddle = {metal::fast::cos(theta), metal::fast::sin(theta)};
return twiddle;
}
METAL_FUNC void radix2(thread float2* x, thread float2* y) {
y[0] = x[0] + x[1];
y[1] = x[0] - x[1];
}
METAL_FUNC void radix3(thread float2* x, thread float2* y) {
float pi_2_3 = -0.8660254037844387;
float2 a_1 = x[1] + x[2];
float2 a_2 = x[1] - x[2];
y[0] = x[0] + a_1;
float2 b_1 = x[0] - 0.5 * a_1;
float2 b_2 = pi_2_3 * a_2;
float2 b_2_j = {-b_2.y, b_2.x};
y[1] = b_1 + b_2_j;
y[2] = b_1 - b_2_j;
}
METAL_FUNC void radix4(thread float2* x, thread float2* y) {
float2 z_0 = x[0] + x[2];
float2 z_1 = x[0] - x[2];
float2 z_2 = x[1] + x[3];
float2 z_3 = x[1] - x[3];
float2 z_3_i = {z_3.y, -z_3.x};
y[0] = z_0 + z_2;
y[1] = z_1 + z_3_i;
y[2] = z_0 - z_2;
y[3] = z_1 - z_3_i;
}
METAL_FUNC void radix5(thread float2* x, thread float2* y) {
float2 root_5_4 = 0.5590169943749475;
float2 sin_2pi_5 = 0.9510565162951535;
float2 sin_1pi_5 = 0.5877852522924731;
float2 a_1 = x[1] + x[4];
float2 a_2 = x[2] + x[3];
float2 a_3 = x[1] - x[4];
float2 a_4 = x[2] - x[3];
float2 a_5 = a_1 + a_2;
float2 a_6 = root_5_4 * (a_1 - a_2);
float2 a_7 = x[0] - a_5 / 4;
float2 a_8 = a_7 + a_6;
float2 a_9 = a_7 - a_6;
float2 a_10 = sin_2pi_5 * a_3 + sin_1pi_5 * a_4;
float2 a_11 = sin_1pi_5 * a_3 - sin_2pi_5 * a_4;
float2 a_10_j = {a_10.y, -a_10.x};
float2 a_11_j = {a_11.y, -a_11.x};
y[0] = x[0] + a_5;
y[1] = a_8 + a_10_j;
y[2] = a_9 + a_11_j;
y[3] = a_9 - a_11_j;
y[4] = a_8 - a_10_j;
}
METAL_FUNC void radix6(thread float2* x, thread float2* y) {
float sin_pi_3 = 0.8660254037844387;
float2 a_1 = x[2] + x[4];
float2 a_2 = x[0] - a_1 / 2;
float2 a_3 = sin_pi_3 * (x[2] - x[4]);
float2 a_4 = x[5] + x[1];
float2 a_5 = x[3] - a_4 / 2;
float2 a_6 = sin_pi_3 * (x[5] - x[1]);
float2 a_7 = x[0] + a_1;
float2 a_3_i = {a_3.y, -a_3.x};
float2 a_6_i = {a_6.y, -a_6.x};
float2 a_8 = a_2 + a_3_i;
float2 a_9 = a_2 - a_3_i;
float2 a_10 = x[3] + a_4;
float2 a_11 = a_5 + a_6_i;
float2 a_12 = a_5 - a_6_i;
y[0] = a_7 + a_10;
y[1] = a_8 - a_11;
y[2] = a_9 + a_12;
y[3] = a_7 - a_10;
y[4] = a_8 + a_11;
y[5] = a_9 - a_12;
}
METAL_FUNC void radix7(thread float2* x, thread float2* y) {
// Rader's algorithm
float2 inv = {1 / 6.0, -1 / 6.0};
// fft
float2 in1[6] = {x[1], x[3], x[2], x[6], x[4], x[5]};
radix6(in1, y + 1);
y[0] = y[1] + x[0];
// b_q
y[1] = complex_mul_conj(y[1], float2(-1, 0));
y[2] = complex_mul_conj(y[2], float2(2.44013336, -1.02261879));
y[3] = complex_mul_conj(y[3], float2(2.37046941, -1.17510629));
y[4] = complex_mul_conj(y[4], float2(0, -2.64575131));
y[5] = complex_mul_conj(y[5], float2(2.37046941, 1.17510629));
y[6] = complex_mul_conj(y[6], float2(-2.44013336, -1.02261879));
// ifft
radix6(y + 1, x + 1);
y[1] = x[1] * inv + x[0];
y[5] = x[2] * inv + x[0];
y[4] = x[3] * inv + x[0];
y[6] = x[4] * inv + x[0];
y[2] = x[5] * inv + x[0];
y[3] = x[6] * inv + x[0];
}
METAL_FUNC void radix8(thread float2* x, thread float2* y) {
float cos_pi_4 = 0.7071067811865476;
float2 w_0 = {cos_pi_4, -cos_pi_4};
float2 w_1 = {-cos_pi_4, -cos_pi_4};
float2 temp[8] = {x[0], x[2], x[4], x[6], x[1], x[3], x[5], x[7]};
radix4(temp, x);
radix4(temp + 4, x + 4);
y[0] = x[0] + x[4];
y[4] = x[0] - x[4];
float2 x_5 = complex_mul(x[5], w_0);
y[1] = x[1] + x_5;
y[5] = x[1] - x_5;
float2 x_6 = {x[6].y, -x[6].x};
y[2] = x[2] + x_6;
y[6] = x[2] - x_6;
float2 x_7 = complex_mul(x[7], w_1);
y[3] = x[3] + x_7;
y[7] = x[3] - x_7;
}
template <bool raders_perm>
METAL_FUNC void radix10(thread float2* x, thread float2* y) {
float2 w[4];
w[0] = {0.8090169943749475, -0.5877852522924731};
w[1] = {0.30901699437494745, -0.9510565162951535};
w[2] = {-w[1].x, w[1].y};
w[3] = {-w[0].x, w[0].y};
if (raders_perm) {
float2 temp[10] = {
x[0], x[3], x[4], x[8], x[2], x[1], x[7], x[9], x[6], x[5]};
radix5(temp, x);
radix5(temp + 5, x + 5);
} else {
float2 temp[10] = {
x[0], x[2], x[4], x[6], x[8], x[1], x[3], x[5], x[7], x[9]};
radix5(temp, x);
radix5(temp + 5, x + 5);
}
y[0] = x[0] + x[5];
y[5] = x[0] - x[5];
for (int t = 1; t < 5; t++) {
float2 a = complex_mul(x[t + 5], w[t - 1]);
y[t] = x[t] + a;
y[t + 5] = x[t] - a;
}
}
METAL_FUNC void radix11(thread float2* x, thread float2* y) {
// Raders Algorithm
float2 inv = {1 / 10.0, -1 / 10.0};
// fft
radix10<true>(x + 1, y + 1);
y[0] = y[1] + x[0];
// b_q
y[1] = complex_mul_conj(y[1], float2(-1, 0));
y[2] = complex_mul_conj(y[2], float2(0.955301878, -3.17606649));
y[3] = complex_mul_conj(y[3], float2(2.63610556, 2.01269656));
y[4] = complex_mul_conj(y[4], float2(2.54127802, 2.13117479));
y[5] = complex_mul_conj(y[5], float2(2.07016210, 2.59122150));
y[6] = complex_mul_conj(y[6], float2(0, -3.31662479));
y[7] = complex_mul_conj(y[7], float2(2.07016210, -2.59122150));
y[8] = complex_mul_conj(y[8], float2(-2.54127802, 2.13117479));
y[9] = complex_mul_conj(y[9], float2(2.63610556, -2.01269656));
y[10] = complex_mul_conj(y[10], float2(-0.955301878, -3.17606649));
// ifft
radix10<false>(y + 1, x + 1);
y[1] = x[1] * inv + x[0];
y[6] = x[2] * inv + x[0];
y[3] = x[3] * inv + x[0];
y[7] = x[4] * inv + x[0];
y[9] = x[5] * inv + x[0];
y[10] = x[6] * inv + x[0];
y[5] = x[7] * inv + x[0];
y[8] = x[8] * inv + x[0];
y[4] = x[9] * inv + x[0];
y[2] = x[10] * inv + x[0];
}
template <bool raders_perm>
METAL_FUNC void radix12(thread float2* x, thread float2* y) {
float2 w[6];
float sin_pi_3 = 0.8660254037844387;
w[0] = {sin_pi_3, -0.5};
w[1] = {0.5, -sin_pi_3};
w[2] = {0, -1};
w[3] = {-0.5, -sin_pi_3};
w[4] = {-sin_pi_3, -0.5};
if (raders_perm) {
float2 temp[12] = {
x[0],
x[3],
x[2],
x[11],
x[8],
x[9],
x[1],
x[7],
x[5],
x[10],
x[4],
x[6]};
radix6(temp, x);
radix6(temp + 6, x + 6);
} else {
float2 temp[12] = {
x[0],
x[2],
x[4],
x[6],
x[8],
x[10],
x[1],
x[3],
x[5],
x[7],
x[9],
x[11]};
radix6(temp, x);
radix6(temp + 6, x + 6);
}
y[0] = x[0] + x[6];
y[6] = x[0] - x[6];
for (int t = 1; t < 6; t++) {
float2 a = complex_mul(x[t + 6], w[t - 1]);
y[t] = x[t] + a;
y[t + 6] = x[t] - a;
}
}
METAL_FUNC void radix13(thread float2* x, thread float2* y) {
// Raders Algorithm
float2 inv = {1 / 12.0, -1 / 12.0};
// fft
radix12<true>(x + 1, y + 1);
y[0] = y[1] + x[0];
// b_q
y[1] = complex_mul_conj(y[1], float2(-1, 0));
y[2] = complex_mul_conj(y[2], float2(3.07497206, -1.88269669));
y[3] = complex_mul_conj(y[3], float2(3.09912468, 1.84266823));
y[4] = complex_mul_conj(y[4], float2(3.45084438, -1.04483161));
y[5] = complex_mul_conj(y[5], float2(0.91083583, 3.48860690));
y[6] = complex_mul_conj(y[6], float2(-3.60286363, 0.139189267));
y[7] = complex_mul_conj(y[7], float2(3.60555128, 0));
y[8] = complex_mul_conj(y[8], float2(3.60286363, 0.139189267));
y[9] = complex_mul_conj(y[9], float2(0.91083583, -3.48860690));
y[10] = complex_mul_conj(y[10], float2(-3.45084438, -1.04483161));
y[11] = complex_mul_conj(y[11], float2(3.09912468, -1.84266823));
y[12] = complex_mul_conj(y[12], float2(-3.07497206, -1.88269669));
// ifft
radix12<false>(y + 1, x + 1);
y[1] = x[1] * inv + x[0];
y[7] = x[2] * inv + x[0];
y[10] = x[3] * inv + x[0];
y[5] = x[4] * inv + x[0];
y[9] = x[5] * inv + x[0];
y[11] = x[6] * inv + x[0];
y[12] = x[7] * inv + x[0];
y[6] = x[8] * inv + x[0];
y[3] = x[9] * inv + x[0];
y[8] = x[10] * inv + x[0];
y[4] = x[11] * inv + x[0];
y[2] = x[12] * inv + x[0];
}

View File

@@ -0,0 +1,622 @@
// Copyright © 2024 Apple Inc.
#include <metal_common>
#include "mlx/backend/metal/kernels/fft/radix.h"
/* FFT helpers for reading and writing from/to device memory.
For many sizes, GPU FFTs are memory bandwidth bound so
read/write performance is important.
Where possible, we read 128 bits sequentially in each thread,
coalesced with accesses from adajcent threads for optimal performance.
We implement specialized reading/writing for:
- FFT
- RFFT
- IRFFT
Each with support for:
- Contiguous reads
- Padded reads
- Strided reads
*/
#define MAX_RADIX 13
using namespace metal;
template <
typename in_T,
typename out_T,
int step = 0,
bool four_step_real = false>
struct ReadWriter {
const device in_T* in;
threadgroup float2* buf;
device out_T* out;
int n;
int batch_size;
int elems_per_thread;
uint3 elem;
uint3 grid;
int threads_per_tg;
bool inv;
// Used for strided access
int strided_device_idx = 0;
int strided_shared_idx = 0;
METAL_FUNC ReadWriter(
const device in_T* in_,
threadgroup float2* buf_,
device out_T* out_,
const short n_,
const int batch_size_,
const short elems_per_thread_,
const uint3 elem_,
const uint3 grid_,
const bool inv_)
: in(in_),
buf(buf_),
out(out_),
n(n_),
batch_size(batch_size_),
elems_per_thread(elems_per_thread_),
elem(elem_),
grid(grid_),
inv(inv_) {
// Account for padding on last threadgroup
threads_per_tg = elem.x == grid.x - 1
? (batch_size - (grid.x - 1) * grid.y) * grid.z
: grid.y * grid.z;
}
// ifft(x) = 1/n * conj(fft(conj(x)))
METAL_FUNC float2 post_in(float2 elem) const {
return inv ? float2(elem.x, -elem.y) : elem;
}
// Handle float case for generic RFFT alg
METAL_FUNC float2 post_in(float elem) const {
return float2(elem, 0);
}
METAL_FUNC float2 pre_out(float2 elem) const {
return inv ? float2(elem.x / n, -elem.y / n) : elem;
}
METAL_FUNC float2 pre_out(float2 elem, int length) const {
return inv ? float2(elem.x / length, -elem.y / length) : elem;
}
METAL_FUNC bool out_of_bounds() const {
// Account for possible extra threadgroups
int grid_index = elem.x * grid.y + elem.y;
return grid_index >= batch_size;
}
METAL_FUNC void load() const {
int batch_idx = elem.x * grid.y * n;
short tg_idx = elem.y * grid.z + elem.z;
short max_index = grid.y * n - 2;
// 2 complex64s = 128 bits
constexpr int read_width = 2;
for (short e = 0; e < (elems_per_thread / read_width); e++) {
short index = read_width * tg_idx + read_width * threads_per_tg * e;
index = metal::min(index, max_index);
// vectorized reads
buf[index] = post_in(in[batch_idx + index]);
buf[index + 1] = post_in(in[batch_idx + index + 1]);
}
max_index += 1;
if (elems_per_thread % 2 != 0) {
short index = tg_idx +
read_width * threads_per_tg * (elems_per_thread / read_width);
index = metal::min(index, max_index);
buf[index] = post_in(in[batch_idx + index]);
}
}
METAL_FUNC void write() const {
int batch_idx = elem.x * grid.y * n;
short tg_idx = elem.y * grid.z + elem.z;
short max_index = grid.y * n - 2;
constexpr int read_width = 2;
for (short e = 0; e < (elems_per_thread / read_width); e++) {
short index = read_width * tg_idx + read_width * threads_per_tg * e;
index = metal::min(index, max_index);
// vectorized reads
out[batch_idx + index] = pre_out(buf[index]);
out[batch_idx + index + 1] = pre_out(buf[index + 1]);
}
max_index += 1;
if (elems_per_thread % 2 != 0) {
short index = tg_idx +
read_width * threads_per_tg * (elems_per_thread / read_width);
index = metal::min(index, max_index);
out[batch_idx + index] = pre_out(buf[index]);
}
}
// Padded IO for Bluestein's algorithm
METAL_FUNC void load_padded(int length, const device float2* w_k) const {
int batch_idx = elem.x * grid.y * length + elem.y * length;
int fft_idx = elem.z;
int m = grid.z;
threadgroup float2* seq_buf = buf + elem.y * n;
for (int e = 0; e < elems_per_thread; e++) {
int index = metal::min(fft_idx + e * m, n - 1);
if (index < length) {
float2 elem = post_in(in[batch_idx + index]);
seq_buf[index] = complex_mul(elem, w_k[index]);
} else {
seq_buf[index] = 0.0;
}
}
}
METAL_FUNC void write_padded(int length, const device float2* w_k) const {
int batch_idx = elem.x * grid.y * length + elem.y * length;
int fft_idx = elem.z;
int m = grid.z;
float2 inv_factor = {1.0f / n, -1.0f / n};
threadgroup float2* seq_buf = buf + elem.y * n;
for (int e = 0; e < elems_per_thread; e++) {
int index = metal::min(fft_idx + e * m, n - 1);
if (index < length) {
float2 elem = seq_buf[index + length - 1] * inv_factor;
out[batch_idx + index] = pre_out(complex_mul(elem, w_k[index]), length);
}
}
}
// Strided IO for four step FFT
METAL_FUNC void compute_strided_indices(int stride, int overall_n) {
// Use the batch threadgroup dimension to coalesce memory accesses:
// e.g. stride = 12
// device | shared mem
// 0 1 2 3 | 0 12 - -
// - - - - | 1 13 - -
// - - - - | 2 14 - -
// 12 13 14 15 | 3 15 - -
int coalesce_width = grid.y;
int tg_idx = elem.y * grid.z + elem.z;
int outer_batch_size = stride / coalesce_width;
int strided_batch_idx = (elem.x % outer_batch_size) * coalesce_width +
overall_n * (elem.x / outer_batch_size);
strided_device_idx = strided_batch_idx +
tg_idx / coalesce_width * elems_per_thread * stride +
tg_idx % coalesce_width;
strided_shared_idx = (tg_idx % coalesce_width) * n +
tg_idx / coalesce_width * elems_per_thread;
}
// Four Step FFT First Step
METAL_FUNC void load_strided(int stride, int overall_n) {
compute_strided_indices(stride, overall_n);
for (int e = 0; e < elems_per_thread; e++) {
buf[strided_shared_idx + e] =
post_in(in[strided_device_idx + e * stride]);
}
}
METAL_FUNC void write_strided(int stride, int overall_n) {
for (int e = 0; e < elems_per_thread; e++) {
float2 output = buf[strided_shared_idx + e];
int combined_idx = (strided_device_idx + e * stride) % overall_n;
int ij = (combined_idx / stride) * (combined_idx % stride);
// Apply four step twiddles at end of first step
float2 twiddle = get_twiddle(ij, overall_n);
out[strided_device_idx + e * stride] = complex_mul(output, twiddle);
}
}
};
// Four Step FFT Second Step
template <>
METAL_FUNC void ReadWriter<float2, float2, /*step=*/1>::load_strided(
int stride,
int overall_n) {
// Silence compiler warnings
(void)stride;
(void)overall_n;
// Don't invert between steps
bool default_inv = inv;
inv = false;
load();
inv = default_inv;
}
template <>
METAL_FUNC void ReadWriter<float2, float2, /*step=*/1>::write_strided(
int stride,
int overall_n) {
compute_strided_indices(stride, overall_n);
for (int e = 0; e < elems_per_thread; e++) {
float2 output = buf[strided_shared_idx + e];
out[strided_device_idx + e * stride] = pre_out(output, overall_n);
}
}
// For RFFT, we interleave batches of two real sequences into one complex one:
//
// z_k = x_k + j.y_k
// X_k = (Z_k + Z_(N-k)*) / 2
// Y_k = -j * ((Z_k - Z_(N-k)*) / 2)
//
// This roughly doubles the throughput over the regular FFT.
template <>
METAL_FUNC bool ReadWriter<float, float2>::out_of_bounds() const {
int grid_index = elem.x * grid.y + elem.y;
// We pack two sequences into one for RFFTs
return grid_index * 2 >= batch_size;
}
template <>
METAL_FUNC void ReadWriter<float, float2>::load() const {
int batch_idx = elem.x * grid.y * n * 2 + elem.y * n * 2;
threadgroup float2* seq_buf = buf + elem.y * n;
// No out of bounds accesses on odd batch sizes
int grid_index = elem.x * grid.y + elem.y;
short next_in =
batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : n;
short m = grid.z;
short fft_idx = elem.z;
for (int e = 0; e < elems_per_thread; e++) {
int index = metal::min(fft_idx + e * m, n - 1);
seq_buf[index].x = in[batch_idx + index];
seq_buf[index].y = in[batch_idx + index + next_in];
}
}
template <>
METAL_FUNC void ReadWriter<float, float2>::write() const {
short n_over_2 = (n / 2) + 1;
int batch_idx = elem.x * grid.y * n_over_2 * 2 + elem.y * n_over_2 * 2;
threadgroup float2* seq_buf = buf + elem.y * n;
int grid_index = elem.x * grid.y + elem.y;
short next_out =
batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : n_over_2;
float2 conj = {1, -1};
float2 minus_j = {0, -1};
short m = grid.z;
short fft_idx = elem.z;
for (int e = 0; e < elems_per_thread / 2 + 1; e++) {
int index = metal::min(fft_idx + e * m, n_over_2 - 1);
// x_0 = z_0.real
// y_0 = z_0.imag
if (index == 0) {
out[batch_idx + index] = {seq_buf[index].x, 0};
out[batch_idx + index + next_out] = {seq_buf[index].y, 0};
} else {
float2 x_k = seq_buf[index];
float2 x_n_minus_k = seq_buf[n - index] * conj;
out[batch_idx + index] = (x_k + x_n_minus_k) / 2;
out[batch_idx + index + next_out] =
complex_mul(((x_k - x_n_minus_k) / 2), minus_j);
}
}
}
template <>
METAL_FUNC void ReadWriter<float, float2>::load_padded(
int length,
const device float2* w_k) const {
int batch_idx = elem.x * grid.y * length * 2 + elem.y * length * 2;
threadgroup float2* seq_buf = buf + elem.y * n;
// No out of bounds accesses on odd batch sizes
int grid_index = elem.x * grid.y + elem.y;
short next_in =
batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : length;
short m = grid.z;
short fft_idx = elem.z;
for (int e = 0; e < elems_per_thread; e++) {
int index = metal::min(fft_idx + e * m, n - 1);
if (index < length) {
float2 elem =
float2(in[batch_idx + index], in[batch_idx + index + next_in]);
seq_buf[index] = complex_mul(elem, w_k[index]);
} else {
seq_buf[index] = 0;
}
}
}
template <>
METAL_FUNC void ReadWriter<float, float2>::write_padded(
int length,
const device float2* w_k) const {
int length_over_2 = (length / 2) + 1;
int batch_idx =
elem.x * grid.y * length_over_2 * 2 + elem.y * length_over_2 * 2;
threadgroup float2* seq_buf = buf + elem.y * n + length - 1;
int grid_index = elem.x * grid.y + elem.y;
short next_out = batch_size % 2 == 1 && grid_index * 2 == batch_size - 1
? 0
: length_over_2;
float2 conj = {1, -1};
float2 inv_factor = {1.0f / n, -1.0f / n};
float2 minus_j = {0, -1};
short m = grid.z;
short fft_idx = elem.z;
for (int e = 0; e < elems_per_thread / 2 + 1; e++) {
int index = metal::min(fft_idx + e * m, length_over_2 - 1);
// x_0 = z_0.real
// y_0 = z_0.imag
if (index == 0) {
float2 elem = complex_mul(w_k[index], seq_buf[index] * inv_factor);
out[batch_idx + index] = float2(elem.x, 0);
out[batch_idx + index + next_out] = float2(elem.y, 0);
} else {
float2 x_k = complex_mul(w_k[index], seq_buf[index] * inv_factor);
float2 x_n_minus_k = complex_mul(
w_k[length - index], seq_buf[length - index] * inv_factor);
x_n_minus_k *= conj;
// w_k should happen before this extraction
out[batch_idx + index] = (x_k + x_n_minus_k) / 2;
out[batch_idx + index + next_out] =
complex_mul(((x_k - x_n_minus_k) / 2), minus_j);
}
}
}
// For IRFFT, we do the opposite
//
// Z_k = X_k + j.Y_k
// x_k = Re(Z_k)
// Y_k = Imag(Z_k)
template <>
METAL_FUNC bool ReadWriter<float2, float>::out_of_bounds() const {
int grid_index = elem.x * grid.y + elem.y;
// We pack two sequences into one for IRFFTs
return grid_index * 2 >= batch_size;
}
template <>
METAL_FUNC void ReadWriter<float2, float>::load() const {
short n_over_2 = (n / 2) + 1;
int batch_idx = elem.x * grid.y * n_over_2 * 2 + elem.y * n_over_2 * 2;
threadgroup float2* seq_buf = buf + elem.y * n;
// No out of bounds accesses on odd batch sizes
int grid_index = elem.x * grid.y + elem.y;
short next_in =
batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : n_over_2;
short m = grid.z;
short fft_idx = elem.z;
float2 conj = {1, -1};
float2 plus_j = {0, 1};
for (int t = 0; t < elems_per_thread / 2 + 1; t++) {
int index = metal::min(fft_idx + t * m, n_over_2 - 1);
float2 x = in[batch_idx + index];
float2 y = in[batch_idx + index + next_in];
// NumPy forces first input to be real
bool first_val = index == 0;
// NumPy forces last input on even irffts to be real
bool last_val = n % 2 == 0 && index == n_over_2 - 1;
if (first_val || last_val) {
x = float2(x.x, 0);
y = float2(y.x, 0);
}
seq_buf[index] = x + complex_mul(y, plus_j);
seq_buf[index].y = -seq_buf[index].y;
if (index > 0 && !last_val) {
seq_buf[n - index] = (x * conj) + complex_mul(y * conj, plus_j);
seq_buf[n - index].y = -seq_buf[n - index].y;
}
}
}
template <>
METAL_FUNC void ReadWriter<float2, float>::write() const {
int batch_idx = elem.x * grid.y * n * 2 + elem.y * n * 2;
threadgroup float2* seq_buf = buf + elem.y * n;
int grid_index = elem.x * grid.y + elem.y;
short next_out =
batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : n;
short m = grid.z;
short fft_idx = elem.z;
for (int e = 0; e < elems_per_thread; e++) {
int index = metal::min(fft_idx + e * m, n - 1);
out[batch_idx + index] = seq_buf[index].x / n;
out[batch_idx + index + next_out] = seq_buf[index].y / -n;
}
}
template <>
METAL_FUNC void ReadWriter<float2, float>::load_padded(
int length,
const device float2* w_k) const {
int n_over_2 = (n / 2) + 1;
int length_over_2 = (length / 2) + 1;
int batch_idx =
elem.x * grid.y * length_over_2 * 2 + elem.y * length_over_2 * 2;
threadgroup float2* seq_buf = buf + elem.y * n;
// No out of bounds accesses on odd batch sizes
int grid_index = elem.x * grid.y + elem.y;
short next_in = batch_size % 2 == 1 && grid_index * 2 == batch_size - 1
? 0
: length_over_2;
short m = grid.z;
short fft_idx = elem.z;
float2 conj = {1, -1};
float2 plus_j = {0, 1};
for (int t = 0; t < elems_per_thread / 2 + 1; t++) {
int index = metal::min(fft_idx + t * m, n_over_2 - 1);
float2 x = in[batch_idx + index];
float2 y = in[batch_idx + index + next_in];
if (index < length_over_2) {
bool last_val = length % 2 == 0 && index == length_over_2 - 1;
if (last_val) {
x = float2(x.x, 0);
y = float2(y.x, 0);
}
float2 elem1 = x + complex_mul(y, plus_j);
seq_buf[index] = complex_mul(elem1 * conj, w_k[index]);
if (index > 0 && !last_val) {
float2 elem2 = (x * conj) + complex_mul(y * conj, plus_j);
seq_buf[length - index] =
complex_mul(elem2 * conj, w_k[length - index]);
}
} else {
short pad_index = metal::min(length + (index - length_over_2) * 2, n - 2);
seq_buf[pad_index] = 0;
seq_buf[pad_index + 1] = 0;
}
}
}
template <>
METAL_FUNC void ReadWriter<float2, float>::write_padded(
int length,
const device float2* w_k) const {
int batch_idx = elem.x * grid.y * length * 2 + elem.y * length * 2;
threadgroup float2* seq_buf = buf + elem.y * n + length - 1;
int grid_index = elem.x * grid.y + elem.y;
short next_out =
batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : length;
short m = grid.z;
short fft_idx = elem.z;
float2 inv_factor = {1.0f / n, -1.0f / n};
for (int e = 0; e < elems_per_thread; e++) {
int index = fft_idx + e * m;
if (index < length) {
float2 output = complex_mul(seq_buf[index] * inv_factor, w_k[index]);
out[batch_idx + index] = output.x / length;
out[batch_idx + index + next_out] = output.y / -length;
}
}
}
// Four Step RFFT
template <>
METAL_FUNC void
ReadWriter<float2, float2, /*step=*/1, /*real=*/true>::load_strided(
int stride,
int overall_n) {
// Silence compiler warnings
(void)stride;
(void)overall_n;
// Don't invert between steps
bool default_inv = inv;
inv = false;
load();
inv = default_inv;
}
template <>
METAL_FUNC void
ReadWriter<float2, float2, /*step=*/1, /*real=*/true>::write_strided(
int stride,
int overall_n) {
int overall_n_over_2 = overall_n / 2 + 1;
int coalesce_width = grid.y;
int tg_idx = elem.y * grid.z + elem.z;
int outer_batch_size = stride / coalesce_width;
int strided_batch_idx = (elem.x % outer_batch_size) * coalesce_width +
overall_n_over_2 * (elem.x / outer_batch_size);
strided_device_idx = strided_batch_idx +
tg_idx / coalesce_width * elems_per_thread / 2 * stride +
tg_idx % coalesce_width;
strided_shared_idx = (tg_idx % coalesce_width) * n +
tg_idx / coalesce_width * elems_per_thread / 2;
for (int e = 0; e < elems_per_thread / 2; e++) {
float2 output = buf[strided_shared_idx + e];
out[strided_device_idx + e * stride] = output;
}
// Add on n/2 + 1 element
if (tg_idx == 0 && elem.x % outer_batch_size == 0) {
out[strided_batch_idx + overall_n / 2] = buf[n / 2];
}
}
// Four Step IRFFT
template <>
METAL_FUNC void
ReadWriter<float2, float2, /*step=*/0, /*real=*/true>::load_strided(
int stride,
int overall_n) {
int overall_n_over_2 = overall_n / 2 + 1;
auto conj = float2(1, -1);
compute_strided_indices(stride, overall_n);
// Translate indices in terms of N - k
for (int e = 0; e < elems_per_thread; e++) {
int device_idx = strided_device_idx + e * stride;
int overall_batch = device_idx / overall_n;
int overall_index = device_idx % overall_n;
if (overall_index < overall_n_over_2) {
device_idx -= overall_batch * (overall_n - overall_n_over_2);
buf[strided_shared_idx + e] = in[device_idx] * conj;
} else {
int conj_idx = overall_n - overall_index;
device_idx = overall_batch * overall_n_over_2 + conj_idx;
buf[strided_shared_idx + e] = in[device_idx];
}
}
}
template <>
METAL_FUNC void
ReadWriter<float2, float, /*step=*/1, /*real=*/true>::load_strided(
int stride,
int overall_n) {
// Silence compiler warnings
(void)stride;
(void)overall_n;
bool default_inv = inv;
inv = false;
load();
inv = default_inv;
}
template <>
METAL_FUNC void
ReadWriter<float2, float, /*step=*/1, /*real=*/true>::write_strided(
int stride,
int overall_n) {
compute_strided_indices(stride, overall_n);
for (int e = 0; e < elems_per_thread; e++) {
out[strided_device_idx + e * stride] =
pre_out(buf[strided_shared_idx + e], overall_n).x;
}
}

View File

@@ -0,0 +1,45 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/backend/metal/kernels/indexing.h"
template <typename T, typename IdxT, int NIDX, int IDX_NDIM>
METAL_FUNC void gather_impl(
const device T* src [[buffer(0)]],
device T* out [[buffer(1)]],
const constant int* src_shape [[buffer(2)]],
const constant size_t* src_strides [[buffer(3)]],
const constant size_t& src_ndim [[buffer(4)]],
const constant int* slice_sizes [[buffer(5)]],
const constant int* axes [[buffer(6)]],
const thread Indices<IdxT, NIDX>& indices,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto ind_idx = index.x;
auto ind_offset = index.y;
size_t src_idx = 0;
for (int i = 0; i < NIDX; ++i) {
size_t idx_loc;
if (IDX_NDIM == 0) {
idx_loc = 0;
} else if (IDX_NDIM == 1) {
idx_loc = ind_idx * indices.strides[indices.ndim * i];
} else {
idx_loc = elem_to_loc(
ind_idx,
&indices.shapes[indices.ndim * i],
&indices.strides[indices.ndim * i],
indices.ndim);
}
auto ax = axes[i];
auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]);
src_idx += idx_val * src_strides[ax];
}
auto src_offset = elem_to_loc(ind_offset, slice_sizes, src_strides, src_ndim);
size_t out_idx = index.y + static_cast<size_t>(grid_dim.y) * index.x;
out[out_idx] = src[src_offset + src_idx];
}

View File

@@ -1,173 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#include <metal_atomic>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/indexing.h"
#include "mlx/backend/metal/kernels/utils.h"
using namespace metal;
/////////////////////////////////////////////////////////////////////
// Gather kernel
/////////////////////////////////////////////////////////////////////
template <typename T, typename IdxT, int NIDX, int IDX_NDIM>
METAL_FUNC void gather_impl(
const device T* src [[buffer(0)]],
device T* out [[buffer(1)]],
const constant int* src_shape [[buffer(2)]],
const constant size_t* src_strides [[buffer(3)]],
const constant size_t& src_ndim [[buffer(4)]],
const constant int* slice_sizes [[buffer(5)]],
const constant int* axes [[buffer(6)]],
const thread Indices<IdxT, NIDX>& indices,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto ind_idx = index.x;
auto ind_offset = index.y;
size_t src_idx = 0;
for (int i = 0; i < NIDX; ++i) {
size_t idx_loc;
if (IDX_NDIM == 0) {
idx_loc = 0;
} else if (IDX_NDIM == 1) {
idx_loc = ind_idx * indices.strides[indices.ndim * i];
} else {
idx_loc = elem_to_loc(
ind_idx,
&indices.shapes[indices.ndim * i],
&indices.strides[indices.ndim * i],
indices.ndim);
}
auto ax = axes[i];
auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]);
src_idx += idx_val * src_strides[ax];
}
auto src_offset = elem_to_loc(ind_offset, slice_sizes, src_strides, src_ndim);
size_t out_idx = index.y + static_cast<size_t>(grid_dim.y) * index.x;
out[out_idx] = src[src_offset + src_idx];
}
#define make_gather_impl(IDX_ARG, IDX_ARR) \
template <typename T, typename IdxT, int NIDX, int IDX_NDIM> \
[[kernel]] void gather( \
const device T* src [[buffer(0)]], \
device T* out [[buffer(1)]], \
const constant int* src_shape [[buffer(2)]], \
const constant size_t* src_strides [[buffer(3)]], \
const constant size_t& src_ndim [[buffer(4)]], \
const constant int* slice_sizes [[buffer(5)]], \
const constant int* axes [[buffer(6)]], \
const constant int* idx_shapes [[buffer(7)]], \
const constant size_t* idx_strides [[buffer(8)]], \
const constant int& idx_ndim [[buffer(9)]], \
IDX_ARG(IdxT) uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]) { \
Indices<IdxT, NIDX> idxs{ \
{{IDX_ARR()}}, idx_shapes, idx_strides, idx_ndim}; \
\
return gather_impl<T, IdxT, NIDX, IDX_NDIM>( \
src, \
out, \
src_shape, \
src_strides, \
src_ndim, \
slice_sizes, \
axes, \
idxs, \
index, \
grid_dim); \
}
#define make_gather(n) make_gather_impl(IDX_ARG_##n, IDX_ARR_##n)
make_gather(0) make_gather(1) make_gather(2) make_gather(3) make_gather(4)
make_gather(5) make_gather(6) make_gather(7) make_gather(8) make_gather(9)
make_gather(10)
/////////////////////////////////////////////////////////////////////
// Gather instantiations
/////////////////////////////////////////////////////////////////////
#define instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG, nd, nd_name) \
template [[host_name("gather" name "_" #nidx "" #nd_name)]] [[kernel]] void \
gather<src_t, idx_t, nidx, nd>( \
const device src_t* src [[buffer(0)]], \
device src_t* out [[buffer(1)]], \
const constant int* src_shape [[buffer(2)]], \
const constant size_t* src_strides [[buffer(3)]], \
const constant size_t& src_ndim [[buffer(4)]], \
const constant int* slice_sizes [[buffer(5)]], \
const constant int* axes [[buffer(6)]], \
const constant int* idx_shapes [[buffer(7)]], \
const constant size_t* idx_strides [[buffer(8)]], \
const constant int& idx_ndim [[buffer(9)]], \
IDX_ARG(idx_t) uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]);
// clang-format off
#define instantiate_gather5(name, src_t, idx_t, nidx, nd, nd_name) \
instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG_ ##nidx, nd, nd_name) // clang-format on
// clang-format off
#define instantiate_gather4(name, src_t, idx_t, nidx) \
instantiate_gather5(name, src_t, idx_t, nidx, 0, _0) \
instantiate_gather5(name, src_t, idx_t, nidx, 1, _1) \
instantiate_gather5(name, src_t, idx_t, nidx, 2, )
// Special for case NIDX=0
instantiate_gather4("bool_", bool, bool, 0)
instantiate_gather4("uint8", uint8_t, bool, 0)
instantiate_gather4("uint16", uint16_t, bool, 0)
instantiate_gather4("uint32", uint32_t, bool, 0)
instantiate_gather4("uint64", uint64_t, bool, 0)
instantiate_gather4("int8", int8_t, bool, 0)
instantiate_gather4("int16", int16_t, bool, 0)
instantiate_gather4("int32", int32_t, bool, 0)
instantiate_gather4("int64", int64_t, bool, 0)
instantiate_gather4("float16", half, bool, 0)
instantiate_gather4("float32", float, bool, 0)
instantiate_gather4("bfloat16", bfloat16_t, bool, 0) // clang-format on
// clang-format off
#define instantiate_gather3(name, src_type, ind_type) \
instantiate_gather4(name, src_type, ind_type, 1) \
instantiate_gather4(name, src_type, ind_type, 2) \
instantiate_gather4(name, src_type, ind_type, 3) \
instantiate_gather4(name, src_type, ind_type, 4) \
instantiate_gather4(name, src_type, ind_type, 5) \
instantiate_gather4(name, src_type, ind_type, 6) \
instantiate_gather4(name, src_type, ind_type, 7) \
instantiate_gather4(name, src_type, ind_type, 8) \
instantiate_gather4(name, src_type, ind_type, 9) \
instantiate_gather4(name, src_type, ind_type, 10) // clang-format on
// clang-format off
#define instantiate_gather(name, src_type) \
instantiate_gather3(#name "bool_", src_type, bool) \
instantiate_gather3(#name "uint8", src_type, uint8_t) \
instantiate_gather3(#name "uint16", src_type, uint16_t) \
instantiate_gather3(#name "uint32", src_type, uint32_t) \
instantiate_gather3(#name "uint64", src_type, uint64_t) \
instantiate_gather3(#name "int8", src_type, int8_t) \
instantiate_gather3(#name "int16", src_type, int16_t) \
instantiate_gather3(#name "int32", src_type, int32_t) \
instantiate_gather3(#name "int64", src_type, int64_t)
instantiate_gather(bool_, bool)
instantiate_gather(uint8, uint8_t)
instantiate_gather(uint16, uint16_t)
instantiate_gather(uint32, uint32_t)
instantiate_gather(uint64, uint64_t)
instantiate_gather(int8, int8_t)
instantiate_gather(int16, int16_t)
instantiate_gather(int32, int32_t)
instantiate_gather(int64, int64_t)
instantiate_gather(float16, half)
instantiate_gather(float32, float)
instantiate_gather(bfloat16, bfloat16_t) // clang-format on

View File

@@ -1,13 +1,9 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
#include <metal_stdlib>
using namespace metal;
/////////////////////////////////////////////////////////////////////
// Indexing utils
/////////////////////////////////////////////////////////////////////
template <typename IdxT, int NIDX>
struct Indices {
const array<const device IdxT*, NIDX> buffers;
@@ -24,31 +20,3 @@ METAL_FUNC size_t offset_neg_idx(IdxT idx, size_t size) {
return (idx < 0) ? idx + size : idx;
}
}
#define IDX_ARG_N(idx_t, n) const device idx_t *idx##n [[buffer(n)]],
#define IDX_ARG_0(idx_t)
#define IDX_ARG_1(idx_t) IDX_ARG_0(idx_t) IDX_ARG_N(idx_t, 21)
#define IDX_ARG_2(idx_t) IDX_ARG_1(idx_t) IDX_ARG_N(idx_t, 22)
#define IDX_ARG_3(idx_t) IDX_ARG_2(idx_t) IDX_ARG_N(idx_t, 23)
#define IDX_ARG_4(idx_t) IDX_ARG_3(idx_t) IDX_ARG_N(idx_t, 24)
#define IDX_ARG_5(idx_t) IDX_ARG_4(idx_t) IDX_ARG_N(idx_t, 25)
#define IDX_ARG_6(idx_t) IDX_ARG_5(idx_t) IDX_ARG_N(idx_t, 26)
#define IDX_ARG_7(idx_t) IDX_ARG_6(idx_t) IDX_ARG_N(idx_t, 27)
#define IDX_ARG_8(idx_t) IDX_ARG_7(idx_t) IDX_ARG_N(idx_t, 28)
#define IDX_ARG_9(idx_t) IDX_ARG_8(idx_t) IDX_ARG_N(idx_t, 29)
#define IDX_ARG_10(idx_t) IDX_ARG_9(idx_t) IDX_ARG_N(idx_t, 30)
#define IDX_ARR_N(n) idx##n,
#define IDX_ARR_0()
#define IDX_ARR_1() IDX_ARR_0() IDX_ARR_N(21)
#define IDX_ARR_2() IDX_ARR_1() IDX_ARR_N(22)
#define IDX_ARR_3() IDX_ARR_2() IDX_ARR_N(23)
#define IDX_ARR_4() IDX_ARR_3() IDX_ARR_N(24)
#define IDX_ARR_5() IDX_ARR_4() IDX_ARR_N(25)
#define IDX_ARR_6() IDX_ARR_5() IDX_ARR_N(26)
#define IDX_ARR_7() IDX_ARR_6() IDX_ARR_N(27)
#define IDX_ARR_8() IDX_ARR_7() IDX_ARR_N(28)
#define IDX_ARR_9() IDX_ARR_8() IDX_ARR_N(29)
#define IDX_ARR_10() IDX_ARR_9() IDX_ARR_N(30)

View File

@@ -601,14 +601,18 @@ METAL_FUNC void qvm_impl(
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int num_simdgroups = 8;
constexpr int num_simdgroups = 2;
constexpr int pack_factor = 32 / bits;
constexpr int tn = 32 / pack_factor;
constexpr int blocksize = SIMD_SIZE;
typedef float U;
typedef struct {
uint32_t wi[tn];
} vec_w;
thread uint32_t w_local;
thread U result[pack_factor] = {0};
thread vec_w w_local;
thread U result[tn * pack_factor] = {0};
thread U scale = 1;
thread U bias = 0;
thread U x_local = 0;
@@ -616,11 +620,12 @@ METAL_FUNC void qvm_impl(
// Adjust positions
const int out_vec_size_w = out_vec_size / pack_factor;
const int out_vec_size_g = out_vec_size / group_size;
int out_col = tid.x * (num_simdgroups * pack_factor) + simd_gid * pack_factor;
w += out_col / pack_factor;
scales += out_col / group_size;
biases += out_col / group_size;
x += tid.y * in_vec_size;
int out_col =
tid.x * (num_simdgroups * pack_factor * tn) + simd_gid * pack_factor * tn;
w += out_col / pack_factor + simd_lid * out_vec_size_w;
scales += out_col / group_size + simd_lid * out_vec_size_g;
biases += out_col / group_size + simd_lid * out_vec_size_g;
x += tid.y * in_vec_size + simd_lid;
y += tid.y * out_vec_size + out_col;
if (out_col >= out_vec_size) {
@@ -628,40 +633,61 @@ METAL_FUNC void qvm_impl(
}
// Loop over in_vec in blocks of blocksize
int i = 0;
for (; i + blocksize <= in_vec_size; i += blocksize) {
x_local = x[i + simd_lid];
scale = scales[(i + simd_lid) * out_vec_size_g];
bias = biases[(i + simd_lid) * out_vec_size_g];
w_local = w[(i + simd_lid) * out_vec_size_w];
int remaining = in_vec_size % blocksize;
if (remaining == 0) {
for (int i = 0; i < in_vec_size; i += blocksize) {
x_local = *x;
scale = *scales;
bias = *biases;
w_local = *((device vec_w*)w);
qouter<U, pack_factor, bits>(
qouter<U, tn * pack_factor, bits>(
(thread uint8_t*)&w_local, x_local, scale, bias, result);
x += blocksize;
scales += blocksize * out_vec_size_g;
biases += blocksize * out_vec_size_g;
w += blocksize * out_vec_size_w;
}
} else {
for (int i = blocksize; i < in_vec_size; i += blocksize) {
x_local = *x;
scale = *scales;
bias = *biases;
w_local = *((device vec_w*)w);
qouter<U, tn * pack_factor, bits>(
(thread uint8_t*)&w_local, x_local, scale, bias, result);
x += blocksize;
scales += blocksize * out_vec_size_g;
biases += blocksize * out_vec_size_g;
w += blocksize * out_vec_size_w;
}
if (static_cast<int>(simd_lid) < remaining) {
x_local = *x;
scale = *scales;
bias = *biases;
w_local = *((device vec_w*)w);
} else {
x_local = 0;
scale = 0;
bias = 0;
}
qouter<U, tn * pack_factor, bits>(
(thread uint8_t*)&w_local, x_local, scale, bias, result);
}
if (static_cast<int>(i + simd_lid) < in_vec_size) {
x_local = x[i + simd_lid];
scale = scales[(i + simd_lid) * out_vec_size_g];
bias = biases[(i + simd_lid) * out_vec_size_g];
w_local = w[(i + simd_lid) * out_vec_size_w];
} else {
x_local = 0;
scale = 0;
bias = 0;
w_local = 0;
}
qouter<U, pack_factor, bits>(
(thread uint8_t*)&w_local, x_local, scale, bias, result);
// Accumulate in the simdgroup
#pragma clang loop unroll(full)
for (int k = 0; k < pack_factor; k++) {
for (int k = 0; k < tn * pack_factor; k++) {
result[k] = simd_sum(result[k]);
}
// Store the result
if (simd_lid == 0) {
#pragma clang loop unroll(full)
for (int k = 0; k < pack_factor; k++) {
for (int k = 0; k < tn * pack_factor; k++) {
y[k] = static_cast<T>(result[k]);
}
}

View File

@@ -0,0 +1,4 @@
#pragma once
#include "mlx/backend/metal/kernels/reduction/reduce_all.h"
#include "mlx/backend/metal/kernels/reduction/reduce_col.h"
#include "mlx/backend/metal/kernels/reduction/reduce_row.h"

View File

@@ -0,0 +1,293 @@
// Copyright © 2024 Apple Inc.
#include <metal_atomic>
#include <metal_simdgroup>
// clang-format off
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/atomic.h"
#include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/reduction/reduce_init.h"
#include "mlx/backend/metal/kernels/reduce.h"
#define instantiate_reduce_helper_floats(inst_f, name, op) \
inst_f(name, float16, half, op) \
inst_f(name, float32, float, op) \
inst_f(name, bfloat16, bfloat16_t, op)
#define instantiate_reduce_helper_uints(inst_f, name, op) \
inst_f(name, uint8, uint8_t, op) \
inst_f(name, uint16, uint16_t, op) \
inst_f(name, uint32, uint32_t, op)
#define instantiate_reduce_helper_ints(inst_f, name, op) \
inst_f(name, int8, int8_t, op) \
inst_f(name, int16, int16_t, op) \
inst_f(name, int32, int32_t, op)
#define instantiate_reduce_helper_64b(inst_f, name, op) \
inst_f(name, int64, int64_t, op) \
inst_f(name, uint64, uint64_t, op)
#define instantiate_reduce_helper_types(inst_f, name, op) \
instantiate_reduce_helper_floats(inst_f, name, op) \
instantiate_reduce_helper_uints(inst_f, name, op) \
instantiate_reduce_helper_ints(inst_f, name, op)
#define instantiate_reduce_ops(inst_f, type_f) \
type_f(inst_f, sum, Sum) \
type_f(inst_f, prod, Prod) \
type_f(inst_f, min, Min) \
type_f(inst_f, max, Max)
// Special case for bool reductions
#define instantiate_reduce_from_types_helper( \
inst_f, name, tname, itype, otype, op) \
inst_f(name##tname, itype, otype, op)
#define instantiate_reduce_from_types(inst_f, name, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, bool_, bool, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, uint8, uint8_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, uint16, uint16_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, uint32, uint32_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, uint64, uint64_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, int8, int8_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, int16, int16_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, int32, int32_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, int64, int64_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, float16, half, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, \
name, \
float32, \
float, \
otype, \
op) \
instantiate_reduce_from_types_helper( \
inst_f, \
name, \
bfloat16, \
bfloat16_t, \
otype, \
op)
#define instantiate_init_reduce(name, otype, op) \
template [[host_name("i_reduce_" #name)]] [[kernel]] void \
init_reduce<otype, op>( \
device otype * out [[buffer(1)]], uint tid [[thread_position_in_grid]]);
#define instantiate_init_reduce_helper(name, tname, type, op) \
instantiate_init_reduce(name##tname, type, op<type>)
instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_64b)
instantiate_init_reduce(andbool_, bool, And<bool>)
instantiate_init_reduce(orbool_, bool, Or<bool>)
#define instantiate_all_reduce(name, itype, otype, op) \
template [[host_name("all_reduce_" #name)]] [[kernel]] void \
all_reduce<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device mlx_atomic<otype>* out [[buffer(1)]], \
const device size_t& in_size [[buffer(2)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint grid_size [[threads_per_grid]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_all_reduce_no_atomics(name, itype, otype, op) \
template [[host_name("allNoAtomics_reduce_" #name)]] [[kernel]] void \
all_reduce_no_atomics<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const device size_t& in_size [[buffer(2)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint grid_size [[threads_per_grid]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint thread_group_id [[threadgroup_position_in_grid]]);
#define instantiate_same_all_reduce_helper(name, tname, type, op) \
instantiate_all_reduce(name##tname, type, type, op<type>)
#define instantiate_same_all_reduce_na_helper(name, tname, type, op) \
instantiate_all_reduce_no_atomics(name##tname, type, type, op<type>)
instantiate_reduce_ops(instantiate_same_all_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_all_reduce_na_helper, instantiate_reduce_helper_64b)
instantiate_reduce_from_types(instantiate_all_reduce, and, bool, And<bool>)
instantiate_reduce_from_types(instantiate_all_reduce, or, bool, Or<bool>)
// special case bool with larger output type
instantiate_all_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>)
#define instantiate_col_reduce_general(name, itype, otype, op) \
template [[host_name("colGeneral_reduce_" #name)]] [[kernel]] void \
col_reduce_general<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device mlx_atomic<otype>* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
threadgroup otype* local_data [[threadgroup(0)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]]);
#define instantiate_col_reduce_general_no_atomics(name, itype, otype, op) \
template \
[[host_name("colGeneralNoAtomics_reduce_" #name)]] [[kernel]] void \
col_reduce_general_no_atomics<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
threadgroup otype* local_data [[threadgroup(0)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 gid [[thread_position_in_grid]], \
uint3 lsize [[threads_per_threadgroup]], \
uint3 gsize [[threads_per_grid]]);
#define instantiate_col_reduce_small(name, itype, otype, op) \
template [[host_name("colSmall_reduce_" #name)]] [[kernel]] void \
col_reduce_small<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
const constant size_t& non_col_reductions [[buffer(8)]], \
const constant int* non_col_shapes [[buffer(9)]], \
const constant size_t* non_col_strides [[buffer(10)]], \
const constant int& non_col_ndim [[buffer(11)]], \
uint tid [[thread_position_in_grid]]);
#define instantiate_same_col_reduce_helper(name, tname, type, op) \
instantiate_col_reduce_small(name ##tname, type, type, op<type>) \
instantiate_col_reduce_general(name ##tname, type, type, op<type>)
#define instantiate_same_col_reduce_na_helper(name, tname, type, op) \
instantiate_col_reduce_small(name ##tname, type, type, op<type>) \
instantiate_col_reduce_general_no_atomics(name ##tname, type, type, op<type>)
instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_col_reduce_na_helper, instantiate_reduce_helper_64b)
instantiate_col_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>)
instantiate_reduce_from_types(instantiate_col_reduce_general, and, bool, And<bool>)
instantiate_reduce_from_types(instantiate_col_reduce_general, or, bool, Or<bool>)
instantiate_col_reduce_small(sumbool_, bool, uint32_t, Sum<uint32_t>)
instantiate_reduce_from_types(instantiate_col_reduce_small, and, bool, And<bool>)
instantiate_reduce_from_types(instantiate_col_reduce_small, or, bool, Or<bool>)
#define instantiate_row_reduce_small(name, itype, otype, op) \
template [[host_name("rowGeneralSmall_reduce_" #name)]] [[kernel]] void \
row_reduce_general_small<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant size_t& non_row_reductions [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
uint lid [[thread_position_in_grid]]); \
template [[host_name("rowGeneralMed_reduce_" #name)]] [[kernel]] void \
row_reduce_general_med<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant size_t& non_row_reductions [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
uint tid [[threadgroup_position_in_grid]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_per_group [[dispatch_simdgroups_per_threadgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_row_reduce_general(name, itype, otype, op) \
instantiate_row_reduce_small(name, itype, otype, op) \
template \
[[host_name("rowGeneral_reduce_" #name)]] [[kernel]] void \
row_reduce_general<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device mlx_atomic<otype>* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant size_t& non_row_reductions [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_row_reduce_general_no_atomics(name, itype, otype, op) \
instantiate_row_reduce_small(name, itype, otype, op) \
template \
[[host_name("rowGeneralNoAtomics_reduce_" #name)]] [[kernel]] void \
row_reduce_general_no_atomics<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& out_size [[buffer(3)]], \
const constant size_t& non_row_reductions [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]], \
uint3 gsize [[threads_per_grid]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_same_row_reduce_helper(name, tname, type, op) \
instantiate_row_reduce_general(name##tname, type, type, op<type>)
#define instantiate_same_row_reduce_na_helper(name, tname, type, op) \
instantiate_row_reduce_general_no_atomics(name##tname, type, type, op<type>)
instantiate_reduce_ops(instantiate_same_row_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_row_reduce_na_helper, instantiate_reduce_helper_64b)
instantiate_reduce_from_types(instantiate_row_reduce_general, and, bool, And<bool>)
instantiate_reduce_from_types(instantiate_row_reduce_general, or, bool, Or<bool>)
instantiate_row_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>)
// clang-format on

View File

@@ -0,0 +1,6 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/backend/metal/kernels/atomic.h"
#include "mlx/backend/metal/kernels/reduction/ops.h"

View File

@@ -1,32 +0,0 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/reduction/reduce_inst.h"
#include "mlx/backend/metal/kernels/reduction/utils.h"
using namespace metal;
///////////////////////////////////////////////////////////////////////////////
// Reduce init
///////////////////////////////////////////////////////////////////////////////
template <typename T, typename Op>
[[kernel]] void init_reduce(
device T* out [[buffer(0)]],
uint tid [[thread_position_in_grid]]) {
out[tid] = Op::init;
}
#define instantiate_init_reduce(name, otype, op) \
template [[host_name("i" #name)]] [[kernel]] void init_reduce<otype, op>( \
device otype * out [[buffer(1)]], uint tid [[thread_position_in_grid]]);
#define instantiate_init_reduce_helper(name, tname, type, op) \
instantiate_init_reduce(name##tname, type, op<type>)
// clang-format off
instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_64b)
instantiate_init_reduce(andbool_, bool, And)
instantiate_init_reduce(orbool_, bool, Or) // clang-format on

View File

@@ -5,9 +5,7 @@
#include <metal_atomic>
#include <metal_simdgroup>
#include "mlx/backend/metal/kernels/atomic.h"
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"
static constant constexpr const uint8_t simd_size = 32;
union bool4_or_uint {
bool4 b;
@@ -21,6 +19,7 @@ struct None {
}
};
template <typename U = bool>
struct And {
bool simd_reduce(bool val) {
return simd_all(val);
@@ -58,6 +57,7 @@ struct And {
}
};
template <typename U = bool>
struct Or {
bool simd_reduce(bool val) {
return simd_any(val);

View File

@@ -1,11 +1,5 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/reduction/reduce_inst.h"
#include "mlx/backend/metal/kernels/reduction/utils.h"
using namespace metal;
///////////////////////////////////////////////////////////////////////////////
// All reduce helper
///////////////////////////////////////////////////////////////////////////////
@@ -139,50 +133,3 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
out[thread_group_id] = total_val;
}
}
#define instantiate_all_reduce(name, itype, otype, op) \
template [[host_name("all_reduce_" #name)]] [[kernel]] void \
all_reduce<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device mlx_atomic<otype>* out [[buffer(1)]], \
const device size_t& in_size [[buffer(2)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint grid_size [[threads_per_grid]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_all_reduce_no_atomics(name, itype, otype, op) \
template [[host_name("all_reduce_no_atomics_" #name)]] [[kernel]] void \
all_reduce_no_atomics<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const device size_t& in_size [[buffer(2)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint grid_size [[threads_per_grid]], \
uint simd_per_group [[simdgroups_per_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint thread_group_id [[threadgroup_position_in_grid]]);
///////////////////////////////////////////////////////////////////////////////
// Instantiations
///////////////////////////////////////////////////////////////////////////////
#define instantiate_same_all_reduce_helper(name, tname, type, op) \
instantiate_all_reduce(name##tname, type, type, op<type>)
#define instantiate_same_all_reduce_na_helper(name, tname, type, op) \
instantiate_all_reduce_no_atomics(name##tname, type, type, op<type>)
// clang-format off
instantiate_reduce_ops(instantiate_same_all_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_all_reduce_na_helper, instantiate_reduce_helper_64b)
instantiate_reduce_from_types(instantiate_all_reduce, and, bool, And)
instantiate_reduce_from_types(instantiate_all_reduce, or, bool, Or)
// special case bool with larger output type
instantiate_all_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>) // clang-format on

View File

@@ -1,11 +1,5 @@
// Copyright © 2023-2024 Apple Inc.
#include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/reduction/reduce_inst.h"
#include "mlx/backend/metal/kernels/reduction/utils.h"
using namespace metal;
///////////////////////////////////////////////////////////////////////////////
// Small column reduce kernel
///////////////////////////////////////////////////////////////////////////////
@@ -52,23 +46,6 @@ template <typename T, typename U, typename Op>
out[out_idx] = total_val;
}
#define instantiate_col_reduce_small(name, itype, otype, op) \
template [[host_name("col_reduce_small_" #name)]] [[kernel]] void \
col_reduce_small<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
const constant size_t& non_col_reductions [[buffer(8)]], \
const constant int* non_col_shapes [[buffer(9)]], \
const constant size_t* non_col_strides [[buffer(10)]], \
const constant int& non_col_ndim [[buffer(11)]], \
uint tid [[thread_position_in_grid]]);
///////////////////////////////////////////////////////////////////////////////
// Column reduce helper
///////////////////////////////////////////////////////////////////////////////
@@ -186,64 +163,3 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
}
}
}
#define instantiate_col_reduce_general(name, itype, otype, op) \
template [[host_name("col_reduce_general_" #name)]] [[kernel]] void \
col_reduce_general<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device mlx_atomic<otype>* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
threadgroup otype* local_data [[threadgroup(0)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 lsize [[threads_per_threadgroup]]);
#define instantiate_col_reduce_general_no_atomics(name, itype, otype, op) \
template \
[[host_name("col_reduce_general_no_atomics_" #name)]] [[kernel]] void \
col_reduce_general_no_atomics<itype, otype, op>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t& reduction_size [[buffer(2)]], \
const constant size_t& reduction_stride [[buffer(3)]], \
const constant size_t& out_size [[buffer(4)]], \
const constant int* shape [[buffer(5)]], \
const constant size_t* strides [[buffer(6)]], \
const constant int& ndim [[buffer(7)]], \
threadgroup otype* local_data [[threadgroup(0)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint3 gid [[thread_position_in_grid]], \
uint3 lsize [[threads_per_threadgroup]], \
uint3 gsize [[threads_per_grid]]);
///////////////////////////////////////////////////////////////////////////////
// Instantiations
///////////////////////////////////////////////////////////////////////////////
// clang-format off
#define instantiate_same_col_reduce_helper(name, tname, type, op) \
instantiate_col_reduce_small(name ##tname, type, type, op<type>) \
instantiate_col_reduce_general(name ##tname, type, type, op<type>) // clang-format on
// clang-format off
#define instantiate_same_col_reduce_na_helper(name, tname, type, op) \
instantiate_col_reduce_small(name ##tname, type, type, op<type>) \
instantiate_col_reduce_general_no_atomics(name ##tname, type, type, op<type>) // clang-format on
// clang-format off
instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_col_reduce_na_helper, instantiate_reduce_helper_64b)
instantiate_col_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>)
instantiate_reduce_from_types(instantiate_col_reduce_general, and, bool, And)
instantiate_reduce_from_types(instantiate_col_reduce_general, or, bool, Or)
instantiate_col_reduce_small(sumbool_, bool, uint32_t, Sum<uint32_t>)
instantiate_reduce_from_types(instantiate_col_reduce_small, and, bool, And)
instantiate_reduce_from_types(instantiate_col_reduce_small, or, bool, Or) // clang-format on

View File

@@ -0,0 +1,8 @@
// Copyright © 2023-2024 Apple Inc.
template <typename T, typename Op>
[[kernel]] void init_reduce(
device T* out [[buffer(0)]],
uint tid [[thread_position_in_grid]]) {
out[tid] = Op::init;
}

Some files were not shown because too many files have changed in this diff Show More