Compare commits

..

66 Commits

Author SHA1 Message Date
Awni Hannun
0c1155faf5 binding + tests 2024-12-09 12:57:36 -08:00
Awni Hannun
2b9c24c517 works 2024-12-09 12:57:36 -08:00
Awni Hannun
ee59d50293 try dynamic reshape 2024-12-09 12:57:36 -08:00
Awni Hannun
40c62c1321 Use int64 stride everywhere (#1671)
* use int64 stride everywhere

* fix ext

* fix ext

* more shape + cleanup

* one more

* few more
2024-12-09 11:09:02 -08:00
Awni Hannun
35b412c099 Fix compile hasher for string constants. (#1677)
* fix hash

* add test

* nit
2024-12-09 09:26:18 -08:00
Cheng
d0f471cff7 Using math defines requires switch in MSVC (#1665)
* Using math defines requires switch in MSVC

* Fix more math macros

* Fix type

* Remove _MSC_VER guard for math defines
2024-12-08 08:16:28 -08:00
Cheng
6f316b8bf5 Use int64_t instead of ssize_t (#1673) 2024-12-07 20:10:44 -08:00
Cheng
7c10c93a1f Convert filesystem path to std::string explicitly (#1672) 2024-12-07 20:10:06 -08:00
Cheng
d92ea094f1 Use && instead of and (#1663)
* Use && instead of and

* Remove "and" in ops.cpp
2024-12-07 18:26:39 -08:00
Cheng
6ae5423b4a Do not pass integers to isnan (#1664) 2024-12-07 18:26:23 -08:00
Cheng
9635cffdc8 Include io.h in MSVC for IO functions (#1661) 2024-12-07 18:26:06 -08:00
Cheng
96986fb362 Use auto* for pointers (#1662) 2024-12-07 18:25:40 -08:00
Cheng
3ceb341a75 Use correct complex type for MSVC (#1660) 2024-12-07 18:25:22 -08:00
Awni Hannun
50fa705125 patch bump (#1656) 2024-12-06 13:16:19 -08:00
Awni Hannun
69a2991614 allow compiling lambdas in C++ (#1650)
* allow compiling lambdas in C++

* fix test

* more tests

* auto detect capture-less lambda
2024-12-06 13:13:21 -08:00
mt_caret
fd3377dd1f Support bias correction in Adam and AdamW optimizers (#1640) 2024-12-06 12:13:34 -08:00
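
As a usage sketch for #1640: the bias-corrected updates would be opted into on the optimizer itself. The bias_correction keyword below is an assumed name for the new flag; everything else uses existing mlx.nn / mlx.optimizers APIs.

    import mlx.core as mx
    import mlx.nn as nn
    import mlx.optimizers as optim

    model = nn.Linear(16, 4)
    # bias_correction is the assumed name of the option added in #1640
    opt = optim.AdamW(learning_rate=1e-3, bias_correction=True)

    def loss_fn(m, x, y):
        return nn.losses.mse_loss(m(x), y)

    x, y = mx.random.normal((8, 16)), mx.random.normal((8, 4))
    loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
    opt.update(model, grads)
    mx.eval(model.parameters(), opt.state)
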
Awni Hannun
d0b6cb0425 More primitives for compiling with shapeless (#1653)
* more shapeless and more Shape

* more shape

* fix

* fix
2024-12-06 11:29:18 -08:00
Alex Barron
95c4a2e3af add back conditionaltype (#1655) 2024-12-06 11:12:01 -08:00
Awni Hannun
bc2a29f033 fix (#1654) 2024-12-06 10:48:58 -08:00
Nripesh Niketan
3bb5b4a302 Chore: Add default language in pre-commit and bump hooks (#1652) 2024-12-06 07:54:29 -08:00
Awni Hannun
fc88fd9097 Shape and Strides 1 / N (#1645)
* shape and stride type def

* more shape
2024-12-05 12:53:43 -08:00
Awni Hannun
c5b0928c1f fix fallback (#1646) 2024-12-05 11:59:53 -08:00
Awni Hannun
e047fd977d compile changes if stream changes (#1644) 2024-12-03 14:37:44 -08:00
Jagrit Digani
9d40e521d7 Stop matrix copies with new attention kernel (#1639) 2024-12-02 14:12:38 -08:00
Alex Barron
1445dcaa60 let class predicate specify quantization parameters (#1638) 2024-12-02 14:09:28 -08:00
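
For #1638: mlx.nn.quantize takes a class_predicate that selects which submodules get quantized, and this change lets the predicate also drive the quantization parameters. Below is a minimal sketch of the selector form; the per-layer override convention mentioned in the comment is an assumption, not confirmed by this page.

    import mlx.nn as nn

    model = nn.Sequential(nn.Linear(512, 512), nn.Linear(512, 64))

    # Quantize only Linear layers; other module types are left as-is.
    nn.quantize(
        model,
        group_size=64,
        bits=4,
        class_predicate=lambda path, module: isinstance(module, nn.Linear),
    )
    # Per #1638 (assumed convention): the predicate may instead return
    # per-layer parameters such as {"group_size": 32, "bits": 8} to override
    # the defaults for that particular module.
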
Jesper Stemann Andersen
e4eeb4e910 Added missing unordered_map includes (#1635)
* Added missing includes in mlx/io.h and mlx/backend/metal/metal.h

* Added additional missing unordered_map includes that fixes build on FreeBSD
2024-12-02 07:03:03 -08:00
Awni Hannun
aa86876813 fix transformer decoder post norm LN (#1637) 2024-12-02 07:02:17 -08:00
Jesper Stemann Andersen
974bb54ab2 CMake: Enabled using Accelerate on x86_64 / x64 (#1625)
* CMake: Enabled using Accelerate on x86_64 / x64

Cf. https://github.com/JuliaPackaging/Yggdrasil/pull/9761

* CMake: Removed superfluous MLX_BUILD_ARM
2024-11-28 10:55:45 -08:00
Ikko Eltociear Ashimine
9bc2183a31 docs: update device.cpp (#1632)
unecessary -> unnecessary
2024-11-27 20:58:26 -08:00
Awni Hannun
d4b222b6d3 Fix some leaks and races (#1629)
* fix leak and fix potential race

* more leak fixes

* fix one more
2024-11-27 20:01:20 -08:00
Jesper Stemann Andersen
af2af818a6 Enables build for *-linux-musl (#1627)
Also contributes to being able to build for *-w64-mingw32.

Cf. https://github.com/JuliaPackaging/Yggdrasil/pull/9761
2024-11-27 13:14:24 -08:00
Jesper Stemann Andersen
698e63a608 CMake: Build with dlfcn-win32 to have dlopen etc. on win32 (#1628)
Cf. https://github.com/JuliaPackaging/Yggdrasil/pull/9761
2024-11-27 13:14:13 -08:00
Awni Hannun
211411faf2 fix large ops (#1620) 2024-11-24 09:17:10 -08:00
Awni Hannun
bb303c45a5 version (#1617) 2024-11-22 12:00:03 -08:00
Alex Barron
6f7986d592 Cleaner qmv/qvm (#1616) 2024-11-22 11:14:08 -08:00
Awni Hannun
7cbb4aef17 Doc fix (#1615) 2024-11-22 11:12:25 -08:00
Jagrit Digani
02bec0bb6d Matrix Attention kernel (#1610)
* Rough INIT

* [WIP]: Loading and Matmuls added

* [WIP]: Reductions and min working aligned kernel at headdim = 64

* [WIP] Added headdim 80 for testing

* [WIP] Update dispatch params for testing

* [WIP] Add support for unaligned seq lengths - still looks messy

* Update sdpa_benchmarks

* Update sdpa_benchmarks

* Update sdpa_benchmarks

* Enable gqa support

* Update benchmark and switch off 128 headdim

* Update headdim 128 tuning

* Remove older fast attention code. Write out O strided

* Disable hd=128 until further optimizations

* Enable bf16

* Fix data size bug

* Enable attn build outside of jit
2024-11-22 10:34:05 -08:00
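
The kernel from #1610 sits behind mx.fast.scaled_dot_product_attention, which the benchmark added further down in this diff calls directly. A small usage sketch with illustrative shapes:

    import math
    import mlx.core as mx

    B, n_heads, L, head_dim = 1, 32, 1024, 64
    q = mx.random.normal((B, n_heads, L, head_dim), dtype=mx.float16)
    k = mx.random.normal((B, n_heads, L, head_dim), dtype=mx.float16)
    v = mx.random.normal((B, n_heads, L, head_dim), dtype=mx.float16)

    scale = 1.0 / math.sqrt(head_dim)
    out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=None)
    mx.eval(out)
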
Alex Barron
c79f6a4a8c 3 and 6 bit quantization (#1613)
* Support 3 and 6 bit quantization
2024-11-22 10:22:13 -08:00
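
At the op level, #1613 extends the existing quantization routines to the new bit widths. A sketch with illustrative shapes and group size:

    import mlx.core as mx

    w = mx.random.normal((512, 512))

    # bits=3 (and 6) are now accepted in addition to the previous widths.
    w_q, scales, biases = mx.quantize(w, group_size=64, bits=3)
    w_hat = mx.dequantize(w_q, scales, biases, group_size=64, bits=3)

    x = mx.random.normal((1, 512))
    y = mx.quantized_matmul(x, w_q, scales, biases, group_size=64, bits=3)
    mx.eval(y, w_hat)
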
Awni Hannun
0c5eea226b Reduce specializations (#1607)
* start of reduce specializations

* fix all reduce

* fix many dims

* fix

* non-jit tests clear

* cleanup instantiations

* cpu merges

* change dim specializations

* optimize

* fix jit

* fix jit

* use higher precision for integer sum+prod

* fixes
2024-11-21 19:53:00 -08:00
Awni Hannun
dcca0d7477 contiguous op / prim (#1612) 2024-11-21 19:51:49 -08:00
Cocoa
0d5e7716ad fix typo: accross -> across (#1609)
Signed-off-by: Cocoa <i@uwucocoa.moe>
2024-11-20 15:30:51 -08:00
Angelos Katharopoulos
d8c824c594 Formatting fixes (#1606) 2024-11-20 15:30:36 -08:00
Saanidhya
cb431dfc9f Adds 3D pooling (#1526) 2024-11-19 16:45:24 -08:00
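
For the 3D pooling layers from #1526 (also reflected in the docs update further down): a minimal sketch, assuming they mirror the 2D pooling signatures and MLX's channels-last layout.

    import mlx.core as mx
    import mlx.nn as nn

    # Channels-last input: (batch, depth, height, width, channels).
    x = mx.random.normal((2, 8, 16, 16, 3))

    pool = nn.MaxPool3d(kernel_size=2, stride=2)
    y = pool(x)  # expected (2, 4, 8, 8, 3)

    avg = nn.AvgPool3d(kernel_size=(2, 2, 2), stride=2, padding=0)
    z = avg(x)
    mx.eval(y, z)
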
Awni Hannun
61d787726a Fix view scalar bug segfault (#1603)
* fix view scalar bug

* fix view scalar bug

* one more fix
2024-11-19 10:54:05 -08:00
Angelos Katharopoulos
5e89aace9b Fix concatenate vmap (#1600) 2024-11-19 10:44:04 -08:00
Awni Hannun
2af7e8a9a6 fix cmake version (#1601) 2024-11-19 08:45:05 -08:00
Awni Hannun
2419edd5b2 Faster indexing math in a few kernels (#1589)
* wip: faster compiled kernels

* faster general unary with uint specialization

* index type in compiled, unary, binary, ternary, copy

* fix jit

* jit fix

* specialize gather + scatter

* nit in docs
2024-11-18 19:52:00 -08:00
Awni Hannun
bf481e8e5d Fix sibling leak (#1590)
* add test

* fix + test

* fix fix
2024-11-18 19:17:01 -08:00
Awni Hannun
9d7fa6b8e6 Use osx deployment target to pick Metal version (#1595)
* choose metal based on deployment target rather than system version

* nit

* unused compile def
2024-11-18 19:16:49 -08:00
Angelos Katharopoulos
073076ac7d 2-Pass Sdpa Inference Kernel (#1597) 2024-11-18 17:31:53 -08:00
Awni Hannun
9bd03dd9b4 More buffer donation with no-ops (#1591)
* more donation

* fix test

* fix build
2024-11-18 08:35:41 -08:00
Awni Hannun
6931f84412 fix dispatch threads for a few kernels (#1594) 2024-11-18 08:35:25 -08:00
xnorai
16ec0556a0 Allocate raw JSON metadata buffer on the heap, and limit its size (#1596)
* Allocate raw JSON metadata buffer on the heap, and limit its size to 1GiB

* Set the upper size limit for the header to 100K as in Rust safetensors
2024-11-18 07:22:51 -08:00
Awni Hannun
610af352d4 Dispatch bf16 at run time when using the JIT (#1584)
* Dispatch bf16 at run time when using the JIT

* fix extension

* fix extension build

* fix extension build

* Update utils.h
2024-11-15 16:54:36 -08:00
Awni Hannun
b35f1e3c9c fix donation in sdpa (#1587) 2024-11-13 17:21:13 -08:00
Awni Hannun
dfa0b9aab4 Cpu fast quantize (#1578)
* cpu quantize

* fix
2024-11-08 20:10:39 -08:00
Alex Barron
a4c47b0276 OOB QMV fix (#1579)
* fix oob access in qmv

* skip more

* fix small case
2024-11-08 17:59:45 -08:00
Alex Barron
111fefd5e9 Fix OOB access in qmv (#1577)
* fix oob access in qmv

* skip more
2024-11-08 15:41:30 -08:00
Awni Hannun
c1fe1ef081 Bfs width limit (#1568)
* width limit

* fix

* large limit

* put env vars in env namespace
2024-11-08 15:00:46 -08:00
Awni Hannun
8c34c9dac4 throw for invalid case and remove test (#1575) 2024-11-08 12:04:03 -08:00
Awni Hannun
91c0277356 fix per-example mask + docs in sdpa (#1574) 2024-11-08 11:51:15 -08:00
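
For #1574: the mask argument to mx.fast.scaled_dot_product_attention is an additive array that can differ per batch element and broadcasts over heads. A sketch with illustrative shapes, building a padding mask from per-example lengths:

    import math
    import mlx.core as mx

    B, H, L, D = 2, 8, 128, 64
    q = mx.random.normal((B, H, L, D))
    k = mx.random.normal((B, H, L, D))
    v = mx.random.normal((B, H, L, D))

    # One additive mask per batch element, shape (B, 1, 1, L).
    lengths = mx.array([128, 64])
    valid = mx.arange(L)[None, :] < lengths[:, None]
    mask = mx.where(valid, 0.0, -float("inf"))[:, None, None, :]

    out = mx.fast.scaled_dot_product_attention(
        q, k, v, scale=1.0 / math.sqrt(D), mask=mask
    )
    mx.eval(out)
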
Awni Hannun
9f0d5c12fc Fully wrap the command encoder (#1572)
* fully wrap the command encoder

* use consistent style + fix extensions
2024-11-08 11:50:21 -08:00
Awni Hannun
59247c2b62 add groups in conv2d (#1569) 2024-11-07 13:57:53 -08:00
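
For #1569: grouped convolution at the op level, in MLX's NHWC layout; the weight's last dimension is the per-group input channel count.

    import mlx.core as mx

    groups = 4
    x = mx.random.normal((1, 32, 32, 64))            # (N, H, W, C_in)
    w = mx.random.normal((128, 3, 3, 64 // groups))  # (C_out, kH, kW, C_in // groups)

    y = mx.conv2d(x, w, stride=1, padding=1, groups=groups)
    mx.eval(y)  # (1, 32, 32, 128)
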
Awni Hannun
9a3842a2d9 fix (#1566) 2024-11-06 17:10:33 -08:00
Alex Barron
726dbd9267 v0.20.0 (#1565) 2024-11-05 12:37:57 -08:00
Awni Hannun
54f05e7195 Fix gather vmap (#1563)
* fix gather

* fix
2024-11-05 11:29:20 -08:00
213 changed files with 7990 additions and 6254 deletions

View File

@@ -1,13 +1,14 @@
repos: repos:
- repo: https://github.com/pre-commit/mirrors-clang-format - repo: https://github.com/pre-commit/mirrors-clang-format
rev: v18.1.8 rev: v19.1.4
hooks: hooks:
- id: clang-format - id: clang-format
# Using this mirror lets us use mypyc-compiled black, which is about 2x faster # Using this mirror lets us use mypyc-compiled black, which is about 2x faster
- repo: https://github.com/psf/black-pre-commit-mirror - repo: https://github.com/psf/black-pre-commit-mirror
rev: 24.8.0 rev: 24.10.0
hooks: hooks:
- id: black - id: black
- repo: https://github.com/pycqa/isort - repo: https://github.com/pycqa/isort
rev: 5.13.2 rev: 5.13.2
hooks: hooks:

View File

@@ -24,7 +24,7 @@ option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF) option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
if(NOT MLX_VERSION) if(NOT MLX_VERSION)
set(MLX_VERSION 0.19.3) set(MLX_VERSION 0.21.1)
endif() endif()
# --------------------- Processor tests ------------------------- # --------------------- Processor tests -------------------------
@@ -34,8 +34,6 @@ message(
"Building MLX for ${CMAKE_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}" "Building MLX for ${CMAKE_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}"
) )
set(MLX_BUILD_ARM OFF)
if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin") if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64") if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64")
if(NOT MLX_ENABLE_X64_MAC) if(NOT MLX_ENABLE_X64_MAC)
@@ -57,10 +55,6 @@ else()
message(WARNING "MLX is prioritised for Apple silicon systems using macOS.") message(WARNING "MLX is prioritised for Apple silicon systems using macOS.")
endif() endif()
if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm64")
set(MLX_BUILD_ARM ON)
endif()
# ----------------------------- Lib ----------------------------- # ----------------------------- Lib -----------------------------
include(FetchContent) include(FetchContent)
@@ -89,25 +83,27 @@ elseif(MLX_BUILD_METAL)
# Throw an error if xcrun not found # Throw an error if xcrun not found
execute_process( execute_process(
COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version" COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
OUTPUT_VARIABLE MACOS_VERSION COMMAND_ERROR_IS_FATAL ANY) OUTPUT_VARIABLE MACOS_SDK_VERSION COMMAND_ERROR_IS_FATAL ANY)
if(${MACOS_VERSION} LESS 14.0) if(${MACOS_SDK_VERSION} LESS 14.0)
message( message(
FATAL_ERROR FATAL_ERROR
"MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON") "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON")
endif() endif()
message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}") message(STATUS "Building with macOS SDK version ${MACOS_SDK_VERSION}")
set(METAL_CPP_URL set(METAL_CPP_URL
https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18-beta.zip https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18-beta.zip
) )
# Get the metal version
if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "")
set(XCRUN_FLAGS "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
endif()
execute_process( execute_process(
COMMAND COMMAND
zsh "-c" zsh "-c"
"echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal -E -x metal -P - | tail -1 | tr -d '\n'" "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY) OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL}) FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})
FetchContent_MakeAvailable(metal_cpp) FetchContent_MakeAvailable(metal_cpp)
@@ -115,13 +111,11 @@ elseif(MLX_BUILD_METAL)
mlx PUBLIC $<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}> mlx PUBLIC $<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>
$<INSTALL_INTERFACE:include/metal_cpp>) $<INSTALL_INTERFACE:include/metal_cpp>)
target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB}) target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
add_compile_definitions("MLX_METAL_VERSION=${MLX_METAL_VERSION}")
endif() endif()
if(MLX_BUILD_CPU) if(MLX_BUILD_CPU)
find_library(ACCELERATE_LIBRARY Accelerate) find_library(ACCELERATE_LIBRARY Accelerate)
if(MLX_BUILD_ARM AND ACCELERATE_LIBRARY) if(ACCELERATE_LIBRARY)
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}") message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
set(MLX_BUILD_ACCELERATE ON) set(MLX_BUILD_ACCELERATE ON)
target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY}) target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY})
@@ -160,6 +154,13 @@ if(MLX_BUILD_CPU)
message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS}) message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS}) target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
target_link_libraries(mlx PUBLIC ${BLAS_LIBRARIES}) target_link_libraries(mlx PUBLIC ${BLAS_LIBRARIES})
if(WIN32)
find_package(dlfcn-win32 REQUIRED)
message(STATUS "dlfcn-win32 lib " ${dlfcn-win32_LIBRARIES})
message(STATUS "dlfcn-win32 include " ${dlfcn-win32_INCLUDE_DIRS})
target_link_libraries(mlx PUBLIC ${dlfcn-win32_LIBRARIES})
endif()
endif() endif()
else() else()
set(MLX_BUILD_ACCELERATE OFF) set(MLX_BUILD_ACCELERATE OFF)

View File

@@ -1,62 +1,189 @@
# Copyright © 2024 Apple Inc.
import argparse import argparse
import math import math
import os
import subprocess
import time
import mlx.core as mx import mlx.core as mx
from time_utils import time_fn import numpy as np
MAX_SEQ = 300 device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
START_SEQ = 100 device_name = device_name.decode("utf-8").strip("\n")
SEQ_INCREMENT = 50
N_warmup = 5
N_iter_bench = 40
N_iter_func = 8
def time_self_attention_primitives(): def bench(f, *args):
mx.random.seed(3) for i in range(N_warmup):
B = 2 f(*args)
H = 38
D = 64
for R in range(START_SEQ, MAX_SEQ, SEQ_INCREMENT):
q = mx.random.uniform(shape=(B, H, R, D))
k = mx.random.uniform(shape=(B, H, R, D))
v = mx.random.uniform(shape=(B, H, R, D))
scale = 1.0 / math.sqrt(float(D))
mx.eval(q, k, v)
def sdpa_primitives(qs, ks, vs, alpha): s = time.perf_counter_ns()
s = (alpha * qs) @ ks.transpose(0, 1, 3, 2) for i in range(N_iter_bench):
p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype) f(*args)
o = p @ vs e = time.perf_counter_ns()
return o return (e - s) * 1e-9
time_fn(sdpa_primitives, q, k, v, scale)
def time_self_attention_sdpa(): def mlx_sdpa_fused_inner(q, k, v, scale):
mx.random.seed(3) return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=None)
B = 2
H = 38
D = 64
for R in range(START_SEQ, MAX_SEQ, SEQ_INCREMENT):
q = mx.random.uniform(shape=(B, H, R, D))
k = mx.random.uniform(shape=(B, H, R, D))
v = mx.random.uniform(shape=(B, H, R, D))
scale = 1.0 / math.sqrt(float(D))
mx.eval(q, k, v)
def sdpa_fused(qs, ks, vs, alpha):
o = mx.fast.scaled_dot_product_attention(qs, ks, vs, scale=alpha)
return o
time_fn(sdpa_fused, q, k, v, scale) def mlx_sdpa_unfused_inner(q, k, v, scale, f32softmax=False):
q_dtype = q.dtype
q = q * mx.array(scale, q_dtype)
n_q_heads = q.shape[-3]
n_kv_heads = k.shape[-3]
n_repeats = n_q_heads // n_kv_heads
B = q.shape[0]
L = q.shape[2]
if n_repeats > 1:
q = mx.reshape(q, [B, n_kv_heads, n_repeats, L, -1])
k = mx.expand_dims(k, 2)
v = mx.expand_dims(v, 2)
scores = q @ mx.swapaxes(k, -1, -2)
if f32softmax:
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(q_dtype)
else:
scores = mx.softmax(scores, axis=-1)
out = scores @ v
if n_repeats > 1:
out = mx.reshape(out, [B, n_q_heads, L, -1])
return out
def mlx_spda_unfused(q, k, v, scale, transpose):
q_out = q
if transpose:
k = mx.transpose(k, (0, 2, 1, 3))
v = mx.transpose(v, (0, 2, 1, 3))
for i in range(N_iter_func):
if transpose:
q_out = mx.transpose(q_out, (0, 2, 1, 3))
q_out = mlx_sdpa_unfused_inner(q_out, k, v, scale)
if transpose:
q_out = mx.transpose(q_out, (0, 2, 1, 3))
mx.eval(q_out)
return q_out
def mlx_spda_fused(q, k, v, scale, transpose):
q_out = q
if transpose:
k = mx.transpose(k, (0, 2, 1, 3))
v = mx.transpose(v, (0, 2, 1, 3))
for i in range(N_iter_func):
if transpose:
q_out = mx.transpose(q_out, (0, 2, 1, 3))
q_out = mlx_sdpa_fused_inner(q_out, k, v, scale)
if transpose:
q_out = mx.transpose(q_out, (0, 2, 1, 3))
mx.eval(q_out)
return q_out
def bench_shape(B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose=True):
shape_q = (
(B, qsl, n_q_heads, head_dim) if transpose else (B, n_q_heads, qsl, head_dim)
)
shape_kv = (
(B, ksl, n_kv_heads, head_dim) if transpose else (B, n_kv_heads, ksl, head_dim)
)
q_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_q).astype(np_dtype)
k_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype)
v_np = np.random.normal(0.0, 1.0 / math.sqrt(head_dim), shape_kv).astype(np_dtype)
scale = math.sqrt(1.0 / head_dim)
q_mx = mx.array(q_np)
k_mx = mx.array(k_np)
v_mx = mx.array(v_np)
time_mlx_unfused = bench(mlx_spda_unfused, q_mx, k_mx, v_mx, scale, transpose)
time_mlx_fused = bench(mlx_spda_fused, q_mx, k_mx, v_mx, scale, transpose)
if transpose:
q_mx = mx.transpose(q_mx, (0, 2, 1, 3))
k_mx = mx.transpose(k_mx, (0, 2, 1, 3))
v_mx = mx.transpose(v_mx, (0, 2, 1, 3))
o_mlx_fused = mlx_sdpa_fused_inner(q_mx, k_mx, v_mx, scale)
o_mlx_unfused = mlx_sdpa_unfused_inner(q_mx, k_mx, v_mx, scale, f32softmax=True)
atol = 1e-5 if np_dtype == np.float32 else 1e-4
if not mx.allclose(o_mlx_fused, o_mlx_unfused, atol=atol):
print(
f"Failed at (B: {B}, qsl: {qsl}, ksl: {ksl}, head_dim: {head_dim}, n_qh: {n_q_heads}, n_kvh: {n_kv_heads}) [tpose = {transpose}] with max(|a - b|) = {mx.max(mx.abs(o_mlx_unfused - o_mlx_fused)):3.2e}"
)
return time_mlx_fused, time_mlx_unfused
def get_gflop_count(B, M, N, K):
return float(2.0 * N_iter_bench * N_iter_func * B * M * N * K) / float(1024.0**3)
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser("MLX benchmarks.") parser = argparse.ArgumentParser(description="Run gemm benchmarks")
parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
args = parser.parse_args()
if args.gpu:
mx.set_default_device(mx.gpu)
else:
mx.set_default_device(mx.cpu)
time_self_attention_sdpa() dtypes = ("float16", "float32")[:1]
time_self_attention_primitives() transposes = (False,)
# fmt: off
shapes_64 = (
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
( 1, 32, 32, 64, 32, 32),
( 1, 64, 64, 64, 32, 32),
( 1, 128, 128, 64, 32, 32),
( 1, 256, 256, 64, 32, 32),
( 1, 512, 512, 64, 32, 32),
( 1, 1024, 1024, 64, 32, 32),
( 1, 2048, 2048, 64, 32, 32),
( 1, 4096, 4096, 64, 32, 32),
)
shapes_80 = (
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
( 1, 1024, 1024, 80, 32, 32),
( 1, 2048, 2048, 80, 32, 32),
( 1, 4096, 4096, 80, 32, 32),
)
shapes_128 = (
# ( B, qsl, ksl, head_dim, n_qh, n_kvh)
( 1, 1024, 1024, 128, 32, 32),
( 1, 2048, 2048, 128, 32, 32),
( 1, 4096, 4096, 128, 32, 32),
)
# fmt: on
shapes = shapes_64 + shapes_80 + shapes_128
print(" B, qsl, ksl, hdim, n_qh, n_kvh, tpose, dtype, t_unfs, t_fuse, diff%")
for dtype in dtypes:
for transpose in transposes:
for B, qsl, ksl, head_dim, n_q_heads, n_kv_heads in shapes:
np_dtype = getattr(np, dtype)
time_mlx_fused, time_mlx_unfused = bench_shape(
B, qsl, ksl, head_dim, n_q_heads, n_kv_heads, np_dtype, transpose
)
diff = time_mlx_unfused / time_mlx_fused - 1.0
t_str = 1 if transpose else 0
print(
f"{B:3d}, {qsl:5d}, {ksl:5d}, {head_dim:4d}, {n_q_heads:4d}, {n_kv_heads:5d}, {t_str:5d}, {dtype}, {time_mlx_unfused: 2.3f}, {time_mlx_fused: 2.3f}, {100. * diff:+5.2f}%"
)

View File

@@ -4,42 +4,51 @@ import math
import mlx.core as mx import mlx.core as mx
from time_utils import time_fn from time_utils import time_fn
L = 1024 L = 16384
H = 32 H = 32
H_k = 32 // 4 H_k = H // 4
D = 128 D = 128
dtype = mx.float16
loops = 10
def attention(q, k, v): def attention(q, k, v):
B, Hq, L, D = q.shape def _sdpa(q, k, v):
_, Hk, S, _ = k.shape B, Hq, L, D = q.shape
q = q.reshape(B, Hk, Hq // Hk, L, D) _, Hk, S, _ = k.shape
k = k[:, :, None, :, :] q = q.reshape(B, Hk, Hq // Hk, L, D)
v = v[:, :, None, :, :] k = k[:, :, None, :, :]
s = q @ k.transpose(0, 1, 2, 4, 3) v = v[:, :, None, :, :]
p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype) s = q @ k.transpose(0, 1, 2, 4, 3)
o = p @ v p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
return o.reshape(B, Hq, L, D) o = p @ v
return o.reshape(B, Hq, L, D)
for i in range(loops):
q = _sdpa(q, k, v)
return q
def sdpa(q, k, v): def sdpa(q, k, v):
return mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0) for i in range(loops):
q = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0)
return q
def time_self_attention_primitives(): def time_self_attention_primitives():
mx.random.seed(3) mx.random.seed(3)
q = mx.random.uniform(shape=(1, H, 1, D)) q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
k = mx.random.uniform(shape=(1, H_k, L, D)) k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
v = mx.random.uniform(shape=(1, H_k, L, D)) v = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
mx.eval(q, k, v) mx.eval(q, k, v)
time_fn(attention, q, k, v) time_fn(attention, q, k, v)
def time_self_attention_sdpa(): def time_self_attention_sdpa():
mx.random.seed(3) mx.random.seed(3)
q = mx.random.uniform(shape=(1, H, 1, D)) q = mx.random.uniform(shape=(1, H, 1, D)).astype(dtype)
k = mx.random.uniform(shape=(1, H_k, L, D)) k = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
v = mx.random.uniform(shape=(1, H_k, L, D)) v = mx.random.uniform(shape=(1, H_k, L, D)).astype(dtype)
mx.eval(q, k, v) mx.eval(q, k, v)
time_fn(sdpa, q, k, v) time_fn(sdpa, q, k, v)

View File

@@ -420,8 +420,8 @@ element in the output.
constant const float& alpha [[buffer(3)]], constant const float& alpha [[buffer(3)]],
constant const float& beta [[buffer(4)]], constant const float& beta [[buffer(4)]],
constant const int* shape [[buffer(5)]], constant const int* shape [[buffer(5)]],
constant const size_t* x_strides [[buffer(6)]], constant const int64_t* x_strides [[buffer(6)]],
constant const size_t* y_strides [[buffer(7)]], constant const int64_t* y_strides [[buffer(7)]],
constant const int& ndim [[buffer(8)]], constant const int& ndim [[buffer(8)]],
uint index [[thread_position_in_grid]]) { uint index [[thread_position_in_grid]]) {
// Convert linear indices to offsets in array // Convert linear indices to offsets in array
@@ -438,24 +438,10 @@ each instantiation a unique host name so we can identify it.
.. code-block:: C++ .. code-block:: C++
#define instantiate_axpby(type_name, type) \ instantiate_kernel("axpby_general_float32", axpby_general, float)
template [[host_name("axpby_general_" #type_name)]] \ instantiate_kernel("axpby_general_float16", axpby_general, float16_t)
[[kernel]] void axpby_general<type>( \ instantiate_kernel("axpby_general_bfloat16", axpby_general, bfloat16_t)
device const type* x [[buffer(0)]], \ instantiate_kernel("axpby_general_complex64", axpby_general, complex64_t)
device const type* y [[buffer(1)]], \
device type* out [[buffer(2)]], \
constant const float& alpha [[buffer(3)]], \
constant const float& beta [[buffer(4)]], \
constant const int* shape [[buffer(5)]], \
constant const size_t* x_strides [[buffer(6)]], \
constant const size_t* y_strides [[buffer(7)]], \
constant const int& ndim [[buffer(8)]], \
uint index [[thread_position_in_grid]]);
instantiate_axpby(float32, float);
instantiate_axpby(float16, half);
instantiate_axpby(bfloat16, bfloat16_t);
instantiate_axpby(complex64, complex64_t);
The logic to determine the kernel, set the inputs, resolve the grid dimensions, The logic to determine the kernel, set the inputs, resolve the grid dimensions,
and dispatch to the GPU are contained in :meth:`Axpby::eval_gpu` as shown and dispatch to the GPU are contained in :meth:`Axpby::eval_gpu` as shown
@@ -494,7 +480,7 @@ below.
// Prepare to encode kernel // Prepare to encode kernel
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
// Kernel parameters are registered with buffer indices corresponding to // Kernel parameters are registered with buffer indices corresponding to
// those in the kernel declaration at axpby.metal // those in the kernel declaration at axpby.metal
@@ -509,14 +495,14 @@ below.
compute_encoder.set_output_array(out, 2); compute_encoder.set_output_array(out, 2);
// Encode alpha and beta // Encode alpha and beta
compute_encoder->setBytes(&alpha_, sizeof(float), 3); compute_encoder.set_bytes(alpha_, 3);
compute_encoder->setBytes(&beta_, sizeof(float), 4); compute_encoder.set_bytes(beta_, 4);
// Encode shape, strides and ndim // Encode shape, strides and ndim
compute_encoder->setBytes(x.shape().data(), ndim * sizeof(int), 5); compute_encoder.set_vector_bytes(x.shape(), 5);
compute_encoder->setBytes(x.strides().data(), ndim * sizeof(size_t), 6); compute_encoder.set_vector_bytes(x.strides(), 6);
compute_encoder->setBytes(y.strides().data(), ndim * sizeof(size_t), 7); compute_encoder.set_bytes(y.strides(), 7);
compute_encoder->setBytes(&ndim, sizeof(int), 8); compute_encoder.set_bytes(ndim, 8);
// We launch 1 thread for each input and make sure that the number of // We launch 1 thread for each input and make sure that the number of
// threads in any given threadgroup is not higher than the max allowed // threads in any given threadgroup is not higher than the max allowed
@@ -530,7 +516,7 @@ below.
// Launch the grid with the given number of threads divided among // Launch the grid with the given number of threads divided among
// the given threadgroups // the given threadgroups
compute_encoder.dispatchThreads(grid_dims, group_dims); compute_encoder.dispatch_threads(grid_dims, group_dims);
} }
We can now call the :meth:`axpby` operation on both the CPU and the GPU! We can now call the :meth:`axpby` operation on both the CPU and the GPU!

View File

@@ -209,7 +209,7 @@ Metal library by run-time compiling kernels the first time they are used in MLX
on a given machine. Note run-time compilation incurs a cold-start cost which can on a given machine. Note run-time compilation incurs a cold-start cost which can
be anwywhere from a few hundred millisecond to a few seconds depending on the be anwywhere from a few hundred millisecond to a few seconds depending on the
application. Once a kernel is compiled, it will be cached by the system. The application. Once a kernel is compiled, it will be cached by the system. The
Metal kernel cache persists accross reboots. Metal kernel cache persists across reboots.
Troubleshooting Troubleshooting
^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^

View File

@@ -12,5 +12,4 @@ Fast
layer_norm layer_norm
rope rope
scaled_dot_product_attention scaled_dot_product_attention
affine_quantize
metal_kernel metal_kernel

View File

@@ -12,6 +12,7 @@ Layers
ALiBi ALiBi
AvgPool1d AvgPool1d
AvgPool2d AvgPool2d
AvgPool3d
BatchNorm BatchNorm
CELU CELU
Conv1d Conv1d
@@ -41,6 +42,7 @@ Layers
LSTM LSTM
MaxPool1d MaxPool1d
MaxPool2d MaxPool2d
MaxPool3d
Mish Mish
MultiHeadAttention MultiHeadAttention
PReLU PReLU

View File

@@ -184,8 +184,8 @@ Let's time these two different versions:
print(timeit.timeit(lambda: mx.eval(naive_add(xs, ys)), number=100)) print(timeit.timeit(lambda: mx.eval(naive_add(xs, ys)), number=100))
print(timeit.timeit(lambda: mx.eval(vmap_add(xs, ys)), number=100)) print(timeit.timeit(lambda: mx.eval(vmap_add(xs, ys)), number=100))
On an M1 Max the naive version takes in total ``0.390`` seconds whereas the On an M1 Max the naive version takes in total ``5.639`` seconds whereas the
vectorized version takes only ``0.025`` seconds, more than ten times faster. vectorized version takes only ``0.024`` seconds, more than 200 times faster.
Of course, this operation is quite contrived. A better approach is to simply do Of course, this operation is quite contrived. A better approach is to simply do
``xs + ys.T``, but for more complex functions :func:`vmap` can be quite handy. ``xs + ys.T``, but for more complex functions :func:`vmap` can be quite handy.
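
For context on the timing note above: naive_add and vmap_add are defined earlier on that docs page. A sketch of what such a comparison looks like; the shapes and definitions here are illustrative, not a verbatim quote of the docs.

    import timeit
    import mlx.core as mx

    xs = mx.random.uniform(shape=(4096, 100))
    ys = mx.random.uniform(shape=(100, 4096))

    def naive_add(xs, ys):
        # One small add per row, driven from a Python loop.
        return [xs[i] + ys[:, i] for i in range(xs.shape[0])]

    # Vectorize over axis 0 of xs and axis 1 of ys.
    vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(0, 1))

    print(timeit.timeit(lambda: mx.eval(naive_add(xs, ys)), number=100))
    print(timeit.timeit(lambda: mx.eval(vmap_add(xs, ys)), number=100))
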

View File

@@ -257,7 +257,7 @@ void Axpby::eval_gpu(
// Prepare to encode kernel // Prepare to encode kernel
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
// Kernel parameters are registered with buffer indices corresponding to // Kernel parameters are registered with buffer indices corresponding to
// those in the kernel declaration at axpby.metal // those in the kernel declaration at axpby.metal
@@ -272,15 +272,15 @@ void Axpby::eval_gpu(
compute_encoder.set_output_array(out, 2); compute_encoder.set_output_array(out, 2);
// Encode alpha and beta // Encode alpha and beta
compute_encoder->setBytes(&alpha_, sizeof(float), 3); compute_encoder.set_bytes(alpha_, 3);
compute_encoder->setBytes(&beta_, sizeof(float), 4); compute_encoder.set_bytes(beta_, 4);
// Encode shape, strides and ndim if needed // Encode shape, strides and ndim if needed
if (!contiguous_kernel) { if (!contiguous_kernel) {
compute_encoder->setBytes(x.shape().data(), ndim * sizeof(int), 5); compute_encoder.set_vector_bytes(x.shape(), 5);
compute_encoder->setBytes(x.strides().data(), ndim * sizeof(size_t), 6); compute_encoder.set_vector_bytes(x.strides(), 6);
compute_encoder->setBytes(y.strides().data(), ndim * sizeof(size_t), 7); compute_encoder.set_bytes(y.strides(), 7);
compute_encoder->setBytes(&ndim, sizeof(int), 8); compute_encoder.set_bytes(ndim, 8);
} }
// We launch 1 thread for each input and make sure that the number of // We launch 1 thread for each input and make sure that the number of
@@ -295,7 +295,7 @@ void Axpby::eval_gpu(
// Launch the grid with the given number of threads divided among // Launch the grid with the given number of threads divided among
// the given threadgroups // the given threadgroups
compute_encoder.dispatchThreads(grid_dims, group_dims); compute_encoder.dispatch_threads(grid_dims, group_dims);
} }
#else // Metal is not available #else // Metal is not available

View File

@@ -2,7 +2,6 @@
#include <metal_stdlib> #include <metal_stdlib>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/utils.h"
template <typename T> template <typename T>
@@ -13,8 +12,8 @@ template <typename T>
constant const float& alpha [[buffer(3)]], constant const float& alpha [[buffer(3)]],
constant const float& beta [[buffer(4)]], constant const float& beta [[buffer(4)]],
constant const int* shape [[buffer(5)]], constant const int* shape [[buffer(5)]],
constant const size_t* x_strides [[buffer(6)]], constant const int64_t* x_strides [[buffer(6)]],
constant const size_t* y_strides [[buffer(7)]], constant const int64_t* y_strides [[buffer(7)]],
constant const int& ndim [[buffer(8)]], constant const int& ndim [[buffer(8)]],
uint index [[thread_position_in_grid]]) { uint index [[thread_position_in_grid]]) {
auto x_offset = elem_to_loc(index, shape, x_strides, ndim); auto x_offset = elem_to_loc(index, shape, x_strides, ndim);
@@ -35,29 +34,14 @@ template <typename T>
static_cast<T>(alpha) * x[index] + static_cast<T>(beta) * y[index]; static_cast<T>(alpha) * x[index] + static_cast<T>(beta) * y[index];
} }
#define instantiate_axpby(type_name, type) \ // clang-format off
template [[host_name("axpby_general_" #type_name)]] [[kernel]] void \ #define instantiate_axpby(type_name, type) \
axpby_general<type>( \ instantiate_kernel("axpby_general_" #type_name, axpby_general, type) \
device const type* x [[buffer(0)]], \ instantiate_kernel( \
device const type* y [[buffer(1)]], \ "axpby_contiguous_" #type_name, axpby_contiguous, type)
device type* out [[buffer(2)]], \
constant const float& alpha [[buffer(3)]], \
constant const float& beta [[buffer(4)]], \
constant const int* shape [[buffer(5)]], \
constant const size_t* x_strides [[buffer(6)]], \
constant const size_t* y_strides [[buffer(7)]], \
constant const int& ndim [[buffer(8)]], \
uint index [[thread_position_in_grid]]); \
template [[host_name("axpby_contiguous_" #type_name)]] [[kernel]] void \
axpby_contiguous<type>( \
device const type* x [[buffer(0)]], \
device const type* y [[buffer(1)]], \
device type* out [[buffer(2)]], \
constant const float& alpha [[buffer(3)]], \
constant const float& beta [[buffer(4)]], \
uint index [[thread_position_in_grid]]);
instantiate_axpby(float32, float); instantiate_axpby(float32, float);
instantiate_axpby(float16, half); instantiate_axpby(float16, half);
instantiate_axpby(bfloat16, bfloat16_t); instantiate_axpby(bfloat16, bfloat16_t);
instantiate_axpby(complex64, complex64_t); instantiate_axpby(complex64, complex64_t);
// clang-format on

View File

@@ -1,4 +1,4 @@
setuptools>=42 setuptools>=42
cmake>=3.24 cmake>=3.24
mlx>=0.18.1 mlx>=0.21.0
nanobind==2.2.0 nanobind==2.2.0

View File

@@ -28,10 +28,19 @@ endif()
if (@MLX_BUILD_METAL@) if (@MLX_BUILD_METAL@)
set(MLX_BUILD_METAL @MLX_BUILD_METAL@) set(MLX_BUILD_METAL @MLX_BUILD_METAL@)
set(MLX_CXX_FLAGS ${MLX_CXX_FLAGS} -D_METAL_) set(MLX_CXX_FLAGS ${MLX_CXX_FLAGS} -D_METAL_)
set_and_check(MLX_INCLUDE_DIRS set(MLX_INCLUDE_DIRS
${MLX_INCLUDE_DIRS} "${MLX_INCLUDE_DIRS};"
@PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/metal_cpp @PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/metal_cpp
) )
if(@MLX_METAL_VERSION@ GREATER_EQUAL 310)
set(MLX_INCLUDE_DIRS
"${MLX_INCLUDE_DIRS};"
@PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/mlx/backend/metal/kernels/metal_3_1)
else()
set(MLX_INCLUDE_DIRS
"${MLX_INCLUDE_DIRS};"
@PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/mlx/backend/metal/kernels/metal_3_0)
endif()
endif() endif()
set_target_properties(mlx PROPERTIES set_target_properties(mlx PROPERTIES
@@ -40,4 +49,4 @@ set_target_properties(mlx PROPERTIES
) )
include(FindPackageHandleStandardArgs) include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(MLX DEFAULT_MSG MLX_LIBRARY MLX_INCLUDE_DIRS) find_package_handle_standard_args(MLX DEFAULT_MSG MLX_LIBRARY MLX_INCLUDE_DIRS)

View File

@@ -19,7 +19,7 @@ Buffer malloc(size_t size) {
} }
void free(Buffer buffer) { void free(Buffer buffer) {
return allocator().free(buffer); allocator().free(buffer);
} }
Buffer CommonAllocator::malloc(size_t size, bool) { Buffer CommonAllocator::malloc(size_t size, bool) {

View File

@@ -1,5 +1,6 @@
// Copyright © 2023-2024 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include <functional> #include <functional>
#include <unordered_map>
#include "mlx/array.h" #include "mlx/array.h"
#include "mlx/ops.h" #include "mlx/ops.h"
@@ -30,7 +31,7 @@ array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
} }
array::array( array::array(
std::vector<int> shape, Shape shape,
Dtype dtype, Dtype dtype,
std::shared_ptr<Primitive> primitive, std::shared_ptr<Primitive> primitive,
std::vector<array> inputs) std::vector<array> inputs)
@@ -41,7 +42,7 @@ array::array(
std::move(inputs))) {} std::move(inputs))) {}
std::vector<array> array::make_arrays( std::vector<array> array::make_arrays(
std::vector<std::vector<int>> shapes, std::vector<Shape> shapes,
const std::vector<Dtype>& dtypes, const std::vector<Dtype>& dtypes,
const std::shared_ptr<Primitive>& primitive, const std::shared_ptr<Primitive>& primitive,
const std::vector<array>& inputs) { const std::vector<array>& inputs) {
@@ -73,11 +74,7 @@ array::array(std::initializer_list<int> data, Dtype dtype)
} }
/* Build an array from a shared buffer */ /* Build an array from a shared buffer */
array::array( array::array(allocator::Buffer data, Shape shape, Dtype dtype, Deleter deleter)
allocator::Buffer data,
std::vector<int> shape,
Dtype dtype,
deleter_t deleter)
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) { : array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
set_data(data, deleter); set_data(data, deleter);
} }
@@ -125,7 +122,7 @@ bool array::is_tracer() const {
return array_desc_->is_tracer && in_tracing() || retain_graph(); return array_desc_->is_tracer && in_tracing() || retain_graph();
} }
void array::set_data(allocator::Buffer buffer, deleter_t d) { void array::set_data(allocator::Buffer buffer, Deleter d) {
array_desc_->data = std::make_shared<Data>(buffer, d); array_desc_->data = std::make_shared<Data>(buffer, d);
array_desc_->data_ptr = buffer.raw_ptr(); array_desc_->data_ptr = buffer.raw_ptr();
array_desc_->data_size = size(); array_desc_->data_size = size();
@@ -138,9 +135,9 @@ void array::set_data(allocator::Buffer buffer, deleter_t d) {
void array::set_data( void array::set_data(
allocator::Buffer buffer, allocator::Buffer buffer,
size_t data_size, size_t data_size,
std::vector<size_t> strides, Strides strides,
Flags flags, Flags flags,
deleter_t d) { Deleter d) {
array_desc_->data = std::make_shared<Data>(buffer, d); array_desc_->data = std::make_shared<Data>(buffer, d);
array_desc_->data_ptr = buffer.raw_ptr(); array_desc_->data_ptr = buffer.raw_ptr();
array_desc_->data_size = data_size; array_desc_->data_size = data_size;
@@ -150,7 +147,7 @@ void array::set_data(
void array::copy_shared_buffer( void array::copy_shared_buffer(
const array& other, const array& other,
const std::vector<size_t>& strides, const Strides& strides,
Flags flags, Flags flags,
size_t data_size, size_t data_size,
size_t offset /* = 0 */) { size_t offset /* = 0 */) {
@@ -169,7 +166,7 @@ void array::copy_shared_buffer(const array& other) {
void array::move_shared_buffer( void array::move_shared_buffer(
array other, array other,
const std::vector<size_t>& strides, const Strides& strides,
Flags flags, Flags flags,
size_t data_size, size_t data_size,
size_t offset /* = 0 */) { size_t offset /* = 0 */) {
@@ -214,6 +211,8 @@ array::~array() {
if (do_detach) { if (do_detach) {
for (auto& s : siblings()) { for (auto& s : siblings()) {
for (auto& ss : s.siblings()) { for (auto& ss : s.siblings()) {
// Set to null here to avoid descending into array destructor
// for siblings
ss.array_desc_ = nullptr; ss.array_desc_ = nullptr;
} }
s.array_desc_->siblings.clear(); s.array_desc_->siblings.clear();
@@ -234,13 +233,13 @@ void array::ArrayDesc::init() {
} }
} }
array::ArrayDesc::ArrayDesc(std::vector<int> shape, Dtype dtype) array::ArrayDesc::ArrayDesc(Shape shape, Dtype dtype)
: shape(std::move(shape)), dtype(dtype), status(Status::available) { : shape(std::move(shape)), dtype(dtype), status(Status::available) {
init(); init();
} }
array::ArrayDesc::ArrayDesc( array::ArrayDesc::ArrayDesc(
std::vector<int> shape, Shape shape,
Dtype dtype, Dtype dtype,
std::shared_ptr<Primitive> primitive, std::shared_ptr<Primitive> primitive,
std::vector<array> inputs) std::vector<array> inputs)
@@ -292,6 +291,14 @@ array::ArrayDesc::~ArrayDesc() {
auto top = std::move(for_deletion.back()); auto top = std::move(for_deletion.back());
for_deletion.pop_back(); for_deletion.pop_back();
append_deletable_inputs(*top); append_deletable_inputs(*top);
// Clear out possible siblings to break circular references
for (auto& s : top->siblings) {
// Set to null here to avoid descending into top-level
// array destructor for siblings
s.array_desc_ = nullptr;
}
top->siblings.clear();
} }
} }

View File

@@ -15,7 +15,10 @@ namespace mlx::core {
// Forward declaration // Forward declaration
class Primitive; class Primitive;
using deleter_t = std::function<void(allocator::Buffer)>;
using Deleter = std::function<void(allocator::Buffer)>;
using Shape = std::vector<int32_t>;
using Strides = std::vector<int64_t>;
class array { class array {
/* An array is really a node in a graph. It contains a shared ArrayDesc /* An array is really a node in a graph. It contains a shared ArrayDesc
@@ -33,7 +36,7 @@ class array {
template <typename It> template <typename It>
array( array(
It data, It data,
std::vector<int> shape, Shape shape,
Dtype dtype = Dtype dtype =
TypeToDtype<typename std::iterator_traits<It>::value_type>()); TypeToDtype<typename std::iterator_traits<It>::value_type>());
@@ -49,15 +52,15 @@ class array {
template <typename T> template <typename T>
array( array(
std::initializer_list<T> data, std::initializer_list<T> data,
std::vector<int> shape, Shape shape,
Dtype dtype = TypeToDtype<T>()); Dtype dtype = TypeToDtype<T>());
/* Build an array from a buffer */ /* Build an array from a buffer */
array( array(
allocator::Buffer data, allocator::Buffer data,
std::vector<int> shape, Shape shape,
Dtype dtype, Dtype dtype,
deleter_t deleter = allocator::free); Deleter deleter = allocator::free);
/** Assignment to rvalue does not compile. */ /** Assignment to rvalue does not compile. */
array& operator=(const array& other) && = delete; array& operator=(const array& other) && = delete;
@@ -96,7 +99,7 @@ class array {
} }
/** The shape of the array as a vector of integers. */ /** The shape of the array as a vector of integers. */
const std::vector<int>& shape() const { const Shape& shape() const {
return array_desc_->shape; return array_desc_->shape;
} }
@@ -105,12 +108,12 @@ class array {
* *
* This function supports negative indexing and provides * This function supports negative indexing and provides
* bounds checking. */ * bounds checking. */
int shape(int dim) const { auto shape(int dim) const {
return shape().at(dim < 0 ? dim + ndim() : dim); return shape().at(dim < 0 ? dim + ndim() : dim);
} }
/** The strides of the array. */ /** The strides of the array. */
const std::vector<size_t>& strides() const { const Strides& strides() const {
return array_desc_->strides; return array_desc_->strides;
} }
@@ -119,7 +122,7 @@ class array {
* *
* This function supports negative indexing and provides * This function supports negative indexing and provides
* bounds checking. */ * bounds checking. */
size_t strides(int dim) const { auto strides(int dim) const {
return strides().at(dim < 0 ? dim + ndim() : dim); return strides().at(dim < 0 ? dim + ndim() : dim);
} }
@@ -184,13 +187,13 @@ class array {
*/ */
array( array(
std::vector<int> shape, Shape shape,
Dtype dtype, Dtype dtype,
std::shared_ptr<Primitive> primitive, std::shared_ptr<Primitive> primitive,
std::vector<array> inputs); std::vector<array> inputs);
static std::vector<array> make_arrays( static std::vector<array> make_arrays(
std::vector<std::vector<int>> shapes, std::vector<Shape> shapes,
const std::vector<Dtype>& dtypes, const std::vector<Dtype>& dtypes,
const std::shared_ptr<Primitive>& primitive, const std::shared_ptr<Primitive>& primitive,
const std::vector<array>& inputs); const std::vector<array>& inputs);
@@ -207,8 +210,8 @@ class array {
struct Data { struct Data {
allocator::Buffer buffer; allocator::Buffer buffer;
deleter_t d; Deleter d;
Data(allocator::Buffer buffer, deleter_t d = allocator::free) Data(allocator::Buffer buffer, Deleter d = allocator::free)
: buffer(buffer), d(d) {} : buffer(buffer), d(d) {}
// Not copyable // Not copyable
Data(const Data& d) = delete; Data(const Data& d) = delete;
@@ -397,18 +400,18 @@ class array {
// Check if the array is a tracer array // Check if the array is a tracer array
bool is_tracer() const; bool is_tracer() const;
void set_data(allocator::Buffer buffer, deleter_t d = allocator::free); void set_data(allocator::Buffer buffer, Deleter d = allocator::free);
void set_data( void set_data(
allocator::Buffer buffer, allocator::Buffer buffer,
size_t data_size, size_t data_size,
std::vector<size_t> strides, Strides strides,
Flags flags, Flags flags,
deleter_t d = allocator::free); Deleter d = allocator::free);
void copy_shared_buffer( void copy_shared_buffer(
const array& other, const array& other,
const std::vector<size_t>& strides, const Strides& strides,
Flags flags, Flags flags,
size_t data_size, size_t data_size,
size_t offset = 0); size_t offset = 0);
@@ -417,7 +420,7 @@ class array {
void move_shared_buffer( void move_shared_buffer(
array other, array other,
const std::vector<size_t>& strides, const Strides& strides,
Flags flags, Flags flags,
size_t data_size, size_t data_size,
size_t offset = 0); size_t offset = 0);
@@ -436,8 +439,8 @@ class array {
void init(const It src); void init(const It src);
struct ArrayDesc { struct ArrayDesc {
std::vector<int> shape; Shape shape;
std::vector<size_t> strides; Strides strides;
size_t size; size_t size;
Dtype dtype; Dtype dtype;
std::shared_ptr<Primitive> primitive; std::shared_ptr<Primitive> primitive;
@@ -471,10 +474,10 @@ class array {
// The arrays position in the output list // The arrays position in the output list
uint32_t position{0}; uint32_t position{0};
explicit ArrayDesc(std::vector<int> shape, Dtype dtype); explicit ArrayDesc(Shape shape, Dtype dtype);
explicit ArrayDesc( explicit ArrayDesc(
std::vector<int> shape, Shape shape,
Dtype dtype, Dtype dtype,
std::shared_ptr<Primitive> primitive, std::shared_ptr<Primitive> primitive,
std::vector<array> inputs); std::vector<array> inputs);
@@ -502,7 +505,7 @@ array::array(T val, Dtype dtype /* = TypeToDtype<T>() */)
template <typename It> template <typename It>
array::array( array::array(
It data, It data,
std::vector<int> shape, Shape shape,
Dtype dtype /* = TypeToDtype<typename std::iterator_traits<It>::value_type>() */) : Dtype dtype /* = TypeToDtype<typename std::iterator_traits<It>::value_type>() */) :
array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) { array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
init(data); init(data);
@@ -521,7 +524,7 @@ array::array(
template <typename T> template <typename T>
array::array( array::array(
std::initializer_list<T> data, std::initializer_list<T> data,
std::vector<int> shape, Shape shape,
Dtype dtype /* = TypeToDtype<T>() */) Dtype dtype /* = TypeToDtype<T>() */)
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) { : array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
if (data.size() != size()) { if (data.size() != size()) {

View File

@@ -13,8 +13,8 @@ template <typename InT, typename OpT>
void arg_reduce(const array& in, array& out, const OpT& op, int axis) { void arg_reduce(const array& in, array& out, const OpT& op, int axis) {
auto axis_size = in.shape()[axis]; auto axis_size = in.shape()[axis];
auto axis_stride = in.strides()[axis]; auto axis_stride = in.strides()[axis];
std::vector<size_t> strides = in.strides(); Strides strides = in.strides();
std::vector<int> shape = in.shape(); Shape shape = in.shape();
strides.erase(strides.begin() + axis); strides.erase(strides.begin() + axis);
shape.erase(shape.begin() + axis); shape.erase(shape.begin() + axis);
for (uint32_t i = 0; i < out.size(); ++i) { for (uint32_t i = 0; i < out.size(); ++i) {

View File

@@ -178,10 +178,10 @@ void binary_op_dims(
const T* b, const T* b,
U* out, U* out,
Op op, Op op,
const std::vector<int>& shape, const Shape& shape,
const std::vector<size_t>& a_strides, const Strides& a_strides,
const std::vector<size_t>& b_strides, const Strides& b_strides,
const std::vector<size_t>& out_strides, const Strides& out_strides,
int axis) { int axis) {
auto stride_a = a_strides[axis]; auto stride_a = a_strides[axis];
auto stride_b = b_strides[axis]; auto stride_b = b_strides[axis];
@@ -212,10 +212,10 @@ void binary_op_dispatch_dims(
array& out, array& out,
Op op, Op op,
int dim, int dim,
const std::vector<int>& shape, const Shape& shape,
const std::vector<size_t>& a_strides, const Strides& a_strides,
const std::vector<size_t>& b_strides, const Strides& b_strides,
const std::vector<size_t>& out_strides) { const Strides& out_strides) {
const T* a_ptr = a.data<T>(); const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>(); const T* b_ptr = b.data<T>();
U* out_ptr = out.data<U>(); U* out_ptr = out.data<U>();
@@ -258,10 +258,10 @@ void binary_op_dispatch_dims(
return; return;
} }
ContiguousIterator<size_t> a_it(shape, a_strides, dim - 3); ContiguousIterator a_it(shape, a_strides, dim - 3);
ContiguousIterator<size_t> b_it(shape, b_strides, dim - 3); ContiguousIterator b_it(shape, b_strides, dim - 3);
size_t stride = out_strides[dim - 4]; auto stride = out_strides[dim - 4];
for (size_t elem = 0; elem < a.size(); elem += stride) { for (int64_t elem = 0; elem < a.size(); elem += stride) {
binary_op_dims<T, U, Op, 3, Strided>( binary_op_dims<T, U, Op, 3, Strided>(
a_ptr + a_it.loc, a_ptr + a_it.loc,
b_ptr + b_it.loc, b_ptr + b_it.loc,
@@ -327,7 +327,7 @@ void binary_op(
const auto& strides = new_strides[2]; const auto& strides = new_strides[2];
// Get the left-most dim such that the array is row contiguous after // Get the left-most dim such that the array is row contiguous after
auto leftmost_rc_dim = [&strides](const std::vector<size_t>& arr_strides) { auto leftmost_rc_dim = [&strides](const auto& arr_strides) {
int d = arr_strides.size() - 1; int d = arr_strides.size() - 1;
for (; d >= 0 && arr_strides[d] == strides[d]; d--) { for (; d >= 0 && arr_strides[d] == strides[d]; d--) {
} }
@@ -337,7 +337,7 @@ void binary_op(
auto b_rc_dim = leftmost_rc_dim(b_strides); auto b_rc_dim = leftmost_rc_dim(b_strides);
// Get the left-most dim such that the array is a broadcasted "scalar" after // Get the left-most dim such that the array is a broadcasted "scalar" after
auto leftmost_s_dim = [](const std::vector<size_t>& arr_strides) { auto leftmost_s_dim = [](const auto& arr_strides) {
int d = arr_strides.size() - 1; int d = arr_strides.size() - 1;
for (; d >= 0 && arr_strides[d] == 0; d--) { for (; d >= 0 && arr_strides[d] == 0; d--) {
} }

View File

@@ -16,10 +16,10 @@ void binary_op_dims(
U* out_a, U* out_a,
U* out_b, U* out_b,
Op op, Op op,
const std::vector<int>& shape, const Shape& shape,
const std::vector<size_t>& a_strides, const Strides& a_strides,
const std::vector<size_t>& b_strides, const Strides& b_strides,
const std::vector<size_t>& out_strides, const Strides& out_strides,
int axis) { int axis) {
auto stride_a = a_strides[axis]; auto stride_a = a_strides[axis];
auto stride_b = b_strides[axis]; auto stride_b = b_strides[axis];
@@ -96,9 +96,9 @@ void binary_op_dispatch_dims(
return; return;
} }
ContiguousIterator<size_t> a_it(shape, a_strides, ndim - 2); ContiguousIterator a_it(shape, a_strides, ndim - 2);
ContiguousIterator<size_t> b_it(shape, b_strides, ndim - 2); ContiguousIterator b_it(shape, b_strides, ndim - 2);
size_t stride = out_strides[ndim - 3]; auto stride = out_strides[ndim - 3];
for (size_t elem = 0; elem < a.size(); elem += stride) { for (size_t elem = 0; elem < a.size(); elem += stride) {
binary_op_dims<T, U, Op, 2>( binary_op_dims<T, U, Op, 2>(
a_ptr + a_it.loc, a_ptr + a_it.loc,

View File

@@ -39,7 +39,7 @@ void AsStrided::eval(const std::vector<array>& inputs, array& out) {
// rely on data_size anyway. // rely on data_size anyway.
size_t data_size = out.size(); size_t data_size = out.size();
return out.copy_shared_buffer(in, strides_, flags, data_size, offset_); return move_or_copy(in, out, strides_, flags, data_size, offset_);
} }
void Broadcast::eval(const std::vector<array>& inputs, array& out) { void Broadcast::eval(const std::vector<array>& inputs, array& out) {
@@ -49,7 +49,7 @@ void Broadcast::eval(const std::vector<array>& inputs, array& out) {
out.set_data(nullptr); out.set_data(nullptr);
return; return;
} }
std::vector<size_t> strides(out.ndim(), 0); Strides strides(out.ndim(), 0);
int diff = out.ndim() - in.ndim(); int diff = out.ndim() - in.ndim();
for (int i = in.ndim() - 1; i >= 0; --i) { for (int i = in.ndim() - 1; i >= 0; --i) {
strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i]; strides[i + diff] = (in.shape()[i] == 1) ? 0 : in.strides()[i];
@@ -58,12 +58,12 @@ void Broadcast::eval(const std::vector<array>& inputs, array& out) {
if (out.size() > in.size()) { if (out.size() > in.size()) {
flags.row_contiguous = flags.col_contiguous = false; flags.row_contiguous = flags.col_contiguous = false;
} }
out.copy_shared_buffer(in, strides, flags, in.data_size()); move_or_copy(in, out, strides, flags, in.data_size());
} }
void Copy::eval(const std::vector<array>& inputs, array& out) { void Copy::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
out.copy_shared_buffer(inputs[0]); move_or_copy(inputs[0], out);
} }
void CustomTransforms::eval( void CustomTransforms::eval(
@@ -72,7 +72,7 @@ void CustomTransforms::eval(
assert(inputs.size() > outputs.size()); assert(inputs.size() > outputs.size());
for (int i = 0, j = inputs.size() - outputs.size(); i < outputs.size(); for (int i = 0, j = inputs.size() - outputs.size(); i < outputs.size();
i++, j++) { i++, j++) {
outputs[i].copy_shared_buffer(inputs[j]); move_or_copy(inputs[j], outputs[i]);
} }
} }
@@ -81,7 +81,7 @@ void Depends::eval(
std::vector<array>& outputs) { std::vector<array>& outputs) {
assert(inputs.size() > outputs.size()); assert(inputs.size() > outputs.size());
for (int i = 0; i < outputs.size(); i++) { for (int i = 0; i < outputs.size(); i++) {
outputs[i].copy_shared_buffer(inputs[i]); move_or_copy(inputs[i], outputs[i]);
} }
} }
@@ -141,7 +141,7 @@ void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
} }
} }
std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape( std::pair<bool, Strides> Reshape::prepare_reshape(
const array& in, const array& in,
const array& out) { const array& out) {
// Special case for empty arrays or row contiguous arrays // Special case for empty arrays or row contiguous arrays
@@ -151,8 +151,7 @@ std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(
// Special case for scalars // Special case for scalars
if (in.ndim() == 0) { if (in.ndim() == 0) {
std::vector<size_t> out_strides(out.ndim(), 0); return {false, Strides(out.ndim(), 0)};
return {false, out_strides};
} }
// Firstly let's collapse all the contiguous dimensions of the input // Firstly let's collapse all the contiguous dimensions of the input
@@ -160,7 +159,7 @@ std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(
// If shapes fit exactly in the contiguous dims then no copy is necessary so // If shapes fit exactly in the contiguous dims then no copy is necessary so
// let's check. // let's check.
std::vector<size_t> out_strides; Strides out_strides;
bool copy_necessary = false; bool copy_necessary = false;
int j = 0; int j = 0;
for (int i = 0; i < out.ndim(); i++) { for (int i = 0; i < out.ndim(); i++) {
@@ -183,7 +182,7 @@ std::pair<bool, std::vector<size_t>> Reshape::prepare_reshape(
void Reshape::shared_buffer_reshape( void Reshape::shared_buffer_reshape(
const array& in, const array& in,
const std::vector<size_t>& out_strides, const Strides& out_strides,
array& out) { array& out) {
auto flags = in.flags(); auto flags = in.flags();
if (flags.row_contiguous) { if (flags.row_contiguous) {
@@ -194,7 +193,7 @@ void Reshape::shared_buffer_reshape(
auto max_dim = std::max_element(out.shape().begin(), out.shape().end()); auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim; flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
} }
out.copy_shared_buffer(in, out_strides, flags, in.data_size()); move_or_copy(in, out, out_strides, flags, in.data_size());
} }
void Split::eval( void Split::eval(
@@ -249,26 +248,14 @@ void Split::eval(
} }
} }
std::tuple<int64_t, std::vector<int64_t>> SliceUpdate::prepare_slice(
const array& in) {
int64_t data_offset = 0;
std::vector<int64_t> inp_strides(in.ndim(), 0);
for (int i = 0; i < in.ndim(); ++i) {
data_offset += start_indices_[i] * in.strides()[i];
inp_strides[i] = in.strides()[i] * strides_[i];
}
return std::make_tuple(data_offset, inp_strides);
}
void StopGradient::eval(const std::vector<array>& inputs, array& out) { void StopGradient::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
out.copy_shared_buffer(inputs[0]); move_or_copy(inputs[0], out);
} }
void Transpose::eval(const std::vector<array>& inputs, array& out) { void Transpose::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
std::vector<size_t> out_strides(out.ndim()); Strides out_strides(out.ndim());
auto& in = inputs[0]; auto& in = inputs[0];
for (int ax = 0; ax < axes_.size(); ++ax) { for (int ax = 0; ax < axes_.size(); ++ax) {
out_strides[ax] = in.strides()[axes_[ax]]; out_strides[ax] = in.strides()[axes_[ax]];
@@ -285,8 +272,8 @@ void Transpose::eval(const std::vector<array>& inputs, array& out) {
// true, they stay true) // true, they stay true)
auto flags = in.flags(); auto flags = in.flags();
if (flags.contiguous && in.data_size() == in.size()) { if (flags.contiguous && in.data_size() == in.size()) {
size_t f_stride = 1; int64_t f_stride = 1;
size_t b_stride = 1; int64_t b_stride = 1;
flags.col_contiguous = true; flags.col_contiguous = true;
flags.row_contiguous = true; flags.row_contiguous = true;
for (int i = 0, ri = out.ndim() - 1; i < out.ndim(); ++i, --ri) { for (int i = 0, ri = out.ndim() - 1; i < out.ndim(); ++i, --ri) {
@@ -297,7 +284,7 @@ void Transpose::eval(const std::vector<array>& inputs, array& out) {
b_stride *= out.shape(ri); b_stride *= out.shape(ri);
} }
} }
out.copy_shared_buffer(in, out_strides, flags, in.data_size()); move_or_copy(in, out, out_strides, flags, in.data_size());
} }
} // namespace mlx::core } // namespace mlx::core
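The move_or_copy helper these call sites now use is added later in this change set (see the common utils file further down): it moves the shared buffer when the input array is donatable and falls back to copying the shared buffer otherwise. A rough standalone sketch of that decision, with a hypothetical Holder struct standing in for the MLX array type:

#include <memory>
#include <utility>

// Hypothetical stand-in for an array that shares a reference-counted buffer.
struct Holder {
  std::shared_ptr<int> data;
  bool donatable;  // true when this holder is free to give its buffer away
};

// Mirrors the choice move_or_copy makes: steal the buffer when the input is
// donatable, otherwise just add another shared reference to it.
void move_or_copy_sketch(Holder& in, Holder& out) {
  if (in.donatable) {
    out.data = std::move(in.data);
  } else {
    out.data = in.data;
  }
}

The payoff is that primitives which previously always aliased the input via copy_shared_buffer can now consume donated inputs without holding an extra reference.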

View File

@@ -165,7 +165,7 @@ void compiled_allocate_outputs(
bool move_buffers /* = false */) { bool move_buffers /* = false */) {
if (contiguous) { if (contiguous) {
int o = 0; int o = 0;
std::vector<size_t> strides; Strides strides;
size_t data_size; size_t data_size;
array::Flags flags; array::Flags flags;
for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) { for (int i = 0; i < inputs.size() && o < outputs.size(); ++i) {

View File

@@ -47,7 +47,7 @@ bool compile_available_for_device(const Device& device) {
} // namespace detail } // namespace detail
std::string get_temp_file(const std::string& name) { std::string get_temp_file(const std::string& name) {
return std::filesystem::temp_directory_path().append(name); return std::filesystem::temp_directory_path().append(name).string();
} }
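A note on the .string() call added above: std::filesystem::path uses a wide native string on Windows, so a path does not convert implicitly to std::string there; the explicit .string() member performs the conversion and is also harmless on POSIX. A small standard-library-only illustration (the function name is invented for this sketch):

#include <filesystem>
#include <string>

std::string temp_file_sketch(const std::string& name) {
  auto p = std::filesystem::temp_directory_path().append(name);
  // On Windows path::value_type is wchar_t, so returning p directly as a
  // std::string does not compile; .string() converts on every platform.
  return p.string();
}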
// Return a pointer to a compiled function // Return a pointer to a compiled function
@@ -279,7 +279,7 @@ void Compiled::eval_cpu(
// Figure out which kernel we are using // Figure out which kernel we are using
auto& shape = outputs[0].shape(); auto& shape = outputs[0].shape();
bool contiguous = compiled_check_contiguity(inputs, shape); auto contiguous = compiled_check_contiguity(inputs, shape);
// Handle all broadcasting and collect function input arguments // Handle all broadcasting and collect function input arguments
std::vector<void*> args; std::vector<void*> args;

View File

@@ -746,9 +746,9 @@ void explicit_gemm_conv_1D_cpu(
copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral); copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral);
// Make strided view // Make strided view
std::vector<int> strided_shape = {N, oH, wH, C}; Shape strided_shape = {N, oH, wH, C};
std::vector<size_t> strided_strides = { Strides strided_strides = {
in_padded.strides()[0], in_padded.strides()[0],
in_padded.strides()[1] * wt_strides[0], in_padded.strides()[1] * wt_strides[0],
in_padded.strides()[1], in_padded.strides()[1],
@@ -865,9 +865,9 @@ void explicit_gemm_conv_2D_cpu(
copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral); copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral);
// Make strided view // Make strided view
std::vector<int> strided_shape = {N, oH, oW, wH, wW, C}; Shape strided_shape = {N, oH, oW, wH, wW, C};
std::vector<size_t> strided_strides = { Strides strided_strides = {
in_padded.strides()[0], in_padded.strides()[0],
in_padded.strides()[1] * wt_strides[0], in_padded.strides()[1] * wt_strides[0],
in_padded.strides()[2] * wt_strides[1], in_padded.strides()[2] * wt_strides[1],
@@ -974,7 +974,7 @@ void explicit_gemm_conv_ND_cpu(
copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral); copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral);
// Make strided view // Make strided view
std::vector<int> strided_shape(oDim.size() + wDim.size() + 2); Shape strided_shape(oDim.size() + wDim.size() + 2);
strided_shape.front() = N; strided_shape.front() = N;
for (size_t i = 0; i < oDim.size(); i++) { for (size_t i = 0; i < oDim.size(); i++) {
strided_shape[i + 1] = oDim[i]; strided_shape[i + 1] = oDim[i];
@@ -984,7 +984,7 @@ void explicit_gemm_conv_ND_cpu(
} }
strided_shape.back() = C; strided_shape.back() = C;
std::vector<size_t> strided_strides(in.shape().size() * 2 - 2); Strides strided_strides(in.shape().size() * 2 - 2);
strided_strides[0] = in_padded.strides()[0]; strided_strides[0] = in_padded.strides()[0];
for (size_t i = 0; i < wt_strides.size(); i++) { for (size_t i = 0; i < wt_strides.size(); i++) {
strided_strides[i + 1] = in_padded.strides()[i + 1] * wt_strides[i]; strided_strides[i + 1] = in_padded.strides()[i + 1] * wt_strides[i];
@@ -1000,7 +1000,7 @@ void explicit_gemm_conv_ND_cpu(
in_padded, strided_strides, flags, in_strided_view.size(), 0); in_padded, strided_strides, flags, in_strided_view.size(), 0);
// Materialize strided view // Materialize strided view
std::vector<int> strided_reshape = {N, C}; Shape strided_reshape = {N, C};
for (const auto& o : oDim) { for (const auto& o : oDim) {
strided_reshape[0] *= o; strided_reshape[0] *= o;
} }

View File

@@ -26,13 +26,13 @@ void copy_vector(const array& src, array& dst) {
std::copy(src_ptr, src_ptr + src.data_size(), dst_ptr); std::copy(src_ptr, src_ptr + src.data_size(), dst_ptr);
} }
template <typename SrcT, typename DstT, typename StrideT, int D> template <typename SrcT, typename DstT, int D>
inline void copy_dims( inline void copy_dims(
const SrcT* src, const SrcT* src,
DstT* dst, DstT* dst,
const std::vector<int>& shape, const Shape& shape,
const std::vector<StrideT>& i_strides, const Strides& i_strides,
const std::vector<StrideT>& o_strides, const Strides& o_strides,
int axis) { int axis) {
auto stride_src = i_strides[axis]; auto stride_src = i_strides[axis];
auto stride_dst = o_strides[axis]; auto stride_dst = o_strides[axis];
@@ -40,7 +40,7 @@ inline void copy_dims(
for (int i = 0; i < N; i++) { for (int i = 0; i < N; i++) {
if constexpr (D > 1) { if constexpr (D > 1) {
copy_dims<SrcT, DstT, StrideT, D - 1>( copy_dims<SrcT, DstT, D - 1>(
src, dst, shape, i_strides, o_strides, axis + 1); src, dst, shape, i_strides, o_strides, axis + 1);
} else { } else {
*dst = static_cast<DstT>(*src); *dst = static_cast<DstT>(*src);
@@ -50,13 +50,13 @@ inline void copy_dims(
} }
} }
template <typename SrcT, typename DstT, typename StrideT> template <typename SrcT, typename DstT>
void copy_general_general( void copy_general_general(
const array& src, const array& src,
array& dst, array& dst,
const std::vector<int>& data_shape, const Shape& data_shape,
const std::vector<StrideT>& i_strides, const Strides& i_strides,
const std::vector<StrideT>& o_strides, const Strides& o_strides,
int64_t i_offset, int64_t i_offset,
int64_t o_offset) { int64_t o_offset) {
if (data_shape.empty()) { if (data_shape.empty()) {
@@ -65,30 +65,30 @@ void copy_general_general(
*dst_ptr = val; *dst_ptr = val;
return; return;
} }
-  auto [shape, strides] = collapse_contiguous_dims(
-      data_shape, std::vector<std::vector<StrideT>>{i_strides, o_strides});
+  auto [shape, strides] =
+      collapse_contiguous_dims(data_shape, {i_strides, o_strides});
auto src_ptr = src.data<SrcT>() + i_offset; auto src_ptr = src.data<SrcT>() + i_offset;
auto dst_ptr = dst.data<DstT>() + o_offset; auto dst_ptr = dst.data<DstT>() + o_offset;
int ndim = shape.size(); int ndim = shape.size();
if (ndim == 1) { if (ndim == 1) {
copy_dims<SrcT, DstT, StrideT, 1>( copy_dims<SrcT, DstT, 1>(
src_ptr, dst_ptr, shape, strides[0], strides[1], 0); src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
return; return;
} else if (ndim == 2) { } else if (ndim == 2) {
copy_dims<SrcT, DstT, StrideT, 2>( copy_dims<SrcT, DstT, 2>(
src_ptr, dst_ptr, shape, strides[0], strides[1], 0); src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
return; return;
} else if (ndim == 3) { } else if (ndim == 3) {
copy_dims<SrcT, DstT, StrideT, 3>( copy_dims<SrcT, DstT, 3>(
src_ptr, dst_ptr, shape, strides[0], strides[1], 0); src_ptr, dst_ptr, shape, strides[0], strides[1], 0);
return; return;
} }
ContiguousIterator<StrideT> in(shape, strides[0], ndim - 3); ContiguousIterator in(shape, strides[0], ndim - 3);
ContiguousIterator<StrideT> out(shape, strides[1], ndim - 3); ContiguousIterator out(shape, strides[1], ndim - 3);
StrideT stride = std::accumulate( auto stride = std::accumulate(
shape.end() - 3, shape.end(), 1, std::multiplies<StrideT>()); shape.end() - 3, shape.end(), 1, std::multiplies<int64_t>());
for (StrideT elem = 0; elem < src.size(); elem += stride) { for (int64_t elem = 0; elem < src.size(); elem += stride) {
copy_dims<SrcT, DstT, StrideT, 3>( copy_dims<SrcT, DstT, 3>(
src_ptr + in.loc, src_ptr + in.loc,
dst_ptr + out.loc, dst_ptr + out.loc,
shape, shape,
@@ -102,37 +102,37 @@ void copy_general_general(
template <typename SrcT, typename DstT> template <typename SrcT, typename DstT>
inline void copy_general_general(const array& src, array& dst) { inline void copy_general_general(const array& src, array& dst) {
copy_general_general<SrcT, DstT, size_t>( copy_general_general<SrcT, DstT>(
src, dst, src.shape(), src.strides(), dst.strides(), 0, 0); src, dst, src.shape(), src.strides(), dst.strides(), 0, 0);
} }
template <typename SrcT, typename DstT, typename StrideT> template <typename SrcT, typename DstT>
void copy_general( void copy_general(
const array& src, const array& src,
array& dst, array& dst,
const std::vector<int>& data_shape, const Shape& data_shape,
const std::vector<StrideT>& i_strides, const Strides& i_strides,
const std::vector<StrideT>&, const Strides&,
int64_t i_offset, int64_t i_offset,
int64_t o_offset) { int64_t o_offset) {
copy_general_general<SrcT, DstT, StrideT>( copy_general_general<SrcT, DstT>(
src, src,
dst, dst,
data_shape, data_shape,
i_strides, i_strides,
make_contiguous_strides<StrideT>(data_shape), make_contiguous_strides(data_shape),
i_offset, i_offset,
o_offset); o_offset);
} }
template <typename SrcT, typename DstT> template <typename SrcT, typename DstT>
inline void copy_general(const array& src, array& dst) { inline void copy_general(const array& src, array& dst) {
copy_general_general<SrcT, DstT, size_t>( copy_general_general<SrcT, DstT>(
src, src,
dst, dst,
src.shape(), src.shape(),
src.strides(), src.strides(),
make_contiguous_strides<size_t>(src.shape()), make_contiguous_strides(src.shape()),
0, 0,
0); 0);
} }
@@ -282,13 +282,12 @@ void copy(const array& src, array& dst, CopyType ctype) {
copy_inplace(src, dst, ctype); copy_inplace(src, dst, ctype);
} }
template <typename StrideT>
void copy_inplace( void copy_inplace(
const array& src, const array& src,
array& dst, array& dst,
const std::vector<int>& data_shape, const Shape& data_shape,
const std::vector<StrideT>& i_strides, const Strides& i_strides,
const std::vector<StrideT>& o_strides, const Strides& o_strides,
int64_t i_offset, int64_t i_offset,
int64_t o_offset, int64_t o_offset,
CopyType ctype) { CopyType ctype) {
@@ -311,24 +310,4 @@ void copy_inplace(
} }
} }
template void copy_inplace<size_t>(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<size_t>& i_strides,
const std::vector<size_t>& o_strides,
int64_t i_offset,
int64_t o_offset,
CopyType ctype);
template void copy_inplace<int64_t>(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<int64_t>& i_strides,
const std::vector<int64_t>& o_strides,
int64_t i_offset,
int64_t o_offset,
CopyType ctype);
} // namespace mlx::core } // namespace mlx::core

View File

@@ -26,13 +26,12 @@ enum class CopyType {
void copy(const array& src, array& dst, CopyType ctype); void copy(const array& src, array& dst, CopyType ctype);
void copy_inplace(const array& src, array& dst, CopyType ctype); void copy_inplace(const array& src, array& dst, CopyType ctype);
template <typename stride_t>
void copy_inplace( void copy_inplace(
const array& src, const array& src,
array& dst, array& dst,
const std::vector<int>& data_shape, const Shape& data_shape,
const std::vector<stride_t>& i_strides, const Strides& i_strides,
const std::vector<stride_t>& o_strides, const Strides& o_strides,
int64_t i_offset, int64_t i_offset,
int64_t o_offset, int64_t o_offset,
CopyType ctype); CopyType ctype);
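Throughout these signatures std::vector<int> becomes Shape and the stride vectors become Strides carried as signed 64-bit values. The aliases below are an illustrative guess at what those typedefs look like; the exact element types live in the MLX headers and are not shown in this diff:

#include <cstdint>
#include <vector>

// Assumed for illustration only.
using Shape = std::vector<int32_t>;    // dimension sizes
using Strides = std::vector<int64_t>;  // element strides, signed 64-bit

Carrying strides as one signed 64-bit type is what lets the removed StrideT template parameter go away.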

View File

@@ -130,7 +130,7 @@ inline void matmul_common_general(
} else { } else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General); copy(arr, arr_copy, CopyType::General);
size_t stx = arr.shape(-1); stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy); return std::make_tuple(false, stx, arr_copy);
} }
}; };

View File

@@ -32,7 +32,7 @@ void gather(
const std::vector<array>& inds, const std::vector<array>& inds,
array& out, array& out,
const std::vector<int>& axes, const std::vector<int>& axes,
const std::vector<int>& slice_sizes) { const Shape& slice_sizes) {
// If the array is row contiguous then we can do a contiguous copy given // If the array is row contiguous then we can do a contiguous copy given
// two conditions on the slice size: // two conditions on the slice size:
// - Any number of leading ones in the slice sizes are allowed // - Any number of leading ones in the slice sizes are allowed
@@ -80,11 +80,10 @@ void gather(
T* dst_ptr = out.data<T>(); T* dst_ptr = out.data<T>();
size_t out_idx = 0; size_t out_idx = 0;
std::vector<ContiguousIterator<size_t>> its(inds.begin(), inds.end()); std::vector<ContiguousIterator> its(inds.begin(), inds.end());
ContiguousIterator<size_t> src_it; ContiguousIterator src_it;
if (!can_copy && src.ndim() > 0) { if (!can_copy && src.ndim() > 0) {
-    src_it = std::move(
-        ContiguousIterator<size_t>(slice_sizes, src.strides(), src.ndim()));
+    src_it = ContiguousIterator(slice_sizes, src.strides(), src.ndim());
} }
for (int idx = 0; idx < ind_size; idx++) { for (int idx = 0; idx < ind_size; idx++) {
size_t src_idx = 0; size_t src_idx = 0;
@@ -119,7 +118,7 @@ void dispatch_gather(
const std::vector<array>& inds, const std::vector<array>& inds,
array& out, array& out,
const std::vector<int>& axes, const std::vector<int>& axes,
const std::vector<int>& size) { const Shape& size) {
switch (out.dtype()) { switch (out.dtype()) {
case bool_: case bool_:
gather<bool, IdxT>(src, inds, out, axes, size); gather<bool, IdxT>(src, inds, out, axes, size);
@@ -223,16 +222,16 @@ void scatter(
auto inds_ndim = updates.ndim() - out.ndim(); auto inds_ndim = updates.ndim() - out.ndim();
size_t n_updates = nind ? inds[0].size() : 1; size_t n_updates = nind ? inds[0].size() : 1;
std::vector<int> update_shape( Shape update_shape(
updates.shape().begin() + inds_ndim, updates.shape().end()); updates.shape().begin() + inds_ndim, updates.shape().end());
size_t update_size = 1; size_t update_size = 1;
for (auto us : update_shape) { for (auto us : update_shape) {
update_size *= us; update_size *= us;
} }
std::vector<ContiguousIterator<size_t>> its(inds.begin(), inds.end()); std::vector<ContiguousIterator> its(inds.begin(), inds.end());
ContiguousIterator<size_t> update_it(updates); ContiguousIterator update_it(updates);
ContiguousIterator<size_t> out_it(update_shape, out.strides(), out.ndim()); ContiguousIterator out_it(update_shape, out.strides(), out.ndim());
for (int i = 0; i < n_updates; ++i) { for (int i = 0; i < n_updates; ++i) {
size_t out_offset = 0; size_t out_offset = 0;

View File

@@ -2,6 +2,15 @@
#pragma once #pragma once
// Required for Visual Studio.
// https://github.com/OpenMathLib/OpenBLAS/blob/develop/docs/install.md
#ifdef _MSC_VER
#include <complex>
#define LAPACK_COMPLEX_CUSTOM
#define lapack_complex_float std::complex<float>
#define lapack_complex_double std::complex<double>
#endif
#ifdef ACCELERATE_NEW_LAPACK #ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h> #include <Accelerate/Accelerate.h>
#else #else

View File

@@ -19,10 +19,10 @@ inline void mask_matrix(
int block_size, int block_size,
const int X, const int X,
const int Y, const int Y,
const size_t X_data_str, const int64_t X_data_str,
const size_t Y_data_str, const int64_t Y_data_str,
const size_t X_mask_str, const int64_t X_mask_str,
const size_t Y_mask_str, const int64_t Y_mask_str,
const size_t mask_offset) { const size_t mask_offset) {
int tX = (X + block_size - 1) / block_size; int tX = (X + block_size - 1) / block_size;
int tY = (Y + block_size - 1) / block_size; int tY = (Y + block_size - 1) / block_size;
@@ -84,7 +84,7 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
} else { } else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General); copy(arr, arr_copy, CopyType::General);
size_t stx = arr.shape(-1); int64_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy); return std::make_tuple(false, stx, arr_copy);
} }
}; };
@@ -117,13 +117,13 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
int Y, int Y,
size_t X_data_str, size_t X_data_str,
size_t Y_data_str) { size_t Y_data_str) {
size_t mask_offset = elem_to_loc( auto mask_offset = elem_to_loc(
mask.shape(-1) * mask.shape(-2) * batch_idx, mask.shape(-1) * mask.shape(-2) * batch_idx,
mask.shape(), mask.shape(),
mask.strides()); mask.strides());
size_t X_mask_str = mask.strides()[mask.ndim() - 2]; auto X_mask_str = mask.strides()[mask.ndim() - 2];
size_t Y_mask_str = mask.strides()[mask.ndim() - 1]; auto Y_mask_str = mask.strides()[mask.ndim() - 1];
if (mask.dtype() == bool_) { if (mask.dtype() == bool_) {
return mask_matrix( return mask_matrix(
@@ -230,7 +230,7 @@ void GatherMM::eval(const std::vector<array>& inputs, array& out) {
} else { } else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General); copy(arr, arr_copy, CopyType::General);
size_t stx = arr.shape(-1); int64_t stx = arr.shape(-1);
return std::make_tuple(false, stx, arr_copy); return std::make_tuple(false, stx, arr_copy);
} }
}; };
@@ -262,13 +262,13 @@ void GatherMM::eval(const std::vector<array>& inputs, array& out) {
auto& lhs_indices = inputs[2]; auto& lhs_indices = inputs[2];
auto& rhs_indices = inputs[3]; auto& rhs_indices = inputs[3];
std::vector<int> batch_shape = get_batch_dims(out.shape()); auto batch_shape = get_batch_dims(out.shape());
int batch_ndim = batch_shape.size(); int batch_ndim = batch_shape.size();
std::vector<int> batch_shape_A = get_batch_dims(a.shape()); auto batch_shape_A = get_batch_dims(a.shape());
std::vector<size_t> batch_strides_A = get_batch_dims(a.strides()); auto batch_strides_A = get_batch_dims(a.strides());
std::vector<int> batch_shape_B = get_batch_dims(b.shape()); auto batch_shape_B = get_batch_dims(b.shape());
std::vector<size_t> batch_strides_B = get_batch_dims(b.strides()); auto batch_strides_B = get_batch_dims(b.strides());
const uint32_t* lhs_indices_ptr = lhs_indices.data<uint32_t>(); const uint32_t* lhs_indices_ptr = lhs_indices.data<uint32_t>();
const uint32_t* rhs_indices_ptr = rhs_indices.data<uint32_t>(); const uint32_t* rhs_indices_ptr = rhs_indices.data<uint32_t>();

View File

@@ -500,7 +500,12 @@ struct Equal {
struct NaNEqual { struct NaNEqual {
template <typename T> template <typename T>
bool operator()(T x, T y) { bool operator()(T x, T y) {
-    return x == y || (std::isnan(x) && std::isnan(y));
+    if constexpr (std::is_integral_v<T>) {
+      // isnan always returns false for integers, and MSVC refuses to compile.
+      return x == y;
+    } else {
+      return x == y || (std::isnan(x) && std::isnan(y));
+    }
} }
}; };
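The guard added above keeps std::isnan away from integer arguments, which (per the in-diff comment) is both the correct semantics and what MSVC requires to compile. A standalone version of the predicate for experimenting with the behavior:

#include <cmath>
#include <iostream>
#include <type_traits>

template <typename T>
bool nan_equal_sketch(T x, T y) {
  if constexpr (std::is_integral_v<T>) {
    return x == y;  // integers have no NaN, so plain equality is enough
  } else {
    return x == y || (std::isnan(x) && std::isnan(y));
  }
}

int main() {
  std::cout << nan_equal_sketch(3, 3) << '\n';                        // 1
  std::cout << nan_equal_sketch(std::nan(""), std::nan("")) << '\n';  // 1
  std::cout << nan_equal_sketch(1.0, 2.0) << '\n';                    // 0
  return 0;
}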

View File

@@ -159,6 +159,17 @@ void Conjugate::eval(const std::vector<array>& inputs, array& out) {
} }
} }
void Contiguous::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
if (in.flags().row_contiguous ||
(allow_col_major_ && in.flags().col_contiguous)) {
out.copy_shared_buffer(in);
} else {
copy(in, out, CopyType::General);
}
}
void Cos::eval(const std::vector<array>& inputs, array& out) { void Cos::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
@@ -487,14 +498,15 @@ void Slice::eval(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0]; auto& in = inputs[0];
// Calculate out strides, initial offset and if copy needs to be made // Calculate out strides, initial offset and if copy needs to be made
-  auto [copy_needed, data_offset, inp_strides] =
-      prepare_slice(in, start_indices_, strides_);
+  auto [data_offset, inp_strides] = prepare_slice(in, start_indices_, strides_);
+  auto copy_needed = std::any_of(
+      strides_.begin(), strides_.end(), [](auto i) { return i < 0; });
// Do copy if needed // Do copy if needed
if (copy_needed) { if (copy_needed) {
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc_or_wait(out.nbytes()));
std::vector<int64_t> ostrides{out.strides().begin(), out.strides().end()}; Strides ostrides{out.strides().begin(), out.strides().end()};
copy_inplace<int64_t>( copy_inplace(
/* const array& src = */ in, /* const array& src = */ in,
/* array& dst = */ out, /* array& dst = */ out,
/* const std::vector<int>& data_shape = */ out.shape(), /* const std::vector<int>& data_shape = */ out.shape(),
@@ -512,7 +524,7 @@ void Slice::eval(const std::vector<array>& inputs, array& out) {
} }
} }
size_t data_size = data_end - data_offset; size_t data_size = data_end - data_offset;
std::vector<size_t> ostrides{inp_strides.begin(), inp_strides.end()}; Strides ostrides{inp_strides.begin(), inp_strides.end()};
shared_buffer_slice(in, ostrides, data_offset, data_size, out); shared_buffer_slice(in, ostrides, data_offset, data_size, out);
} }
} }
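With prepare_slice no longer returning a copy_needed flag, Slice::eval derives it itself: a copy is required exactly when some slice step is negative, since this backend materializes such slices instead of aliasing the input buffer. The check in isolation (function name invented for the sketch):

#include <algorithm>
#include <vector>

bool slice_needs_copy_sketch(const std::vector<int>& steps) {
  return std::any_of(steps.begin(), steps.end(), [](int s) { return s < 0; });
}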
@@ -539,11 +551,11 @@ void SliceUpdate::eval(const std::vector<array>& inputs, array& out) {
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype); copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype);
// Calculate out strides, initial offset and if copy needs to be made // Calculate out strides, initial offset and if copy needs to be made
auto [data_offset, out_strides] = prepare_slice(out); auto [data_offset, out_strides] = prepare_slice(in, start_indices_, strides_);
// Do copy // Do copy
std::vector<int64_t> upd_strides{upd.strides().begin(), upd.strides().end()}; Strides upd_strides{upd.strides().begin(), upd.strides().end()};
copy_inplace<int64_t>( copy_inplace(
/* const array& src = */ upd, /* const array& src = */ upd,
/* array& dst = */ out, /* array& dst = */ out,
/* const std::vector<int>& data_shape = */ upd.shape(), /* const std::vector<int>& data_shape = */ upd.shape(),
@@ -606,7 +618,7 @@ void View::eval_cpu(const std::vector<array>& inputs, array& out) {
if (ibytes == obytes || obytes < ibytes && in.strides().back() == 1 || if (ibytes == obytes || obytes < ibytes && in.strides().back() == 1 ||
in.flags().row_contiguous) { in.flags().row_contiguous) {
auto strides = in.strides(); auto strides = in.strides();
for (int i = 0; i < strides.size() - 1; ++i) { for (int i = 0; i < static_cast<int>(strides.size()) - 1; ++i) {
strides[i] *= ibytes; strides[i] *= ibytes;
strides[i] /= obytes; strides[i] /= obytes;
} }

View File

@@ -54,7 +54,7 @@ void qrf_impl(const array& a, array& q, array& r) {
// Copy the input to be column contiguous // Copy the input to be column contiguous
flags.col_contiguous = num_matrices == 1; flags.col_contiguous = num_matrices == 1;
flags.row_contiguous = false; flags.row_contiguous = false;
std::vector<size_t> strides = in.strides(); auto strides = in.strides();
strides[in.ndim() - 2] = 1; strides[in.ndim() - 2] = 1;
strides[in.ndim() - 1] = M; strides[in.ndim() - 1] = M;
in.set_data( in.set_data(

View File

@@ -2,13 +2,38 @@
#include <cassert> #include <cassert>
#include "mlx/backend/metal/copy.h" #include "mlx/backend/common/copy.h"
#include "mlx/backend/common/ops.h"
#include "mlx/fast_primitives.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core { namespace mlx::core {
namespace { namespace {
template <typename T, int bits>
void extract_bits(const uint8_t* w_in, T* w_out) {
assert(bits == 3 || bits == 6);
if (bits == 3) {
w_out[0] = static_cast<T>(w_in[0] & 0x7);
w_out[1] = static_cast<T>((w_in[0] & 0x38) >> 3);
w_out[2] = static_cast<T>(((w_in[0] & 0xc0) >> 6) + ((w_in[1] & 0x1) << 2));
w_out[3] = static_cast<T>((w_in[1] & 0xe) >> 1);
w_out[4] = static_cast<T>((w_in[1] & 0x70) >> 4);
w_out[5] = static_cast<T>(((w_in[1] & 0x80) >> 7) + ((w_in[2] & 0x3) << 1));
w_out[6] = static_cast<T>((w_in[2] & 0x1c) >> 2);
w_out[7] = static_cast<T>((w_in[2] & 0xe0) >> 5);
} else if (bits == 6) {
w_out[0] = static_cast<T>(w_in[0] & 0x3f);
w_out[1] =
static_cast<T>(((w_in[0] >> 6) & 0x03) + ((w_in[1] & 0x0f) << 2));
w_out[2] =
static_cast<T>(((w_in[1] >> 4) & 0x0f) + ((w_in[2] & 0x03) << 4));
w_out[3] = static_cast<T>((w_in[2] >> 2) & 0x3f);
}
}
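As a sanity check on the 3-bit branch of extract_bits: eight 3-bit values occupy exactly 24 bits, i.e. three bytes, which is why three input bytes expand to eight outputs. The self-contained round-trip below packs eight values in the little-endian layout this decoder assumes and verifies they come back unchanged (pack3 is written here only for the test; it is not part of MLX):

#include <cassert>
#include <cstdint>

// Pack eight 3-bit values little-endian into three bytes.
void pack3(const uint8_t* vals, uint8_t* bytes) {
  uint32_t acc = 0;
  for (int i = 0; i < 8; ++i) {
    acc |= uint32_t(vals[i] & 0x7) << (3 * i);
  }
  bytes[0] = acc & 0xff;
  bytes[1] = (acc >> 8) & 0xff;
  bytes[2] = (acc >> 16) & 0xff;
}

int main() {
  uint8_t vals[8] = {1, 5, 7, 0, 3, 6, 2, 4};
  uint8_t bytes[3];
  pack3(vals, bytes);
  // Decode with the same masks and shifts as extract_bits<T, 3> above.
  uint8_t out[8];
  out[0] = bytes[0] & 0x7;
  out[1] = (bytes[0] & 0x38) >> 3;
  out[2] = ((bytes[0] & 0xc0) >> 6) + ((bytes[1] & 0x1) << 2);
  out[3] = (bytes[1] & 0xe) >> 1;
  out[4] = (bytes[1] & 0x70) >> 4;
  out[5] = ((bytes[1] & 0x80) >> 7) + ((bytes[2] & 0x3) << 1);
  out[6] = (bytes[2] & 0x1c) >> 2;
  out[7] = (bytes[2] & 0xe0) >> 5;
  for (int i = 0; i < 8; ++i) {
    assert(out[i] == vals[i]);
  }
  return 0;
}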
template <typename T, int bits, int group_size> template <typename T, int bits, int group_size>
void _qmm( void _qmm(
T* result, T* result,
@@ -20,13 +45,12 @@ void _qmm(
int N, int N,
int K) { int K) {
constexpr int bitmask = (1 << bits) - 1; constexpr int bitmask = (1 << bits) - 1;
-  constexpr int pack_factor = 32 / bits;
+  constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
+  constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
   constexpr int packs_in_group = group_size / pack_factor;
-  const int Ng = N / group_size;
-  const int Nw = N / pack_factor;
for (int m = 0; m < M; m++) { for (int m = 0; m < M; m++) {
const uint32_t* w_local = w; const uint8_t* w_local = (const uint8_t*)w;
const T* scales_local = scales; const T* scales_local = scales;
const T* biases_local = biases; const T* biases_local = biases;
@@ -40,13 +64,25 @@ void _qmm(
T scale = *scales_local++; T scale = *scales_local++;
T bias = *biases_local++; T bias = *biases_local++;
       for (int ng = 0; ng < packs_in_group; ng++) {
-        uint32_t wi = *w_local++;
-#pragma clang loop unroll(full)
-        for (int p = 0; p < pack_factor; p++) {
-          (*result_local++) +=
-              xi * (scale * static_cast<T>(wi & bitmask) + bias);
-          wi >>= bits;
-        }
+        if (bits == 3 || bits == 6) {
+          T wl[pack_factor];
+          extract_bits<T, bits>(w_local, wl);
+#pragma clang loop unroll(full)
+          for (int p = 0; p < pack_factor; p++) {
+            (*result_local++) += xi * (scale * wl[p] + bias);
+          }
+          w_local += bytes_per_pack;
+        } else {
+          uint8_t wi = *w_local++;
+#pragma clang loop unroll(full)
+          for (int p = 0; p < pack_factor; p++) {
+            (*result_local++) +=
+                xi * (scale * static_cast<T>(wi & bitmask) + bias);
+            if (bits != 8) {
+              wi >>= bits;
+            }
+          }
+        }
       }
} }
@@ -67,13 +103,12 @@ void _qmm_t(
int N, int N,
int K) { int K) {
constexpr int bitmask = (1 << bits) - 1; constexpr int bitmask = (1 << bits) - 1;
-  constexpr int pack_factor = 32 / bits;
+  constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
+  constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1;
   constexpr int packs_in_group = group_size / pack_factor;
-  const int Kg = K / group_size;
-  const int Kw = K / pack_factor;
for (int m = 0; m < M; m++) { for (int m = 0; m < M; m++) {
const uint32_t* w_local = w; const uint8_t* w_local = (const uint8_t*)w;
const T* scales_local = scales; const T* scales_local = scales;
const T* biases_local = biases; const T* biases_local = biases;
@@ -85,12 +120,26 @@ void _qmm_t(
T bias = *biases_local++; T bias = *biases_local++;
       for (int kw = 0; kw < packs_in_group; kw++) {
-        uint32_t wi = *w_local++;
-#pragma clang loop unroll(full)
-        for (int p = 0; p < pack_factor; p++) {
-          sum += (*x_local++) * (scale * static_cast<T>(wi & bitmask) + bias);
-          wi >>= bits;
-        }
+        if (bits == 3 || bits == 6) {
+          T wl[pack_factor];
+          extract_bits<T, bits>(w_local, wl);
+#pragma clang loop unroll(full)
+          for (int p = 0; p < pack_factor; p++) {
+            sum += x_local[p] * (scale * wl[p] + bias);
+          }
+          w_local += bytes_per_pack;
+          x_local += pack_factor;
+        } else {
+          uint8_t wi = *w_local++;
+#pragma clang loop unroll(full)
+          for (int p = 0; p < pack_factor; p++) {
+            sum +=
+                (*x_local++) * (scale * static_cast<T>(wi & bitmask) + bias);
+            if (bits != 8) {
+              wi >>= bits;
+            }
+          }
+        }
       }
} }
@@ -102,6 +151,55 @@ void _qmm_t(
} }
} }
template <typename T, int bits, int group_size>
void _qmm_dispatch_transpose(
T* result,
const T* x,
const uint32_t* w,
const T* scales,
const T* biases,
int M,
int N,
int K,
bool transposed_w) {
if (transposed_w) {
return _qmm_t<T, bits, group_size>(result, x, w, scales, biases, M, N, K);
} else {
return _qmm<T, bits, group_size>(result, x, w, scales, biases, M, N, K);
}
}
template <typename T, int bits>
void _qmm_dispatch_group(
T* result,
const T* x,
const uint32_t* w,
const T* scales,
const T* biases,
int M,
int N,
int K,
int group_size,
bool transposed_w) {
switch (group_size) {
case 32:
_qmm_dispatch_transpose<T, bits, 32>(
result, x, w, scales, biases, M, N, K, transposed_w);
break;
case 64:
_qmm_dispatch_transpose<T, bits, 64>(
result, x, w, scales, biases, M, N, K, transposed_w);
break;
case 128:
_qmm_dispatch_transpose<T, bits, 128>(
result, x, w, scales, biases, M, N, K, transposed_w);
break;
default:
throw std::invalid_argument(
"Quantization group size must be 32, 64 or 128.");
}
}
template <typename T> template <typename T>
void _qmm_dispatch_typed( void _qmm_dispatch_typed(
T* result, T* result,
@@ -116,79 +214,29 @@ void _qmm_dispatch_typed(
int bits, int bits,
bool transposed_w) { bool transposed_w) {
   switch (bits) {
-    case 2: {
-      switch (group_size) {
-        case 32:
-          if (transposed_w) {
-            return _qmm_t<T, 2, 32>(result, x, w, scales, biases, M, N, K);
-          } else {
-            return _qmm<T, 2, 32>(result, x, w, scales, biases, M, N, K);
-          }
-        case 64:
-          if (transposed_w) {
-            return _qmm_t<T, 2, 64>(result, x, w, scales, biases, M, N, K);
-          } else {
-            return _qmm<T, 2, 64>(result, x, w, scales, biases, M, N, K);
-          }
-        case 128:
-          if (transposed_w) {
-            return _qmm_t<T, 2, 128>(result, x, w, scales, biases, M, N, K);
-          } else {
-            return _qmm<T, 2, 128>(result, x, w, scales, biases, M, N, K);
-          }
-      }
-    }
-    case 4: {
-      switch (group_size) {
-        case 32:
-          if (transposed_w) {
-            return _qmm_t<T, 4, 32>(result, x, w, scales, biases, M, N, K);
-          } else {
-            return _qmm<T, 4, 32>(result, x, w, scales, biases, M, N, K);
-          }
-        case 64:
-          if (transposed_w) {
-            return _qmm_t<T, 4, 64>(result, x, w, scales, biases, M, N, K);
-          } else {
-            return _qmm<T, 4, 64>(result, x, w, scales, biases, M, N, K);
-          }
-        case 128:
-          if (transposed_w) {
-            return _qmm_t<T, 4, 128>(result, x, w, scales, biases, M, N, K);
-          } else {
-            return _qmm<T, 4, 128>(result, x, w, scales, biases, M, N, K);
-          }
-      }
-    }
-    case 8: {
-      switch (group_size) {
-        case 32:
-          if (transposed_w) {
-            return _qmm_t<T, 8, 32>(result, x, w, scales, biases, M, N, K);
-          } else {
-            return _qmm<T, 8, 32>(result, x, w, scales, biases, M, N, K);
-          }
-        case 64:
-          if (transposed_w) {
-            return _qmm_t<T, 8, 64>(result, x, w, scales, biases, M, N, K);
-          } else {
-            return _qmm<T, 8, 64>(result, x, w, scales, biases, M, N, K);
-          }
-        case 128:
-          if (transposed_w) {
-            return _qmm_t<T, 8, 128>(result, x, w, scales, biases, M, N, K);
-          } else {
-            return _qmm<T, 8, 128>(result, x, w, scales, biases, M, N, K);
-          }
-      }
-    }
+    case 2:
+      _qmm_dispatch_group<T, 2>(
+          result, x, w, scales, biases, M, N, K, group_size, transposed_w);
+      break;
+    case 3:
+      _qmm_dispatch_group<T, 3>(
+          result, x, w, scales, biases, M, N, K, group_size, transposed_w);
+      break;
+    case 4:
+      _qmm_dispatch_group<T, 4>(
+          result, x, w, scales, biases, M, N, K, group_size, transposed_w);
+      break;
+    case 6:
+      _qmm_dispatch_group<T, 6>(
+          result, x, w, scales, biases, M, N, K, group_size, transposed_w);
+      break;
+    case 8:
+      _qmm_dispatch_group<T, 8>(
+          result, x, w, scales, biases, M, N, K, group_size, transposed_w);
+      break;
+    default:
+      throw std::invalid_argument("Quantization bits must be 2, 3, 4, 6 or 8.");
   }
-  std::ostringstream msg;
-  msg << "Quantization type not supported. Provided bits=" << bits
-      << " and group_size=" << group_size
-      << ". The supported options are bits in "
-      << "{2, 4, 8} and group_size in {64, 128}.";
-  throw std::invalid_argument(msg.str());
 }
void _qmm_dispatch( void _qmm_dispatch(
@@ -404,4 +452,114 @@ void GatherQMM::eval(const std::vector<array>& inputs, array& out) {
transpose_); transpose_);
} }
template <typename T, typename U>
void quantize(
const array& w_,
array& out_,
array& scales_,
array& biases_,
int bits,
int group_size) {
const T* w = w_.data<T>();
auto out = out_.data<U>();
T* scales = scales_.data<T>();
T* biases = biases_.data<T>();
T n_bins = (1 << bits) - 1;
T eps = 1e-7;
bool power_of_2_bits = is_power_of_2(bits);
int el_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
// For 3/6 bits we read 3 uint8s at a time instead of 1 uint32
int bytes_per_pack = power_of_2_bits ? 1 : 3;
int int_per_group = group_size * bytes_per_pack / el_per_int;
size_t n_groups = w_.size() / group_size;
for (size_t i = 0; i < n_groups; ++i) {
size_t w_idx = i * group_size;
T w_min = std::numeric_limits<float>::infinity();
T w_max = -w_min;
for (int j = 0; j < group_size; ++j) {
w_max = std::max(w_max, w[w_idx + j]);
w_min = std::min(w_min, w[w_idx + j]);
}
bool mask = std::abs(w_min) > std::abs(w_max);
T scale = std::max(T((w_max - w_min) / n_bins), eps);
scale = mask ? scale : -scale;
auto edge = mask ? w_min : w_max;
auto q0 = std::rint(edge / scale);
if (q0 == 0) {
scales[i] = scale;
biases[i] = 0;
} else {
scales[i] = edge / q0;
biases[i] = edge;
}
size_t out_idx = i * int_per_group;
for (int j = 0; j < int_per_group / bytes_per_pack; ++j) {
uint32_t out_el = 0;
for (int k = 0; k < el_per_int; ++k) {
T w_el = w[w_idx + j * el_per_int + k];
w_el = std::rint((w_el - biases[i]) / scales[i]);
w_el = std::min(std::max(w_el, T(0)), n_bins);
out_el |= static_cast<uint32_t>(w_el) << (k * bits);
}
if (power_of_2_bits) {
out[out_idx + j] = out_el;
} else {
out[out_idx + bytes_per_pack * j] = out_el & 0xff;
out[out_idx + bytes_per_pack * j + 1] = (out_el & 0xff00) >> 8;
out[out_idx + bytes_per_pack * j + 2] = (out_el & 0xff0000) >> 16;
}
}
}
}
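To make the packing bookkeeping in quantize concrete: el_per_int counts values per stored pack, bytes_per_pack is 3 for the 3- and 6-bit formats (one pack spans three bytes) and 1 otherwise, and int_per_group is how many output elements (uint32 words, or bytes in the 3/6-bit case) each group of weights occupies. A standalone check of that arithmetic for group_size 64 (helper names invented for the sketch):

constexpr int el_per_int(int bits) {
  return bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
}
constexpr int bytes_per_pack(int bits) {
  return (bits == 3 || bits == 6) ? 3 : 1;
}
constexpr int int_per_group(int bits, int group_size) {
  return group_size * bytes_per_pack(bits) / el_per_int(bits);
}

int main() {
  static_assert(int_per_group(4, 64) == 8);   // eight uint32 words per group
  static_assert(int_per_group(3, 64) == 24);  // 24 bytes: 8 packs of 3 bytes
  static_assert(int_per_group(6, 64) == 48);  // 48 bytes: 16 packs of 3 bytes
  return 0;
}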
void fast::AffineQuantize::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto ensure_row_contiguous = [](const array& arr) {
if (arr.flags().row_contiguous) {
return arr;
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
return arr_copy;
}
};
auto w = ensure_row_contiguous(inputs[0]);
auto& out = outputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& scales = outputs[1];
auto& biases = outputs[2];
scales.set_data(allocator::malloc_or_wait(scales.nbytes()));
biases.set_data(allocator::malloc_or_wait(biases.nbytes()));
if (w.dtype() == float16) {
if (is_power_of_2(bits_)) {
quantize<float16_t, uint32_t>(w, out, scales, biases, bits_, group_size_);
} else {
quantize<float16_t, uint8_t>(w, out, scales, biases, bits_, group_size_);
}
} else if (w.dtype() == bfloat16) {
if (is_power_of_2(bits_)) {
quantize<bfloat16_t, uint32_t>(
w, out, scales, biases, bits_, group_size_);
} else {
quantize<bfloat16_t, uint8_t>(w, out, scales, biases, bits_, group_size_);
}
} else if (w.dtype() == float32) {
if (is_power_of_2(bits_)) {
quantize<float, uint32_t>(w, out, scales, biases, bits_, group_size_);
} else {
quantize<float, uint8_t>(w, out, scales, biases, bits_, group_size_);
}
} else {
throw std::runtime_error(
"[fast::AffineQuantize::eval_cpu] Only supports floating point inputs");
}
}
} // namespace mlx::core } // namespace mlx::core

View File

@@ -120,65 +120,73 @@ struct MinReduce {
}; };
 template <typename InT>
-void reduce_dispatch_out(
+void reduce_dispatch_and_or(
     const array& in,
     array& out,
     Reduce::ReduceType rtype,
     const std::vector<int>& axes) {
-  switch (rtype) {
-    case Reduce::And: {
-      reduction_op<InT, bool>(in, out, axes, true, AndReduce());
-      break;
-    }
-    case Reduce::Or: {
-      reduction_op<InT, bool>(in, out, axes, false, OrReduce());
-      break;
-    }
-    case Reduce::Sum: {
-      auto op = [](auto y, auto x) { (*y) = (*y) + x; };
-      if (out.dtype() == int32) {
-        // special case since the input type can be bool
-        reduction_op<InT, int32_t>(in, out, axes, 0, op);
-      } else {
-        reduction_op<InT, InT>(in, out, axes, 0, op);
-      }
-      break;
-    }
-    case Reduce::Prod: {
-      auto op = [](auto y, auto x) { (*y) *= x; };
-      reduction_op<InT, InT>(in, out, axes, 1, op);
-      break;
-    }
-    case Reduce::Max: {
-      auto init = Limits<InT>::min;
-      reduction_op<InT, InT>(in, out, axes, init, MaxReduce());
-      break;
-    }
-    case Reduce::Min: {
-      auto init = Limits<InT>::max;
-      reduction_op<InT, InT>(in, out, axes, init, MinReduce());
-      break;
-    }
-  }
-}
+  if (rtype == Reduce::And) {
+    reduction_op<InT, bool>(in, out, axes, true, AndReduce());
+  } else {
+    reduction_op<InT, bool>(in, out, axes, false, OrReduce());
+  }
+}
+
+template <typename InT>
+void reduce_dispatch_sum_prod(
+    const array& in,
+    array& out,
+    Reduce::ReduceType rtype,
+    const std::vector<int>& axes) {
+  if (rtype == Reduce::Sum) {
+    auto op = [](auto y, auto x) { (*y) = (*y) + x; };
+    if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
+      reduction_op<InT, int32_t>(in, out, axes, 0, op);
+    } else {
+      reduction_op<InT, InT>(in, out, axes, 0, op);
+    }
+  } else {
+    auto op = [](auto y, auto x) { (*y) *= x; };
+    if constexpr (std::is_integral_v<InT> && sizeof(InT) <= 4) {
+      reduction_op<InT, int32_t>(in, out, axes, 1, op);
+    } else {
+      reduction_op<InT, InT>(in, out, axes, 1, op);
+    }
+  }
+}
+
+template <typename InT>
+void reduce_dispatch_min_max(
+    const array& in,
+    array& out,
+    Reduce::ReduceType rtype,
+    const std::vector<int>& axes) {
+  if (rtype == Reduce::Max) {
+    auto init = Limits<InT>::min;
+    reduction_op<InT, InT>(in, out, axes, init, MaxReduce());
+  } else {
+    auto init = Limits<InT>::max;
+    reduction_op<InT, InT>(in, out, axes, init, MinReduce());
+  }
+}
} // namespace } // namespace
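One design point in reduce_dispatch_sum_prod above: for integral inputs no wider than four bytes the accumulator is widened to int32_t, which covers bool inputs (a sum of bools is a count) and keeps narrow integer sums from wrapping, while wider and floating types accumulate in their own type. A standalone illustration of what the widening avoids:

#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  std::vector<int8_t> v(100, 100);  // true sum is 10000, far outside int8_t
  int8_t narrow = std::accumulate(v.begin(), v.end(), int8_t{0});
  int32_t wide = std::accumulate(v.begin(), v.end(), int32_t{0});
  std::cout << int(narrow) << " vs " << wide << '\n';  // wrapped value vs 10000
  return 0;
}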
void nd_loop( void nd_loop(
std::function<void(int)> callback, std::function<void(int)> callback,
const std::vector<int>& shape, const Shape& shape,
const std::vector<size_t>& strides) { const Strides& strides) {
std::function<void(int, int)> loop_inner; std::function<void(int, int)> loop_inner;
loop_inner = [&](int dim, int offset) { loop_inner = [&](int dim, int offset) {
if (dim < shape.size() - 1) { if (dim < shape.size() - 1) {
int size = shape[dim]; auto size = shape[dim];
size_t stride = strides[dim]; auto stride = strides[dim];
for (int i = 0; i < size; i++) { for (int i = 0; i < size; i++) {
loop_inner(dim + 1, offset + i * stride); loop_inner(dim + 1, offset + i * stride);
} }
} else { } else {
int size = shape[dim]; auto size = shape[dim];
size_t stride = strides[dim]; auto stride = strides[dim];
for (int i = 0; i < size; i++) { for (int i = 0; i < size; i++) {
callback(offset + i * stride); callback(offset + i * stride);
} }
@@ -190,46 +198,114 @@ void nd_loop(
void Reduce::eval(const std::vector<array>& inputs, array& out) { void Reduce::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
-  switch (in.dtype()) {
-    case bool_:
-      reduce_dispatch_out<bool>(in, out, reduce_type_, axes_);
-      break;
-    case uint8:
-      reduce_dispatch_out<uint8_t>(in, out, reduce_type_, axes_);
-      break;
-    case uint16:
-      reduce_dispatch_out<uint16_t>(in, out, reduce_type_, axes_);
-      break;
-    case uint32:
-      reduce_dispatch_out<uint32_t>(in, out, reduce_type_, axes_);
-      break;
-    case uint64:
-      reduce_dispatch_out<uint64_t>(in, out, reduce_type_, axes_);
-      break;
-    case int8:
-      reduce_dispatch_out<uint8_t>(in, out, reduce_type_, axes_);
-      break;
-    case int16:
-      reduce_dispatch_out<uint16_t>(in, out, reduce_type_, axes_);
-      break;
-    case int32:
-      reduce_dispatch_out<int32_t>(in, out, reduce_type_, axes_);
-      break;
-    case int64:
-      reduce_dispatch_out<int64_t>(in, out, reduce_type_, axes_);
-      break;
-    case float16:
-      reduce_dispatch_out<float16_t>(in, out, reduce_type_, axes_);
-      break;
-    case float32:
-      reduce_dispatch_out<float>(in, out, reduce_type_, axes_);
-      break;
-    case bfloat16:
-      reduce_dispatch_out<bfloat16_t>(in, out, reduce_type_, axes_);
-      break;
-    case complex64:
-      reduce_dispatch_out<complex64_t>(in, out, reduce_type_, axes_);
-      break;
+  switch (reduce_type_) {
+    case Reduce::And:
+    case Reduce::Or: {
+      switch (in.dtype()) {
+        case bool_:
+        case uint8:
+        case int8:
+          reduce_dispatch_and_or<int8_t>(in, out, reduce_type_, axes_);
+          break;
+        case int16:
+        case uint16:
+        case float16:
+        case bfloat16:
+          reduce_dispatch_and_or<int16_t>(in, out, reduce_type_, axes_);
+          break;
+        case uint32:
+        case int32:
+        case float32:
+          reduce_dispatch_and_or<int32_t>(in, out, reduce_type_, axes_);
+          break;
+        case uint64:
+        case int64:
+        case complex64:
+          reduce_dispatch_and_or<int64_t>(in, out, reduce_type_, axes_);
+          break;
+      }
+      break;
+    }
+    case Reduce::Sum:
+    case Reduce::Prod: {
+      switch (in.dtype()) {
+        case bool_:
+        case uint8:
+        case int8:
+          reduce_dispatch_sum_prod<int8_t>(in, out, reduce_type_, axes_);
+          break;
+        case int16:
+        case uint16:
+          reduce_dispatch_sum_prod<int16_t>(in, out, reduce_type_, axes_);
+          break;
+        case int32:
+        case uint32:
+          reduce_dispatch_sum_prod<int32_t>(in, out, reduce_type_, axes_);
+          break;
+        case int64:
+        case uint64:
+          reduce_dispatch_sum_prod<int64_t>(in, out, reduce_type_, axes_);
+          break;
+        case float16:
+          reduce_dispatch_sum_prod<float16_t>(in, out, reduce_type_, axes_);
+          break;
+        case bfloat16:
+          reduce_dispatch_sum_prod<bfloat16_t>(in, out, reduce_type_, axes_);
+          break;
+        case float32:
+          reduce_dispatch_sum_prod<float>(in, out, reduce_type_, axes_);
+          break;
+        case complex64:
+          reduce_dispatch_sum_prod<complex64_t>(in, out, reduce_type_, axes_);
+          break;
+      }
+      break;
+    }
+    case Reduce::Max:
+    case Reduce::Min: {
+      switch (in.dtype()) {
+        case bool_:
+          reduce_dispatch_min_max<bool>(in, out, reduce_type_, axes_);
+          break;
+        case uint8:
+          reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
+          break;
+        case uint16:
+          reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
+          break;
+        case uint32:
+          reduce_dispatch_min_max<uint32_t>(in, out, reduce_type_, axes_);
+          break;
+        case uint64:
+          reduce_dispatch_min_max<uint64_t>(in, out, reduce_type_, axes_);
+          break;
+        case int8:
+          reduce_dispatch_min_max<uint8_t>(in, out, reduce_type_, axes_);
+          break;
+        case int16:
+          reduce_dispatch_min_max<uint16_t>(in, out, reduce_type_, axes_);
+          break;
+        case int32:
+          reduce_dispatch_min_max<int32_t>(in, out, reduce_type_, axes_);
+          break;
+        case int64:
+          reduce_dispatch_min_max<int64_t>(in, out, reduce_type_, axes_);
+          break;
+        case float16:
+          reduce_dispatch_min_max<float16_t>(in, out, reduce_type_, axes_);
+          break;
+        case float32:
+          reduce_dispatch_min_max<float>(in, out, reduce_type_, axes_);
+          break;
+        case bfloat16:
+          reduce_dispatch_min_max<bfloat16_t>(in, out, reduce_type_, axes_);
+          break;
+        case complex64:
+          reduce_dispatch_min_max<complex64_t>(in, out, reduce_type_, axes_);
+          break;
+      }
+      break;
+    }
   }
 }
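The And/Or arm above routes dtypes by element width alone, so for example float16, bfloat16, int16 and uint16 all share the int16_t instantiation; reading this as "only zero versus nonzero matters, so reuse one instantiation per width" is an inference from the case grouping, not something the diff states. A small sketch of how much that collapses the template count:

#include <iostream>
#include <set>
#include <string>
#include <utility>
#include <vector>

int main() {
  // Width in bytes each dtype is routed through in the And/Or dispatch above.
  std::vector<std::pair<std::string, int>> routing = {
      {"bool_", 1},  {"uint8", 1},  {"int8", 1},
      {"int16", 2},  {"uint16", 2}, {"float16", 2}, {"bfloat16", 2},
      {"uint32", 4}, {"int32", 4},  {"float32", 4},
      {"uint64", 8}, {"int64", 8},  {"complex64", 8}};
  std::set<int> widths;
  for (const auto& entry : routing) {
    widths.insert(entry.second);
  }
  std::cout << routing.size() << " dtypes -> " << widths.size()
            << " instantiations\n";  // 13 dtypes -> 4 instantiations
  return 0;
}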

View File

@@ -38,13 +38,10 @@ enum ReductionOpType {
struct ReductionPlan { struct ReductionPlan {
ReductionOpType type; ReductionOpType type;
std::vector<int> shape; Shape shape;
std::vector<size_t> strides; Strides strides;
-  ReductionPlan(
-      ReductionOpType type_,
-      std::vector<int> shape_,
-      std::vector<size_t> strides_)
+  ReductionPlan(ReductionOpType type_, Shape shape_, Strides strides_)
: type(type_), shape(std::move(shape_)), strides(std::move(strides_)) {} : type(type_), shape(std::move(shape_)), strides(std::move(strides_)) {}
ReductionPlan(ReductionOpType type_) : type(type_) {} ReductionPlan(ReductionOpType type_) : type(type_) {}
}; };
@@ -55,10 +52,10 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes);
// Should this be in utils? // Should this be in utils?
void nd_loop( void nd_loop(
std::function<void(int)> callback, std::function<void(int)> callback,
const std::vector<int>& shape, const Shape& shape,
const std::vector<size_t>& strides); const Strides& strides);
std::pair<std::vector<int>, std::vector<size_t>> shapes_without_reduction_axes( std::pair<Shape, Strides> shapes_without_reduction_axes(
const array& x, const array& x,
const std::vector<int>& axes); const std::vector<int>& axes);
@@ -113,9 +110,6 @@ void reduction_op(
return; return;
} }
std::vector<int> shape;
std::vector<size_t> strides;
if (plan.type == ContiguousReduce && plan.shape.size() == 1) { if (plan.type == ContiguousReduce && plan.shape.size() == 1) {
int reduction_size = plan.shape[0]; int reduction_size = plan.shape[0];
const T* x_ptr = x.data<T>(); const T* x_ptr = x.data<T>();
@@ -135,7 +129,7 @@ void reduction_op(
U* out_ptr = out.data<U>(); U* out_ptr = out.data<U>();
// Unrolling the following loop (and implementing it in order for // Unrolling the following loop (and implementing it in order for
// ContiguousReduce) should hold extra performance boost. // ContiguousReduce) should hold extra performance boost.
std::tie(shape, strides) = shapes_without_reduction_axes(x, axes); auto [shape, strides] = shapes_without_reduction_axes(x, axes);
if (plan.shape.size() == 0) { if (plan.shape.size() == 0) {
for (int i = 0; i < out.size(); i++, out_ptr++) { for (int i = 0; i < out.size(); i++, out_ptr++) {
int offset = elem_to_loc(i, shape, strides); int offset = elem_to_loc(i, shape, strides);
@@ -181,7 +175,7 @@ void reduction_op(
plan.strides.pop_back(); plan.strides.pop_back();
const T* x_ptr = x.data<T>(); const T* x_ptr = x.data<T>();
U* out_ptr = out.data<U>(); U* out_ptr = out.data<U>();
std::tie(shape, strides) = shapes_without_reduction_axes(x, axes); auto [shape, strides] = shapes_without_reduction_axes(x, axes);
if (plan.shape.size() == 0) { if (plan.shape.size() == 0) {
for (int i = 0; i < out.size(); i += reduction_stride) { for (int i = 0; i < out.size(); i += reduction_stride) {
int offset = elem_to_loc(i, shape, strides); int offset = elem_to_loc(i, shape, strides);
@@ -211,7 +205,7 @@ void reduction_op(
if (plan.type == GeneralReduce) { if (plan.type == GeneralReduce) {
const T* x_ptr = x.data<T>(); const T* x_ptr = x.data<T>();
U* out_ptr = out.data<U>(); U* out_ptr = out.data<U>();
std::tie(shape, strides) = shapes_without_reduction_axes(x, axes); auto [shape, strides] = shapes_without_reduction_axes(x, axes);
for (int i = 0; i < out.size(); i++, out_ptr++) { for (int i = 0; i < out.size(); i++, out_ptr++) {
int offset = elem_to_loc(i, shape, strides); int offset = elem_to_loc(i, shape, strides);
U val = init; U val = init;

View File

@@ -4,11 +4,11 @@
namespace mlx::core { namespace mlx::core {
std::pair<std::vector<int>, std::vector<size_t>> shapes_without_reduction_axes( std::pair<Shape, Strides> shapes_without_reduction_axes(
const array& x, const array& x,
const std::vector<int>& axes) { const std::vector<int>& axes) {
std::vector<int> shape = x.shape(); auto shape = x.shape();
std::vector<size_t> strides = x.strides(); auto strides = x.strides();
for (int i = axes.size() - 1; i >= 0; i--) { for (int i = axes.size() - 1; i >= 0; i--) {
int a = axes[i]; int a = axes[i];
@@ -29,8 +29,8 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
// Row contiguous input so the output is row contiguous // Row contiguous input so the output is row contiguous
if (x.flags().row_contiguous) { if (x.flags().row_contiguous) {
// Merge consecutive axes // Merge consecutive axes
std::vector<int> shape = {x.shape(axes[0])}; Shape shape = {x.shape(axes[0])};
std::vector<size_t> strides = {x.strides()[axes[0]]}; Strides strides = {x.strides()[axes[0]]};
for (int i = 1; i < axes.size(); i++) { for (int i = 1; i < axes.size(); i++) {
if (axes[i] - 1 == axes[i - 1] && x.shape(axes[i]) > 1) { if (axes[i] - 1 == axes[i - 1] && x.shape(axes[i]) > 1) {
shape.back() *= x.shape(axes[i]); shape.back() *= x.shape(axes[i]);
@@ -69,7 +69,7 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
// Sort reduction axes by stride in order to merge them and figure out if we // Sort reduction axes by stride in order to merge them and figure out if we
// have a contiguous reduction. // have a contiguous reduction.
std::vector<std::pair<int, size_t>> reductions; std::vector<std::pair<int, int64_t>> reductions;
for (auto a : axes) { for (auto a : axes) {
if (x.shape(a) > 1) { if (x.shape(a) > 1) {
reductions.push_back(std::make_pair(x.shape(a), x.strides()[a])); reductions.push_back(std::make_pair(x.shape(a), x.strides()[a]));
@@ -93,8 +93,8 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
} }
} }
std::vector<int> shape; Shape shape;
std::vector<size_t> strides; Strides strides;
for (auto r : reductions) { for (auto r : reductions) {
shape.push_back(r.first); shape.push_back(r.first);
strides.push_back(r.second); strides.push_back(r.second);
@@ -109,15 +109,15 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
// Delegate to the general strided reduction op if the axes after // Delegate to the general strided reduction op if the axes after
// strides.back() are contiguous. // strides.back() are contiguous.
if (strides.back() > 1) { if (strides.back() > 1) {
int size = 1; int64_t size = 1;
bool have_expand = false; bool have_expand = false;
for (int i = x.ndim() - 1; i >= 0; i--) { for (int i = x.ndim() - 1; i >= 0; i--) {
if (axes.back() == i) { if (axes.back() == i) {
continue; continue;
} }
size_t stride_i = x.strides()[i]; auto stride_i = x.strides()[i];
int shape_i = x.shape(i); auto shape_i = x.shape(i);
if (stride_i == 0) { if (stride_i == 0) {
if (shape_i == 1) { if (shape_i == 1) {
continue; continue;

View File

@@ -4,24 +4,22 @@
namespace mlx::core { namespace mlx::core {
std::tuple<bool, int64_t, std::vector<int64_t>> prepare_slice( std::tuple<int64_t, Strides> prepare_slice(
const array& in, const array& in,
const std::vector<int>& start_indices, const Shape& start_indices,
const std::vector<int>& strides) { const Shape& strides) {
int64_t data_offset = 0; int64_t data_offset = 0;
-  bool copy_needed = false;
-  std::vector<int64_t> inp_strides(in.ndim(), 0);
+  Strides inp_strides(in.ndim(), 0);
for (int i = 0; i < in.ndim(); ++i) { for (int i = 0; i < in.ndim(); ++i) {
data_offset += start_indices[i] * in.strides()[i]; data_offset += start_indices[i] * in.strides()[i];
inp_strides[i] = in.strides()[i] * strides[i]; inp_strides[i] = in.strides()[i] * strides[i];
copy_needed |= strides[i] < 0;
} }
return std::make_tuple(copy_needed, data_offset, inp_strides); return std::make_tuple(data_offset, inp_strides);
} }
void shared_buffer_slice( void shared_buffer_slice(
const array& in, const array& in,
const std::vector<size_t>& out_strides, const Strides& out_strides,
size_t data_offset, size_t data_offset,
size_t data_size, size_t data_size,
array& out) { array& out) {
@@ -34,7 +32,7 @@ void shared_buffer_slice(
flags.col_contiguous = is_col_contiguous; flags.col_contiguous = is_col_contiguous;
flags.contiguous = (no_bsx_size == data_size); flags.contiguous = (no_bsx_size == data_size);
out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset); move_or_copy(in, out, out_strides, flags, data_size, data_offset);
} }
} // namespace mlx::core } // namespace mlx::core
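prepare_slice above reduces a slice to an element offset plus rescaled strides: the offset accumulates start index times input stride for each axis, and every output stride is the input stride times the slice step. A standalone numeric check for a row-contiguous {4, 5} array sliced with starts {1, 2} and steps {2, 1}:

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  std::vector<int64_t> in_strides = {5, 1};  // row-contiguous {4, 5} array
  std::vector<int> start = {1, 2};
  std::vector<int> step = {2, 1};
  int64_t data_offset = 0;
  std::vector<int64_t> out_strides(in_strides.size(), 0);
  // Same per-axis arithmetic as the loop in prepare_slice above.
  for (size_t i = 0; i < in_strides.size(); ++i) {
    data_offset += start[i] * in_strides[i];
    out_strides[i] = in_strides[i] * step[i];
  }
  assert(data_offset == 7);
  assert(out_strides[0] == 10 && out_strides[1] == 1);
  return 0;
}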

View File

@@ -6,14 +6,14 @@
namespace mlx::core { namespace mlx::core {
std::tuple<bool, int64_t, std::vector<int64_t>> prepare_slice( std::tuple<int64_t, Strides> prepare_slice(
const array& in, const array& in,
const std::vector<int>& start_indices, const Shape& start_indices,
const std::vector<int>& strides); const Shape& strides);
void shared_buffer_slice( void shared_buffer_slice(
const array& in, const array& in,
const std::vector<size_t>& out_strides, const Strides& out_strides,
size_t data_offset, size_t data_offset,
size_t data_size, size_t data_size,
array& out); array& out);

View File

@@ -25,7 +25,7 @@ struct StridedIterator {
// Constructors // Constructors
StridedIterator() = default; StridedIterator() = default;
explicit StridedIterator(T* ptr, size_t stride, difference_type offset = 0) explicit StridedIterator(T* ptr, int64_t stride, difference_type offset = 0)
: ptr_(ptr + offset * stride), stride_(stride) {} : ptr_(ptr + offset * stride), stride_(stride) {}
explicit StridedIterator(array& arr, int axis, difference_type offset = 0) explicit StridedIterator(array& arr, int axis, difference_type offset = 0)
@@ -99,7 +99,7 @@ struct StridedIterator {
} }
private: private:
size_t stride_; int64_t stride_;
T* ptr_; T* ptr_;
}; };
@@ -120,11 +120,11 @@ void sort(const array& in, array& out, int axis) {
auto remaining_strides = out.strides(); auto remaining_strides = out.strides();
remaining_strides.erase(remaining_strides.begin() + axis); remaining_strides.erase(remaining_strides.begin() + axis);
size_t axis_stride = out.strides()[axis]; auto axis_stride = out.strides()[axis];
int axis_size = out.shape(axis); auto axis_size = out.shape(axis);
// Perform sorting in place // Perform sorting in place
ContiguousIterator<size_t> src_it( ContiguousIterator src_it(
remaining_shape, remaining_strides, remaining_shape.size()); remaining_shape, remaining_strides, remaining_shape.size());
for (int i = 0; i < n_rows; i++) { for (int i = 0; i < n_rows; i++) {
T* data_ptr = out.data<T>() + src_it.loc; T* data_ptr = out.data<T>() + src_it.loc;
@@ -158,14 +158,14 @@ void argsort(const array& in, array& out, int axis) {
auto out_remaining_strides = out.strides(); auto out_remaining_strides = out.strides();
out_remaining_strides.erase(out_remaining_strides.begin() + axis); out_remaining_strides.erase(out_remaining_strides.begin() + axis);
size_t in_stride = in.strides()[axis]; auto in_stride = in.strides()[axis];
size_t out_stride = out.strides()[axis]; auto out_stride = out.strides()[axis];
int axis_size = in.shape(axis); auto axis_size = in.shape(axis);
// Perform sorting // Perform sorting
ContiguousIterator<size_t> in_it( ContiguousIterator in_it(
in_remaining_shape, in_remaining_strides, in_remaining_shape.size()); in_remaining_shape, in_remaining_strides, in_remaining_shape.size());
ContiguousIterator<size_t> out_it( ContiguousIterator out_it(
out_remaining_shape, out_remaining_strides, out_remaining_shape.size()); out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
for (int i = 0; i < n_rows; i++) { for (int i = 0; i < n_rows; i++) {
const T* data_ptr = in.data<T>() + in_it.loc; const T* data_ptr = in.data<T>() + in_it.loc;
@@ -208,13 +208,13 @@ void partition(const array& in, array& out, int axis, int kth) {
auto remaining_strides = in.strides(); auto remaining_strides = in.strides();
remaining_strides.erase(remaining_strides.begin() + axis); remaining_strides.erase(remaining_strides.begin() + axis);
size_t axis_stride = in.strides()[axis]; auto axis_stride = in.strides()[axis];
int axis_size = in.shape(axis); int axis_size = in.shape(axis);
kth = kth < 0 ? kth + axis_size : kth; kth = kth < 0 ? kth + axis_size : kth;
// Perform partition in place // Perform partition in place
ContiguousIterator<size_t> src_it( ContiguousIterator src_it(
remaining_shape, remaining_strides, remaining_shape.size()); remaining_shape, remaining_strides, remaining_shape.size());
for (int i = 0; i < n_rows; i++) { for (int i = 0; i < n_rows; i++) {
T* data_ptr = out.data<T>() + src_it.loc; T* data_ptr = out.data<T>() + src_it.loc;
@@ -249,16 +249,16 @@ void argpartition(const array& in, array& out, int axis, int kth) {
auto out_remaining_strides = out.strides(); auto out_remaining_strides = out.strides();
out_remaining_strides.erase(out_remaining_strides.begin() + axis); out_remaining_strides.erase(out_remaining_strides.begin() + axis);
size_t in_stride = in.strides()[axis]; auto in_stride = in.strides()[axis];
size_t out_stride = out.strides()[axis]; auto out_stride = out.strides()[axis];
int axis_size = in.shape(axis); auto axis_size = in.shape(axis);
kth = kth < 0 ? kth + axis_size : kth; kth = kth < 0 ? kth + axis_size : kth;
// Perform partition // Perform partition
ContiguousIterator<size_t> in_it( ContiguousIterator in_it(
in_remaining_shape, in_remaining_strides, in_remaining_shape.size()); in_remaining_shape, in_remaining_strides, in_remaining_shape.size());
ContiguousIterator<size_t> out_it( ContiguousIterator out_it(
out_remaining_shape, out_remaining_strides, out_remaining_shape.size()); out_remaining_shape, out_remaining_strides, out_remaining_shape.size());
for (int i = 0; i < n_rows; i++) { for (int i = 0; i < n_rows; i++) {
const T* data_ptr = in.data<T>() + in_it.loc; const T* data_ptr = in.data<T>() + in_it.loc;

View File

@@ -78,11 +78,11 @@ void ternary_op_dims(
const T3* c, const T3* c,
U* out, U* out,
Op op, Op op,
const std::vector<int>& shape, const Shape& shape,
const std::vector<size_t>& a_strides, const Strides& a_strides,
const std::vector<size_t>& b_strides, const Strides& b_strides,
const std::vector<size_t>& c_strides, const Strides& c_strides,
const std::vector<size_t>& out_strides, const Strides& out_strides,
int axis) { int axis) {
auto stride_a = a_strides[axis]; auto stride_a = a_strides[axis];
auto stride_b = b_strides[axis]; auto stride_b = b_strides[axis];
@@ -164,10 +164,10 @@ void ternary_op_dispatch_dims(
return; return;
} }
ContiguousIterator<size_t> a_it(shape, a_strides, ndim - 2); ContiguousIterator a_it(shape, a_strides, ndim - 2);
ContiguousIterator<size_t> b_it(shape, b_strides, ndim - 2); ContiguousIterator b_it(shape, b_strides, ndim - 2);
ContiguousIterator<size_t> c_it(shape, c_strides, ndim - 2); ContiguousIterator c_it(shape, c_strides, ndim - 2);
size_t stride = out_strides[ndim - 3]; auto stride = out_strides[ndim - 3];
for (size_t elem = 0; elem < a.size(); elem += stride) { for (size_t elem = 0; elem < a.size(); elem += stride) {
ternary_op_dims<T1, T2, T3, U, Op, 2>( ternary_op_dims<T1, T2, T3, U, Op, 2>(
a_ptr + a_it.loc, a_ptr + a_it.loc,
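
Note: the general ternary path applies op(a, b, c) elementwise where each input carries its own (possibly broadcast) strides; the code above optimizes this by iterating the leading dimensions incrementally and handling the innermost two dims in a tight loop. A naive but equivalent sketch that recomputes each input offset per output element, with Shape/Strides written as local aliases:

    #include <cstdint>
    #include <vector>

    using Shape = std::vector<int>;
    using Strides = std::vector<int64_t>;

    // Flat output index -> element offset for a given shape/strides pair.
    int64_t elem_to_loc(int64_t elem, const Shape& shape, const Strides& strides) {
      int64_t loc = 0;
      for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
        loc += (elem % shape[i]) * strides[i];
        elem /= shape[i];
      }
      return loc;
    }

    // out[i] = op(a, b, c) elementwise over a row-contiguous output; each input
    // may have broadcast (including zero) strides.
    template <typename T1, typename T2, typename T3, typename U, typename Op>
    void ternary_op(const T1* a, const T2* b, const T3* c, U* out, Op op,
                    const Shape& shape, const Strides& a_strides,
                    const Strides& b_strides, const Strides& c_strides) {
      int64_t size = 1;
      for (int d : shape) size *= d;
      for (int64_t i = 0; i < size; ++i) {
        out[i] = op(a[elem_to_loc(i, shape, a_strides)],
                    b[elem_to_loc(i, shape, b_strides)],
                    c[elem_to_loc(i, shape, c_strides)]);
      }
    }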


@@ -4,15 +4,35 @@
namespace mlx::core { namespace mlx::core {
template <typename StrideT> void move_or_copy(const array& in, array& out) {
std::tuple<std::vector<int>, std::vector<std::vector<StrideT>>> if (in.is_donatable()) {
collapse_contiguous_dims_impl( out.move_shared_buffer(in);
const std::vector<int>& shape, } else {
const std::vector<std::vector<StrideT>>& strides, out.copy_shared_buffer(in);
StrideT size_cap) { }
}
void move_or_copy(
const array& in,
array& out,
const Strides& strides,
array::Flags flags,
size_t data_size,
size_t offset /* = 0 */) {
if (in.is_donatable()) {
out.move_shared_buffer(in, strides, flags, data_size, offset);
} else {
out.copy_shared_buffer(in, strides, flags, data_size, offset);
}
}
std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
const Shape& shape,
const std::vector<Strides>& strides,
int64_t size_cap) {
// Make a vector that has axes separated with -1. Collapse all axes between // Make a vector that has axes separated with -1. Collapse all axes between
// -1. // -1.
std::vector<int> to_collapse; Shape to_collapse;
if (shape.size() > 0) { if (shape.size() > 0) {
if (shape[0] != 1) { if (shape[0] != 1) {
to_collapse.push_back(0); to_collapse.push_back(0);
@@ -21,7 +41,7 @@ collapse_contiguous_dims_impl(
for (int i = 1; i < shape.size(); i++) { for (int i = 1; i < shape.size(); i++) {
bool contiguous = true; bool contiguous = true;
size *= shape[i]; size *= shape[i];
for (const std::vector<StrideT>& st : strides) { for (const auto& st : strides) {
if (st[i] * shape[i] != st[i - 1] || size > size_cap) { if (st[i] * shape[i] != st[i - 1] || size > size_cap) {
contiguous = false; contiguous = false;
size = shape[i]; size = shape[i];
@@ -38,8 +58,8 @@ collapse_contiguous_dims_impl(
to_collapse.push_back(-1); to_collapse.push_back(-1);
} }
std::vector<int> out_shape; Shape out_shape;
std::vector<std::vector<StrideT>> out_strides(strides.size()); std::vector<Strides> out_strides(strides.size());
for (int i = 0;;) { for (int i = 0;;) {
while (i < to_collapse.size() && to_collapse[i] == -1) { while (i < to_collapse.size() && to_collapse[i] == -1) {
++i; ++i;
@@ -54,7 +74,7 @@ collapse_contiguous_dims_impl(
} }
out_shape.push_back(current_shape); out_shape.push_back(current_shape);
for (int j = 0; j < strides.size(); j++) { for (int j = 0; j < strides.size(); j++) {
const std::vector<StrideT>& st = strides[j]; const auto& st = strides[j];
out_strides[j].push_back(st[to_collapse[k - 1]]); out_strides[j].push_back(st[to_collapse[k - 1]]);
} }
i = k + 1; i = k + 1;
@@ -69,29 +89,12 @@ collapse_contiguous_dims_impl(
return std::make_tuple(out_shape, out_strides); return std::make_tuple(out_shape, out_strides);
} }
std::tuple<std::vector<int>, std::vector<std::vector<int64_t>>> std::pair<Shape, Strides> collapse_contiguous_dims(
collapse_contiguous_dims( const Shape& shape,
const std::vector<int>& shape, const Strides& strides,
const std::vector<std::vector<int64_t>>& strides, int64_t size_cap) {
int64_t size_cap /* = std::numeric_limits<int32_t>::max() */) { Shape collapsed_shape;
return collapse_contiguous_dims_impl(shape, strides, size_cap); Strides collapsed_strides;
}
std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<std::vector<size_t>>& strides,
size_t size_cap /* = std::numeric_limits<int32>::max() */) {
return collapse_contiguous_dims_impl(shape, strides, size_cap);
}
template <typename StrideT>
std::pair<std::vector<int>, std::vector<StrideT>> collapse_contiguous_dims_impl(
const std::vector<int>& shape,
const std::vector<StrideT>& strides,
StrideT size_cap) {
std::vector<int> collapsed_shape;
std::vector<StrideT> collapsed_strides;
if (shape.size() > 0) { if (shape.size() > 0) {
collapsed_shape.push_back(shape[0]); collapsed_shape.push_back(shape[0]);
@@ -101,7 +104,7 @@ std::pair<std::vector<int>, std::vector<StrideT>> collapse_contiguous_dims_impl(
continue; continue;
} else if ( } else if (
strides[i] * shape[i] != collapsed_strides.back() || strides[i] * shape[i] != collapsed_strides.back() ||
collapsed_shape.back() * static_cast<StrideT>(shape[i]) > size_cap) { collapsed_shape.back() * static_cast<int64_t>(shape[i]) > size_cap) {
collapsed_shape.push_back(shape[i]); collapsed_shape.push_back(shape[i]);
collapsed_strides.push_back(strides[i]); collapsed_strides.push_back(strides[i]);
} else { } else {
@@ -114,25 +117,10 @@ std::pair<std::vector<int>, std::vector<StrideT>> collapse_contiguous_dims_impl(
return std::make_pair(collapsed_shape, collapsed_strides); return std::make_pair(collapsed_shape, collapsed_strides);
} }
std::pair<std::vector<int>, std::vector<int64_t>> collapse_contiguous_dims( std::pair<Shape, Strides> collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<int64_t>& strides,
int64_t size_cap /* = std::numeric_limits<int32_t>::max() */) {
return collapse_contiguous_dims_impl<int64_t>(shape, strides, size_cap);
}
std::pair<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<size_t>& strides,
size_t size_cap /* = std::numeric_limits<int32_t>::max() */) {
return collapse_contiguous_dims_impl<size_t>(shape, strides, size_cap);
}
std::pair<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims(
const array& a, const array& a,
size_t size_cap /* = std::numeric_limits<int32_t>::max()*/) { int64_t size_cap /* = std::numeric_limits<int32_t>::max()*/) {
return collapse_contiguous_dims_impl<size_t>( return collapse_contiguous_dims(a.shape(), a.strides(), size_cap);
a.shape(), a.strides(), size_cap);
} }
} // namespace mlx::core } // namespace mlx::core
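
Note: to make the collapse concrete, adjacent dimensions merge whenever strides[i] * shape[i] equals the previously collapsed stride and the merged extent stays under size_cap, so a contiguous (2, 3, 4) row-major array collapses to shape {24} with stride {1}. A standalone sketch mirroring the single-array overload above, with Shape and Strides spelled out as local aliases:

    #include <cstdint>
    #include <limits>
    #include <utility>
    #include <vector>

    using Shape = std::vector<int>;
    using Strides = std::vector<int64_t>;

    std::pair<Shape, Strides> collapse_contiguous_dims(
        const Shape& shape,
        const Strides& strides,
        int64_t size_cap = std::numeric_limits<int32_t>::max()) {
      Shape collapsed_shape;
      Strides collapsed_strides;
      if (!shape.empty()) {
        collapsed_shape.push_back(shape[0]);
        collapsed_strides.push_back(strides[0]);
        for (size_t i = 1; i < shape.size(); ++i) {
          if (shape[i] == 1) {
            continue;  // Size-1 dims never change element addresses; drop them.
          } else if (
              strides[i] * shape[i] != collapsed_strides.back() ||
              collapsed_shape.back() * static_cast<int64_t>(shape[i]) > size_cap) {
            collapsed_shape.push_back(shape[i]);
            collapsed_strides.push_back(strides[i]);
          } else {
            // Dim i is contiguous with the previous one: merge the extents.
            collapsed_shape.back() *= shape[i];
            collapsed_strides.back() = strides[i];
          }
        }
      }
      return {collapsed_shape, collapsed_strides};
    }

    // e.g. collapse_contiguous_dims({2, 3, 4}, {12, 4, 1}) yields ({24}, {1})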


@@ -8,12 +8,9 @@
namespace mlx::core { namespace mlx::core {
template <typename StrideT> inline int64_t
inline StrideT elem_to_loc( elem_to_loc(int elem, const Shape& shape, const Strides& strides) {
int elem, int64_t loc = 0;
const std::vector<int>& shape,
const std::vector<StrideT>& strides) {
StrideT loc = 0;
for (int i = shape.size() - 1; i >= 0; --i) { for (int i = shape.size() - 1; i >= 0; --i) {
auto q_and_r = ldiv(elem, shape[i]); auto q_and_r = ldiv(elem, shape[i]);
loc += q_and_r.rem * strides[i]; loc += q_and_r.rem * strides[i];
@@ -22,16 +19,15 @@ inline StrideT elem_to_loc(
return loc; return loc;
} }
inline size_t elem_to_loc(int elem, const array& a) { inline int64_t elem_to_loc(int elem, const array& a) {
if (a.flags().row_contiguous) { if (a.flags().row_contiguous) {
return elem; return elem;
} }
return elem_to_loc(elem, a.shape(), a.strides()); return elem_to_loc(elem, a.shape(), a.strides());
} }
template <typename StrideT> inline Strides make_contiguous_strides(const Shape& shape) {
std::vector<StrideT> make_contiguous_strides(const std::vector<int>& shape) { Strides strides(shape.size(), 1);
std::vector<StrideT> strides(shape.size(), 1);
for (int i = shape.size() - 1; i > 0; i--) { for (int i = shape.size() - 1; i > 0; i--) {
strides[i - 1] = strides[i] * shape[i]; strides[i - 1] = strides[i] * shape[i];
} }
@@ -44,22 +40,15 @@ std::vector<StrideT> make_contiguous_strides(const std::vector<int>& shape) {
// //
// When multiple arrays are passed they should all have the same shape. The // When multiple arrays are passed they should all have the same shape. The
// collapsed axes are also the same so one shape is returned. // collapsed axes are also the same so one shape is returned.
std::tuple<std::vector<int>, std::vector<std::vector<int64_t>>> std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
collapse_contiguous_dims( const Shape& shape,
const std::vector<int>& shape, const std::vector<Strides>& strides,
const std::vector<std::vector<int64_t>>& strides,
int64_t size_cap = std::numeric_limits<int32_t>::max()); int64_t size_cap = std::numeric_limits<int32_t>::max());
std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<std::vector<size_t>>& strides,
size_t size_cap = std::numeric_limits<int32_t>::max());
inline std::tuple<std::vector<int>, std::vector<std::vector<size_t>>> inline std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
collapse_contiguous_dims(
const std::vector<array>& xs, const std::vector<array>& xs,
size_t size_cap = std::numeric_limits<int32_t>::max()) { size_t size_cap = std::numeric_limits<int32_t>::max()) {
std::vector<std::vector<size_t>> strides; std::vector<Strides> strides;
for (auto& x : xs) { for (auto& x : xs) {
strides.emplace_back(x.strides()); strides.emplace_back(x.strides());
} }
@@ -73,19 +62,14 @@ inline auto collapse_contiguous_dims(Arrays&&... xs) {
} }
// The single array version of the above. // The single array version of the above.
std::pair<std::vector<int>, std::vector<int64_t>> collapse_contiguous_dims( std::pair<Shape, Strides> collapse_contiguous_dims(
const std::vector<int>& shape, const Shape& shape,
const std::vector<int64_t>& strides, const Strides& strides,
int64_t size_cap = std::numeric_limits<int32_t>::max()); int64_t size_cap = std::numeric_limits<int32_t>::max());
std::pair<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims( std::pair<Shape, Strides> collapse_contiguous_dims(
const std::vector<int>& shape,
const std::vector<size_t>& strides,
size_t size_cap = std::numeric_limits<int32_t>::max());
std::pair<std::vector<int>, std::vector<size_t>> collapse_contiguous_dims(
const array& a, const array& a,
size_t size_cap = std::numeric_limits<int32_t>::max()); int64_t size_cap = std::numeric_limits<int32_t>::max());
template <typename StrideT>
struct ContiguousIterator { struct ContiguousIterator {
inline void step() { inline void step() {
int dims = shape_.size(); int dims = shape_.size();
@@ -102,7 +86,7 @@ struct ContiguousIterator {
loc += strides_[i]; loc += strides_[i];
} }
void seek(StrideT n) { void seek(int64_t n) {
loc = 0; loc = 0;
for (int i = shape_.size() - 1; i >= 0; --i) { for (int i = shape_.size() - 1; i >= 0; --i) {
auto q_and_r = ldiv(n, shape_[i]); auto q_and_r = ldiv(n, shape_[i]);
@@ -128,32 +112,29 @@ struct ContiguousIterator {
} }
explicit ContiguousIterator( explicit ContiguousIterator(
const std::vector<int>& shape, const Shape& shape,
const std::vector<StrideT>& strides, const Strides& strides,
int dims) int dims)
: shape_(shape.begin(), shape.begin() + dims), : shape_(shape.begin(), shape.begin() + dims),
strides_(strides.begin(), strides.begin() + dims) { strides_(strides.begin(), strides.begin() + dims) {
if (!shape_.empty()) { if (!shape_.empty()) {
std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_); std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_);
pos_ = std::vector<int>(shape_.size(), 0); pos_ = Shape(shape_.size(), 0);
} }
} }
StrideT loc{0}; int64_t loc{0};
private: private:
std::vector<int> shape_; Shape shape_;
std::vector<StrideT> strides_; Strides strides_;
std::vector<int> pos_; Shape pos_;
}; };
template <typename StrideT> inline auto check_contiguity(const Shape& shape, const Strides& strides) {
inline auto check_contiguity(
const std::vector<int>& shape,
const std::vector<StrideT>& strides) {
size_t no_broadcast_data_size = 1; size_t no_broadcast_data_size = 1;
size_t f_stride = 1; int64_t f_stride = 1;
size_t b_stride = 1; int64_t b_stride = 1;
bool is_row_contiguous = true; bool is_row_contiguous = true;
bool is_col_contiguous = true; bool is_col_contiguous = true;
@@ -178,4 +159,13 @@ inline bool is_donatable(const array& in, const array& out) {
in.buffer_size() <= out.nbytes() + donation_extra; in.buffer_size() <= out.nbytes() + donation_extra;
} }
void move_or_copy(const array& in, array& out);
void move_or_copy(
const array& in,
array& out,
const Strides& strides,
array::Flags flags,
size_t data_size,
size_t offset = 0);
} // namespace mlx::core } // namespace mlx::core
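
Note: check_contiguity above decides row- and column-contiguity in one pass. The underlying idea: a row-contiguous layout matches make_contiguous_strides(shape) exactly, a column-contiguous layout matches the same running product taken front to back, and size-1 dimensions are ignored. A small sketch of both helpers under local aliases (the real function also tracks the no-broadcast data size):

    #include <cstdint>
    #include <utility>
    #include <vector>

    using Shape = std::vector<int>;
    using Strides = std::vector<int64_t>;

    // Row-major strides for a given shape, e.g. {2, 3, 4} -> {12, 4, 1}.
    Strides make_contiguous_strides(const Shape& shape) {
      Strides strides(shape.size(), 1);
      for (int i = static_cast<int>(shape.size()) - 1; i > 0; --i) {
        strides[i - 1] = strides[i] * shape[i];
      }
      return strides;
    }

    // Returns {row_contiguous, col_contiguous}; size-1 dims are skipped because
    // their stride can be anything without changing element addresses.
    std::pair<bool, bool> check_contiguity(const Shape& shape, const Strides& strides) {
      int64_t f_stride = 1;  // expected stride scanning back to front (row-major)
      int64_t b_stride = 1;  // expected stride scanning front to back (col-major)
      bool row = true, col = true;
      int ndim = static_cast<int>(shape.size());
      for (int i = 0, ri = ndim - 1; i < ndim; ++i, --ri) {
        if (shape[ri] != 1) {
          row = row && (strides[ri] == f_stride);
          f_stride *= shape[ri];
        }
        if (shape[i] != 1) {
          col = col && (strides[i] == b_stride);
          b_stride *= shape[i];
        }
      }
      return {row, col};
    }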


@@ -14,14 +14,21 @@ function(make_jit_source SRC_FILE)
COMMAND COMMAND
/bin/bash ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh /bin/bash ${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
${CMAKE_CURRENT_BINARY_DIR}/jit ${CMAKE_C_COMPILER} ${PROJECT_SOURCE_DIR} ${CMAKE_CURRENT_BINARY_DIR}/jit ${CMAKE_C_COMPILER} ${PROJECT_SOURCE_DIR}
${SRC_FILE} "-DMLX_METAL_VERSION=${MLX_METAL_VERSION}" ${SRC_FILE}
DEPENDS make_compiled_preamble.sh kernels/${SRC_FILE}.h ${ARGN}) DEPENDS make_compiled_preamble.sh kernels/${SRC_FILE}.h ${ARGN})
add_custom_target(${SRC_NAME} DEPENDS jit/${SRC_NAME}.cpp) add_custom_target(${SRC_NAME} DEPENDS jit/${SRC_NAME}.cpp)
add_dependencies(mlx ${SRC_NAME}) add_dependencies(mlx ${SRC_NAME})
target_sources(mlx PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/jit/${SRC_NAME}.cpp) target_sources(mlx PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/jit/${SRC_NAME}.cpp)
endfunction(make_jit_source) endfunction(make_jit_source)
make_jit_source(utils kernels/bf16.h kernels/complex.h kernels/defines.h) make_jit_source(
utils
kernels/jit/bf16.h
kernels/metal_3_0/bf16.h
kernels/metal_3_1/bf16.h
kernels/bf16_math.h
kernels/complex.h
kernels/defines.h)
make_jit_source(unary_ops kernels/erf.h kernels/expm1f.h) make_jit_source(unary_ops kernels/erf.h kernels/expm1f.h)
make_jit_source(binary_ops) make_jit_source(binary_ops)
make_jit_source(ternary_ops) make_jit_source(ternary_ops)
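
Note: make_jit_source turns each listed header into a generated .cpp under jit/ whose contents are compiled into the library and later spliced into JIT kernel sources (see metal::utils() and friends in compiled.cpp below). A hypothetical sketch of what such a generated file could look like; the function name and raw-string layout here are assumptions for illustration, not the actual generator output:

    // Hypothetical jit/utils.cpp emitted by make_compiled_preamble.sh: the
    // header text is captured as a raw string and exposed through a function
    // that the JIT path can concatenate into kernel sources.
    namespace mlx::core::metal {

    const char* utils() {
      return R"preamble(
    // ... contents of kernels/utils.h with bf16/complex/defines inlined ...
    )preamble";
    }

    } // namespace mlx::core::metal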


@@ -30,7 +30,7 @@ BufferCache::BufferCache(MTL::Device* device)
: device_(device), head_(nullptr), tail_(nullptr), pool_size_(0) {} : device_(device), head_(nullptr), tail_(nullptr), pool_size_(0) {}
BufferCache::~BufferCache() { BufferCache::~BufferCache() {
auto thread_pool = metal::new_scoped_memory_pool(); auto pool = metal::new_scoped_memory_pool();
clear(); clear();
} }
@@ -155,11 +155,13 @@ MetalAllocator::MetalAllocator()
} }
size_t MetalAllocator::set_cache_limit(size_t limit) { size_t MetalAllocator::set_cache_limit(size_t limit) {
std::unique_lock lk(mutex_);
std::swap(limit, max_pool_size_); std::swap(limit, max_pool_size_);
return limit; return limit;
}; };
size_t MetalAllocator::set_memory_limit(size_t limit, bool relaxed) { size_t MetalAllocator::set_memory_limit(size_t limit, bool relaxed) {
std::unique_lock lk(mutex_);
std::swap(limit, block_limit_); std::swap(limit, block_limit_);
relaxed_ = relaxed; relaxed_ = relaxed;
gc_limit_ = std::min( gc_limit_ = std::min(
@@ -169,6 +171,7 @@ size_t MetalAllocator::set_memory_limit(size_t limit, bool relaxed) {
}; };
size_t MetalAllocator::set_wired_limit(size_t limit) { size_t MetalAllocator::set_wired_limit(size_t limit) {
std::unique_lock lk(mutex_);
std::swap(limit, wired_limit_); std::swap(limit, wired_limit_);
residency_set_.resize(wired_limit_); residency_set_.resize(wired_limit_);
return limit; return limit;
@@ -205,7 +208,7 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
return Buffer{nullptr}; return Buffer{nullptr};
} }
auto thread_pool = metal::new_scoped_memory_pool(); auto pool = metal::new_scoped_memory_pool();
// If we have a lot of memory pressure or are over the maximum cache size, // If we have a lot of memory pressure or are over the maximum cache size,
// try to reclaim memory from the cache // try to reclaim memory from the cache
@@ -226,7 +229,7 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
// Maintain the cache below the requested limit // Maintain the cache below the requested limit
if (get_cache_memory() >= max_pool_size_) { if (get_cache_memory() >= max_pool_size_) {
auto thread_pool = metal::new_scoped_memory_pool(); auto pool = metal::new_scoped_memory_pool();
buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_); buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
} }
@@ -237,11 +240,15 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
void MetalAllocator::clear_cache() { void MetalAllocator::clear_cache() {
std::unique_lock lk(mutex_); std::unique_lock lk(mutex_);
auto pool = metal::new_scoped_memory_pool();
buffer_cache_.clear(); buffer_cache_.clear();
} }
void MetalAllocator::free(Buffer buffer) { void MetalAllocator::free(Buffer buffer) {
auto buf = static_cast<MTL::Buffer*>(buffer.ptr()); auto buf = static_cast<MTL::Buffer*>(buffer.ptr());
if (buf == nullptr) {
return;
}
std::unique_lock lk(mutex_); std::unique_lock lk(mutex_);
residency_set_.erase(buf); residency_set_.erase(buf);
active_memory_ -= buf->length(); active_memory_ -= buf->length();
@@ -249,7 +256,7 @@ void MetalAllocator::free(Buffer buffer) {
buffer_cache_.recycle_to_cache(buf); buffer_cache_.recycle_to_cache(buf);
} else { } else {
lk.unlock(); lk.unlock();
auto thread_pool = metal::new_scoped_memory_pool(); auto pool = metal::new_scoped_memory_pool();
buf->release(); buf->release();
} }
} }
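
Note: the allocator changes above take the mutex inside the limit setters and bail out of free() on null buffers. The setter pattern, swapping the new limit in under the lock and returning the old one, looks like this in isolation (Allocator here is a stand-in class, not MetalAllocator):

    #include <cstddef>
    #include <mutex>
    #include <utility>

    class Allocator {
     public:
      // Returns the previous limit; the swap happens under the lock so other
      // threads never observe a torn update.
      size_t set_cache_limit(size_t limit) {
        std::unique_lock<std::mutex> lk(mutex_);
        std::swap(limit, max_pool_size_);
        return limit;
      }

     private:
      std::mutex mutex_;
      size_t max_pool_size_{0};
    };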


@@ -22,37 +22,37 @@ std::string get_kernel_name(
BinaryOpType bopt, BinaryOpType bopt,
const std::string& op, const std::string& op,
const array& a, const array& a,
bool use_2d, bool large,
int ndim, int ndim,
int work_per_thread) { int work_per_thread) {
std::ostringstream kname; std::string kname;
switch (bopt) { switch (bopt) {
case BinaryOpType::ScalarScalar: case BinaryOpType::ScalarScalar:
kname << "ss"; kname = "ss";
break; break;
case BinaryOpType::ScalarVector: case BinaryOpType::ScalarVector:
kname << (use_2d ? "sv2" : "sv"); kname = (large ? "sv2" : "sv");
break; break;
case BinaryOpType::VectorScalar: case BinaryOpType::VectorScalar:
kname << (use_2d ? "vs2" : "vs"); kname = (large ? "vs2" : "vs");
break; break;
case BinaryOpType::VectorVector: case BinaryOpType::VectorVector:
kname << (use_2d ? "vv2" : "vv"); kname = (large ? "vv2" : "vv");
break; break;
case BinaryOpType::General: case BinaryOpType::General:
kname << "g"; kname = "g";
if (ndim <= 3) { if (ndim <= 3) {
kname << ndim; kname += std::to_string(ndim);
} else { } else {
kname << "n"; concatenate(kname, "n", std::to_string(work_per_thread));
if (work_per_thread > 1) { }
kname << work_per_thread; if (large) {
} kname += "large";
} }
break; break;
} }
kname << "_" << op << type_to_name(a); concatenate(kname, "_", op, type_to_name(a));
return kname.str(); return kname;
} }
void binary_op_gpu_inplace( void binary_op_gpu_inplace(
@@ -75,24 +75,30 @@ void binary_op_gpu_inplace(
auto [shape, strides] = collapse_contiguous_dims(a, b, out); auto [shape, strides] = collapse_contiguous_dims(a, b, out);
return std::make_tuple(shape, strides[0], strides[1], strides[2]); return std::make_tuple(shape, strides[0], strides[1], strides[2]);
} else { } else {
std::vector<size_t> e; decltype(a.strides()) e{};
return std::make_tuple(std::vector<int>{}, e, e, e); return std::make_tuple(decltype(a.shape()){}, e, e, e);
} }
}; };
auto [shape, strides_a, strides_b, strides_out] = maybe_collapse(); auto [shape, strides_a, strides_b, strides_out] = maybe_collapse();
bool use_2d = out.data_size() > UINT32_MAX; bool large = out.data_size() > UINT32_MAX;
auto ndim = shape.size(); auto ndim = shape.size();
int work_per_thread = (bopt == BinaryOpType::General) ? 4 : 1; int work_per_thread;
if (bopt == BinaryOpType::General) {
large |= (a.data_size() > UINT32_MAX || b.data_size() > UINT32_MAX);
work_per_thread = large ? 4 : 2;
} else {
work_per_thread = 1;
}
std::string kernel_name = std::string kernel_name =
get_kernel_name(bopt, op, a, use_2d, shape.size(), work_per_thread); get_kernel_name(bopt, op, a, large, shape.size(), work_per_thread);
auto& d = metal::device(s.device); auto& d = metal::device(s.device);
auto kernel = outputs.size() == 2 auto kernel = outputs.size() == 2
? get_binary_two_kernel(d, kernel_name, a.dtype(), out.dtype(), op) ? get_binary_two_kernel(d, kernel_name, a.dtype(), out.dtype(), op)
: get_binary_kernel(d, kernel_name, a.dtype(), out.dtype(), op); : get_binary_kernel(d, kernel_name, a.dtype(), out.dtype(), op);
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
// - If a is donated it goes to the first output // - If a is donated it goes to the first output
// - If b is donated it goes to the first output if a was not donated // - If b is donated it goes to the first output if a was not donated
@@ -117,19 +123,15 @@ void binary_op_gpu_inplace(
size_t rest = out.size() / (dim0 * dim1); size_t rest = out.size() / (dim0 * dim1);
if (ndim > 3) { if (ndim > 3) {
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), arg_idx++); compute_encoder.set_vector_bytes(shape, arg_idx++);
compute_encoder->setBytes( compute_encoder.set_vector_bytes(strides_a, arg_idx++);
strides_a.data(), ndim * sizeof(size_t), arg_idx++); compute_encoder.set_vector_bytes(strides_b, arg_idx++);
compute_encoder->setBytes( compute_encoder.set_bytes<int>(ndim, arg_idx++);
strides_b.data(), ndim * sizeof(size_t), arg_idx++);
compute_encoder->setBytes(&ndim, sizeof(int), arg_idx++);
dim0 = (dim0 + work_per_thread - 1) / work_per_thread; dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
} else { } else {
// The shape is implicit in the grid for <= 3D // The shape is implicit in the grid for <= 3D
compute_encoder->setBytes( compute_encoder.set_vector_bytes(strides_a, arg_idx++);
strides_a.data(), ndim * sizeof(size_t), arg_idx++); compute_encoder.set_vector_bytes(strides_b, arg_idx++);
compute_encoder->setBytes(
strides_b.data(), ndim * sizeof(size_t), arg_idx++);
} }
if (thread_group_size != 1024) { if (thread_group_size != 1024) {
@@ -137,7 +139,7 @@ void binary_op_gpu_inplace(
} }
auto group_dims = get_block_dims(dim0, dim1, rest); auto group_dims = get_block_dims(dim0, dim1, rest);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest); MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder.dispatchThreads(grid_dims, group_dims); compute_encoder.dispatch_threads(grid_dims, group_dims);
} else { } else {
// Launch a 1D or 2D grid of threads // Launch a 1D or 2D grid of threads
size_t nthreads = out.data_size(); size_t nthreads = out.data_size();
@@ -145,9 +147,9 @@ void binary_op_gpu_inplace(
thread_group_size = nthreads; thread_group_size = nthreads;
} }
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides()) MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1); : MTL::Size(nthreads, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims); compute_encoder.dispatch_threads(grid_dims, group_dims);
} }
} }
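
Note: the rewritten get_kernel_name, and the copy and compiled paths below, rely on a small variadic concatenate(...) helper instead of std::ostringstream. Its definition is not part of this diff; a plausible one, assumed for illustration, is a fold over std::string::operator+=:

    #include <string>
    #include <utility>

    // Append any number of string-like pieces to dst in one call, e.g.
    //   concatenate(kname, "_", op, type_to_name(a));
    template <typename... Args>
    void concatenate(std::string& dst, Args&&... args) {
      ((dst += std::forward<Args>(args)), ...);
    }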


@@ -1,5 +1,6 @@
// Copyright © 2023-2024 Apple Inc. // Copyright © 2023-2024 Apple Inc.
#include <fmt/format.h>
#include <iostream> //TODO
#include <sstream> #include <sstream>
#include "mlx/backend/common/compiled.h" #include "mlx/backend/common/compiled.h"
@@ -11,12 +12,12 @@
#include "mlx/primitives.h" #include "mlx/primitives.h"
#include "mlx/utils.h" #include "mlx/utils.h"
using namespace fmt::literals;
namespace mlx::core { namespace mlx::core {
constexpr int WORK_PER_THREAD = 4;
inline void build_kernel( inline void build_kernel(
std::ostream& os, std::string& os,
const std::string& kernel_name, const std::string& kernel_name,
const std::vector<array>& inputs, const std::vector<array>& inputs,
const std::vector<array>& outputs, const std::vector<array>& outputs,
@@ -41,8 +42,8 @@ inline void build_kernel(
int cnt = 0; int cnt = 0;
// Start the kernel // Start the kernel
os << "[[host_name(\"" << kernel_name << "\")]]\n" os += fmt::format(
<< "[[kernel]] void " << kernel_name << "(\n"; "[[host_name(\"{0}\")]]\n[[kernel]] void {0}(\n", kernel_name);
// Add the input arguments // Add the input arguments
for (auto& x : inputs) { for (auto& x : inputs) {
@@ -54,51 +55,61 @@ inline void build_kernel(
} }
// Scalars and contiguous need no strides // Scalars and contiguous need no strides
if (is_scalar(x) || contiguous) { if (!is_scalar(x) && !contiguous) {
os << " device const " << get_type_string(x.dtype()) << "* " << xname
<< " [[buffer(" << cnt++ << ")]],\n";
} else {
add_indices = true; add_indices = true;
os << " device const " << get_type_string(x.dtype()) << "* " << xname
<< " [[buffer(" << cnt++ << ")]],\n";
} }
os += fmt::format(
" device const {0}* {1} [[buffer({2})]],\n",
get_type_string(x.dtype()),
xname,
cnt++);
} }
if (add_indices) { if (add_indices) {
os << " constant const size_t* in_strides [[buffer(" << cnt++ os += fmt::format(
<< ")]],\n"; " constant const int64_t* in_strides [[buffer({0})]],\n", cnt++);
} }
// Add the output arguments // Add the output arguments
for (auto& x : outputs) { for (auto& x : outputs) {
os << " device " << get_type_string(x.dtype()) << "* " os += fmt::format(
<< namer.get_name(x) << " [[buffer(" << cnt++ << ")]],\n"; " device {0}* {1} [[buffer({2})]],\n",
get_type_string(x.dtype()),
namer.get_name(x),
cnt++);
} }
// Add output strides and shape to extract the indices. // Add output strides and shape to extract the indices.
if (!contiguous) { if (!contiguous) {
os << " constant const size_t* output_strides [[buffer(" << cnt++ os += fmt::format(
<< ")]],\n" " constant const int64_t* output_strides [[buffer({0})]],\n", cnt++);
<< " constant const int* output_shape [[buffer(" << cnt++ << ")]],\n"; os += fmt::format(
" constant const int* output_shape [[buffer({0})]],\n", cnt++);
} }
if (dynamic_dims) { if (dynamic_dims) {
os << " constant const int& ndim [[buffer(" << cnt++ << ")]],\n"; os += fmt::format(" constant const int& ndim [[buffer({0})]],\n", cnt++);
} }
// The thread index in the whole grid // The thread index in the whole grid
os << " uint3 pos [[thread_position_in_grid]],\n" os += " uint3 pos [[thread_position_in_grid]],\n";
<< " uint3 grid [[threads_per_grid]]) {\n"; os += " uint3 grid [[threads_per_grid]]) {\n";
if (use_big_index) { std::string idx_type = use_big_index ? "int64_t" : "uint";
if (contiguous && use_big_index) {
// This is only used for contiguous kernels which don't have // This is only used for contiguous kernels which don't have
// a third grid dimension // a third grid dimension
os << " size_t index = pos.x + grid.x * size_t(pos.y);\n"; os += " int64_t index = pos.x + grid.x * int64_t(pos.y);\n";
} else if (work_per_thread > 1) { } else if (work_per_thread > 1) {
os << " constexpr int N_ = " << std::to_string(work_per_thread) << ";\n" os += fmt::format(" constexpr int N_ = {0};\n", work_per_thread);
<< " int xshape = output_shape[" os += fmt::format(
<< (dynamic_dims ? "ndim - 1" : std::to_string(ndim - 1)) << "];\n" " int xshape = output_shape[{0}];\n",
<< " size_t index = N_ * pos.x + xshape * (pos.y + size_t(grid.y) * pos.z);\n"; dynamic_dims ? "ndim - 1" : std::to_string(ndim - 1));
os += fmt::format(
" {0} index = N_ * pos.x + xshape * (pos.y + {0}(grid.y) * pos.z);\n",
idx_type);
} else { } else {
os << " size_t index = pos.x + grid.x * (pos.y + size_t(grid.y) * pos.z);\n"; os += fmt::format(
" {0} index = pos.x + grid.x * (pos.y + {0}(grid.y) * pos.z);\n",
idx_type);
} }
// Read constant / contiguous inputs in tmps // Read constant / contiguous inputs in tmps
@@ -109,16 +120,19 @@ inline void build_kernel(
if (is_constant(x)) { if (is_constant(x)) {
auto type_str = get_type_string(x.dtype()); auto type_str = get_type_string(x.dtype());
os << " auto tmp_" << xname << " = static_cast<" std::ostringstream ss;
<< get_type_string(x.dtype()) << ">("; print_constant(ss, x);
print_constant(os, x); os += fmt::format(
os << ");\n"; " auto tmp_{0} = static_cast<{1}>({2});\n",
xname,
get_type_string(x.dtype()),
ss.str());
} else if (is_scalar(x)) { } else if (is_scalar(x)) {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = " os += fmt::format(
<< xname << "[0];\n"; " {0} tmp_{1} = {1}[0];\n", get_type_string(x.dtype()), xname);
} else if (contiguous) { } else if (contiguous) {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = " os += fmt::format(
<< xname << "[index];\n"; " {0} tmp_{1} = {1}[index];\n", get_type_string(x.dtype()), xname);
} else { } else {
nc_inputs.push_back(x); nc_inputs.push_back(x);
} }
@@ -127,83 +141,96 @@ inline void build_kernel(
// Initialize the indices for non-contiguous inputs // Initialize the indices for non-contiguous inputs
for (int i = 0; i < nc_inputs.size(); ++i) { for (int i = 0; i < nc_inputs.size(); ++i) {
auto& xname = namer.get_name(nc_inputs[i]); auto& xname = namer.get_name(nc_inputs[i]);
os += fmt::format(" {0} index_{1} = ", idx_type, xname);
if (ndim == 1) { if (ndim == 1) {
int offset = i * ndim; int offset = i * ndim;
os << " size_t index_" << xname << " = elem_to_loc_1(pos.x, " os +=
<< "in_strides[" << offset << "]);\n"; fmt::format("elem_to_loc_1<uint>(pos.x, in_strides[{0}]);\n", offset);
} else if (ndim == 2) { } else if (ndim == 2) {
int offset = i * ndim; int offset = i * ndim;
os << " size_t index_" << xname << " = elem_to_loc_2({pos.x, pos.y}, " os += fmt::format(
<< "in_strides + " << offset << ");\n"; "elem_to_loc_2<{0}>({{pos.x, pos.y}}, in_strides + {1});\n",
idx_type,
offset);
} else if (ndim == 3) { } else if (ndim == 3) {
int offset = i * ndim; int offset = i * ndim;
os << " size_t index_" << xname << " = elem_to_loc_3(pos, " os += fmt::format(
<< "in_strides + " << offset << ");\n"; "elem_to_loc_3<{0}>(pos, in_strides + {1});\n", idx_type, offset);
} else if (!dynamic_dims) { } else if (!dynamic_dims) {
int offset = i * ndim; int offset = (i + 1) * ndim;
os << " size_t index_" << xname << " = N_ * pos.x * in_strides[" os += fmt::format(
<< offset + ndim - 1 << "]" "N_ * pos.x * {0}(in_strides[{1}]) + pos.y * {0}(in_strides[{2}]);\n",
<< " + pos.y * in_strides[" << offset + ndim - 2 << "];\n"; idx_type,
offset - 1,
offset - 2);
} else { } else {
os << " size_t index_" << xname << " = N_ * pos.x * in_strides[ndim * " os += fmt::format(
<< i << " + ndim - 1]" "N_ * pos.x * {0}(in_strides[ndim * {1} + ndim - 1]) + pos.y * {0}(in_strides[ndim * {1} + ndim - 2]);\n",
<< " + pos.y * in_strides[ndim * " << i << " + ndim - 2];\n"; idx_type,
i);
} }
} }
if (!nc_inputs.empty() && (ndim > 3 || dynamic_dims)) { if (!nc_inputs.empty() && (ndim > 3 || dynamic_dims)) {
os << " uint zpos = pos.z;\n"; os += " uint zpos = pos.z;\n";
if (dynamic_dims) { if (dynamic_dims) {
os << " for (int d = ndim - 3; d >= 0; --d) {\n"; os += " for (int d = ndim - 3; d >= 0; --d) {\n";
} else { } else {
os << " for (int d = " << ndim - 3 << "; d >= 0; --d) {\n"; os += fmt::format(" for (int d = {0}; d >= 0; --d) {{\n", ndim - 3);
} }
os << " uint l = zpos % output_shape[d];\n"; os += " uint l = zpos % output_shape[d];\n";
for (int i = 0; i < nc_inputs.size(); ++i) { for (int i = 0; i < nc_inputs.size(); ++i) {
auto& xname = namer.get_name(nc_inputs[i]); auto& xname = namer.get_name(nc_inputs[i]);
os << " index_" << xname << " += "; os += fmt::format(" index_{0} += ", xname);
if (dynamic_dims) { if (dynamic_dims) {
os << "l * in_strides[" << i << " * ndim + d];\n"; os +=
fmt::format("l * {0}(in_strides[{1} * ndim + d]);\n", idx_type, i);
} else { } else {
os << "l * in_strides[" << i * ndim << " + d];\n"; os +=
fmt::format("l * {0}(in_strides[{1} + d]);\n", idx_type, i * ndim);
} }
} }
os << " zpos /= output_shape[d];\n }\n"; os += " zpos /= output_shape[d];\n }\n";
} }
// Open per-thread loop // Open per-thread loop
if (work_per_thread > 1) { if (work_per_thread > 1) {
os << " for (int i = 0; i < N_ && (int(N_ * pos.x) + i) < xshape; ++i) {\n"; os +=
" for (int i = 0; i < N_ && (int(N_ * pos.x) + i) < xshape; ++i) {\n";
} }
// Read non-contiguous inputs into tmps // Read non-contiguous inputs into tmps
for (int i = 0; i < nc_inputs.size(); ++i) { for (int i = 0; i < nc_inputs.size(); ++i) {
auto& x = nc_inputs[i]; auto& x = nc_inputs[i];
auto& xname = namer.get_name(x); auto& xname = namer.get_name(x);
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = " os += fmt::format(
<< xname << "[index_" << xname << "];\n"; " {0} tmp_{1} = {1}[index_{1}];\n", get_type_string(x.dtype()), xname);
} }
// Actually write the computation // Actually write the computation
for (auto& x : tape) { for (auto& x : tape) {
os << " " << get_type_string(x.dtype()) << " tmp_" << namer.get_name(x) os += fmt::format(
<< " = "; " {0} tmp_{1} = ", get_type_string(x.dtype()), namer.get_name(x));
if (is_static_cast(x.primitive())) { if (is_static_cast(x.primitive())) {
os << "static_cast<" << get_type_string(x.dtype()) << ">(tmp_" os += fmt::format(
<< namer.get_name(x.inputs()[0]) << ");\n"; "static_cast<{0}>(tmp_{1});\n",
get_type_string(x.dtype()),
namer.get_name(x.inputs()[0]));
} else { } else {
x.primitive().print(os); std::ostringstream ss;
os << "()("; x.primitive().print(ss);
os += ss.str();
os += "()(";
for (int i = 0; i < x.inputs().size() - 1; i++) { for (int i = 0; i < x.inputs().size() - 1; i++) {
os << "tmp_" << namer.get_name(x.inputs()[i]) << ", "; os += fmt::format("tmp_{0}, ", namer.get_name(x.inputs()[i]));
} }
os << "tmp_" << namer.get_name(x.inputs().back()) << ");\n"; os += fmt::format("tmp_{0});\n", namer.get_name(x.inputs().back()));
} }
} }
// Write the outputs from tmps // Write the outputs from tmps
for (auto& x : outputs) { for (auto& x : outputs) {
os << " " << namer.get_name(x) << "[index] = tmp_" << namer.get_name(x) os += fmt::format(" {0}[index] = tmp_{0};\n", namer.get_name(x));
<< ";\n";
} }
// Increment indices and close per thread loop // Increment indices and close per thread loop
if (work_per_thread > 1) { if (work_per_thread > 1) {
@@ -211,18 +238,18 @@ inline void build_kernel(
auto& x = nc_inputs[i]; auto& x = nc_inputs[i];
auto& xname = namer.get_name(x); auto& xname = namer.get_name(x);
if (!dynamic_dims) { if (!dynamic_dims) {
os << " index_" << xname << " += " os += fmt::format(
<< "in_strides[" << i * ndim + ndim - 1 << "];\n"; " index_{0} += in_strides[{1}];\n", xname, i * ndim + ndim - 1);
} else { } else {
os << " index_" << xname << " += " os += fmt::format(
<< "in_strides[" << i << " * ndim + ndim - 1];\n"; " index_{0} += in_strides[{1} * ndim + ndim - 1];\n", xname, i);
} }
} }
os << " index++;\n }\n"; os += " index++;\n }\n";
} }
// Finish the kernel // Finish the kernel
os << "}\n"; os += "}\n";
if (cnt > 31) { if (cnt > 31) {
std::ostringstream msg; std::ostringstream msg;
@@ -246,9 +273,9 @@ void Compiled::eval_gpu(
auto& s = stream(); auto& s = stream();
auto& d = metal::device(s.device); auto& d = metal::device(s.device);
auto lib = d.get_library(kernel_lib_, [&]() { auto lib = d.get_library(kernel_lib_, [&]() {
std::ostringstream kernel; std::string kernel = metal::utils();
kernel << metal::utils() << metal::unary_ops() << metal::binary_ops() concatenate(
<< metal::ternary_ops(); kernel, metal::unary_ops(), metal::binary_ops(), metal::ternary_ops());
build_kernel( build_kernel(
kernel, kernel,
kernel_lib_ + "_contiguous", kernel_lib_ + "_contiguous",
@@ -261,7 +288,7 @@ void Compiled::eval_gpu(
/* dynamic_dims = */ false); /* dynamic_dims = */ false);
build_kernel( build_kernel(
kernel, kernel,
kernel_lib_ + "_contiguous_big", kernel_lib_ + "_contiguous_large",
inputs_, inputs_,
outputs_, outputs_,
tape_, tape_,
@@ -282,7 +309,21 @@ void Compiled::eval_gpu(
/* ndim = */ i, /* ndim = */ i,
/* dynamic_dims = */ false, /* dynamic_dims = */ false,
/* use_big_index = */ false, /* use_big_index = */ false,
/* work_per_thread = */ i > 3 ? WORK_PER_THREAD : 1); /* work_per_thread = */ i > 3 ? 2 : 1);
if (i > 1) {
build_kernel(
kernel,
kernel_lib_ + "_strided_" + std::to_string(i) + "_large",
inputs_,
outputs_,
tape_,
constant_ids_,
/* contiguous = */ false,
/* ndim = */ i,
/* dynamic_dims = */ false,
/* use_big_index = */ true,
/* work_per_thread = */ i > 3 ? 4 : 1);
}
} }
build_kernel( build_kernel(
kernel, kernel,
@@ -295,20 +336,32 @@ void Compiled::eval_gpu(
/* ndim = */ 0, /* ndim = */ 0,
/* dynamic_dims = */ true, /* dynamic_dims = */ true,
/* use_big_index = */ false, /* use_big_index = */ false,
/* work_per_thread = */ WORK_PER_THREAD); /* work_per_thread = */ 2);
return kernel.str(); build_kernel(
kernel,
kernel_lib_ + "_strided_dynamic_large",
inputs_,
outputs_,
tape_,
constant_ids_,
/* contiguous = */ false,
/* ndim = */ 0,
/* dynamic_dims = */ true,
/* use_big_index = */ true,
/* work_per_thread = */ 4);
return kernel;
}); });
// Figure out which kernel we are using // Figure out which kernel we are using
auto& output_shape = outputs[0].shape(); auto& output_shape = outputs[0].shape();
bool contiguous = compiled_check_contiguity(inputs, output_shape); auto contiguous = compiled_check_contiguity(inputs, output_shape);
// Collapse contiguous dims to route to a faster kernel if possible. Also // Collapse contiguous dims to route to a faster kernel if possible. Also
// handle all broadcasting. // handle all broadcasting.
std::vector<std::vector<size_t>> initial_strides; std::vector<Strides> initial_strides;
initial_strides.push_back(outputs[0].strides()); initial_strides.push_back(outputs[0].strides());
std::vector<int> shape; Shape shape;
std::vector<std::vector<size_t>> strides; std::vector<Strides> strides;
if (!contiguous) { if (!contiguous) {
for (int i = 0; i < inputs.size(); i++) { for (int i = 0; i < inputs.size(); i++) {
// Skip constants. // Skip constants.
@@ -323,7 +376,7 @@ void Compiled::eval_gpu(
} }
// Broadcast the inputs to the output shape. // Broadcast the inputs to the output shape.
std::vector<size_t> xstrides; Strides xstrides;
int j = 0; int j = 0;
for (; j < output_shape.size() - x.ndim(); j++) { for (; j < output_shape.size() - x.ndim(); j++) {
if (output_shape[j] == 1) { if (output_shape[j] == 1) {
@@ -349,13 +402,19 @@ void Compiled::eval_gpu(
collapse_contiguous_dims(output_shape, initial_strides, INT32_MAX); collapse_contiguous_dims(output_shape, initial_strides, INT32_MAX);
} }
bool use_2d = false; bool large;
if (contiguous) { if (contiguous) {
size_t max_size = 0; size_t max_size = 0;
for (auto& in : inputs) { for (auto& in : inputs) {
max_size = std::max(max_size, in.data_size()); max_size = std::max(max_size, in.data_size());
} }
use_2d = (max_size > UINT32_MAX); large = (max_size > UINT32_MAX);
} else {
size_t max_size = 0;
for (auto& o : outputs) {
max_size = std::max(max_size, o.size());
}
large = (max_size > UINT32_MAX);
} }
// Get the kernel from the lib // Get the kernel from the lib
@@ -368,17 +427,18 @@ void Compiled::eval_gpu(
} else { } else {
kernel_name += std::to_string(shape.size()); kernel_name += std::to_string(shape.size());
} }
} else if (use_2d) { }
kernel_name += "_big"; if (large) {
kernel_name += "_large";
} }
auto kernel = d.get_kernel(kernel_name, lib); auto kernel = d.get_kernel(kernel_name, lib);
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
// Put the inputs in // Put the inputs in
int cnt = 0; int cnt = 0;
int stride_idx = 1; // idx 0 is the output strides int stride_idx = 1; // idx 0 is the output strides
std::vector<size_t> in_strides; Strides in_strides;
for (int i = 0; i < inputs.size(); i++) { for (int i = 0; i < inputs.size(); i++) {
if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) { if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
continue; continue;
@@ -394,8 +454,7 @@ void Compiled::eval_gpu(
} }
} }
if (!in_strides.empty()) { if (!in_strides.empty()) {
compute_encoder->setBytes( compute_encoder.set_vector_bytes(in_strides, cnt++);
in_strides.data(), in_strides.size() * sizeof(size_t), cnt++);
} }
compiled_allocate_outputs( compiled_allocate_outputs(
@@ -408,14 +467,13 @@ void Compiled::eval_gpu(
// Put the output shape and strides in // Put the output shape and strides in
if (!contiguous) { if (!contiguous) {
compute_encoder->setBytes( compute_encoder.set_vector_bytes(strides[0], cnt++);
strides[0].data(), strides[0].size() * sizeof(size_t), cnt++); compute_encoder.set_vector_bytes(shape, cnt++);
compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), cnt++);
} }
// Put the number of dims in if it is dynamic // Put the number of dims in if it is dynamic
if (dynamic) { if (dynamic) {
compute_encoder->setBytes(&ndim, sizeof(int), cnt++); compute_encoder.set_bytes(ndim, cnt++);
} }
// Launch the kernel // Launch the kernel
@@ -424,15 +482,15 @@ void Compiled::eval_gpu(
MTL::Size group_dims( MTL::Size group_dims(
std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1); std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
MTL::Size grid_dims = use_2d MTL::Size grid_dims = large
? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides()) ? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides())
: MTL::Size(nthreads, 1, 1); : MTL::Size(nthreads, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims); compute_encoder.dispatch_threads(grid_dims, group_dims);
} else { } else {
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1; size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1; size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
size_t rest = outputs[0].size() / (dim0 * dim1); size_t rest = outputs[0].size() / (dim0 * dim1);
int work_per_thread = ndim > 3 ? WORK_PER_THREAD : 1; int work_per_thread = ndim > 3 ? (large ? 4 : 2) : 1;
dim0 = (dim0 + work_per_thread - 1) / work_per_thread; dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
int pow2; int pow2;
@@ -445,7 +503,7 @@ void Compiled::eval_gpu(
} }
auto group_dims = get_block_dims(dim0, dim1, rest, pow2); auto group_dims = get_block_dims(dim0, dim1, rest, pow2);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest); MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder.dispatchThreads(grid_dims, group_dims); compute_encoder.dispatch_threads(grid_dims, group_dims);
} }
} }
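
Note: build_kernel now assembles the Metal source with fmt::format and plain std::string appends instead of stream insertion. A minimal standalone illustration of the same string-building style; the kernel and buffer names below are invented for the example:

    #include <fmt/format.h>

    #include <cstdio>
    #include <string>

    int main() {
      std::string os;
      std::string kernel_name = "example_contiguous";  // invented name
      // {0} can be reused, which lets one argument fill both the host_name
      // attribute and the function name.
      os += fmt::format(
          "[[host_name(\"{0}\")]]\n[[kernel]] void {0}(\n", kernel_name);
      int cnt = 0;
      os += fmt::format(
          "  device const {0}* {1} [[buffer({2})]],\n", "float", "in0", cnt++);
      os += fmt::format(
          "  device {0}* {1} [[buffer({2})]],\n", "float", "out0", cnt++);
      os += "  uint3 pos [[thread_position_in_grid]],\n";
      os += "  uint3 grid [[threads_per_grid]]) {\n";
      os += "  out0[pos.x] = in0[pos.x];\n}\n";
      std::fputs(os.c_str(), stdout);
      return 0;
    }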


@@ -44,27 +44,28 @@ void explicit_gemm_conv_ND_gpu(
kname << "naive_unfold_nd_" << type_to_name(in_unfolded) << "_" << N; kname << "naive_unfold_nd_" << type_to_name(in_unfolded) << "_" << N;
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str()); auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(in, 0); compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(in_unfolded, 1); compute_encoder.set_output_array(in_unfolded, 1);
compute_encoder->setBytes(&conv_params, sizeof(conv_params), 2); compute_encoder.set_bytes(conv_params, 2);
// Launch unfolding kernel // Launch unfolding kernel
int tgp_x = std::min(conv_params.C, 64); size_t tgp_x = std::min(conv_params.C, 64);
tgp_x = 32 * ((tgp_x + 32 - 1) / 32); tgp_x = 32 * ((tgp_x + 32 - 1) / 32);
int tgp_y = 256 / tgp_x; size_t tgp_y = 256 / tgp_x;
MTL::Size group_dims = MTL::Size(tgp_x, tgp_y, 1);
MTL::Size grid_dims = MTL::Size( MTL::Size grid_dims = MTL::Size(
conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]); conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
MTL::Size group_dims = MTL::Size(
std::min(tgp_x, grid_dims.width), std::min(tgp_y, grid_dims.height), 1);
compute_encoder.dispatchThreads(grid_dims, group_dims); compute_encoder.dispatch_threads(grid_dims, group_dims);
// Reshape weight // Reshape weight
std::vector<int> wt_reshape{implicit_K, implicit_N}; Shape wt_reshape{implicit_K, implicit_N};
std::vector<size_t> wt_restride{1, static_cast<size_t>(implicit_K)}; Strides wt_restride{1, implicit_K};
array wt_reshaped(wt_reshape, wt.dtype(), nullptr, {}); array wt_reshaped(wt_reshape, wt.dtype(), nullptr, {});
auto wt_flags = wt.flags(); auto wt_flags = wt.flags();
wt_flags.row_contiguous = false; wt_flags.row_contiguous = false;
@@ -122,33 +123,31 @@ void explicit_gemm_conv_group_ND_gpu(
<< N; << N;
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str()); auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(in, 0); compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(in_unfolded, 1); compute_encoder.set_output_array(in_unfolded, 1);
compute_encoder->setBytes(&conv_params, sizeof(conv_params), 2); compute_encoder.set_bytes(conv_params, 2);
// Launch unfolding kernel // Launch unfolding kernel
int tgp_x = std::min(conv_params.C, 64); size_t tgp_x = std::min(conv_params.C, 64);
tgp_x = 32 * ((tgp_x + 32 - 1) / 32); tgp_x = 32 * ((tgp_x + 32 - 1) / 32);
int tgp_y = 256 / tgp_x; size_t tgp_y = 256 / tgp_x;
MTL::Size group_dims = MTL::Size(tgp_x, tgp_y, 1);
MTL::Size grid_dims = MTL::Size( MTL::Size grid_dims = MTL::Size(
conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]); conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
MTL::Size group_dims = MTL::Size(
std::min(tgp_x, grid_dims.width), std::min(tgp_y, grid_dims.height), 1);
compute_encoder.dispatchThreads(grid_dims, group_dims); compute_encoder.dispatch_threads(grid_dims, group_dims);
// Transpose kernel weights so that we can slice them by contiguous chunks // Transpose kernel weights so that we can slice them by contiguous chunks
// of channel groups. // of channel groups.
array wt_view( array wt_view(
{wt.shape(0), C_per_group, kernel_size}, wt.dtype(), nullptr, {}); {wt.shape(0), C_per_group, kernel_size}, wt.dtype(), nullptr, {});
wt_view.copy_shared_buffer( wt_view.copy_shared_buffer(
wt, wt, {wt.strides(0), 1, C_per_group}, wt.flags(), wt.size());
{wt.strides(0), 1, static_cast<size_t>(C_per_group)},
wt.flags(),
wt.size());
// Materialize // Materialize
auto wt_transpose = array(wt_view.shape(), wt_view.dtype(), nullptr, {}); auto wt_transpose = array(wt_view.shape(), wt_view.dtype(), nullptr, {});
@@ -237,7 +236,7 @@ void slow_conv_2D_gpu(
// Encode and dispatch kernel // Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str()); auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
size_t n_pixels = conv_params.oS[0] * conv_params.oS[1]; size_t n_pixels = conv_params.oS[0] * conv_params.oS[1];
@@ -252,8 +251,8 @@ void slow_conv_2D_gpu(
compute_encoder.set_input_array(wt, 1); compute_encoder.set_input_array(wt, 1);
compute_encoder.set_output_array(out, 2); compute_encoder.set_output_array(out, 2);
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3); compute_encoder.set_bytes(conv_params, 3);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims); compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
} }
void implicit_gemm_conv_2D_gpu( void implicit_gemm_conv_2D_gpu(
@@ -352,7 +351,7 @@ void implicit_gemm_conv_2D_gpu(
wn, wn,
n_channel_specialization, n_channel_specialization,
small_filter); small_filter);
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
// Deduce grid launch dimensions // Deduce grid launch dimensions
int tile = 1 << swizzle_log; int tile = 1 << swizzle_log;
@@ -368,11 +367,11 @@ void implicit_gemm_conv_2D_gpu(
compute_encoder.set_output_array(out, 2); compute_encoder.set_output_array(out, 2);
// Encode params // Encode params
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3); compute_encoder.set_bytes(conv_params, 3);
compute_encoder->setBytes(&gemm_params, sizeof(ImplicitGemmConv2DParams), 4); compute_encoder.set_bytes(gemm_params, 4);
// Launch kernel // Launch kernel
compute_encoder.dispatchThreadgroups(grid_dims, group_dims); compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
} }
void implicit_gemm_conv_2D_general_gpu( void implicit_gemm_conv_2D_general_gpu(
@@ -506,7 +505,7 @@ void implicit_gemm_conv_2D_general_gpu(
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = auto kernel =
get_steel_conv_general_kernel(d, kname.str(), out, bm, bn, bk, wm, wn); get_steel_conv_general_kernel(d, kname.str(), out, bm, bn, bk, wm, wn);
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
// Deduce grid launch dimensions // Deduce grid launch dimensions
int tile = 1 << swizzle_log; int tile = 1 << swizzle_log;
@@ -523,17 +522,15 @@ void implicit_gemm_conv_2D_general_gpu(
compute_encoder.set_output_array(out, 2); compute_encoder.set_output_array(out, 2);
// Encode params // Encode params
compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3); compute_encoder.set_bytes(conv_params, 3);
compute_encoder->setBytes(&gemm_params, sizeof(ImplicitGemmConv2DParams), 4); compute_encoder.set_bytes(gemm_params, 4);
compute_encoder->setBytes(&jump_params, sizeof(Conv2DGeneralJumpParams), 5); compute_encoder.set_bytes(jump_params, 5);
compute_encoder->setBytes( compute_encoder.set_vector_bytes(base_h, 6);
base_h.data(), sizeof(Conv2DGeneralBaseInfo) * base_h.size(), 6); compute_encoder.set_vector_bytes(base_w, 7);
compute_encoder->setBytes(
base_w.data(), sizeof(Conv2DGeneralBaseInfo) * base_w.size(), 7);
// Launch kernel // Launch kernel
compute_encoder.dispatchThreadgroups(grid_dims, group_dims); compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
} }
void winograd_conv_2D_gpu( void winograd_conv_2D_gpu(
@@ -622,18 +619,18 @@ void winograd_conv_2D_gpu(
<< bc; << bc;
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str()); auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(wt, 0); compute_encoder.set_input_array(wt, 0);
compute_encoder.set_output_array(filt_wg, 1); compute_encoder.set_output_array(filt_wg, 1);
compute_encoder->setBytes(&C_c, sizeof(int), 2); compute_encoder.set_bytes(C_c, 2);
compute_encoder->setBytes(&O_c, sizeof(int), 3); compute_encoder.set_bytes(O_c, 3);
MTL::Size group_dims = MTL::Size(32, bo, 1); MTL::Size group_dims = MTL::Size(32, bo, 1);
MTL::Size grid_dims = MTL::Size(O_c / bo, 1, 1); MTL::Size grid_dims = MTL::Size(O_c / bo, 1, 1);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims); compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
} }
// Do input transform // Do input transform
@@ -650,18 +647,17 @@ void winograd_conv_2D_gpu(
<< bc; << bc;
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str()); auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(in_padded, 0); compute_encoder.set_input_array(in_padded, 0);
compute_encoder.set_output_array(inp_wg, 1); compute_encoder.set_output_array(inp_wg, 1);
compute_encoder->setBytes( compute_encoder.set_bytes(conv_params_updated, 2);
&conv_params_updated, sizeof(MLXConvParams<2>), 2);
MTL::Size group_dims = MTL::Size(32, wn, wm); MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n); MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims); compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
} }
// Do batched gemm // Do batched gemm
@@ -698,18 +694,17 @@ void winograd_conv_2D_gpu(
<< bc; << bc;
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str()); auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(out_wg, 0); compute_encoder.set_input_array(out_wg, 0);
compute_encoder.set_output_array(out, 1); compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes( compute_encoder.set_bytes(conv_params_updated, 2);
&conv_params_updated, sizeof(MLXConvParams<2>), 2);
MTL::Size group_dims = MTL::Size(32, wn, wm); MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n); MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
compute_encoder.dispatchThreadgroups(grid_dims, group_dims); compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
} }
} }
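
Note: one behavioral change in the unfold launches above is that the threadgroup dims are now clamped against the grid dims, so small convolutions no longer request more threads per group than there are grid points. The arithmetic in isolation, with example sizes assumed and plain integers standing in for MTL::Size:

    #include <algorithm>
    #include <cstddef>
    #include <cstdio>

    int main() {
      // Assumed example values: channel count C and unfolded matrix extents.
      size_t C = 3, unfolded_rows = 8, unfolded_cols_per_C = 27;

      // Round the x threadgroup size up to a multiple of the 32-wide SIMD
      // group, then clamp both dims to the grid so tiny problems stay valid.
      size_t tgp_x = std::min<size_t>(C, 64);
      tgp_x = 32 * ((tgp_x + 32 - 1) / 32);
      size_t tgp_y = 256 / tgp_x;

      size_t grid_x = C, grid_y = unfolded_cols_per_C, grid_z = unfolded_rows;
      size_t group_x = std::min(tgp_x, grid_x);
      size_t group_y = std::min(tgp_y, grid_y);

      std::printf("grid = (%zu, %zu, %zu), group = (%zu, %zu, 1)\n",
                  grid_x, grid_y, grid_z, group_x, group_y);
      return 0;
    }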


@@ -43,13 +43,12 @@ void copy_gpu(const array& in, array& out, CopyType ctype) {
copy_gpu(in, out, ctype, out.primitive().stream()); copy_gpu(in, out, ctype, out.primitive().stream());
} }
template <typename stride_t>
void copy_gpu_inplace( void copy_gpu_inplace(
const array& in, const array& in,
array& out, array& out,
const std::vector<int>& data_shape, const Shape& data_shape,
const std::vector<stride_t>& strides_in_pre, const Strides& strides_in_pre,
const std::vector<stride_t>& strides_out_pre, const Strides& strides_out_pre,
int64_t inp_offset, int64_t inp_offset,
int64_t out_offset, int64_t out_offset,
CopyType ctype, CopyType ctype,
@@ -68,50 +67,52 @@ void copy_gpu_inplace(
/* size_cap = */ INT32_MAX); /* size_cap = */ INT32_MAX);
return std::make_tuple(shape, strides[0], strides[1]); return std::make_tuple(shape, strides[0], strides[1]);
} else { } else {
std::vector<stride_t> e; Strides e{};
return std::make_tuple(std::vector<int>{}, e, e); return std::make_tuple(Shape{}, e, e);
} }
}; };
auto [shape, strides_in_, strides_out_] = maybe_collapse(); auto [shape, strides_in_, strides_out_] = maybe_collapse();
int ndim = shape.size(); int ndim = shape.size();
bool large;
bool use_2d = out.data_size() > UINT32_MAX; if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
// Allow for negative strides
large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
} else {
large = out.data_size() > UINT32_MAX;
}
auto& d = metal::device(s.device); auto& d = metal::device(s.device);
int work_per_thread = 1; int work_per_thread = 1;
std::string kernel_name; std::string kernel_name;
{ switch (ctype) {
std::ostringstream kname; case CopyType::Scalar:
switch (ctype) { kernel_name = (large ? "s2" : "s");
case CopyType::Scalar: break;
kname << (use_2d ? "s2" : "s"); case CopyType::Vector:
break; kernel_name = (large ? "v2" : "v");
case CopyType::Vector: break;
kname << (use_2d ? "v2" : "v"); case CopyType::General:
break; kernel_name = "g";
case CopyType::General: break;
kname << "g"; case CopyType::GeneralGeneral:
break; kernel_name = "gg";
case CopyType::GeneralGeneral: break;
kname << "gg";
break;
}
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
if (shape.size() <= MAX_COPY_SPECIALIZED_DIMS) {
kname << shape.size();
} else {
work_per_thread = 4;
kname << "n4";
}
}
kname << "_copy";
kname << type_to_name(in) << type_to_name(out);
kernel_name = kname.str();
} }
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
if (shape.size() <= MAX_COPY_SPECIALIZED_DIMS) {
kernel_name += std::to_string(shape.size());
} else {
work_per_thread = large ? 4 : 2;
concatenate(kernel_name, "n", std::to_string(work_per_thread));
}
if (large) {
kernel_name += "large";
}
}
concatenate(kernel_name, "_copy", type_to_name(in), type_to_name(out));
auto kernel = get_copy_kernel(d, kernel_name, in, out); auto kernel = get_copy_kernel(d, kernel_name, in, out);
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
bool donate_in = in.data_shared_ptr() == nullptr; bool donate_in = in.data_shared_ptr() == nullptr;
inp_offset *= size_of(in.dtype()); inp_offset *= size_of(in.dtype());
@@ -122,26 +123,26 @@ void copy_gpu_inplace(
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) { if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
std::vector<int64_t> strides_in{strides_in_.begin(), strides_in_.end()}; Strides strides_in{strides_in_.begin(), strides_in_.end()};
std::vector<int64_t> strides_out{strides_out_.begin(), strides_out_.end()}; Strides strides_out{strides_out_.begin(), strides_out_.end()};
if (ndim > 3) { if (ndim > 3) {
set_vector_bytes(compute_encoder, shape, ndim, 2); compute_encoder.set_vector_bytes(shape, ndim, 2);
} }
set_vector_bytes(compute_encoder, strides_in, ndim, 3); compute_encoder.set_vector_bytes(strides_in, ndim, 3);
if (ctype == CopyType::GeneralGeneral) { if (ctype == CopyType::GeneralGeneral) {
set_vector_bytes(compute_encoder, strides_out, ndim, 4); compute_encoder.set_vector_bytes(strides_out, ndim, 4);
} }
int dim0 = ndim > 0 ? shape[ndim - 1] : 1; size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
int dim1 = ndim > 1 ? shape[ndim - 2] : 1; size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
size_t data_size = 1; size_t data_size = 1;
for (auto& s : shape) for (auto& s : shape)
data_size *= s; data_size *= s;
int rest = data_size / (dim0 * dim1); size_t rest = data_size / (dim0 * dim1);
if (ndim > MAX_COPY_SPECIALIZED_DIMS) { if (ndim > MAX_COPY_SPECIALIZED_DIMS) {
compute_encoder->setBytes(&ndim, sizeof(int), 5); compute_encoder.set_bytes(ndim, 5);
dim0 = (dim0 + work_per_thread - 1) / work_per_thread; dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
} }
@@ -152,16 +153,16 @@ void copy_gpu_inplace(
auto group_dims = get_block_dims(dim0, dim1, rest); auto group_dims = get_block_dims(dim0, dim1, rest);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest); MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder.dispatchThreads(grid_dims, group_dims); compute_encoder.dispatch_threads(grid_dims, group_dims);
} else { } else {
size_t nthreads = out.data_size(); size_t nthreads = out.data_size();
if (thread_group_size > nthreads) { if (thread_group_size > nthreads) {
thread_group_size = nthreads; thread_group_size = nthreads;
} }
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides()) MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1); : MTL::Size(nthreads, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims); compute_encoder.dispatch_threads(grid_dims, group_dims);
} }
} }
@@ -178,14 +179,13 @@ void copy_gpu_inplace(
void copy_gpu_inplace( void copy_gpu_inplace(
const array& in, const array& in,
array& out, array& out,
const std::vector<int64_t>& istride, const Strides& istride,
int64_t ioffset, int64_t ioffset,
CopyType ctype, CopyType ctype,
const Stream& s) { const Stream& s) {
assert(in.shape() == out.shape()); assert(in.shape() == out.shape());
std::vector<int64_t> ostrides{out.strides().begin(), out.strides().end()};
return copy_gpu_inplace( return copy_gpu_inplace(
in, out, in.shape(), istride, ostrides, ioffset, 0, ctype, s); in, out, in.shape(), istride, out.strides(), ioffset, 0, ctype, s);
} }
void fill_gpu(const array& val, array& out, const Stream& s) { void fill_gpu(const array& val, array& out, const Stream& s) {
@@ -193,13 +193,13 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
return; return;
} }
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc_or_wait(out.nbytes()));
bool use_2d = out.data_size() > UINT32_MAX; bool large = out.data_size() > UINT32_MAX;
auto& d = metal::device(s.device); auto& d = metal::device(s.device);
std::string kernel_name = std::string(use_2d ? "s2" : "s") + "_copy" + std::string kernel_name = std::string(large ? "s2" : "s") + "_copy" +
type_to_name(val) + type_to_name(out); type_to_name(val) + type_to_name(out);
auto kernel = get_copy_kernel(d, kernel_name, val, out); auto kernel = get_copy_kernel(d, kernel_name, val, out);
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(val, 0); compute_encoder.set_input_array(val, 0);
compute_encoder.set_output_array(out, 1); compute_encoder.set_output_array(out, 1);
@@ -210,9 +210,9 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
thread_group_size = nthreads; thread_group_size = nthreads;
} }
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides()) MTL::Size grid_dims = large ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1); : MTL::Size(nthreads, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims); compute_encoder.dispatch_threads(grid_dims, group_dims);
} }
} // namespace mlx::core } // namespace mlx::core
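
To summarize the control flow above: the old use_2d flag becomes a large flag that uses a signed 32-bit bound for general copies (so negative strides fit), and it now feeds both the kernel-name suffix and the work-per-thread choice. A self-contained sketch of that naming logic; the specialized-dims cutoff value and the dtype suffix are assumptions here, not taken from the hunk:

    #include <cstdint>
    #include <string>

    enum class CopyType { Scalar, Vector, General, GeneralGeneral };

    std::string copy_kernel_name(CopyType ctype, bool large, int ndim,
                                 int& work_per_thread) {
      constexpr int kMaxSpecializedDims = 3;     // assumed value for this sketch
      work_per_thread = 1;
      std::string name;
      switch (ctype) {
        case CopyType::Scalar:         name = large ? "s2" : "s"; break;
        case CopyType::Vector:         name = large ? "v2" : "v"; break;
        case CopyType::General:        name = "g";  break;
        case CopyType::GeneralGeneral: name = "gg"; break;
      }
      if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
        if (ndim <= kMaxSpecializedDims) {
          name += std::to_string(ndim);          // e.g. "g2", "gg3"
        } else {
          work_per_thread = large ? 4 : 2;       // more work per thread when 64-bit indexed
          name += "n" + std::to_string(work_per_thread);
        }
        if (large) {
          name += "large";                       // selects the int64_t-indexed variant
        }
      }
      return name + "_copy";                     // dtype suffixes omitted in this sketch
    }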

View File

@@ -8,13 +8,12 @@
namespace mlx::core { namespace mlx::core {
// Generic copy inplace // Generic copy inplace
template <typename stride_t>
void copy_gpu_inplace( void copy_gpu_inplace(
const array& in, const array& in,
array& out, array& out,
const std::vector<int>& data_shape, const Shape& data_shape,
const std::vector<stride_t>& i_strides, const Strides& i_strides,
const std::vector<stride_t>& o_strides, const Strides& o_strides,
int64_t i_offset, int64_t i_offset,
int64_t o_offset, int64_t o_offset,
CopyType ctype, CopyType ctype,
@@ -32,7 +31,7 @@ void copy_gpu_inplace(
void copy_gpu_inplace( void copy_gpu_inplace(
const array& in, const array& in,
array& out, array& out,
const std::vector<int64_t>& istride, const Strides& istride,
int64_t ioffset, int64_t ioffset,
CopyType ctype, CopyType ctype,
const Stream& s); const Stream& s);
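
The header drops the stride_t template parameter in favor of the project-wide aliases. As a rough sketch of what the new parameter types amount to (assumed definitions; the real typedefs live elsewhere in the tree), shapes stay 32-bit while strides become signed 64-bit, matching the const int* / const int64_t* buffers in the kernel signatures later in this diff:

    #include <cstdint>
    #include <vector>

    using Shape = std::vector<int32_t>;    // assumed: per-axis extents stay 32-bit
    using Strides = std::vector<int64_t>;  // assumed: strides are signed 64-bit

    // Callers now pass the aliases directly instead of instantiating a template:
    void copy_example(const Shape& shape, const Strides& in_strides,
                      const Strides& out_strides);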

View File

@@ -43,7 +43,7 @@ void CustomKernel::eval_gpu(
d.get_library(lib_name, [this] { return metal::utils() + source_; }); d.get_library(lib_name, [this] { return metal::utils() + source_; });
auto kernel = d.get_kernel(name_, lib); auto kernel = d.get_kernel(name_, lib);
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
int index = 0; int index = 0;
for (int i = 0; i < checked_inputs.size(); i++) { for (int i = 0; i < checked_inputs.size(); i++) {
const array& in = checked_inputs[i]; const array& in = checked_inputs[i];
@@ -53,15 +53,15 @@ void CustomKernel::eval_gpu(
if (in.ndim() > 0) { if (in.ndim() > 0) {
int ndim = in.ndim(); int ndim = in.ndim();
if (shape_info.shape) { if (shape_info.shape) {
set_vector_bytes(compute_encoder, in.shape(), ndim, index); compute_encoder.set_vector_bytes(in.shape(), ndim, index);
index++; index++;
} }
if (shape_info.strides) { if (shape_info.strides) {
set_vector_bytes(compute_encoder, in.strides(), ndim, index); compute_encoder.set_vector_bytes(in.strides(), ndim, index);
index++; index++;
} }
if (shape_info.ndim) { if (shape_info.ndim) {
compute_encoder->setBytes(&ndim, sizeof(int), index); compute_encoder.set_bytes(ndim, index);
index++; index++;
} }
} }
@@ -72,10 +72,11 @@ void CustomKernel::eval_gpu(
} }
const auto [tx, ty, tz] = threadgroup_; const auto [tx, ty, tz] = threadgroup_;
MTL::Size group_dims = MTL::Size(tx, ty, tz);
const auto [gx, gy, gz] = grid_; const auto [gx, gy, gz] = grid_;
MTL::Size group_dims =
MTL::Size(std::min(tx, gx), std::min(ty, gy), std::min(tz, gz));
MTL::Size grid_dims = MTL::Size(gx, gy, gz); MTL::Size grid_dims = MTL::Size(gx, gy, gz);
compute_encoder->dispatchThreads(grid_dims, group_dims); compute_encoder.dispatch_threads(grid_dims, group_dims);
d.add_temporaries(std::move(copies), s.index); d.add_temporaries(std::move(copies), s.index);
} }

View File

@@ -23,14 +23,18 @@ constexpr int MAX_BUFFERS_PER_QUEUE = 12;
constexpr const char* default_mtllib_path = METAL_PATH; constexpr const char* default_mtllib_path = METAL_PATH;
constexpr auto get_metal_version() { auto get_metal_version() {
#if (MLX_METAL_VERSION >= 320) auto get_metal_version_ = []() {
return MTL::LanguageVersion3_2; if (__builtin_available(macOS 15, iOS 18, tvOS 18, visionOS 2, *)) {
#elif (MLX_METAL_VERSION >= 310) return MTL::LanguageVersion3_2;
return MTL::LanguageVersion3_1; } else if (__builtin_available(macOS 14, iOS 17, tvOS 17, visionOS 1, *)) {
#else return MTL::LanguageVersion3_1;
return MTL::LanguageVersion3_0; } else {
#endif return MTL::LanguageVersion3_0;
}
};
static auto metal_version_ = get_metal_version_();
return metal_version_;
} }
auto load_device() { auto load_device() {
@@ -171,14 +175,14 @@ void CommandEncoder::maybeInsertBarrier() {
next_outputs_.clear(); next_outputs_.clear();
} }
void CommandEncoder::dispatchThreadgroups( void CommandEncoder::dispatch_threadgroups(
MTL::Size grid_dims, MTL::Size grid_dims,
MTL::Size group_dims) { MTL::Size group_dims) {
maybeInsertBarrier(); maybeInsertBarrier();
enc_->dispatchThreadgroups(grid_dims, group_dims); enc_->dispatchThreadgroups(grid_dims, group_dims);
} }
void CommandEncoder::dispatchThreads( void CommandEncoder::dispatch_threads(
MTL::Size grid_dims, MTL::Size grid_dims,
MTL::Size group_dims) { MTL::Size group_dims) {
maybeInsertBarrier(); maybeInsertBarrier();
@@ -276,7 +280,7 @@ void Device::end_encoding(int index) {
// - Update the map of outputs to include this command encoder's outputs. // - Update the map of outputs to include this command encoder's outputs.
// - Always signal this command encoder's fence. // - Always signal this command encoder's fence.
// - Add a completion handler for this command encoder that removes outputs // - Add a completion handler for this command encoder that removes outputs
// from the map to limit the growth of the map and avoid unecessary waits // from the map to limit the growth of the map and avoid unnecessary waits
// - Temporaries are a special case as they do not cross command encoder // - Temporaries are a special case as they do not cross command encoder
// boundaries. These can be removed early from the encoder's inputs and // boundaries. These can be removed early from the encoder's inputs and
// outputs since they don't need synchronization. // outputs since they don't need synchronization.
@@ -298,7 +302,7 @@ void Device::end_encoding(int index) {
if (auto it = stream.outputs.find(in); it != stream.outputs.end()) { if (auto it = stream.outputs.find(in); it != stream.outputs.end()) {
// If we've already waited on a fence, don't wait on it again. // If we've already waited on a fence, don't wait on it again.
if (waiting_on.find(it->second) == waiting_on.end()) { if (waiting_on.find(it->second) == waiting_on.end()) {
enc->waitForFence(it->second->fence); enc.wait_for_fence(it->second->fence);
waiting_on.insert(it->second); waiting_on.insert(it->second);
} }
} }
@@ -307,7 +311,7 @@ void Device::end_encoding(int index) {
stream.outputs[out] = stream.fence; stream.outputs[out] = stream.fence;
} }
} }
enc->updateFence(stream.fence->fence); enc.update_fence(stream.fence->fence);
stream.buffer->addCompletedHandler( stream.buffer->addCompletedHandler(
[&stream, [&stream,
waiting_on = std::move(waiting_on), waiting_on = std::move(waiting_on),
@@ -641,21 +645,27 @@ void new_stream(Stream stream) {
std::unordered_map<std::string, std::variant<std::string, size_t>> std::unordered_map<std::string, std::variant<std::string, size_t>>
device_info() { device_info() {
auto raw_device = device(default_device()).mtl_device(); auto init_device_info = []()
auto arch = std::string(raw_device->architecture()->name()->utf8String()); -> std::unordered_map<std::string, std::variant<std::string, size_t>> {
auto pool = new_scoped_memory_pool();
auto raw_device = device(default_device()).mtl_device();
auto arch = std::string(raw_device->architecture()->name()->utf8String());
int mib[] = {CTL_HW, HW_MEMSIZE}; int mib[] = {CTL_HW, HW_MEMSIZE};
size_t memsize = 0; size_t memsize = 0;
size_t length = sizeof(memsize); size_t length = sizeof(memsize);
sysctl(mib, 2, &memsize, &length, NULL, 0); sysctl(mib, 2, &memsize, &length, NULL, 0);
return { return {
{"architecture", arch}, {"architecture", arch},
{"max_buffer_length", raw_device->maxBufferLength()}, {"max_buffer_length", raw_device->maxBufferLength()},
{"max_recommended_working_set_size", {"max_recommended_working_set_size",
raw_device->recommendedMaxWorkingSetSize()}, raw_device->recommendedMaxWorkingSetSize()},
{"memory_size", memsize}}; {"memory_size", memsize}};
};
static auto device_info_ = init_device_info();
return device_info_;
} }
} // namespace mlx::core::metal } // namespace mlx::core::metal
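
Both get_metal_version() and device_info() now follow the same pattern: the probe (an __builtin_available check in one case, a sysctl plus Metal device queries in the other) runs inside a lambda whose result initializes a function-local static, so later calls return the cached result. A minimal sketch of the pattern with an invented query function:

    #include <string>

    // Assumed stand-in for the real probe (sysctl / MTLDevice queries / OS check).
    std::string expensive_platform_query() { return "probed"; }

    const std::string& cached_platform_info() {
      // The lambda runs exactly once; C++ guarantees thread-safe initialization
      // of the function-local static, and every later call reuses the cached value.
      static const std::string info = [] { return expensive_platform_query(); }();
      return info;
    }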

View File

@@ -58,16 +58,43 @@ struct CommandEncoder {
CommandEncoder& enc; CommandEncoder& enc;
}; };
MTL::ComputeCommandEncoder* operator->() {
return enc_;
}
void set_input_array(const array& a, int idx, int64_t offset = 0); void set_input_array(const array& a, int idx, int64_t offset = 0);
void set_output_array(array& a, int idx, int64_t offset = 0); void set_output_array(array& a, int idx, int64_t offset = 0);
void dispatchThreadgroups(MTL::Size grid_dims, MTL::Size group_dims); void dispatch_threadgroups(MTL::Size grid_dims, MTL::Size group_dims);
void dispatchThreads(MTL::Size grid_dims, MTL::Size group_dims); void dispatch_threads(MTL::Size grid_dims, MTL::Size group_dims);
void maybeInsertBarrier(); void maybeInsertBarrier();
void set_compute_pipeline_state(MTL::ComputePipelineState* kernel) {
enc_->setComputePipelineState(kernel);
}
void wait_for_fence(MTL::Fence* fence) {
enc_->waitForFence(fence);
}
void update_fence(MTL::Fence* fence) {
enc_->updateFence(fence);
}
template <typename T>
void set_vector_bytes(const std::vector<T>& vec, size_t nelems, int idx) {
enc_->setBytes(vec.data(), nelems * sizeof(T), idx);
}
template <typename T>
void set_vector_bytes(const std::vector<T>& vec, int idx) {
return set_vector_bytes(vec, vec.size(), idx);
}
template <typename T>
void set_bytes(const T* v, int n, int idx) {
return enc_->setBytes(v, n * sizeof(T), idx);
}
template <typename T>
void set_bytes(const T& v, int idx) {
return enc_->setBytes(&v, sizeof(T), idx);
}
ConcurrentContext start_concurrent() { ConcurrentContext start_concurrent() {
return ConcurrentContext(*this); return ConcurrentContext(*this);
} }
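
This header change is what enables all the compute_encoder.set_bytes(...) / set_vector_bytes(...) call sites elsewhere in this diff: the wrapper computes sizeof(T) itself, so callers no longer spell out a byte count that can drift from the type. An illustrative wrapper (FakeEncoder stands in for MTL::ComputeCommandEncoder; this is not the MLX class):

    #include <cstddef>
    #include <vector>

    struct FakeEncoder {                              // stand-in for MTL::ComputeCommandEncoder
      void setBytes(const void*, size_t, int) {}
    };

    class EncoderWrapper {
     public:
      explicit EncoderWrapper(FakeEncoder* enc) : enc_(enc) {}

      template <typename T>
      void set_bytes(const T& v, int idx) {           // e.g. set_bytes(ndim, 5)
        enc_->setBytes(&v, sizeof(T), idx);
      }

      template <typename T>
      void set_vector_bytes(const std::vector<T>& v, size_t n, int idx) {
        enc_->setBytes(v.data(), n * sizeof(T), idx); // first n elements only
      }

      template <typename T>
      void set_vector_bytes(const std::vector<T>& v, int idx) {
        set_vector_bytes(v, v.size(), idx);
      }

     private:
      FakeEncoder* enc_;
    };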

View File

@@ -363,7 +363,7 @@ void multi_upload_bluestein_fft(
auto [w_k, w_q] = compute_bluestein_constants(n, plan.bluestein_n); auto [w_k, w_q] = compute_bluestein_constants(n, plan.bluestein_n);
// Broadcast w_q and w_k to the batch size // Broadcast w_q and w_k to the batch size
std::vector<size_t> b_strides(in.ndim(), 0); Strides b_strides(in.ndim(), 0);
b_strides[axis] = 1; b_strides[axis] = 1;
array w_k_broadcast({}, complex64, nullptr, {}); array w_k_broadcast({}, complex64, nullptr, {});
array w_q_broadcast({}, complex64, nullptr, {}); array w_q_broadcast({}, complex64, nullptr, {});
@@ -386,8 +386,8 @@ void multi_upload_bluestein_fft(
copies.push_back(slice_temp); copies.push_back(slice_temp);
copies.push_back(conj_temp); copies.push_back(conj_temp);
std::vector<int> rstarts(in.ndim(), 0); Shape rstarts(in.ndim(), 0);
std::vector<int> rstrides(in.ndim(), 1); Shape rstrides(in.ndim(), 1);
rstarts[axis] = in.shape(axis) - back_offset; rstarts[axis] = in.shape(axis) - back_offset;
rstrides[axis] = -1; rstrides[axis] = -1;
unary_op_gpu({in}, conj_temp, "Conjugate", s); unary_op_gpu({in}, conj_temp, "Conjugate", s);
@@ -431,19 +431,19 @@ void multi_upload_bluestein_fft(
s); s);
int offset = plan.bluestein_n - (2 * n - 1); int offset = plan.bluestein_n - (2 * n - 1);
std::vector<int> starts(in.ndim(), 0); Shape starts(in.ndim(), 0);
std::vector<int> strides(in.ndim(), 1); Shape strides(in.ndim(), 1);
starts[axis] = plan.bluestein_n - offset - n; starts[axis] = plan.bluestein_n - offset - n;
slice_gpu(pad_temp1, temp, starts, strides, s); slice_gpu(pad_temp1, temp, starts, strides, s);
binary_op_gpu_inplace({temp, w_k_broadcast}, temp1, "Multiply", s); binary_op_gpu_inplace({temp, w_k_broadcast}, temp1, "Multiply", s);
if (real && !inverse) { if (real && !inverse) {
std::vector<int> rstarts(in.ndim(), 0); Shape rstarts(in.ndim(), 0);
std::vector<int> rstrides(in.ndim(), 1); Shape rstrides(in.ndim(), 1);
slice_gpu(temp1, out, rstarts, strides, s); slice_gpu(temp1, out, rstarts, strides, s);
} else if (real && inverse) { } else if (real && inverse) {
std::vector<size_t> b_strides(in.ndim(), 0); Strides b_strides(in.ndim(), 0);
auto inv_n = array({1.0f / n}, {1}, float32); auto inv_n = array({1.0f / n}, {1}, float32);
array temp_float(out.shape(), out.dtype(), nullptr, {}); array temp_float(out.shape(), out.dtype(), nullptr, {});
copies.push_back(temp_float); copies.push_back(temp_float);
@@ -531,8 +531,8 @@ void fft_op(
return x; return x;
} else { } else {
array x_copy(x.shape(), x.dtype(), nullptr, {}); array x_copy(x.shape(), x.dtype(), nullptr, {});
std::vector<size_t> strides; Strides strides;
size_t cur_stride = x.shape(axis); int64_t cur_stride = x.shape(axis);
for (int a = 0; a < x.ndim(); a++) { for (int a = 0; a < x.ndim(); a++) {
if (a == axis) { if (a == axis) {
strides.push_back(1); strides.push_back(1);
@@ -699,7 +699,7 @@ void fft_op(
auto kernel = auto kernel =
get_fft_kernel(d, base_name, hash_name, func_consts, template_def); get_fft_kernel(d, base_name, hash_name, func_consts, template_def);
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(in_contiguous, 0); compute_encoder.set_input_array(in_contiguous, 0);
compute_encoder.set_output_array(out, 1); compute_encoder.set_output_array(out, 1);
@@ -711,9 +711,9 @@ void fft_op(
compute_encoder.set_input_array(w_q, 2); // w_q compute_encoder.set_input_array(w_q, 2); // w_q
compute_encoder.set_input_array(w_k, 3); // w_k compute_encoder.set_input_array(w_k, 3); // w_k
compute_encoder->setBytes(&n, sizeof(int), 4); compute_encoder.set_bytes(n, 4);
compute_encoder->setBytes(&plan.bluestein_n, sizeof(int), 5); compute_encoder.set_bytes(plan.bluestein_n, 5);
compute_encoder->setBytes(&total_batch_size, sizeof(int), 6); compute_encoder.set_bytes(total_batch_size, 6);
} else if (plan.rader_n > 1) { } else if (plan.rader_n > 1) {
auto [b_q, g_q, g_minus_q] = compute_raders_constants(plan.rader_n, s); auto [b_q, g_q, g_minus_q] = compute_raders_constants(plan.rader_n, s);
copies.push_back(b_q); copies.push_back(b_q);
@@ -723,22 +723,22 @@ void fft_op(
compute_encoder.set_input_array(b_q, 2); compute_encoder.set_input_array(b_q, 2);
compute_encoder.set_input_array(g_q, 3); compute_encoder.set_input_array(g_q, 3);
compute_encoder.set_input_array(g_minus_q, 4); compute_encoder.set_input_array(g_minus_q, 4);
compute_encoder->setBytes(&n, sizeof(int), 5); compute_encoder.set_bytes(n, 5);
compute_encoder->setBytes(&total_batch_size, sizeof(int), 6); compute_encoder.set_bytes(total_batch_size, 6);
compute_encoder->setBytes(&plan.rader_n, sizeof(int), 7); compute_encoder.set_bytes(plan.rader_n, 7);
} else if (four_step_params.required) { } else if (four_step_params.required) {
compute_encoder->setBytes(&four_step_params.n1, sizeof(int), 2); compute_encoder.set_bytes(four_step_params.n1, 2);
compute_encoder->setBytes(&four_step_params.n2, sizeof(int), 3); compute_encoder.set_bytes(four_step_params.n2, 3);
compute_encoder->setBytes(&total_batch_size, sizeof(int), 4); compute_encoder.set_bytes(total_batch_size, 4);
} else { } else {
compute_encoder->setBytes(&n, sizeof(int), 2); compute_encoder.set_bytes(n, 2);
compute_encoder->setBytes(&total_batch_size, sizeof(int), 3); compute_encoder.set_bytes(total_batch_size, 3);
} }
auto group_dims = MTL::Size(1, threadgroup_batch_size, threads_per_fft); auto group_dims = MTL::Size(1, threadgroup_batch_size, threads_per_fft);
auto grid_dims = auto grid_dims =
MTL::Size(batch_size, threadgroup_batch_size, threads_per_fft); MTL::Size(batch_size, threadgroup_batch_size, threads_per_fft);
compute_encoder->dispatchThreads(grid_dims, group_dims); compute_encoder.dispatch_threads(grid_dims, group_dims);
} }
d.add_temporaries(std::move(copies), s.index); d.add_temporaries(std::move(copies), s.index);
@@ -777,7 +777,7 @@ void nd_fft_op(
// Mirror np.fft.(i)rfftn and perform a real transform // Mirror np.fft.(i)rfftn and perform a real transform
// only on the final axis. // only on the final axis.
bool step_real = (real && index == axes.size() - 1); bool step_real = (real && index == axes.size() - 1);
int step_shape = inverse ? out.shape(axis) : in.shape(axis); auto step_shape = inverse ? out.shape(axis) : in.shape(axis);
const array& in_arr = i == axes.size() - 1 ? in : temp_arrs[1 - i % 2]; const array& in_arr = i == axes.size() - 1 ? in : temp_arrs[1 - i % 2];
array& out_arr = i == 0 ? out : temp_arrs[i % 2]; array& out_arr = i == 0 ? out : temp_arrs[i % 2];
fft_op(in_arr, out_arr, axis, inverse, step_real, inplace, s); fft_op(in_arr, out_arr, axis, inverse, step_real, inplace, s);
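
Two small index-vector idioms recur in the Bluestein path above: all-zero strides with a single 1 to broadcast the precomputed constants over the batch, and a start/stride pair with -1 on one axis to slice that axis in reverse. A hedged, stand-alone restatement of both (these helpers are illustrative, not MLX functions):

    #include <cstdint>
    #include <utility>
    #include <vector>

    using Shape = std::vector<int32_t>;    // assumed aliases, as in the copy header
    using Strides = std::vector<int64_t>;

    // Zero strides everywhere except one axis: every other axis re-reads the
    // same 1-D data, which is how w_k / w_q are broadcast over the batch.
    Strides broadcast_strides(int ndim, int axis) {
      Strides s(ndim, 0);
      s[axis] = 1;
      return s;
    }

    // Start near the end of the axis and step by -1 to walk it backwards,
    // as in the conjugate/reverse slice above.
    std::pair<Shape, Shape> reverse_slice(int ndim, int axis, int extent, int offset) {
      Shape starts(ndim, 0), strides(ndim, 1);
      starts[axis] = extent - offset;
      strides[axis] = -1;
      return {starts, strides};
    }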

View File

@@ -137,14 +137,14 @@ void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
auto kernel = d.get_kernel(kernel_name, lib); auto kernel = d.get_kernel(kernel_name, lib);
assert(threads_per <= kernel->maxTotalThreadsPerThreadgroup()); assert(threads_per <= kernel->maxTotalThreadsPerThreadgroup());
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(in, 0); compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1); compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(&scale, sizeof(float), 2); compute_encoder.set_bytes(scale, 2);
MTL::Size group_dims = MTL::Size(1, threads_per, 1); MTL::Size group_dims = MTL::Size(1, threads_per, 1);
MTL::Size grid_dims = MTL::Size(batch_size, threads_per, 1); MTL::Size grid_dims = MTL::Size(batch_size, threads_per, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims); compute_encoder.dispatch_threads(grid_dims, group_dims);
}; };
if (m > 1) { if (m > 1) {

View File

@@ -53,27 +53,31 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
int idx_ndim = nidx ? inputs[1].ndim() : 0; int idx_ndim = nidx ? inputs[1].ndim() : 0;
size_t ndim = src.ndim(); size_t ndim = src.ndim();
std::string lib_name; bool large_index = nidx && inputs[1].size() > UINT32_MAX;
std::string kernel_name; bool large_src = src.size() > UINT32_MAX;
bool large_out = out.size() > UINT32_MAX;
bool large = large_index || large_src || large_out;
std::string idx_type_name = nidx ? type_to_name(inputs[1]) : ""; std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
{ std::string kernel_name = fmt::format(
std::ostringstream kname; "gather{0}{1}_{2}_{3}_{4}",
kname << "gather" << type_to_name(out) << idx_type_name << "_" << nidx type_to_name(out),
<< "_" << idx_ndim; idx_type_name,
lib_name = kname.str(); nidx,
kernel_name = lib_name; idx_ndim,
} large ? "int64_t" : "uint");
std::string lib_name = kernel_name;
auto lib = d.get_library(lib_name, [&]() { auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source; std::string kernel_source = metal::utils();
kernel_source << metal::utils() << metal::gather(); kernel_source += metal::gather();
std::string out_type_str = get_type_string(out.dtype()); std::string out_type_str = get_type_string(out.dtype());
std::string idx_type_str = std::string idx_type_str =
nidx ? get_type_string(inputs[1].dtype()) : "bool"; nidx ? get_type_string(inputs[1].dtype()) : "bool";
auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx); auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx);
// Index dimension specializations // Index dimension specializations
kernel_source << fmt::format( kernel_source += fmt::format(
gather_kernels, gather_kernels,
type_to_name(out) + idx_type_name, type_to_name(out) + idx_type_name,
out_type_str, out_type_str,
@@ -81,13 +85,14 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
nidx, nidx,
idx_args, idx_args,
idx_arr, idx_arr,
idx_ndim); idx_ndim,
return kernel_source.str(); large ? "int64_t" : "uint");
return kernel_source;
}); });
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kernel_name, lib); auto kernel = d.get_kernel(kernel_name, lib);
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
size_t slice_size = 1; size_t slice_size = 1;
for (auto s : slice_sizes_) { for (auto s : slice_sizes_) {
@@ -131,20 +136,20 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_output_array(out, 1); compute_encoder.set_output_array(out, 1);
// Set source info // Set source info
set_vector_bytes(compute_encoder, src.shape(), 2); compute_encoder.set_vector_bytes(src.shape(), 2);
set_vector_bytes(compute_encoder, src.strides(), 3); compute_encoder.set_vector_bytes(src.strides(), 3);
compute_encoder->setBytes(&ndim, sizeof(size_t), 4); compute_encoder.set_bytes(ndim, 4);
set_vector_bytes(compute_encoder, slice_sizes_, 5); compute_encoder.set_vector_bytes(slice_sizes_, 5);
set_vector_bytes(compute_encoder, axes_, 6); compute_encoder.set_vector_bytes(axes_, 6);
// Set index info // Set index info
// //
// We don't need to check for empty idx_shapes because gather has a // We don't need to check for empty idx_shapes because gather has a
// idx_ndim == 0 specialization // idx_ndim == 0 specialization
set_vector_bytes(compute_encoder, idx_shapes, 7); compute_encoder.set_vector_bytes(idx_shapes, 7);
set_vector_bytes(compute_encoder, idx_strides, 8); compute_encoder.set_vector_bytes(idx_strides, 8);
set_vector_bytes(compute_encoder, idx_contigs, 9); compute_encoder.set_vector_bytes(idx_contigs, 9);
compute_encoder->setBytes(&idx_ndim, sizeof(int), 10); compute_encoder.set_bytes(idx_ndim, 10);
// Set index buffers // Set index buffers
for (int i = 0; i < nidx; ++i) { for (int i = 0; i < nidx; ++i) {
@@ -152,7 +157,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
} }
// Launch grid // Launch grid
compute_encoder.dispatchThreads(grid_dims, group_dims); compute_encoder.dispatch_threads(grid_dims, group_dims);
} }
void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) { void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -209,8 +214,6 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
nwork = 32; nwork = 32;
} }
std::string lib_name;
std::string kernel_name;
std::string idx_type_name = nidx ? type_to_name(inputs[1]) : ""; std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
std::string op_name; std::string op_name;
switch (reduce_type_) { switch (reduce_type_) {
@@ -231,18 +234,24 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
break; break;
} }
auto upd_contig = upd.flags().row_contiguous; auto upd_contig = upd.flags().row_contiguous;
{ bool large_out = out.size() > UINT32_MAX;
std::ostringstream kname; bool large_idx = nidx && (inputs[1].size() > UINT32_MAX);
kname << "scatter" << type_to_name(out) << idx_type_name; bool large_upd = upd.size() > UINT32_MAX;
kname << "_" << op_name << "_" << nidx << "_" bool large = large_out || large_idx || large_upd;
<< (upd_contig ? "updc_true" : "updc_false") << "_nwork" << nwork; std::string kernel_name = fmt::format(
lib_name = kname.str(); "scatter{0}{1}_{2}_{3}_{4}_nwork{5}_{6}",
kernel_name = kname.str(); type_to_name(out),
} idx_type_name,
op_name,
nidx,
upd_contig ? "updc_true" : "updc_false",
nwork,
large ? "int64_t" : "uint");
std::string lib_name = kernel_name;
auto lib = d.get_library(lib_name, [&]() { auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source; std::string kernel_source = metal::utils();
kernel_source << metal::utils() << metal::reduce_utils() concatenate(kernel_source, metal::reduce_utils(), metal::scatter());
<< metal::scatter();
std::string out_type_str = get_type_string(out.dtype()); std::string out_type_str = get_type_string(out.dtype());
std::string idx_type_str = std::string idx_type_str =
@@ -270,7 +279,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
} }
auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx); auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx);
kernel_source << fmt::format( kernel_source += fmt::format(
scatter_kernels, scatter_kernels,
type_to_name(out) + idx_type_name + "_" + op_name, type_to_name(out) + idx_type_name + "_" + op_name,
out_type_str, out_type_str,
@@ -280,8 +289,9 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
idx_args, idx_args,
idx_arr, idx_arr,
upd_contig, upd_contig,
nwork); nwork,
return kernel_source.str(); large ? "int64_t" : "uint");
return kernel_source;
}); });
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
@@ -289,7 +299,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
size_t nthreads = upd.size(); size_t nthreads = upd.size();
compute_encoder->setComputePipelineState(kernel); compute_encoder.set_compute_pipeline_state(kernel);
// Set all the buffers // Set all the buffers
compute_encoder.set_input_array(upd, 1); compute_encoder.set_input_array(upd, 1);
@@ -302,8 +312,8 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
upd_size *= upd.shape(i); upd_size *= upd.shape(i);
} }
// Collect all idx shapes and strides into one place // Collect all idx shapes and strides into one place
std::vector<int> idx_shapes; Shape idx_shapes;
std::vector<size_t> idx_strides; Strides idx_strides;
// To access .data() use char instead of bool // To access .data() use char instead of bool
// bool is 1 byte in Metal so this is safe // bool is 1 byte in Metal so this is safe
std::vector<char> idx_contigs; std::vector<char> idx_contigs;
@@ -322,30 +332,30 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
if (upd_ndim == 0) { if (upd_ndim == 0) {
// Need placeholders so Metal doesn't complain // Need placeholders so Metal doesn't complain
int shape_ = 0; int shape_ = 0;
size_t stride_ = 0; int64_t stride_ = 0;
compute_encoder->setBytes(&shape_, sizeof(int), 3); compute_encoder.set_bytes(shape_, 3);
compute_encoder->setBytes(&stride_, sizeof(size_t), 4); compute_encoder.set_bytes(stride_, 4);
} else { } else {
set_vector_bytes(compute_encoder, upd.shape(), 3); compute_encoder.set_vector_bytes(upd.shape(), 3);
set_vector_bytes(compute_encoder, upd.strides(), 4); compute_encoder.set_vector_bytes(upd.strides(), 4);
} }
compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5); compute_encoder.set_bytes(upd_ndim, 5);
compute_encoder->setBytes(&upd_size, sizeof(size_t), 6); compute_encoder.set_bytes(upd_size, 6);
// Set output info // Set output info
size_t out_ndim = out.ndim(); size_t out_ndim = out.ndim();
if (out_ndim == 0) { if (out_ndim == 0) {
// Need placeholders so Metal doesn't complain // Need placeholders so Metal doesn't complain
int shape_ = 0; int shape_ = 0;
size_t stride_ = 0; int64_t stride_ = 0;
compute_encoder->setBytes(&shape_, sizeof(int), 7); compute_encoder.set_bytes(shape_, 7);
compute_encoder->setBytes(&stride_, sizeof(size_t), 8); compute_encoder.set_bytes(stride_, 8);
} else { } else {
set_vector_bytes(compute_encoder, out.shape(), 7); compute_encoder.set_vector_bytes(out.shape(), 7);
set_vector_bytes(compute_encoder, out.strides(), 8); compute_encoder.set_vector_bytes(out.strides(), 8);
} }
compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9); compute_encoder.set_bytes(out_ndim, 9);
compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10); compute_encoder.set_vector_bytes(axes_, 10);
// Set index info // Set index info
if (idx_ndim == 0) { if (idx_ndim == 0) {
@@ -355,11 +365,11 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
idx_strides.push_back(0); idx_strides.push_back(0);
idx_contigs.push_back(false); idx_contigs.push_back(false);
} }
set_vector_bytes(compute_encoder, idx_shapes, 11); compute_encoder.set_vector_bytes(idx_shapes, 11);
set_vector_bytes(compute_encoder, idx_strides, 12); compute_encoder.set_vector_bytes(idx_strides, 12);
set_vector_bytes(compute_encoder, idx_contigs, 13); compute_encoder.set_vector_bytes(idx_contigs, 13);
compute_encoder->setBytes(&idx_ndim, sizeof(int), 14); compute_encoder.set_bytes(idx_ndim, 14);
compute_encoder->setBytes(&idx_size, sizeof(size_t), 15); compute_encoder.set_bytes(idx_size, 15);
// Set index buffers // Set index buffers
for (int i = 0; i < nidx; ++i) { for (int i = 0; i < nidx; ++i) {
@@ -375,7 +385,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
throw std::runtime_error("[Scatter::eval_gpu] Invalid number of threads"); throw std::runtime_error("[Scatter::eval_gpu] Invalid number of threads");
} }
MTL::Size group_dims = get_block_dims(upd_size, grid_y, 1); MTL::Size group_dims = get_block_dims(upd_size, grid_y, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims); compute_encoder.dispatch_threads(grid_dims, group_dims);
} }
} // namespace mlx::core } // namespace mlx::core
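
The gather and scatter paths now build one flat kernel/library name with fmt::format and append an index-type token chosen from the sizes of the involved arrays. A sketch of the gather variant (the size check and name layout mirror the hunk; the dtype strings are placeholders):

    #include <cstdint>
    #include <string>

    #include <fmt/core.h>   // assumes the fmt library used elsewhere in the tree

    std::string gather_kernel_name(const std::string& out_type,  // e.g. "float32"
                                   const std::string& idx_type,  // "" when nidx == 0
                                   int nidx,
                                   int idx_ndim,
                                   uint64_t largest_size) {      // max(index, src, out)
      const bool large = largest_size > UINT32_MAX;
      return fmt::format("gather{0}{1}_{2}_{3}_{4}",
                         out_type, idx_type, nidx, idx_ndim,
                         large ? "int64_t" : "uint");  // 64-bit vs 32-bit indexing
    }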

View File

@@ -11,13 +11,13 @@ gemv_{trans}masked<{itype}, {outm_t}, {opm_t}, {bm}, {bn}, {sm}, {sn}, {tm}, {tn
const constant int& marix_ld [[buffer(6)]], const constant int& marix_ld [[buffer(6)]],
const constant int& batch_ndim [[buffer(9)]], const constant int& batch_ndim [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]], const constant int* batch_shape [[buffer(10)]],
const constant size_t* vector_batch_stride [[buffer(11)]], const constant int64_t* vector_batch_stride [[buffer(11)]],
const constant size_t* matrix_batch_stride [[buffer(12)]], const constant int64_t* matrix_batch_stride [[buffer(12)]],
const device {outm_t}* out_mask [[buffer(20)]], const device {outm_t}* out_mask [[buffer(20)]],
const device {opm_t}* mat_mask [[buffer(21)]], const device {opm_t}* mat_mask [[buffer(21)]],
const device {opm_t}* vec_mask [[buffer(22)]], const device {opm_t}* vec_mask [[buffer(22)]],
const constant int* mask_strides [[buffer(23)]], const constant int* mask_strides [[buffer(23)]],
const constant size_t* mask_batch_strides [[buffer(24)]], const constant int64_t* mask_batch_strides [[buffer(24)]],
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]], uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]],
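
The stride buffers in these kernel signatures (and in the ones that follow) switch from size_t to int64_t across the board. The point of the signed type is the same one noted in the copy change earlier ("Allow for negative strides"): a reversed view needs a negative stride, which an unsigned stride cannot represent. A small host-side illustration:

    #include <cstdint>
    #include <vector>

    struct View {
      int64_t offset;                 // element offset into the underlying buffer
      std::vector<int64_t> strides;   // signed so an axis can run backwards
    };

    // Flip the last axis of a strided view: move the offset to the end of the
    // axis and negate its stride. With size_t strides this negation would wrap.
    View flip_last_axis(View v, int64_t extent) {
      v.offset += (extent - 1) * v.strides.back();
      v.strides.back() = -v.strides.back();
      return v;
    }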

View File

@@ -1,16 +1,16 @@
// Copyright © 2023-2024 Apple Inc. // Copyright © 2023-2024 Apple Inc.
constexpr std::string_view gather_kernels = R"( constexpr std::string_view gather_kernels = R"(
[[kernel]] void gather{0}_{3}_{6}( [[kernel]] void gather{0}_{3}_{6}_{7}(
const device {1}* src [[buffer(0)]], const device {1}* src [[buffer(0)]],
device {1}* out [[buffer(1)]], device {1}* out [[buffer(1)]],
const constant int* src_shape [[buffer(2)]], const constant int* src_shape [[buffer(2)]],
const constant size_t* src_strides [[buffer(3)]], const constant int64_t* src_strides [[buffer(3)]],
const constant size_t& src_ndim [[buffer(4)]], const constant size_t& src_ndim [[buffer(4)]],
const constant int* slice_sizes [[buffer(5)]], const constant int* slice_sizes [[buffer(5)]],
const constant int* axes [[buffer(6)]], const constant int* axes [[buffer(6)]],
const constant int* idx_shapes [[buffer(7)]], const constant int* idx_shapes [[buffer(7)]],
const constant size_t* idx_strides [[buffer(8)]], const constant int64_t* idx_strides [[buffer(8)]],
const constant bool* idx_contigs [[buffer(9)]], const constant bool* idx_contigs [[buffer(9)]],
const constant int& idx_ndim [[buffer(10)]], const constant int& idx_ndim [[buffer(10)]],
{4} {4}
@@ -19,7 +19,7 @@ constexpr std::string_view gather_kernels = R"(
Indices<{2}, {3}> idxs{{ Indices<{2}, {3}> idxs{{
{{ {5} }}, idx_shapes, idx_strides, idx_contigs, idx_ndim}}; {{ {5} }}, idx_shapes, idx_strides, idx_contigs, idx_ndim}};
return gather_impl<{1}, {2}, {3}, {6}>( return gather_impl<{1}, {2}, {3}, {6}, {7}>(
src, src,
out, out,
src_shape, src_shape,
@@ -34,19 +34,19 @@ constexpr std::string_view gather_kernels = R"(
)"; )";
constexpr std::string_view scatter_kernels = R"( constexpr std::string_view scatter_kernels = R"(
[[kernel]] void scatter{0}_{4}_updc_{7}_nwork{8}( [[kernel]] void scatter{0}_{4}_updc_{7}_nwork{8}_{9}(
const device {1}* updates [[buffer(1)]], const device {1}* updates [[buffer(1)]],
device mlx_atomic<{1}>* out [[buffer(2)]], device mlx_atomic<{1}>* out [[buffer(2)]],
const constant int* upd_shape [[buffer(3)]], const constant int* upd_shape [[buffer(3)]],
const constant size_t* upd_strides [[buffer(4)]], const constant int64_t* upd_strides [[buffer(4)]],
const constant size_t& upd_ndim [[buffer(5)]], const constant size_t& upd_ndim [[buffer(5)]],
const constant size_t& upd_size [[buffer(6)]], const constant size_t& upd_size [[buffer(6)]],
const constant int* out_shape [[buffer(7)]], const constant int* out_shape [[buffer(7)]],
const constant size_t* out_strides [[buffer(8)]], const constant int64_t* out_strides [[buffer(8)]],
const constant size_t& out_ndim [[buffer(9)]], const constant size_t& out_ndim [[buffer(9)]],
const constant int* axes [[buffer(10)]], const constant int* axes [[buffer(10)]],
const constant int* idx_shapes [[buffer(11)]], const constant int* idx_shapes [[buffer(11)]],
const constant size_t* idx_strides [[buffer(12)]], const constant int64_t* idx_strides [[buffer(12)]],
const constant bool* idx_contigs [[buffer(13)]], const constant bool* idx_contigs [[buffer(13)]],
const constant int& idx_ndim [[buffer(14)]], const constant int& idx_ndim [[buffer(14)]],
const constant size_t& idx_size [[buffer(15)]], const constant size_t& idx_size [[buffer(15)]],
@@ -54,7 +54,7 @@ constexpr std::string_view scatter_kernels = R"(
uint2 gid [[thread_position_in_grid]]) {{ uint2 gid [[thread_position_in_grid]]) {{
Indices<{2}, {4}> idxs{{ {{ {6} }}, idx_shapes, idx_strides, idx_contigs, idx_ndim}}; Indices<{2}, {4}> idxs{{ {{ {6} }}, idx_shapes, idx_strides, idx_contigs, idx_ndim}};
return scatter_impl<{1}, {2}, {3}, {4}, {7}, {8}>( return scatter_impl<{1}, {2}, {3}, {4}, {7}, {8}, {9}>(
updates, updates,
out, out,
upd_shape, upd_shape,

View File

@@ -10,12 +10,12 @@ template [[host_name("{name}")]]
const constant GEMMParams* params [[buffer(4)]], const constant GEMMParams* params [[buffer(4)]],
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]], const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
const constant int* batch_shape [[buffer(6)]], const constant int* batch_shape [[buffer(6)]],
const constant size_t* batch_strides [[buffer(7)]], const constant int64_t* batch_strides [[buffer(7)]],
const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]], const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]],
const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]], const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]],
const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]], const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]],
const constant int* operand_shape [[buffer(13), function_constant(do_gather)]], const constant int* operand_shape [[buffer(13), function_constant(do_gather)]],
const constant size_t* operand_strides [[buffer(14), function_constant(do_gather)]], const constant int64_t* operand_strides [[buffer(14), function_constant(do_gather)]],
const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]], const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]],
uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]],
@@ -43,7 +43,7 @@ block_masked_gemm<
device {itype}* D [[buffer(3)]], device {itype}* D [[buffer(3)]],
const constant GEMMParams* params [[buffer(4)]], const constant GEMMParams* params [[buffer(4)]],
const constant int* batch_shape [[buffer(6)]], const constant int* batch_shape [[buffer(6)]],
const constant size_t* batch_strides [[buffer(7)]], const constant int64_t* batch_strides [[buffer(7)]],
const device {outmasktype}* out_mask [[buffer(10)]], const device {outmasktype}* out_mask [[buffer(10)]],
const device {opmasktype}* lhs_mask [[buffer(11)]], const device {opmasktype}* lhs_mask [[buffer(11)]],
const device {opmasktype}* rhs_mask [[buffer(12)]], const device {opmasktype}* rhs_mask [[buffer(12)]],

View File

@@ -1,5 +1,4 @@
// Copyright © 2024 Apple Inc. // Copyright © 2024 Apple Inc.
#include "mlx/backend/common/compiled.h" #include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/jit/arange.h" #include "mlx/backend/metal/jit/arange.h"
#include "mlx/backend/metal/jit/gemv_masked.h" #include "mlx/backend/metal/jit/gemv_masked.h"
@@ -46,25 +45,27 @@ MTL::ComputePipelineState* get_unary_kernel(
auto lib = d.get_library(lib_name, [&]() { auto lib = d.get_library(lib_name, [&]() {
auto in_t = get_type_string(in_type); auto in_t = get_type_string(in_type);
auto out_t = get_type_string(out_type); auto out_t = get_type_string(out_type);
std::ostringstream kernel_source; std::string kernel_source = metal::utils();
kernel_source << metal::utils() << metal::unary_ops() << metal::unary(); concatenate(kernel_source, metal::unary_ops(), metal::unary());
kernel_source << get_template_definition( kernel_source +=
"v_" + lib_name, "unary_v", in_t, out_t, op); get_template_definition("v_" + lib_name, "unary_v", in_t, out_t, op);
kernel_source << get_template_definition( kernel_source +=
"v2_" + lib_name, "unary_v2", in_t, out_t, op); get_template_definition("v2_" + lib_name, "unary_v2", in_t, out_t, op);
kernel_source << get_template_definition( kernel_source += get_template_definition(
"gn4_" + lib_name, "unary_g", in_t, out_t, op, 4); "gn1_" + lib_name, "unary_g", in_t, out_t, op, 1, "uint");
return kernel_source.str(); kernel_source += get_template_definition(
"gn4large_" + lib_name, "unary_g", in_t, out_t, op, 4);
return kernel_source;
}); });
return d.get_kernel(kernel_name, lib); return d.get_kernel(kernel_name, lib);
} }
void add_binary_kernels( void append_binary_kernels(
const std::string lib_name, const std::string lib_name,
Dtype in_type, Dtype in_type,
Dtype out_type, Dtype out_type,
const std::string op, const std::string op,
std::ostringstream& kernel_source) { std::string& kernel_source) {
const std::array<std::pair<std::string, std::string>, 10> kernel_types = {{ const std::array<std::pair<std::string, std::string>, 10> kernel_types = {{
{"ss", "binary_ss"}, {"ss", "binary_ss"},
{"vs", "binary_vs"}, {"vs", "binary_vs"},
@@ -74,26 +75,24 @@ void add_binary_kernels(
{"sv2", "binary_sv2"}, {"sv2", "binary_sv2"},
{"vv2", "binary_vv2"}, {"vv2", "binary_vv2"},
{"g1", "binary_g_nd1"}, {"g1", "binary_g_nd1"},
{"g2", "binary_g_nd2"}, {"g2large", "binary_g_nd2"},
{"g3", "binary_g_nd3"}, {"g3large", "binary_g_nd3"},
}}; }};
auto in_t = get_type_string(in_type);
auto out_t = get_type_string(out_type);
for (auto& [name, func] : kernel_types) { for (auto& [name, func] : kernel_types) {
std::string template_def; kernel_source +=
template_def = get_template_definition( get_template_definition(name + "_" + lib_name, func, in_t, out_t, op);
name + "_" + lib_name,
func,
get_type_string(in_type),
get_type_string(out_type),
op);
kernel_source << template_def;
} }
kernel_source << get_template_definition( kernel_source += get_template_definition(
"gn4_" + lib_name, "g2_" + lib_name, "binary_g_nd2", in_t, out_t, op, "uint");
"binary_g", kernel_source += get_template_definition(
get_type_string(in_type), "g3_" + lib_name, "binary_g_nd3", in_t, out_t, op, "uint");
get_type_string(out_type), kernel_source += get_template_definition(
op, "gn2_" + lib_name, "binary_g", in_t, out_t, op, 2, "uint");
4); kernel_source += get_template_definition(
"gn4large_" + lib_name, "binary_g", in_t, out_t, op, 4);
} }
MTL::ComputePipelineState* get_binary_kernel( MTL::ComputePipelineState* get_binary_kernel(
@@ -104,10 +103,11 @@ MTL::ComputePipelineState* get_binary_kernel(
const std::string op) { const std::string op) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&]() { auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source; std::string kernel_source;
kernel_source << metal::utils() << metal::binary_ops() << metal::binary(); kernel_source = metal::utils();
add_binary_kernels(lib_name, in_type, out_type, op, kernel_source); concatenate(kernel_source, metal::binary_ops(), metal::binary());
return kernel_source.str(); append_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
return kernel_source;
}); });
return d.get_kernel(kernel_name, lib); return d.get_kernel(kernel_name, lib);
} }
@@ -120,11 +120,10 @@ MTL::ComputePipelineState* get_binary_two_kernel(
const std::string op) { const std::string op) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&]() { auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source; std::string kernel_source = metal::utils();
kernel_source << metal::utils() << metal::binary_ops() concatenate(kernel_source, metal::binary_ops(), metal::binary_two());
<< metal::binary_two(); append_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
add_binary_kernels(lib_name, in_type, out_type, op, kernel_source); return kernel_source;
return kernel_source.str();
}); });
return d.get_kernel(kernel_name, lib); return d.get_kernel(kernel_name, lib);
} }
@@ -136,24 +135,29 @@ MTL::ComputePipelineState* get_ternary_kernel(
const std::string op) { const std::string op) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&]() { auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source; auto t_str = get_type_string(type);
std::string kernel_source = metal::utils();
concatenate(kernel_source, metal::ternary_ops(), metal::ternary());
const std::array<std::pair<std::string, std::string>, 5> kernel_types = {{ const std::array<std::pair<std::string, std::string>, 5> kernel_types = {{
{"v", "ternary_v"}, {"v", "ternary_v"},
{"v2", "ternary_v2"}, {"v2", "ternary_v2"},
{"g1", "ternary_g_nd1"}, {"g1", "ternary_g_nd1"},
{"g2", "ternary_g_nd2"}, {"g2large", "ternary_g_nd2"},
{"g3", "ternary_g_nd3"}, {"g3large", "ternary_g_nd3"},
}}; }};
kernel_source << metal::utils() << metal::ternary_ops() << metal::ternary();
for (auto& [name, func] : kernel_types) { for (auto& [name, func] : kernel_types) {
std::string template_def; kernel_source +=
template_def = get_template_definition( get_template_definition(name + "_" + lib_name, func, t_str, op);
name + "_" + lib_name, func, get_type_string(type), op);
kernel_source << template_def;
} }
kernel_source << get_template_definition( kernel_source += get_template_definition(
"gn4_" + lib_name, "ternary_g", get_type_string(type), op, 4); "g2_" + lib_name, "ternary_g_nd2", t_str, op, "uint");
return kernel_source.str(); kernel_source += get_template_definition(
"g3_" + lib_name, "ternary_g_nd3", t_str, op, "uint");
kernel_source += get_template_definition(
"gn2_" + lib_name, "ternary_g", t_str, op, 2, "uint");
kernel_source += get_template_definition(
"gn4large_" + lib_name, "ternary_g", t_str, op, 4);
return kernel_source;
}); });
return d.get_kernel(kernel_name, lib); return d.get_kernel(kernel_name, lib);
} }
@@ -165,31 +169,43 @@ MTL::ComputePipelineState* get_copy_kernel(
const array& out) { const array& out) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name, [&]() { auto lib = d.get_library(lib_name, [&]() {
std::ostringstream kernel_source; std::string kernel_source = metal::utils();
kernel_source += metal::copy();
auto in_type = get_type_string(in.dtype()); auto in_type = get_type_string(in.dtype());
auto out_type = get_type_string(out.dtype()); auto out_type = get_type_string(out.dtype());
kernel_source << metal::utils() << metal::copy() kernel_source +=
<< get_template_definition( get_template_definition("s_" + lib_name, "copy_s", in_type, out_type);
"s_" + lib_name, "copy_s", in_type, out_type) kernel_source +=
<< get_template_definition( get_template_definition("v_" + lib_name, "copy_v", in_type, out_type);
"v_" + lib_name, "copy_v", in_type, out_type) kernel_source += get_template_definition(
<< get_template_definition( "g1_" + lib_name, "copy_g_nd1", in_type, out_type);
"g1_" + lib_name, "copy_g_nd1", in_type, out_type) kernel_source += get_template_definition(
<< get_template_definition( "g2_" + lib_name, "copy_g_nd2", in_type, out_type, "int");
"g2_" + lib_name, "copy_g_nd2", in_type, out_type) kernel_source += get_template_definition(
<< get_template_definition( "g3_" + lib_name, "copy_g_nd3", in_type, out_type, "int");
"g3_" + lib_name, "copy_g_nd3", in_type, out_type) kernel_source += get_template_definition(
<< get_template_definition( "gn2_" + lib_name, "copy_g", in_type, out_type, 2, "int");
"gn4_" + lib_name, "copy_g", in_type, out_type, 4) kernel_source += get_template_definition(
<< get_template_definition( "gg1_" + lib_name, "copy_gg_nd1", in_type, out_type);
"gg1_" + lib_name, "copy_gg_nd1", in_type, out_type) kernel_source += get_template_definition(
<< get_template_definition( "gg2_" + lib_name, "copy_gg_nd2", in_type, out_type, "int");
"gg2_" + lib_name, "copy_gg_nd2", in_type, out_type) kernel_source += get_template_definition(
<< get_template_definition( "gg3_" + lib_name, "copy_gg_nd3", in_type, out_type, "int");
"gg3_" + lib_name, "copy_gg_nd3", in_type, out_type) kernel_source += get_template_definition(
<< get_template_definition( "ggn2_" + lib_name, "copy_gg", in_type, out_type, 2, "int");
"ggn4_" + lib_name, "copy_gg", in_type, out_type, 4); kernel_source += get_template_definition(
return kernel_source.str(); "g2large_" + lib_name, "copy_g_nd2", in_type, out_type);
kernel_source += get_template_definition(
"g3large_" + lib_name, "copy_g_nd3", in_type, out_type);
kernel_source += get_template_definition(
"gn4large_" + lib_name, "copy_g", in_type, out_type, 4);
kernel_source += get_template_definition(
"gg2large_" + lib_name, "copy_gg_nd2", in_type, out_type);
kernel_source += get_template_definition(
"gg3large_" + lib_name, "copy_gg_nd3", in_type, out_type);
kernel_source += get_template_definition(
"ggn4large_" + lib_name, "copy_gg", in_type, out_type, 4);
return kernel_source;
}); });
return d.get_kernel(kernel_name, lib); return d.get_kernel(kernel_name, lib);
} }
@@ -321,17 +337,17 @@ MTL::ComputePipelineState* get_reduce_init_kernel(
const std::string& kernel_name, const std::string& kernel_name,
const std::string& func_name, const std::string& func_name,
const std::string& op_name, const std::string& op_name,
const array& out) { const Dtype& out_type) {
auto lib = d.get_library(kernel_name, [&]() { auto lib = d.get_library(kernel_name, [&]() {
std::ostringstream kernel_source;
std::string op_type = op_name; std::string op_type = op_name;
op_type[0] = std::toupper(op_name[0]); op_type[0] = std::toupper(op_name[0]);
auto out_type = get_type_string(out.dtype()); auto out_t = get_type_string(out_type);
std::string op = op_type + "<" + out_type + ">"; std::string op = op_type + "<" + out_t + ">";
kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce(); std::string kernel_source = metal::utils();
kernel_source << get_template_definition( kernel_source += metal::reduce_utils();
kernel_name, func_name, out_type, op); kernel_source += metal::reduce();
return kernel_source.str(); kernel_source += get_template_definition(kernel_name, func_name, out_t, op);
return kernel_source;
}); });
return d.get_kernel(kernel_name, lib); return d.get_kernel(kernel_name, lib);
} }
@@ -341,30 +357,31 @@ MTL::ComputePipelineState* get_reduce_kernel(
const std::string& kernel_name, const std::string& kernel_name,
const std::string& func_name, const std::string& func_name,
const std::string& op_name, const std::string& op_name,
const array& in, const Dtype& in_type,
const array& out, const Dtype& out_type,
const std::string& idx_t,
int ndim /* = -1 */, int ndim /* = -1 */,
int bm /* = -1 */, int bm /* = -1 */,
int bn /* = -1 */) { int bn /* = -1 */) {
auto lib = d.get_library(kernel_name, [&]() { auto lib = d.get_library(kernel_name, [&]() {
std::string op_type = op_name; std::string op_type = op_name;
op_type[0] = std::toupper(op_name[0]); op_type[0] = std::toupper(op_name[0]);
std::ostringstream kernel_source; auto in_t = get_type_string(in_type);
auto in_type = get_type_string(in.dtype()); auto out_t = get_type_string(out_type);
auto out_type = get_type_string(out.dtype()); std::string op = op_type + "<" + out_t + ">";
std::string op = op_type + "<" + out_type + ">"; std::string kernel_source = metal::utils();
kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce(); concatenate(kernel_source, metal::reduce_utils(), metal::reduce());
if (bm >= 0) { if (bm >= 0) {
kernel_source << get_template_definition( kernel_source += get_template_definition(
kernel_name, func_name, in_type, out_type, op, ndim, bm, bn); kernel_name, func_name, in_t, out_t, op, idx_t, ndim, bm, bn);
} else if (ndim >= 0) { } else if (ndim >= 0) {
kernel_source << get_template_definition( kernel_source += get_template_definition(
kernel_name, func_name, in_type, out_type, op, ndim); kernel_name, func_name, in_t, out_t, op, idx_t, ndim);
} else { } else {
kernel_source << get_template_definition( kernel_source += get_template_definition(
kernel_name, func_name, in_type, out_type, op); kernel_name, func_name, in_t, out_t, op, idx_t);
} }
return kernel_source.str(); return kernel_source;
}); });
auto st = d.get_kernel(kernel_name, lib); auto st = d.get_kernel(kernel_name, lib);
return st; return st;
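
Throughout this file the ostringstream-based source assembly is replaced by plain std::string appends plus a variadic concatenate helper, with one explicit get_template_definition call per kernel variant (including the new "uint" and "large" instantiations). The concatenate below is an assumed stand-in with the same call shape, just to show the pattern:

    #include <iostream>
    #include <string>

    // Assumed stand-in for the concatenate helper used above: append each part in order.
    template <typename... Parts>
    void concatenate(std::string& dst, const Parts&... parts) {
      ((dst += parts), ...);            // C++17 fold over the comma operator
    }

    int main() {
      std::string kernel_source = "// metal::utils() stand-in\n";
      concatenate(kernel_source, "// ops\n", "// kernel bodies\n");
      // One line per instantiation, e.g. the small-index and large-index variants:
      kernel_source += "// template definition for gn2_<lib> (uint indexing)\n";
      kernel_source += "// template definition for gn4large_<lib> (int64_t indexing)\n";
      std::cout << kernel_source;
    }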

View File

@@ -81,15 +81,16 @@ MTL::ComputePipelineState* get_reduce_init_kernel(
const std::string& kernel_name, const std::string& kernel_name,
const std::string& func_name, const std::string& func_name,
const std::string& op_name, const std::string& op_name,
const array& out); const Dtype& out_type);
MTL::ComputePipelineState* get_reduce_kernel( MTL::ComputePipelineState* get_reduce_kernel(
metal::Device& d, metal::Device& d,
const std::string& kernel_name, const std::string& kernel_name,
const std::string& func_name, const std::string& func_name,
const std::string& op_name, const std::string& op_name,
const array& in, const Dtype& in_type,
const array& out, const Dtype& out_type,
const std::string& idx_t,
int ndim = -1, int ndim = -1,
int bm = -1, int bm = -1,
int bn = -1); int bn = -1);

View File

@@ -1,13 +1,27 @@
set(BASE_HEADERS bf16.h bf16_math.h complex.h defines.h expm1f.h utils.h) set(BASE_HEADERS
metal_3_1/bf16.h
metal_3_0/bf16.h
bf16_math.h
complex.h
defines.h
expm1f.h
utils.h)
function(build_kernel_base TARGET SRCFILE DEPS) function(build_kernel_base TARGET SRCFILE DEPS)
set(METAL_FLAGS -Wall -Wextra -fno-fast-math) set(METAL_FLAGS -Wall -Wextra -fno-fast-math)
if(MLX_METAL_DEBUG) if(MLX_METAL_DEBUG)
set(METAL_FLAGS ${METAL_FLAGS} -gline-tables-only -frecord-sources) set(METAL_FLAGS ${METAL_FLAGS} -gline-tables-only -frecord-sources)
endif() endif()
if(MLX_METAL_VERSION GREATER_EQUAL 310)
set(VERSION_INCLUDES
${PROJECT_SOURCE_DIR}/mlx/backend/metal/kernels/metal_3_1)
else()
set(VERSION_INCLUDES
${PROJECT_SOURCE_DIR}/mlx/backend/metal/kernels/metal_3_0)
endif()
add_custom_command( add_custom_command(
COMMAND xcrun -sdk macosx metal ${METAL_FLAGS} -c ${SRCFILE} COMMAND xcrun -sdk macosx metal ${METAL_FLAGS} -c ${SRCFILE}
-I${PROJECT_SOURCE_DIR} -o ${TARGET}.air -I${PROJECT_SOURCE_DIR} -I${VERSION_INCLUDES} -o ${TARGET}.air
DEPENDS ${SRCFILE} ${DEPS} ${BASE_HEADERS} DEPENDS ${SRCFILE} ${DEPS} ${BASE_HEADERS}
OUTPUT ${TARGET}.air OUTPUT ${TARGET}.air
COMMENT "Building ${TARGET}.air" COMMENT "Building ${TARGET}.air"
@@ -30,9 +44,7 @@ build_kernel(layer_norm)
build_kernel(random) build_kernel(random)
build_kernel(rms_norm) build_kernel(rms_norm)
build_kernel(rope) build_kernel(rope)
build_kernel( build_kernel(scaled_dot_product_attention sdpa_vector.h)
scaled_dot_product_attention scaled_dot_product_attention_params.h
sdpa_vector.h steel/defines.h steel/gemm/transforms.h steel/utils.h)
set(STEEL_HEADERS set(STEEL_HEADERS
steel/defines.h steel/defines.h
@@ -54,6 +66,24 @@ set(STEEL_HEADERS
steel/utils/type_traits.h steel/utils/type_traits.h
steel/utils/integral_constant.h) steel/utils/integral_constant.h)
set(STEEL_ATTN_HEADERS
steel/defines.h
steel/utils.h
steel/gemm/gemm.h
steel/gemm/mma.h
steel/gemm/loader.h
steel/gemm/transforms.h
steel/utils/type_traits.h
steel/utils/integral_constant.h
steel/attn/attn.h
steel/attn/loader.h
steel/attn/mma.h
steel/attn/params.h
steel/attn/transforms.h
steel/attn/kernels/steel_attention.h)
build_kernel(steel/attn/kernels/steel_attention ${STEEL_ATTN_HEADERS})
if(NOT MLX_METAL_JIT) if(NOT MLX_METAL_JIT)
build_kernel(arange arange.h) build_kernel(arange arange.h)
build_kernel(binary binary.h binary_ops.h) build_kernel(binary binary.h binary_ops.h)

View File

@@ -1,7 +1,7 @@
// Copyright © 2023-2024 Apple Inc. // Copyright © 2023-2024 Apple Inc.
// clang-format off // clang-format off
#include "mlx/backend/metal/kernels/bf16.h" #include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/arange.h" #include "mlx/backend/metal/kernels/arange.h"
#define instantiate_arange(tname, type) \ #define instantiate_arange(tname, type) \

View File

@@ -75,10 +75,10 @@ template <typename T, typename Op, int N_READS = 4>
const device T* in [[buffer(0)]], const device T* in [[buffer(0)]],
device uint32_t* out [[buffer(1)]], device uint32_t* out [[buffer(1)]],
const constant int* shape [[buffer(2)]], const constant int* shape [[buffer(2)]],
const constant size_t* in_strides [[buffer(3)]], const constant int64_t* in_strides [[buffer(3)]],
const constant size_t* out_strides [[buffer(4)]], const constant int64_t* out_strides [[buffer(4)]],
const constant size_t& ndim [[buffer(5)]], const constant size_t& ndim [[buffer(5)]],
const constant size_t& axis_stride [[buffer(6)]], const constant int64_t& axis_stride [[buffer(6)]],
const constant size_t& axis_size [[buffer(7)]], const constant size_t& axis_size [[buffer(7)]],
uint gid [[thread_position_in_grid]], uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]], uint lid [[thread_position_in_threadgroup]],

View File

@@ -2,8 +2,6 @@
#pragma once #pragma once
#include "mlx/backend/metal/kernels/bf16.h"
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// Metal math for bfloat16 // Metal math for bfloat16
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
@@ -369,18 +367,6 @@ instantiate_metal_math_funcs(
return static_cast<otype>(__metal_simd_xor(static_cast<ctype>(data))); \ return static_cast<otype>(__metal_simd_xor(static_cast<ctype>(data))); \
} }
#if (MLX_METAL_VERSION >= 310) || (__METAL_VERSION__ >= 310)
#define bfloat16_to_uint16(x) as_type<uint16_t>(x)
#define uint16_to_bfloat16(x) as_type<bfloat16_t>(x)
#else
#define bfloat16_to_uint16(x) x.bits_
#define uint16_to_bfloat16(x) _MLX_BFloat16(x, _MLX_BFloat16::bits_to_bfloat())
#endif
namespace metal { namespace metal {
instantiate_metal_simd_comm_funcs( instantiate_metal_simd_comm_funcs(

View File

@@ -43,7 +43,7 @@ template <typename T, typename U, typename Op>
device U* c, device U* c,
uint2 index [[thread_position_in_grid]], uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) { uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y); int64_t offset = index.x + grid_dim.x * int64_t(index.y);
c[offset] = Op()(a[0], b[offset]); c[offset] = Op()(a[0], b[offset]);
} }
@@ -54,7 +54,7 @@ template <typename T, typename U, typename Op>
device U* c, device U* c,
uint2 index [[thread_position_in_grid]], uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) { uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y); int64_t offset = index.x + grid_dim.x * int64_t(index.y);
c[offset] = Op()(a[offset], b[0]); c[offset] = Op()(a[offset], b[0]);
} }
@@ -65,72 +65,75 @@ template <typename T, typename U, typename Op>
device U* c, device U* c,
uint2 index [[thread_position_in_grid]], uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) { uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y); int64_t offset = index.x + grid_dim.x * int64_t(index.y);
c[offset] = Op()(a[offset], b[offset]); c[offset] = Op()(a[offset], b[offset]);
} }
template <typename T, typename U, typename Op> template <typename T, typename U, typename Op, typename IdxT = int64_t>
[[kernel]] void binary_g_nd1( [[kernel]] void binary_g_nd1(
device const T* a, device const T* a,
device const T* b, device const T* b,
device U* c, device U* c,
constant const size_t& a_stride, constant const int64_t& a_stride,
constant const size_t& b_stride, constant const int64_t& b_stride,
uint index [[thread_position_in_grid]]) { uint index [[thread_position_in_grid]]) {
auto a_idx = elem_to_loc_1(index, a_stride); auto a_idx = elem_to_loc_1<IdxT>(index, a_stride);
auto b_idx = elem_to_loc_1(index, b_stride); auto b_idx = elem_to_loc_1<IdxT>(index, b_stride);
c[index] = Op()(a[a_idx], b[b_idx]); c[index] = Op()(a[a_idx], b[b_idx]);
} }
template <typename T, typename U, typename Op> template <typename T, typename U, typename Op, typename IdxT = int64_t>
[[kernel]] void binary_g_nd2( [[kernel]] void binary_g_nd2(
device const T* a, device const T* a,
device const T* b, device const T* b,
device U* c, device U* c,
constant const size_t a_strides[2], constant const int64_t a_strides[2],
constant const size_t b_strides[2], constant const int64_t b_strides[2],
uint2 index [[thread_position_in_grid]], uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) { uint2 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_2(index, a_strides); auto a_idx = elem_to_loc_2<IdxT>(index, a_strides);
auto b_idx = elem_to_loc_2(index, b_strides); auto b_idx = elem_to_loc_2<IdxT>(index, b_strides);
size_t out_idx = index.x + size_t(grid_dim.x) * index.y; IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y;
c[out_idx] = Op()(a[a_idx], b[b_idx]); c[out_idx] = Op()(a[a_idx], b[b_idx]);
} }
template <typename T, typename U, typename Op> template <typename T, typename U, typename Op, typename IdxT = int64_t>
[[kernel]] void binary_g_nd3( [[kernel]] void binary_g_nd3(
device const T* a, device const T* a,
device const T* b, device const T* b,
device U* c, device U* c,
constant const size_t a_strides[3], constant const int64_t a_strides[3],
constant const size_t b_strides[3], constant const int64_t b_strides[3],
uint3 index [[thread_position_in_grid]], uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) { uint3 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_3(index, a_strides); auto a_idx = elem_to_loc_3<IdxT>(index, a_strides);
auto b_idx = elem_to_loc_3(index, b_strides); auto b_idx = elem_to_loc_3<IdxT>(index, b_strides);
size_t out_idx = IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
c[out_idx] = Op()(a[a_idx], b[b_idx]); c[out_idx] = Op()(a[a_idx], b[b_idx]);
} }
template <typename T, typename U, typename Op, int N = 1> template <
typename T,
typename U,
typename Op,
int N = 1,
typename IdxT = int64_t>
[[kernel]] void binary_g( [[kernel]] void binary_g(
device const T* a, device const T* a,
device const T* b, device const T* b,
device U* c, device U* c,
constant const int* shape, constant const int* shape,
constant const size_t* a_strides, constant const int64_t* a_strides,
constant const size_t* b_strides, constant const int64_t* b_strides,
constant const int& ndim, constant const int& ndim,
uint3 index [[thread_position_in_grid]], uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) { uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd( auto idx = elem_to_loc_2_nd<IdxT>(
{N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim); {N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim);
auto xshape = shape[ndim - 1]; auto xshape = shape[ndim - 1];
size_t out_idx = IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z); IdxT a_xstride = a_strides[ndim - 1];
auto a_xstride = a_strides[ndim - 1]; IdxT b_xstride = b_strides[ndim - 1];
auto b_xstride = b_strides[ndim - 1];
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
c[out_idx++] = Op()(a[idx.x], b[idx.y]); c[out_idx++] = Op()(a[idx.x], b[idx.y]);
idx.x += a_xstride; idx.x += a_xstride;

View File
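Throughout `binary.h` the kernels now take an `IdxT` template parameter and forward it to the `elem_to_loc_*` helpers, so small inputs can do their index math in 32 bits while the `*large` instantiations keep 64-bit offsets. A rough sketch of what those helpers compute, with the signatures and row-major stride convention inferred from the call sites rather than copied from `utils.h`:

```cpp
#include <cstdint>

struct uint2 {
  uint32_t x, y;
};

// Map a 1-D thread position to a strided element offset.
template <typename IdxT = int64_t>
IdxT elem_to_loc_1(uint32_t elem, int64_t stride) {
  return static_cast<IdxT>(elem) * static_cast<IdxT>(stride);
}

// Map a 2-D thread position to a strided element offset
// (x walks the innermost axis, y the next one out).
template <typename IdxT = int64_t>
IdxT elem_to_loc_2(uint2 elem, const int64_t strides[2]) {
  return static_cast<IdxT>(elem.x) * static_cast<IdxT>(strides[1]) +
      static_cast<IdxT>(elem.y) * static_cast<IdxT>(strides[0]);
}
```

With `IdxT = uint` the multiplies and adds stay in 32-bit registers, which is the point of registering separate `g*`/`g*large` kernel variants in the instantiation file below.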

@@ -9,18 +9,22 @@
#include "mlx/backend/metal/kernels/binary_ops.h" #include "mlx/backend/metal/kernels/binary_ops.h"
#include "mlx/backend/metal/kernels/binary.h" #include "mlx/backend/metal/kernels/binary.h"
#define instantiate_binary_all(op, tname, itype, otype) \ #define instantiate_binary_all(op, tname, itype, otype) \
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \ instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \ instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \ instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \ instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \ instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \ instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \ instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
instantiate_kernel("gn4_" #op #tname, binary_g, itype, otype, op, 4) \ instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, uint) \
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \ instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op) \ instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op, uint) \
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op) \ instantiate_kernel("g1large_" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, uint) \
instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, uint) \
instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op)
#define instantiate_binary_integer(op) \ #define instantiate_binary_integer(op) \
instantiate_binary_all(op, uint8, uint8_t, uint8_t) \ instantiate_binary_all(op, uint8, uint8_t, uint8_t) \

View File
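The instantiation list now registers 32-bit-index variants (`g1_`, `g2_`, `g3_`, and `gn2_` with `uint`) next to 64-bit `*large` variants. A hypothetical host-side selection sketch, assuming the switch is made on whether every flattened output offset still fits in a 32-bit index (the exact threshold and helper name are illustrative, not MLX's dispatch code):

```cpp
#include <cstdint>
#include <string>

// Hypothetical helper: choose between the 32-bit and 64-bit general binary
// kernels based on whether element offsets still fit in a uint.
std::string general_binary_kernel(const std::string& op, int64_t out_size) {
  if (out_size > static_cast<int64_t>(UINT32_MAX)) {
    return "gn4large_" + op;  // 64-bit IdxT, 4 elements per thread
  }
  return "gn2_" + op;  // 32-bit IdxT, 2 elements per thread
}
```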

@@ -56,7 +56,7 @@ template <typename T, typename U, typename Op>
device U* d, device U* d,
uint2 index [[thread_position_in_grid]], uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) { uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y); auto offset = index.x + grid_dim.x * int64_t(index.y);
auto out = Op()(a[0], b[offset]); auto out = Op()(a[0], b[offset]);
c[offset] = out[0]; c[offset] = out[0];
d[offset] = out[1]; d[offset] = out[1];
@@ -70,7 +70,7 @@ template <typename T, typename U, typename Op>
device U* d, device U* d,
uint2 index [[thread_position_in_grid]], uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) { uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y); auto offset = index.x + grid_dim.x * int64_t(index.y);
auto out = Op()(a[offset], b[0]); auto out = Op()(a[offset], b[0]);
c[offset] = out[0]; c[offset] = out[0];
d[offset] = out[1]; d[offset] = out[1];
@@ -84,84 +84,87 @@ template <typename T, typename U, typename Op>
device U* d, device U* d,
uint2 index [[thread_position_in_grid]], uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) { uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y); auto offset = index.x + grid_dim.x * int64_t(index.y);
auto out = Op()(a[offset], b[offset]); auto out = Op()(a[offset], b[offset]);
c[offset] = out[0]; c[offset] = out[0];
d[offset] = out[1]; d[offset] = out[1];
} }
template <typename T, typename U, typename Op> template <typename T, typename U, typename Op, typename IdxT = int64_t>
[[kernel]] void binary_g_nd1( [[kernel]] void binary_g_nd1(
device const T* a, device const T* a,
device const T* b, device const T* b,
device U* c, device U* c,
device U* d, device U* d,
constant const size_t& a_stride, constant const int64_t& a_stride,
constant const size_t& b_stride, constant const int64_t& b_stride,
uint index [[thread_position_in_grid]]) { uint index [[thread_position_in_grid]]) {
auto a_idx = elem_to_loc_1(index, a_stride); auto a_idx = elem_to_loc_1<IdxT>(index, a_stride);
auto b_idx = elem_to_loc_1(index, b_stride); auto b_idx = elem_to_loc_1<IdxT>(index, b_stride);
auto out = Op()(a[a_idx], b[b_idx]); auto out = Op()(a[a_idx], b[b_idx]);
c[index] = out[0]; c[index] = out[0];
d[index] = out[1]; d[index] = out[1];
} }
template <typename T, typename U, typename Op> template <typename T, typename U, typename Op, typename IdxT = int64_t>
[[kernel]] void binary_g_nd2( [[kernel]] void binary_g_nd2(
device const T* a, device const T* a,
device const T* b, device const T* b,
device U* c, device U* c,
device U* d, device U* d,
constant const size_t a_strides[2], constant const int64_t a_strides[2],
constant const size_t b_strides[2], constant const int64_t b_strides[2],
uint2 index [[thread_position_in_grid]], uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) { uint2 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_2(index, a_strides); auto a_idx = elem_to_loc_2<IdxT>(index, a_strides);
auto b_idx = elem_to_loc_2(index, b_strides); auto b_idx = elem_to_loc_2<IdxT>(index, b_strides);
size_t out_idx = index.x + size_t(grid_dim.x) * index.y; IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y;
auto out = Op()(a[a_idx], b[b_idx]); auto out = Op()(a[a_idx], b[b_idx]);
c[out_idx] = out[0]; c[out_idx] = out[0];
d[out_idx] = out[1]; d[out_idx] = out[1];
} }
template <typename T, typename U, typename Op> template <typename T, typename U, typename Op, typename IdxT = int64_t>
[[kernel]] void binary_g_nd3( [[kernel]] void binary_g_nd3(
device const T* a, device const T* a,
device const T* b, device const T* b,
device U* c, device U* c,
device U* d, device U* d,
constant const size_t a_strides[3], constant const int64_t a_strides[3],
constant const size_t b_strides[3], constant const int64_t b_strides[3],
uint3 index [[thread_position_in_grid]], uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) { uint3 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_3(index, a_strides); auto a_idx = elem_to_loc_3<IdxT>(index, a_strides);
auto b_idx = elem_to_loc_3(index, b_strides); auto b_idx = elem_to_loc_3<IdxT>(index, b_strides);
size_t out_idx = IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
auto out = Op()(a[a_idx], b[b_idx]); auto out = Op()(a[a_idx], b[b_idx]);
c[out_idx] = out[0]; c[out_idx] = out[0];
d[out_idx] = out[1]; d[out_idx] = out[1];
} }
template <typename T, typename U, typename Op, int N = 1> template <
typename T,
typename U,
typename Op,
int N = 1,
typename IdxT = int64_t>
[[kernel]] void binary_g( [[kernel]] void binary_g(
device const T* a, device const T* a,
device const T* b, device const T* b,
device U* c, device U* c,
device U* d, device U* d,
constant const int* shape, constant const int* shape,
constant const size_t* a_strides, constant const int64_t* a_strides,
constant const size_t* b_strides, constant const int64_t* b_strides,
constant const int& ndim, constant const int& ndim,
uint3 index [[thread_position_in_grid]], uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) { uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd( auto idx = elem_to_loc_2_nd<IdxT>(
{N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim); {N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim);
auto xshape = shape[ndim - 1]; auto xshape = shape[ndim - 1];
size_t out_idx = IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z); IdxT a_xstride = a_strides[ndim - 1];
auto a_xstride = a_strides[ndim - 1]; IdxT b_xstride = b_strides[ndim - 1];
auto b_xstride = b_strides[ndim - 1];
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
auto out = Op()(a[idx.x], b[idx.y]); auto out = Op()(a[idx.x], b[idx.y]);
c[out_idx] = out[0]; c[out_idx] = out[0];

View File

@@ -7,18 +7,22 @@
#include "mlx/backend/metal/kernels/binary_ops.h" #include "mlx/backend/metal/kernels/binary_ops.h"
#include "mlx/backend/metal/kernels/binary_two.h" #include "mlx/backend/metal/kernels/binary_two.h"
#define instantiate_binary_all(op, tname, itype, otype) \ #define instantiate_binary_all(op, tname, itype, otype) \
instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \ instantiate_kernel("ss_" #op #tname, binary_ss, itype, otype, op) \
instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \ instantiate_kernel("sv_" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \ instantiate_kernel("vs_" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \ instantiate_kernel("vv_" #op #tname, binary_vv, itype, otype, op) \
instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \ instantiate_kernel("sv2_" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \ instantiate_kernel("vs2_" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \ instantiate_kernel("vv2_" #op #tname, binary_vv2, itype, otype, op) \
instantiate_kernel("gn4_" #op #tname, binary_g, itype, otype, op, 4) \ instantiate_kernel("gn2_" #op #tname, binary_g, itype, otype, op, 2, uint) \
instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op) \ instantiate_kernel("gn4large_" #op #tname, binary_g, itype, otype, op, 4) \
instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op) \ instantiate_kernel("g1_" #op #tname, binary_g_nd1, itype, otype, op, uint) \
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op) \ instantiate_kernel("g2_" #op #tname, binary_g_nd2, itype, otype, op, uint) \
instantiate_kernel("g3_" #op #tname, binary_g_nd3, itype, otype, op, uint) \
instantiate_kernel("g1large_" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g2large_" #op #tname, binary_g_nd2, itype, otype, op) \
instantiate_kernel("g3large_" #op #tname, binary_g_nd3, itype, otype, op)
#define instantiate_binary_float(op) \ #define instantiate_binary_float(op) \
instantiate_binary_all(op, float16, half, half) \ instantiate_binary_all(op, float16, half, half) \

View File

@@ -4,8 +4,8 @@
#include <metal_simdgroup_matrix> #include <metal_simdgroup_matrix>
#include <metal_stdlib> #include <metal_stdlib>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/steel/conv/params.h" #include "mlx/backend/metal/kernels/steel/conv/params.h"
#include "mlx/backend/metal/kernels/utils.h"
#define MLX_MTL_CONST static constant constexpr const #define MLX_MTL_CONST static constant constexpr const

View File

@@ -22,7 +22,7 @@ template <typename T, typename U>
device U* dst [[buffer(1)]], device U* dst [[buffer(1)]],
uint2 index [[thread_position_in_grid]], uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) { uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y); auto offset = index.x + grid_dim.x * int64_t(index.y);
dst[offset] = static_cast<U>(src[0]); dst[offset] = static_cast<U>(src[0]);
} }
@@ -32,46 +32,46 @@ template <typename T, typename U>
device U* dst [[buffer(1)]], device U* dst [[buffer(1)]],
uint2 index [[thread_position_in_grid]], uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) { uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y); auto offset = index.x + grid_dim.x * int64_t(index.y);
dst[offset] = static_cast<U>(src[offset]); dst[offset] = static_cast<U>(src[offset]);
} }
template <typename T, typename U> template <typename T, typename U, typename IdxT = int64_t>
[[kernel]] void copy_g_nd1( [[kernel]] void copy_g_nd1(
device const T* src [[buffer(0)]], device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]], device U* dst [[buffer(1)]],
constant const int64_t& src_stride [[buffer(3)]], constant const int64_t& src_stride [[buffer(3)]],
uint index [[thread_position_in_grid]]) { uint index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_1(index, src_stride); auto src_idx = elem_to_loc_1<IdxT>(index, src_stride);
dst[index] = static_cast<U>(src[src_idx]); dst[index] = static_cast<U>(src[src_idx]);
} }
template <typename T, typename U> template <typename T, typename U, typename IdxT = int64_t>
[[kernel]] void copy_g_nd2( [[kernel]] void copy_g_nd2(
device const T* src [[buffer(0)]], device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]], device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]], constant const int64_t* src_strides [[buffer(3)]],
uint2 index [[thread_position_in_grid]], uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) { uint2 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_2(index, src_strides); auto src_idx = elem_to_loc_2<IdxT>(index, src_strides);
int64_t dst_idx = index.x + (int64_t)grid_dim.x * index.y; IdxT dst_idx = index.x + IdxT(grid_dim.x) * index.y;
dst[dst_idx] = static_cast<U>(src[src_idx]); dst[dst_idx] = static_cast<U>(src[src_idx]);
} }
template <typename T, typename U> template <typename T, typename U, typename IdxT = int64_t>
[[kernel]] void copy_g_nd3( [[kernel]] void copy_g_nd3(
device const T* src [[buffer(0)]], device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]], device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]], constant const int64_t* src_strides [[buffer(3)]],
uint3 index [[thread_position_in_grid]], uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) { uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc_3(index, src_strides); auto src_idx = elem_to_loc_3<IdxT>(index, src_strides);
int64_t dst_idx = IdxT dst_idx =
index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z); index.x + IdxT(grid_dim.x) * (index.y + IdxT(grid_dim.y) * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]); dst[dst_idx] = static_cast<U>(src[src_idx]);
} }
template <typename T, typename U, int N = 1> template <typename T, typename U, int N = 1, typename IdxT = int64_t>
[[kernel]] void copy_g( [[kernel]] void copy_g(
device const T* src [[buffer(0)]], device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]], device U* dst [[buffer(1)]],
@@ -80,17 +80,16 @@ template <typename T, typename U, int N = 1>
constant const int& ndim [[buffer(5)]], constant const int& ndim [[buffer(5)]],
uint3 index [[thread_position_in_grid]], uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) { uint3 grid_dim [[threads_per_grid]]) {
auto src_idx = elem_to_loc( auto src_idx = elem_to_loc<IdxT>(
{N * index.x, index.y, index.z}, src_shape, src_strides, ndim); {N * index.x, index.y, index.z}, src_shape, src_strides, ndim);
if (N == 1) { if (N == 1) {
int64_t dst_idx = IdxT dst_idx =
index.x + grid_dim.x * (index.y + int64_t(grid_dim.y) * index.z); index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
dst[dst_idx] = static_cast<U>(src[src_idx]); dst[dst_idx] = static_cast<U>(src[src_idx]);
return; return;
} }
auto xshape = src_shape[ndim - 1]; auto xshape = src_shape[ndim - 1];
int64_t dst_idx = IdxT dst_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
N * index.x + xshape * (index.y + int64_t(grid_dim.y) * index.z);
auto src_xstride = src_strides[ndim - 1]; auto src_xstride = src_strides[ndim - 1];
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
dst[dst_idx + i] = static_cast<U>(src[src_idx]); dst[dst_idx + i] = static_cast<U>(src[src_idx]);
@@ -98,43 +97,43 @@ template <typename T, typename U, int N = 1>
} }
} }
template <typename T, typename U> template <typename T, typename U, typename IdxT = int64_t>
[[kernel]] void copy_gg_nd1( [[kernel]] void copy_gg_nd1(
device const T* src [[buffer(0)]], device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]], device U* dst [[buffer(1)]],
constant const int64_t& src_stride [[buffer(3)]], constant const int64_t& src_stride [[buffer(3)]],
constant const int64_t& dst_stride [[buffer(4)]], constant const int64_t& dst_stride [[buffer(4)]],
uint index [[thread_position_in_grid]]) { uint index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_1(index, src_stride); auto src_idx = elem_to_loc_1<IdxT>(index, src_stride);
auto dst_idx = elem_to_loc_1(index, dst_stride); auto dst_idx = elem_to_loc_1<IdxT>(index, dst_stride);
dst[dst_idx] = static_cast<U>(src[src_idx]); dst[dst_idx] = static_cast<U>(src[src_idx]);
} }
template <typename T, typename U> template <typename T, typename U, typename IdxT = int64_t>
[[kernel]] void copy_gg_nd2( [[kernel]] void copy_gg_nd2(
device const T* src [[buffer(0)]], device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]], device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]], constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]], constant const int64_t* dst_strides [[buffer(4)]],
uint2 index [[thread_position_in_grid]]) { uint2 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_2(index, src_strides); auto src_idx = elem_to_loc_2<IdxT>(index, src_strides);
auto dst_idx = elem_to_loc_2(index, dst_strides); auto dst_idx = elem_to_loc_2<IdxT>(index, dst_strides);
dst[dst_idx] = static_cast<U>(src[src_idx]); dst[dst_idx] = static_cast<U>(src[src_idx]);
} }
template <typename T, typename U> template <typename T, typename U, typename IdxT = int64_t>
[[kernel]] void copy_gg_nd3( [[kernel]] void copy_gg_nd3(
device const T* src [[buffer(0)]], device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]], device U* dst [[buffer(1)]],
constant const int64_t* src_strides [[buffer(3)]], constant const int64_t* src_strides [[buffer(3)]],
constant const int64_t* dst_strides [[buffer(4)]], constant const int64_t* dst_strides [[buffer(4)]],
uint3 index [[thread_position_in_grid]]) { uint3 index [[thread_position_in_grid]]) {
auto src_idx = elem_to_loc_3(index, src_strides); auto src_idx = elem_to_loc_3<IdxT>(index, src_strides);
auto dst_idx = elem_to_loc_3(index, dst_strides); auto dst_idx = elem_to_loc_3<IdxT>(index, dst_strides);
dst[dst_idx] = static_cast<U>(src[src_idx]); dst[dst_idx] = static_cast<U>(src[src_idx]);
} }
template <typename T, typename U, int N = 1> template <typename T, typename U, int N = 1, typename IdxT = int64_t>
[[kernel]] void copy_gg( [[kernel]] void copy_gg(
device const T* src [[buffer(0)]], device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]], device U* dst [[buffer(1)]],
@@ -143,7 +142,7 @@ template <typename T, typename U, int N = 1>
constant const int64_t* dst_strides [[buffer(4)]], constant const int64_t* dst_strides [[buffer(4)]],
constant const int& ndim [[buffer(5)]], constant const int& ndim [[buffer(5)]],
uint3 index [[thread_position_in_grid]]) { uint3 index [[thread_position_in_grid]]) {
auto idx = elem_to_loc_2_nd( auto idx = elem_to_loc_2_nd<IdxT>(
{N * index.x, index.y, index.z}, {N * index.x, index.y, index.z},
src_shape, src_shape,
src_strides, src_strides,
@@ -153,8 +152,8 @@ template <typename T, typename U, int N = 1>
dst[idx.y] = static_cast<U>(src[idx.x]); dst[idx.y] = static_cast<U>(src[idx.x]);
return; return;
} }
auto src_xstride = src_strides[ndim - 1]; IdxT src_xstride = src_strides[ndim - 1];
auto dst_xstride = dst_strides[ndim - 1]; IdxT dst_xstride = dst_strides[ndim - 1];
auto xshape = src_shape[ndim - 1]; auto xshape = src_shape[ndim - 1];
for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
dst[idx.y] = static_cast<U>(src[idx.x]); dst[idx.y] = static_cast<U>(src[idx.x]);

View File

@@ -2,22 +2,29 @@
// clang-format off // clang-format off
#include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/copy.h" #include "mlx/backend/metal/kernels/copy.h"
#define instantiate_copy_all(tname, itype, otype) \ #define instantiate_copy_all(tname, itype, otype) \
instantiate_kernel("s_copy" #tname, copy_s, itype, otype) \ instantiate_kernel("s_copy" #tname, copy_s, itype, otype) \
instantiate_kernel("v_copy" #tname, copy_v, itype, otype) \ instantiate_kernel("v_copy" #tname, copy_v, itype, otype) \
instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype) \ instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype) \
instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype) \ instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype) \
instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype) \ instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype, int) \
instantiate_kernel("g2_copy" #tname, copy_g_nd2, itype, otype) \ instantiate_kernel("g2_copy" #tname, copy_g_nd2, itype, otype, int) \
instantiate_kernel("g3_copy" #tname, copy_g_nd3, itype, otype) \ instantiate_kernel("g3_copy" #tname, copy_g_nd3, itype, otype, int) \
instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, itype, otype) \ instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, itype, otype, int) \
instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype) \ instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype, int) \
instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype) \ instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype, int) \
instantiate_kernel("gn4_copy" #tname, copy_g, itype, otype, 4) \ instantiate_kernel("gn2_copy" #tname, copy_g, itype, otype, 2, int) \
instantiate_kernel("ggn4_copy" #tname, copy_gg, itype, otype, 4) instantiate_kernel("ggn2_copy" #tname, copy_gg, itype, otype, 2, int) \
instantiate_kernel("g1large_copy" #tname, copy_g_nd1, itype, otype) \
instantiate_kernel("g2large_copy" #tname, copy_g_nd2, itype, otype) \
instantiate_kernel("g3large_copy" #tname, copy_g_nd3, itype, otype) \
instantiate_kernel("gg1large_copy" #tname, copy_gg_nd1, itype, otype) \
instantiate_kernel("gg2large_copy" #tname, copy_gg_nd2, itype, otype) \
instantiate_kernel("gg3large_copy" #tname, copy_gg_nd3, itype, otype) \
instantiate_kernel("gn4large_copy" #tname, copy_g, itype, otype, 4) \
instantiate_kernel("ggn4large_copy" #tname, copy_gg, itype, otype, 4)
#define instantiate_copy_itype(itname, itype) \ #define instantiate_copy_itype(itname, itype) \
instantiate_copy_all(itname ##bool_, itype, bool) \ instantiate_copy_all(itname ##bool_, itype, bool) \

View File

@@ -4,30 +4,30 @@
#include "mlx/backend/metal/kernels/indexing.h" #include "mlx/backend/metal/kernels/indexing.h"
template <typename T, typename IdxT, int NIDX, int IDX_NDIM> template <typename T, typename IdxT, int NIDX, int IDX_NDIM, typename LocT>
METAL_FUNC void gather_impl( METAL_FUNC void gather_impl(
const device T* src [[buffer(0)]], const device T* src [[buffer(0)]],
device T* out [[buffer(1)]], device T* out [[buffer(1)]],
const constant int* src_shape [[buffer(2)]], const constant int* src_shape [[buffer(2)]],
const constant size_t* src_strides [[buffer(3)]], const constant int64_t* src_strides [[buffer(3)]],
const constant size_t& src_ndim [[buffer(4)]], const constant size_t& src_ndim [[buffer(4)]],
const constant int* slice_sizes [[buffer(5)]], const constant int* slice_sizes [[buffer(5)]],
const constant int* axes [[buffer(6)]], const constant int* axes [[buffer(6)]],
const thread Indices<IdxT, NIDX>& indices, const thread Indices<IdxT, NIDX>& indices,
uint3 index [[thread_position_in_grid]], uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) { uint3 grid_dim [[threads_per_grid]]) {
size_t src_idx = 0; LocT src_idx = 0;
for (int i = 0; i < NIDX; ++i) { for (int i = 0; i < NIDX; ++i) {
size_t idx_loc; LocT idx_loc;
if (IDX_NDIM == 0) { if (IDX_NDIM == 0) {
idx_loc = 0; idx_loc = 0;
} else if (IDX_NDIM == 1) { } else if (IDX_NDIM == 1) {
idx_loc = index.x * indices.strides[indices.ndim * i]; idx_loc = index.x * static_cast<LocT>(indices.strides[indices.ndim * i]);
} else { } else {
idx_loc = index.x * indices.strides[indices.ndim * i]; idx_loc = index.x * static_cast<LocT>(indices.strides[indices.ndim * i]);
idx_loc += indices.row_contiguous[i] idx_loc += indices.row_contiguous[i]
? index.y ? index.y
: elem_to_loc( : elem_to_loc<LocT>(
index.y, index.y,
&indices.shapes[indices.ndim * i + 1], &indices.shapes[indices.ndim * i + 1],
&indices.strides[indices.ndim * i + 1], &indices.strides[indices.ndim * i + 1],
@@ -35,17 +35,17 @@ METAL_FUNC void gather_impl(
} }
auto ax = axes[i]; auto ax = axes[i];
auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]); auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]);
src_idx += idx_val * src_strides[ax]; src_idx += static_cast<LocT>(idx_val) * static_cast<LocT>(src_strides[ax]);
} }
auto src_offset = elem_to_loc(index.z, slice_sizes, src_strides, src_ndim); auto src_offset =
elem_to_loc<LocT>(index.z, slice_sizes, src_strides, src_ndim);
size_t out_idx = index.z; LocT out_idx = index.z;
if (IDX_NDIM == 1) { if (IDX_NDIM == 1) {
out_idx += static_cast<size_t>(grid_dim.z) * index.x; out_idx += static_cast<LocT>(grid_dim.z) * index.x;
} else if (IDX_NDIM >= 2) { } else if (IDX_NDIM >= 2) {
out_idx += out_idx += grid_dim.z * (index.x * static_cast<LocT>(grid_dim.y) + index.y);
grid_dim.z * (index.x * static_cast<size_t>(grid_dim.y) + index.y);
} }
out[out_idx] = src[src_offset + src_idx]; out[out_idx] = src[src_offset + src_idx];
} }

View File
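`gather_impl` now threads a `LocT` type through every offset computation and casts each index and stride to `LocT` before multiplying, so the per-axis products are done at the chosen width instead of whatever the operands happen to promote to. A standalone sketch of that accumulation pattern; the free-standing signature is illustrative, not the kernel's:

```cpp
#include <cstdint>

// Accumulate the gathered source offset with every term widened to LocT first,
// so index * stride cannot overflow narrower arithmetic on large arrays.
template <typename LocT = int64_t>
LocT gather_offset(
    const int* idx_vals,     // indices already wrapped into [0, shape[ax])
    const int64_t* strides,  // source strides for the gathered axes
    int n_axes) {
  LocT src_idx = 0;
  for (int i = 0; i < n_axes; ++i) {
    src_idx += static_cast<LocT>(idx_vals[i]) * static_cast<LocT>(strides[i]);
  }
  return src_idx;
}
```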

@@ -3,8 +3,6 @@
#include <metal_simdgroup> #include <metal_simdgroup>
#include <metal_stdlib> #include <metal_stdlib>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/utils.h" #include "mlx/backend/metal/kernels/steel/utils.h"
@@ -438,9 +436,9 @@ template <
const constant float& beta [[buffer(8)]], const constant float& beta [[buffer(8)]],
const constant int& batch_ndim [[buffer(9)]], const constant int& batch_ndim [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]], const constant int* batch_shape [[buffer(10)]],
const constant size_t* vector_batch_stride [[buffer(11)]], const constant int64_t* vector_batch_stride [[buffer(11)]],
const constant size_t* matrix_batch_stride [[buffer(12)]], const constant int64_t* matrix_batch_stride [[buffer(12)]],
const constant size_t* bias_batch_stride [[buffer(13)]], const constant int64_t* bias_batch_stride [[buffer(13)]],
const constant int& bias_stride [[buffer(14)]], const constant int& bias_stride [[buffer(14)]],
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]], uint3 lid [[thread_position_in_threadgroup]],
@@ -488,31 +486,21 @@ template <
simd_lid); simd_lid);
} }
#define instantiate_gemv_helper( \ #define instantiate_gemv_helper( \
name, itype, bm, bn, sm, sn, tm, tn, nc, axpby) \ name, itype, bm, bn, sm, sn, tm, tn, nc, axpby) \
template [[host_name("gemv_" #name "_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn \ instantiate_kernel( \
"_tm" #tm "_tn" #tn "_nc" #nc \ "gemv_" #name "_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn "_tm" #tm \
"_axpby" #axpby)]] [[kernel]] void \ "_tn" #tn "_nc" #nc "_axpby" #axpby, \
gemv<itype, bm, bn, sm, sn, tm, tn, nc, axpby>( \ gemv, \
const device itype* mat [[buffer(0)]], \ itype, \
const device itype* in_vec [[buffer(1)]], \ bm, \
const device itype* bias [[buffer(2)]], \ bn, \
device itype* out_vec [[buffer(3)]], \ sm, \
const constant int& in_vec_size [[buffer(4)]], \ sn, \
const constant int& out_vec_size [[buffer(5)]], \ tm, \
const constant int& marix_ld [[buffer(6)]], \ tn, \
const constant float& alpha [[buffer(7)]], \ nc, \
const constant float& beta [[buffer(8)]], \ axpby)
const constant int& batch_ndim [[buffer(9)]], \
const constant int* batch_shape [[buffer(10)]], \
const constant size_t* vector_batch_stride [[buffer(11)]], \
const constant size_t* matrix_batch_stride [[buffer(12)]], \
const constant size_t* bias_batch_stride [[buffer(13)]], \
const constant int& bias_stride [[buffer(14)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
// clang-format off // clang-format off
#define instantiate_gemv(name, itype, bm, bn, tm, tn) \ #define instantiate_gemv(name, itype, bm, bn, tm, tn) \
@@ -551,13 +539,13 @@ template <
const constant float& beta [[buffer(8)]], const constant float& beta [[buffer(8)]],
const constant int& batch_ndim [[buffer(9)]], const constant int& batch_ndim [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]], const constant int* batch_shape [[buffer(10)]],
const constant size_t* index_batch_strides [[buffer(11)]], const constant int64_t* index_batch_strides [[buffer(11)]],
const constant int& vector_batch_ndim [[buffer(12)]], const constant int& vector_batch_ndim [[buffer(12)]],
const constant int* vector_batch_shape [[buffer(13)]], const constant int* vector_batch_shape [[buffer(13)]],
const constant size_t* vector_batch_stride [[buffer(14)]], const constant int64_t* vector_batch_stride [[buffer(14)]],
const constant int& matrix_batch_ndim [[buffer(15)]], const constant int& matrix_batch_ndim [[buffer(15)]],
const constant int* matrix_batch_shape [[buffer(16)]], const constant int* matrix_batch_shape [[buffer(16)]],
const constant size_t* matrix_batch_stride [[buffer(17)]], const constant int64_t* matrix_batch_stride [[buffer(17)]],
const constant uint32_t* vec_indices [[buffer(18)]], const constant uint32_t* vec_indices [[buffer(18)]],
const constant uint32_t* mat_indices [[buffer(19)]], const constant uint32_t* mat_indices [[buffer(19)]],
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
@@ -573,8 +561,8 @@ template <
// Update batch offsets // Update batch offsets
if (batch_ndim > 1) { if (batch_ndim > 1) {
const constant size_t* veci_bstrides = index_batch_strides; const constant auto* veci_bstrides = index_batch_strides;
const constant size_t* mati_bstrides = index_batch_strides + batch_ndim; const constant auto* mati_bstrides = index_batch_strides + batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast( ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, veci_bstrides, mati_bstrides, batch_ndim); tid.z, batch_shape, veci_bstrides, mati_bstrides, batch_ndim);
@@ -621,37 +609,14 @@ template <
simd_lid); simd_lid);
} }
#define instantiate_gemv_bs_helper(nm, itype, bm, bn, sm, sn, tm, tn) \
template [[host_name("gemv_gather_" #nm "_bm" #bm "_bn" #bn "_sm" #sm \
"_sn" #sn "_tm" #tm "_tn" #tn)]] [[kernel]] void \
gemv_gather<itype, bm, bn, sm, sn, tm, tn>( \
const device itype* mat [[buffer(0)]], \
const device itype* in_vec [[buffer(1)]], \
const device itype* bias [[buffer(2)]], \
device itype* out_vec [[buffer(3)]], \
const constant int& in_vec_size [[buffer(4)]], \
const constant int& out_vec_size [[buffer(5)]], \
const constant int& marix_ld [[buffer(6)]], \
const constant float& alpha [[buffer(7)]], \
const constant float& beta [[buffer(8)]], \
const constant int& batch_ndim [[buffer(9)]], \
const constant int* batch_shape [[buffer(10)]], \
const constant size_t* index_batch_strides [[buffer(11)]], \
const constant int& vector_batch_ndim [[buffer(12)]], \
const constant int* vector_batch_shape [[buffer(13)]], \
const constant size_t* vector_batch_stride [[buffer(14)]], \
const constant int& matrix_batch_ndim [[buffer(15)]], \
const constant int* matrix_batch_shape [[buffer(16)]], \
const constant size_t* matrix_batch_stride [[buffer(17)]], \
const constant uint32_t* vec_indices [[buffer(18)]], \
const constant uint32_t* mat_indices [[buffer(19)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
// clang-format off // clang-format off
#define instantiate_gemv_bs_blocks(name, itype) \ #define instantiate_gemv_bs_helper(nm, itype, bm, bn, sm, sn, tm, tn) \
instantiate_kernel( \
"gemv_gather_" #nm "_bm" #bm "_bn" #bn "_sm" #sm \
"_sn" #sn "_tm" #tm "_tn" #tn, \
gemv_gather, itype, bm, bn, sm, sn, tm, tn)
#define instantiate_gemv_bs_blocks(name, itype) \
instantiate_gemv_bs_helper(name, itype, 4, 1, 1, 32, 1, 4) \ instantiate_gemv_bs_helper(name, itype, 4, 1, 1, 32, 1, 4) \
instantiate_gemv_bs_helper(name, itype, 4, 1, 1, 32, 4, 4) \ instantiate_gemv_bs_helper(name, itype, 4, 1, 1, 32, 4, 4) \
instantiate_gemv_bs_helper(name, itype, 8, 1, 1, 32, 4, 4) // clang-format on instantiate_gemv_bs_helper(name, itype, 8, 1, 1, 32, 4, 4) // clang-format on
@@ -686,9 +651,9 @@ template <
const constant float& beta [[buffer(8)]], const constant float& beta [[buffer(8)]],
const constant int& batch_ndim [[buffer(9)]], const constant int& batch_ndim [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]], const constant int* batch_shape [[buffer(10)]],
const constant size_t* vector_batch_stride [[buffer(11)]], const constant int64_t* vector_batch_stride [[buffer(11)]],
const constant size_t* matrix_batch_stride [[buffer(12)]], const constant int64_t* matrix_batch_stride [[buffer(12)]],
const constant size_t* bias_batch_stride [[buffer(13)]], const constant int64_t* bias_batch_stride [[buffer(13)]],
const constant int& bias_stride [[buffer(14)]], const constant int& bias_stride [[buffer(14)]],
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]], uint3 lid [[thread_position_in_threadgroup]],
@@ -736,33 +701,14 @@ template <
simd_lid); simd_lid);
} }
#define instantiate_gemv_t_helper( \
name, itype, bm, bn, sm, sn, tm, tn, nc, axpby) \
template [[host_name("gemv_t_" #name "_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn \
"_tm" #tm "_tn" #tn "_nc" #nc \
"_axpby" #axpby)]] [[kernel]] void \
gemv_t<itype, bm, bn, sm, sn, tm, tn, nc, axpby>( \
const device itype* mat [[buffer(0)]], \
const device itype* in_vec [[buffer(1)]], \
const device itype* bias [[buffer(2)]], \
device itype* out_vec [[buffer(3)]], \
const constant int& in_vec_size [[buffer(4)]], \
const constant int& out_vec_size [[buffer(5)]], \
const constant int& marix_ld [[buffer(6)]], \
const constant float& alpha [[buffer(7)]], \
const constant float& beta [[buffer(8)]], \
const constant int& batch_ndim [[buffer(9)]], \
const constant int* batch_shape [[buffer(10)]], \
const constant size_t* vector_batch_stride [[buffer(11)]], \
const constant size_t* matrix_batch_stride [[buffer(12)]], \
const constant size_t* bias_batch_stride [[buffer(13)]], \
const constant int& bias_stride [[buffer(14)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
// clang-format off // clang-format off
#define instantiate_gemv_t_helper( \
name, itype, bm, bn, sm, sn, tm, tn, nc, axpby) \
instantiate_kernel( \
"gemv_t_" #name "_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn \
"_tm" #tm "_tn" #tn "_nc" #nc "_axpby" #axpby, \
gemv_t, itype, bm, bn, sm, sn, tm, tn, nc, axpby)
#define instantiate_gemv_t(name, itype, bm, bn, sm, sn, tm, tn) \ #define instantiate_gemv_t(name, itype, bm, bn, sm, sn, tm, tn) \
instantiate_gemv_t_helper(name, itype, bm, bn, sm, sn, tm, tn, 0, 0) \ instantiate_gemv_t_helper(name, itype, bm, bn, sm, sn, tm, tn, 0, 0) \
instantiate_gemv_t_helper(name, itype, bm, bn, sm, sn, tm, tn, 0, 1) \ instantiate_gemv_t_helper(name, itype, bm, bn, sm, sn, tm, tn, 0, 1) \
@@ -802,13 +748,13 @@ template <
const constant float& beta [[buffer(8)]], const constant float& beta [[buffer(8)]],
const constant int& batch_ndim [[buffer(9)]], const constant int& batch_ndim [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]], const constant int* batch_shape [[buffer(10)]],
const constant size_t* index_batch_strides [[buffer(11)]], const constant int64_t* index_batch_strides [[buffer(11)]],
const constant int& vector_batch_ndim [[buffer(12)]], const constant int& vector_batch_ndim [[buffer(12)]],
const constant int* vector_batch_shape [[buffer(13)]], const constant int* vector_batch_shape [[buffer(13)]],
const constant size_t* vector_batch_stride [[buffer(14)]], const constant int64_t* vector_batch_stride [[buffer(14)]],
const constant int& matrix_batch_ndim [[buffer(15)]], const constant int& matrix_batch_ndim [[buffer(15)]],
const constant int* matrix_batch_shape [[buffer(16)]], const constant int* matrix_batch_shape [[buffer(16)]],
const constant size_t* matrix_batch_stride [[buffer(17)]], const constant int64_t* matrix_batch_stride [[buffer(17)]],
const constant uint32_t* vec_indices [[buffer(18)]], const constant uint32_t* vec_indices [[buffer(18)]],
const constant uint32_t* mat_indices [[buffer(19)]], const constant uint32_t* mat_indices [[buffer(19)]],
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
@@ -824,8 +770,8 @@ template <
// Update batch offsets // Update batch offsets
if (batch_ndim > 1) { if (batch_ndim > 1) {
const constant size_t* veci_bstrides = index_batch_strides; const constant auto* veci_bstrides = index_batch_strides;
const constant size_t* mati_bstrides = index_batch_strides + batch_ndim; const constant auto* mati_bstrides = index_batch_strides + batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast( ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, veci_bstrides, mati_bstrides, batch_ndim); tid.z, batch_shape, veci_bstrides, mati_bstrides, batch_ndim);
@@ -872,36 +818,14 @@ template <
simd_lid); simd_lid);
} }
#define instantiate_gemv_t_bs_helper(nm, itype, bm, bn, sm, sn, tm, tn) \
template [[host_name("gemv_t_gather_" #nm "_bm" #bm "_bn" #bn "_sm" #sm \
"_sn" #sn "_tm" #tm "_tn" #tn)]] [[kernel]] void \
gemv_t_gather<itype, bm, bn, sm, sn, tm, tn>( \
const device itype* mat [[buffer(0)]], \
const device itype* in_vec [[buffer(1)]], \
const device itype* bias [[buffer(2)]], \
device itype* out_vec [[buffer(3)]], \
const constant int& in_vec_size [[buffer(4)]], \
const constant int& out_vec_size [[buffer(5)]], \
const constant int& marix_ld [[buffer(6)]], \
const constant float& alpha [[buffer(7)]], \
const constant float& beta [[buffer(8)]], \
const constant int& batch_ndim [[buffer(9)]], \
const constant int* batch_shape [[buffer(10)]], \
const constant size_t* index_batch_strides [[buffer(11)]], \
const constant int& vector_batch_ndim [[buffer(12)]], \
const constant int* vector_batch_shape [[buffer(13)]], \
const constant size_t* vector_batch_stride [[buffer(14)]], \
const constant int& matrix_batch_ndim [[buffer(15)]], \
const constant int* matrix_batch_shape [[buffer(16)]], \
const constant size_t* matrix_batch_stride [[buffer(17)]], \
const constant uint32_t* vec_indices [[buffer(18)]], \
const constant uint32_t* mat_indices [[buffer(19)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
// clang-format off // clang-format off
#define instantiate_gemv_t_bs_helper( \
nm, itype, bm, bn, sm, sn, tm, tn) \
instantiate_kernel( \
"gemv_t_gather_" #nm "_bm" #bm "_bn" #bn "_sm" #sm \
"_sn" #sn "_tm" #tm "_tn" #tn, \
gemv_t_gather, itype, bm, bn, sm, sn, tm, tn)
#define instantiate_gemv_t_bs_blocks(name, itype) \ #define instantiate_gemv_t_bs_blocks(name, itype) \
instantiate_gemv_t_bs_helper(name, itype, 1, 2, 8, 4, 4, 1) \ instantiate_gemv_t_bs_helper(name, itype, 1, 2, 8, 4, 4, 1) \
instantiate_gemv_t_bs_helper(name, itype, 1, 2, 8, 4, 4, 4) \ instantiate_gemv_t_bs_helper(name, itype, 1, 2, 8, 4, 4, 4) \
@@ -912,4 +836,4 @@ template <
// clang-format off // clang-format off
instantiate_gemv_t_bs_blocks(float32, float); instantiate_gemv_t_bs_blocks(float32, float);
instantiate_gemv_t_bs_blocks(float16, half); instantiate_gemv_t_bs_blocks(float16, half);
instantiate_gemv_t_bs_blocks(bfloat16, bfloat16_t); // clang-format on instantiate_gemv_t_bs_blocks(bfloat16, bfloat16_t); // clang-format on

View File
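The long hand-written `template [[host_name(...)]] [[kernel]] void gemv<...>(...)` instantiations above collapse into `instantiate_kernel(...)` calls, the same macro the other kernel files already use. Its definition lives in the shared kernel utilities; it expands to roughly the following, so treat this as a sketch rather than the exact source:

```cpp
// Approximate expansion: instantiate the templated kernel and give the
// pipeline a stable host-visible name, without restating the parameter list.
#define instantiate_kernel(name, func, ...)         \
  template [[host_name(name)]] [[kernel]] decltype( \
      func<__VA_ARGS__>) func<__VA_ARGS__>;
```

Because `decltype(func<...>)` recovers the full kernel signature, removing the duplicated buffer lists here cannot drift out of sync with the kernel definitions.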

@@ -642,13 +642,13 @@ template <
const constant int& marix_ld [[buffer(6)]], const constant int& marix_ld [[buffer(6)]],
const constant int& batch_ndim [[buffer(9)]], const constant int& batch_ndim [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]], const constant int* batch_shape [[buffer(10)]],
const constant size_t* vector_batch_stride [[buffer(11)]], const constant int64_t* vector_batch_stride [[buffer(11)]],
const constant size_t* matrix_batch_stride [[buffer(12)]], const constant int64_t* matrix_batch_stride [[buffer(12)]],
const device out_mask_t* out_mask [[buffer(20)]], const device out_mask_t* out_mask [[buffer(20)]],
const device op_mask_t* mat_mask [[buffer(21)]], const device op_mask_t* mat_mask [[buffer(21)]],
const device op_mask_t* vec_mask [[buffer(22)]], const device op_mask_t* vec_mask [[buffer(22)]],
const constant int* mask_strides [[buffer(23)]], const constant int* mask_strides [[buffer(23)]],
const constant size_t* mask_batch_strides [[buffer(24)]], const constant int64_t* mask_batch_strides [[buffer(24)]],
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]], uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -673,8 +673,8 @@ template <
} }
if (has_operand_mask) { if (has_operand_mask) {
const constant size_t* mask_strides_mat = mask_batch_strides; const constant auto* mask_strides_mat = mask_batch_strides;
const constant size_t* mask_strides_vec = mask_strides_mat + batch_ndim; const constant auto* mask_strides_vec = mask_strides_mat + batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast( ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim); tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim);
@@ -742,13 +742,13 @@ template <
const constant int& marix_ld [[buffer(6)]], const constant int& marix_ld [[buffer(6)]],
const constant int& batch_ndim [[buffer(9)]], const constant int& batch_ndim [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]], const constant int* batch_shape [[buffer(10)]],
const constant size_t* vector_batch_stride [[buffer(11)]], const constant int64_t* vector_batch_stride [[buffer(11)]],
const constant size_t* matrix_batch_stride [[buffer(12)]], const constant int64_t* matrix_batch_stride [[buffer(12)]],
const device out_mask_t* out_mask [[buffer(20)]], const device out_mask_t* out_mask [[buffer(20)]],
const device op_mask_t* mat_mask [[buffer(21)]], const device op_mask_t* mat_mask [[buffer(21)]],
const device op_mask_t* vec_mask [[buffer(22)]], const device op_mask_t* vec_mask [[buffer(22)]],
const constant int* mask_strides [[buffer(23)]], const constant int* mask_strides [[buffer(23)]],
const constant size_t* mask_batch_strides [[buffer(24)]], const constant int64_t* mask_batch_strides [[buffer(24)]],
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]], uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -773,8 +773,8 @@ template <
} }
if (has_operand_mask) { if (has_operand_mask) {
const constant size_t* mask_strides_mat = mask_batch_strides; const constant auto* mask_strides_mat = mask_batch_strides;
const constant size_t* mask_strides_vec = mask_strides_mat + batch_ndim; const constant auto* mask_strides_vec = mask_strides_mat + batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast( ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim); tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim);

View File

@@ -4,37 +4,17 @@
#include <metal_simdgroup> #include <metal_simdgroup>
#include <metal_stdlib> #include <metal_stdlib>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/gemv_masked.h" #include "mlx/backend/metal/kernels/gemv_masked.h"
#define instantiate_gemv_helper( \ #define instantiate_gemv_helper( \
outm_n, outm_t, opm_n, opm_t, name, itype, bm, bn, sm, sn, tm, tn, nc) \ outm_n, outm_t, opm_n, opm_t, name, itype, bm, bn, sm, sn, tm, tn, nc) \
template [[host_name("gemv_outmask_" #outm_n "_opmask_" #opm_n "_" #name \ instantiate_kernel( \
"_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn "_tm" #tm \ "gemv_outmask_" #outm_n "_opmask_" #opm_n "_" #name \
"_tn" #tn "_nc" #nc)]] [[kernel]] void \ "_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn "_tm" #tm \
gemv_masked<itype, outm_t, opm_t, bm, bn, sm, sn, tm, tn, nc>( \ "_tn" #tn "_nc" #nc, \
const device itype* mat [[buffer(0)]], \ gemv_masked, itype, outm_t, opm_t, bm, bn, sm, sn, tm, tn, nc)
const device itype* in_vec [[buffer(1)]], \
device itype* out_vec [[buffer(3)]], \
const constant int& in_vec_size [[buffer(4)]], \
const constant int& out_vec_size [[buffer(5)]], \
const constant int& marix_ld [[buffer(6)]], \
const constant int& batch_ndim [[buffer(9)]], \
const constant int* batch_shape [[buffer(10)]], \
const constant size_t* vector_batch_stride [[buffer(11)]], \
const constant size_t* matrix_batch_stride [[buffer(12)]], \
const device outm_t* out_mask [[buffer(20)]], \
const device opm_t* mat_mask [[buffer(21)]], \
const device opm_t* vec_mask [[buffer(22)]], \
const constant int* mask_strides [[buffer(23)]], \
const constant size_t* mask_batch_strides [[buffer(24)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_gemv_base(name, itype, bm, bn, sm, sn, tm, tn, nc) \ #define instantiate_gemv_base(name, itype, bm, bn, sm, sn, tm, tn, nc) \
instantiate_gemv_helper(bool_, bool, bool_, bool, name, itype, bm, bn, sm, sn, tm, tn, nc) \ instantiate_gemv_helper(bool_, bool, bool_, bool, name, itype, bm, bn, sm, sn, tm, tn, nc) \
@@ -63,29 +43,11 @@ instantiate_gemv_blocks(bfloat16, bfloat16_t);
#define instantiate_gemv_t_helper( \ #define instantiate_gemv_t_helper( \
outm_n, outm_t, opm_n, opm_t, name, itype, bm, bn, sm, sn, tm, tn, nc) \ outm_n, outm_t, opm_n, opm_t, name, itype, bm, bn, sm, sn, tm, tn, nc) \
template [[host_name("gemv_t_outmask_" #outm_n "_opmask_" #opm_n "_" #name \ instantiate_kernel( \
"_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn "_tm" #tm \ "gemv_t_outmask_" #outm_n "_opmask_" #opm_n "_" #name \
"_tn" #tn "_nc" #nc)]] [[kernel]] void \ "_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn "_tm" #tm \
gemv_t_masked<itype, outm_t, opm_t, bm, bn, sm, sn, tm, tn, nc>( \ "_tn" #tn "_nc" #nc, \
const device itype* mat [[buffer(0)]], \ gemv_t_masked, itype, outm_t, opm_t, bm, bn, sm, sn, tm, tn, nc)
const device itype* in_vec [[buffer(1)]], \
device itype* out_vec [[buffer(3)]], \
const constant int& in_vec_size [[buffer(4)]], \
const constant int& out_vec_size [[buffer(5)]], \
const constant int& marix_ld [[buffer(6)]], \
const constant int& batch_ndim [[buffer(9)]], \
const constant int* batch_shape [[buffer(10)]], \
const constant size_t* vector_batch_stride [[buffer(11)]], \
const constant size_t* matrix_batch_stride [[buffer(12)]], \
const device outm_t* out_mask [[buffer(20)]], \
const device opm_t* mat_mask [[buffer(21)]], \
const device opm_t* vec_mask [[buffer(22)]], \
const constant int* mask_strides [[buffer(23)]], \
const constant size_t* mask_batch_strides [[buffer(24)]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]], \
uint simd_gid [[simdgroup_index_in_threadgroup]], \
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_gemv_t_base(name, itype, bm, bn, sm, sn, tm, tn, nc) \ #define instantiate_gemv_t_base(name, itype, bm, bn, sm, sn, tm, tn, nc) \
instantiate_gemv_t_helper(bool_, bool, bool_, bool, name, itype, bm, bn, sm, sn, tm, tn, nc) \ instantiate_gemv_t_helper(bool_, bool, bool_, bool, name, itype, bm, bn, sm, sn, tm, tn, nc) \
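
In the two helper macros above, the hand-written template [[host_name(...)]] [[kernel]] instantiation, with its full per-buffer argument list, is replaced by a call to instantiate_kernel, which takes the host name string followed by the kernel template and its template arguments. A sketch of what such a macro can look like (a hypothetical simplification, not necessarily the exact definition in the shared kernel headers):

// Expands to an explicit instantiation of the kernel template, tagged with
// the given [[host_name(...)]], so each new variant is a one-liner.
#define instantiate_kernel(host_name_str, func, ...) \
  template [[host_name(host_name_str)]] [[kernel]]   \
  decltype(func<__VA_ARGS__>) func<__VA_ARGS__>;

With a macro of this shape, the long buffer parameter lists no longer need to be repeated at every instantiation site; they live only in the kernel template's definition.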


@@ -8,13 +8,13 @@ template <typename IdxT, int NIDX>
struct Indices { struct Indices {
const array<const device IdxT*, NIDX> buffers; const array<const device IdxT*, NIDX> buffers;
const constant int* shapes; const constant int* shapes;
const constant size_t* strides; const constant int64_t* strides;
const constant bool* row_contiguous; const constant bool* row_contiguous;
const int ndim; const int ndim;
}; };
template <typename IdxT> template <typename IdxT>
METAL_FUNC size_t offset_neg_idx(IdxT idx, size_t size) { METAL_FUNC size_t offset_neg_idx(IdxT idx, int size) {
if (is_unsigned_v<IdxT>) { if (is_unsigned_v<IdxT>) {
return idx; return idx;
} else { } else {
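
The gather/scatter helpers above switch the Indices strides to int64_t and pass the array size to offset_neg_idx as an int. The else branch is cut off in this hunk; a host-side C++ analogue of what such a helper does (the wrap-around below is an assumption, since it is not shown here):

#include <cstdio>
#include <type_traits>

// Sketch only, not MLX code: map possibly negative indices into [0, size).
template <typename IdxT>
size_t offset_neg_idx(IdxT idx, int size) {
  if (std::is_unsigned_v<IdxT>) {
    return idx;  // unsigned indices cannot be negative
  } else {
    return (idx < 0) ? idx + size : idx;  // e.g. -1 -> size - 1
  }
}

int main() {
  std::printf("%zu %zu\n",
              offset_neg_idx<int>(-1, 5),       // 4
              offset_neg_idx<unsigned>(4u, 5)); // 4
}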


@@ -0,0 +1,16 @@
// Copyright © 2024 Apple Inc.
// clang-format off
#define jit_if #if
#define jit_else #else
#define jit_endif #endif
jit_if (__METAL_VERSION__ >= 310)
#include "mlx/backend/metal/kernels/metal_3_1/bf16.h"
jit_else
#include "mlx/backend/metal/kernels/metal_3_0/bf16.h"
jit_endif // clang-format on
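
The new header above spells the version check with jit_if / jit_else / jit_endif rather than real directives. Defining an object-like macro whose replacement is #if is legal, and a macro expansion that merely looks like a directive is not processed as one, so the check survives whatever build-time preprocessing packages these kernels for JIT compilation and is presumably only evaluated when the generated source is compiled for the target device, where __METAL_VERSION__ is meaningful. After the jit_* macros expand, the device-side compiler effectively sees:

#if (__METAL_VERSION__ >= 310)
#include "mlx/backend/metal/kernels/metal_3_1/bf16.h"
#else
#include "mlx/backend/metal/kernels/metal_3_0/bf16.h"
#endif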


@@ -3,8 +3,6 @@
#include <metal_common> #include <metal_common>
#include <metal_simdgroup> #include <metal_simdgroup>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/utils.h"
using namespace metal; using namespace metal;


@@ -6,12 +6,6 @@
using namespace metal; using namespace metal;
#if (MLX_METAL_VERSION >= 310) || (__METAL_VERSION__ >= 310)
typedef bfloat bfloat16_t;
#else
///////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////
// Helpers // Helpers
///////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////
@@ -311,7 +305,10 @@ METAL_FUNC bool isnan(_MLX_BFloat16 x) {
} // namespace metal } // namespace metal
#pragma METAL internals : disable #pragma METAL internals : disable
inline uint16_t bfloat16_to_uint16(const bfloat16_t x) {
return x.bits_;
}
#endif inline bfloat16_t uint16_to_bfloat16(const uint16_t x) {
return _MLX_BFloat16(x, _MLX_BFloat16::bits_to_bfloat());
#include "mlx/backend/metal/kernels/bf16_math.h" }


@@ -0,0 +1,16 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include <metal_stdlib>
using namespace metal;
typedef bfloat bfloat16_t;
inline uint16_t bfloat16_to_uint16(const bfloat16_t x) {
return as_type<uint16_t>(x);
}
inline bfloat16_t uint16_to_bfloat16(const uint16_t x) {
return as_type<bfloat16_t>(x);
}
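
The new header above (presumably the metal_3_1 variant selected by the jit dispatch earlier) maps bfloat16_t to the native bfloat type and converts to and from uint16_t with as_type reinterpret casts, while the pre-3.1 path in the previous file reads and writes the stored bits of _MLX_BFloat16 directly. Both work because bfloat16 is simply the top 16 bits of an IEEE-754 float32. A host-side C++ sketch of the same bit-level round trip (not MLX code):

#include <cstdint>
#include <cstdio>
#include <cstring>

static uint16_t float_to_bfloat_bits(float x) {
  uint32_t u;
  std::memcpy(&u, &x, sizeof(u));
  return static_cast<uint16_t>(u >> 16);  // keep sign, exponent, top mantissa bits (truncating)
}

static float bfloat_bits_to_float(uint16_t b) {
  uint32_t u = static_cast<uint32_t>(b) << 16;  // low mantissa bits become zero
  float x;
  std::memcpy(&x, &u, sizeof(x));
  return x;
}

int main() {
  uint16_t bits = float_to_bfloat_bits(1.5f);
  std::printf("bits=0x%04x back=%g\n", (unsigned)bits, bfloat_bits_to_float(bits));
  // 1.5 is exactly representable, so it round-trips: bits=0x3fc0 back=1.5
}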

File diff suppressed because it is too large


@@ -72,7 +72,6 @@
#define instantiate_quantized_all_single(type, group_size, bits) \ #define instantiate_quantized_all_single(type, group_size, bits) \
instantiate_quantized(affine_quantize, type, group_size, bits) \ instantiate_quantized(affine_quantize, type, group_size, bits) \
instantiate_quantized(affine_quantize_scales_biases, type, group_size, bits) \
instantiate_quantized(affine_dequantize, type, group_size, bits) \ instantiate_quantized(affine_dequantize, type, group_size, bits) \
instantiate_quantized(bs_qmv_fast, type, group_size, bits) \ instantiate_quantized(bs_qmv_fast, type, group_size, bits) \
instantiate_quantized(bs_qmv, type, group_size, bits) \ instantiate_quantized(bs_qmv, type, group_size, bits) \
@@ -116,7 +115,9 @@
#define instantiate_quantized_all() \ #define instantiate_quantized_all() \
instantiate_quantized_groups(2) \ instantiate_quantized_groups(2) \
instantiate_quantized_groups(3) \
instantiate_quantized_groups(4) \ instantiate_quantized_groups(4) \
instantiate_quantized_groups(6) \
instantiate_quantized_groups(8) instantiate_quantized_groups(8)
instantiate_quantized_all() // clang-format on instantiate_quantized_all() // clang-format on


@@ -71,7 +71,7 @@ rbits threefry2x32_hash(const thread uint2& key, uint2 count) {
constant const uint& bytes_per_key, constant const uint& bytes_per_key,
constant const int& ndim, constant const int& ndim,
constant const int* key_shape, constant const int* key_shape,
constant const size_t* key_strides, constant const int64_t* key_strides,
uint2 grid_dim [[threads_per_grid]], uint2 grid_dim [[threads_per_grid]],
uint2 index [[thread_position_in_grid]]) { uint2 index [[thread_position_in_grid]]) {
auto kidx = 2 * index.x; auto kidx = 2 * index.x;


@@ -10,186 +10,156 @@
#include "mlx/backend/metal/kernels/reduction/ops.h" #include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/reduce.h" #include "mlx/backend/metal/kernels/reduce.h"
#define instantiate_reduce_helper_floats(inst_f, name, op) \ #define instantiate_init_reduce(name, tname, type, op) \
inst_f(name, float16, half, op) \ instantiate_kernel("init_reduce_" #name #tname, init_reduce, type, op<type>)
inst_f(name, float32, float, op) \
inst_f(name, bfloat16, bfloat16_t, op)
#define instantiate_reduce_helper_uints(inst_f, name, op) \ instantiate_init_reduce(and, bool_, bool, And)
inst_f(name, uint8, uint8_t, op) \ instantiate_init_reduce(or, bool_, bool, Or)
inst_f(name, uint16, uint16_t, op) \
inst_f(name, uint32, uint32_t, op)
#define instantiate_reduce_helper_ints(inst_f, name, op) \ #define instantiate_init_sum_prod(name, op) \
inst_f(name, int8, int8_t, op) \ instantiate_init_reduce(name, int32, int32_t, op) \
inst_f(name, int16, int16_t, op) \ instantiate_init_reduce(name, int64, int64_t, op) \
inst_f(name, int32, int32_t, op) instantiate_init_reduce(name, float16, float16_t, op) \
instantiate_init_reduce(name, bfloat16, bfloat16_t, op) \
instantiate_init_reduce(name, float32, float, op) \
instantiate_init_reduce(name, complex64, complex64_t, op)
#define instantiate_reduce_helper_64b(inst_f, name, op) \ instantiate_init_sum_prod(sum, Sum)
inst_f(name, int64, int64_t, op) \ instantiate_init_sum_prod(prod, Prod)
inst_f(name, uint64, uint64_t, op) \
inst_f(name, complex64, complex64_t, op)
#define instantiate_reduce_helper_types(inst_f, name, op) \ #define instantiate_init_min_max(name, op) \
instantiate_reduce_helper_floats(inst_f, name, op) \ instantiate_init_reduce(name, bool_, bool, op) \
instantiate_reduce_helper_uints(inst_f, name, op) \ instantiate_init_reduce(name, int8, int8_t, op) \
instantiate_reduce_helper_ints(inst_f, name, op) instantiate_init_reduce(name, int16, int16_t, op) \
instantiate_init_reduce(name, int32, int32_t, op) \
instantiate_init_reduce(name, int64, int64_t, op) \
instantiate_init_reduce(name, uint8, uint8_t, op) \
instantiate_init_reduce(name, uint16, uint16_t, op) \
instantiate_init_reduce(name, uint32, uint32_t, op) \
instantiate_init_reduce(name, uint64, uint64_t, op) \
instantiate_init_reduce(name, float16, float16_t, op) \
instantiate_init_reduce(name, bfloat16, bfloat16_t, op) \
instantiate_init_reduce(name, float32, float, op) \
instantiate_init_reduce(name, complex64, complex64_t, op)
#define instantiate_reduce_ops(inst_f, type_f) \ instantiate_init_min_max(min, Min)
type_f(inst_f, sum, Sum) \ instantiate_init_min_max(max, Max)
type_f(inst_f, prod, Prod) \
type_f(inst_f, min, Min) \
type_f(inst_f, max, Max)
// Special case for bool reductions
#define instantiate_reduce_from_types_helper( \
inst_f, name, tname, itype, otype, op) \
inst_f(name##tname, itype, otype, op)
#define instantiate_reduce_from_types(inst_f, name, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, bool_, bool, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, uint8, uint8_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, uint16, uint16_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, uint32, uint32_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, uint64, uint64_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, int8, int8_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, int16, int16_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, int32, int32_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, int64, int64_t, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, name, float16, half, otype, op) \
instantiate_reduce_from_types_helper( \
inst_f, \
name, \
float32, \
float, \
otype, \
op) \
instantiate_reduce_from_types_helper( \
inst_f, \
name, \
bfloat16, \
bfloat16_t, \
otype, \
op)
#define instantiate_init_reduce(name, otype, op) \
instantiate_kernel("init_reduce_" #name, \
init_reduce, \
otype, op)
#define instantiate_init_reduce_helper(name, tname, type, op) \
instantiate_init_reduce(name##tname, type, op<type>)
instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_64b)
instantiate_init_reduce(andbool_, bool, And<bool>)
instantiate_init_reduce(orbool_, bool, Or<bool>)
#define instantiate_all_reduce(name, itype, otype, op) \ #define instantiate_all_reduce(name, itype, otype, op) \
instantiate_kernel("all_reduce_" #name, \ instantiate_kernel("all_reduce_" #name, \
all_reduce, \ all_reduce, \
itype, otype, op) itype, otype, op)
#define instantiate_same_all_reduce_helper(name, tname, type, op) \ #define instantiate_col_reduce_small(name, itype, otype, op, dim) \
instantiate_all_reduce(name##tname, type, type, op<type>) instantiate_kernel("col_reduce_small_" #dim "_reduce_" #name, \
col_reduce_small, \
itype, otype, op, uint, dim) \
instantiate_kernel("col_reduce_longcolumn_" #dim "_reduce_" #name, \
col_reduce_longcolumn, \
itype, otype, op, uint, dim) \
instantiate_kernel("col_reduce_small_large_" #dim "_reduce_" #name, \
col_reduce_small, \
itype, otype, op, int64_t, dim) \
instantiate_kernel("col_reduce_longcolumn_large_" #dim "_reduce_" #name, \
col_reduce_longcolumn, \
itype, otype, op, int64_t, dim)
instantiate_reduce_ops(instantiate_same_all_reduce_helper, instantiate_reduce_helper_types) #define instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, bm, bn) \
instantiate_reduce_ops(instantiate_same_all_reduce_helper, instantiate_reduce_helper_64b) instantiate_kernel("col_reduce_looped_" #dim "_" #bm "_" #bn "_reduce_" #name, \
col_reduce_looped, \
itype, otype, op, uint, dim, bm, bn) \
instantiate_kernel("col_reduce_looped_large_" #dim "_" #bm "_" #bn "_reduce_" #name, \
col_reduce_looped, \
itype, otype, op, int64_t, dim, bm, bn)
instantiate_reduce_from_types(instantiate_all_reduce, and, bool, And<bool>) #define instantiate_col_reduce_2pass_tile(name, itype, otype, op, dim, bm, bn) \
instantiate_reduce_from_types(instantiate_all_reduce, or, bool, Or<bool>) instantiate_kernel("col_reduce_2pass_" #dim "_" #bm "_" #bn "_reduce_" #name, \
col_reduce_2pass, \
// special case bool with larger output type itype, otype, op, uint, dim, bm, bn) \
instantiate_all_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>) instantiate_kernel("col_reduce_2pass_large_" #dim "_" #bm "_" #bn "_reduce_" #name, \
col_reduce_2pass, \
#define instantiate_col_reduce_small(name, itype, otype, op, dim) \ itype, otype, op, int64_t, dim, bm, bn)
instantiate_kernel("col_reduce_small_" #dim "_reduce_" #name, \
col_reduce_small, \
itype, otype, op, dim) \
instantiate_kernel("col_reduce_longcolumn_" #dim "_reduce_" #name, \
col_reduce_longcolumn, \
itype, otype, op, dim)
#define instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, bm, bn) \
instantiate_kernel("col_reduce_looped_" #dim "_" #bm "_" #bn "_reduce_" #name, \
col_reduce_looped, \
itype, otype, op, dim, bm, bn)
#define instantiate_col_reduce_2pass_tile(name, itype, otype, op, dim, bm, bn) \
instantiate_kernel("col_reduce_2pass_" #dim "_" #bm "_" #bn "_reduce_" #name, \
col_reduce_2pass, \
itype, otype, op, dim, bm, bn)
#define instantiate_col_reduce_looped(name, itype, otype, op, dim) \ #define instantiate_col_reduce_looped(name, itype, otype, op, dim) \
instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 32, 32) \ instantiate_col_reduce_looped_tile(name, itype, otype, op, dim, 32, 32) \
instantiate_col_reduce_2pass_tile(name, itype, otype, op, dim, 32, 32) instantiate_col_reduce_2pass_tile(name, itype, otype, op, dim, 32, 32)
#define instantiate_col_reduce_general(name, itype, otype, op) \ #define instantiate_col_reduce_general(name, itype, otype, op) \
instantiate_col_reduce_small(name, itype, otype, op, 0) \
instantiate_col_reduce_small(name, itype, otype, op, 1) \ instantiate_col_reduce_small(name, itype, otype, op, 1) \
instantiate_col_reduce_small(name, itype, otype, op, 2) \ instantiate_col_reduce_small(name, itype, otype, op, 2) \
instantiate_col_reduce_small(name, itype, otype, op, 3) \ instantiate_col_reduce_small(name, itype, otype, op, 5) \
instantiate_col_reduce_small(name, itype, otype, op, 4) \
instantiate_col_reduce_looped(name, itype, otype, op, 0) \
instantiate_col_reduce_looped(name, itype, otype, op, 1) \ instantiate_col_reduce_looped(name, itype, otype, op, 1) \
instantiate_col_reduce_looped(name, itype, otype, op, 2) \ instantiate_col_reduce_looped(name, itype, otype, op, 2) \
instantiate_col_reduce_looped(name, itype, otype, op, 3) \ instantiate_col_reduce_looped(name, itype, otype, op, 5)
instantiate_col_reduce_looped(name, itype, otype, op, 4)
#define instantiate_same_col_reduce_helper(name, tname, type, op) \ #define instantiate_row_reduce_small(name, itype, otype, op, dim) \
instantiate_col_reduce_general(name##tname, type, type, op<type>) instantiate_kernel("row_reduce_small_" #dim "_reduce_" #name, \
row_reduce_small, \
itype, otype, op, uint, dim) \
instantiate_kernel("row_reduce_small_large_" #dim "_reduce_" #name, \
row_reduce_small, \
itype, otype, op, int64_t, dim)
instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_types) #define instantiate_row_reduce_looped(name, itype, otype, op, dim) \
instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_64b) instantiate_kernel("row_reduce_looped_" #dim "_reduce_" #name, \
row_reduce_looped, \
instantiate_col_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>) itype, otype, op, uint, dim) \
instantiate_reduce_from_types(instantiate_col_reduce_general, and, bool, And<bool>) instantiate_kernel("row_reduce_looped_large_" #dim "_reduce_" #name, \
instantiate_reduce_from_types(instantiate_col_reduce_general, or, bool, Or<bool>) row_reduce_looped, \
itype, otype, op, int64_t, dim)
#define instantiate_row_reduce_small(name, itype, otype, op, dim) \
instantiate_kernel("row_reduce_small_" #dim "_reduce_" #name, \
row_reduce_small, \
itype, otype, op, dim)
#define instantiate_row_reduce_looped(name, itype, otype, op, dim) \
instantiate_kernel("row_reduce_looped_" #dim "_reduce_" #name, \
row_reduce_looped, \
itype, otype, op, dim)
#define instantiate_row_reduce_general(name, itype, otype, op) \ #define instantiate_row_reduce_general(name, itype, otype, op) \
instantiate_row_reduce_small(name, itype, otype, op, 0) \
instantiate_row_reduce_small(name, itype, otype, op, 1) \ instantiate_row_reduce_small(name, itype, otype, op, 1) \
instantiate_row_reduce_small(name, itype, otype, op, 2) \ instantiate_row_reduce_small(name, itype, otype, op, 2) \
instantiate_row_reduce_small(name, itype, otype, op, 3) \ instantiate_row_reduce_small(name, itype, otype, op, 5) \
instantiate_row_reduce_small(name, itype, otype, op, 4) \
instantiate_row_reduce_looped(name, itype, otype, op, 0) \
instantiate_row_reduce_looped(name, itype, otype, op, 1) \ instantiate_row_reduce_looped(name, itype, otype, op, 1) \
instantiate_row_reduce_looped(name, itype, otype, op, 2) \ instantiate_row_reduce_looped(name, itype, otype, op, 2) \
instantiate_row_reduce_looped(name, itype, otype, op, 3) \ instantiate_row_reduce_looped(name, itype, otype, op, 5) \
instantiate_row_reduce_looped(name, itype, otype, op, 4) \
instantiate_kernel("row_reduce_simple_" #name, \ instantiate_kernel("row_reduce_simple_" #name, \
row_reduce_simple, \ row_reduce_simple, \
itype, otype, op) itype, otype, op)
#define instantiate_same_row_reduce_helper(name, tname, type, op) \ #define instantiate_reduce_functions(name, tname, itype, otype, op) \
instantiate_row_reduce_general(name##tname, type, type, op<type>) instantiate_all_reduce(name##tname, itype, otype, op<otype>) \
instantiate_row_reduce_general(name##tname, itype, otype, op<otype>) \
instantiate_col_reduce_general(name##tname, itype, otype, op<otype>)
instantiate_reduce_ops(instantiate_same_row_reduce_helper, instantiate_reduce_helper_types) #define instantiate_and_or(name, op) \
instantiate_reduce_ops(instantiate_same_row_reduce_helper, instantiate_reduce_helper_64b) instantiate_reduce_functions(name, bool_, bool, bool, op) \
instantiate_reduce_functions(name, int16, int16_t, bool, op) \
instantiate_reduce_functions(name, int32, int32_t, bool, op) \
instantiate_reduce_functions(name, int64, int64_t, bool, op)
instantiate_reduce_from_types(instantiate_row_reduce_general, and, bool, And<bool>) instantiate_and_or(and, And)
instantiate_reduce_from_types(instantiate_row_reduce_general, or, bool, Or<bool>) instantiate_and_or(or, Or)
instantiate_row_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>) #define instantiate_sum_prod(name, op) \
instantiate_reduce_functions(name, int8, int8_t, int32_t, op) \
instantiate_reduce_functions(name, int16, int16_t, int32_t, op) \
instantiate_reduce_functions(name, int32, int32_t, int32_t, op) \
instantiate_reduce_functions(name, int64, int64_t, int64_t, op) \
instantiate_reduce_functions(name, float16, float16_t, float16_t, op) \
instantiate_reduce_functions(name, bfloat16, bfloat16_t, bfloat16_t, op) \
instantiate_reduce_functions(name, float32, float, float, op) \
instantiate_reduce_functions(name, complex64, complex64_t, complex64_t, op)
instantiate_sum_prod(sum, Sum)
instantiate_sum_prod(prod, Prod)
#define instantiate_min_max(name, op) \
instantiate_reduce_functions(name, int8, int8_t, int8_t, op) \
instantiate_reduce_functions(name, int16, int16_t, int16_t, op) \
instantiate_reduce_functions(name, int32, int32_t, int32_t, op) \
instantiate_reduce_functions(name, int64, int64_t, int64_t, op) \
instantiate_reduce_functions(name, uint8, uint8_t, uint8_t, op) \
instantiate_reduce_functions(name, uint16, uint16_t, uint16_t, op) \
instantiate_reduce_functions(name, uint32, uint32_t, uint32_t, op) \
instantiate_reduce_functions(name, uint64, uint64_t, uint64_t, op) \
instantiate_reduce_functions(name, float16, float16_t, float16_t, op) \
instantiate_reduce_functions(name, bfloat16, bfloat16_t, bfloat16_t, op) \
instantiate_reduce_functions(name, float32, float, float, op) \
instantiate_reduce_functions(name, complex64, complex64_t, complex64_t, op)
instantiate_min_max(min, Min)
instantiate_min_max(max, Max)
// clang-format on // clang-format on
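
The rewritten instantiation file above replaces the generic type tables (instantiate_reduce_helper_* and instantiate_reduce_from_types) with per-operation macros: instantiate_and_or, instantiate_sum_prod and instantiate_min_max each list exactly the input/output type pairs their operation supports and fan out through instantiate_reduce_functions to the all/row/col kernels. A trimmed, host-compilable sketch of the fan-out, with a printf standing in for instantiate_kernel so the generated host names are visible (the stand-in macros are hypothetical; only the structure mirrors the file):

#include <cstdio>

#define instantiate_kernel(name, ...) std::printf("%s\n", name);
#define instantiate_all_reduce(name, itype, otype, op) \
  instantiate_kernel("all_reduce_" #name, all_reduce, itype, otype, op)
#define instantiate_reduce_functions(name, tname, itype, otype, op) \
  instantiate_all_reduce(name##tname, itype, otype, op<otype>)
#define instantiate_and_or(name, op)                        \
  instantiate_reduce_functions(name, bool_, bool, bool, op) \
  instantiate_reduce_functions(name, int32, int32_t, bool, op)

int main() {
  instantiate_and_or(and, And)
  // prints: all_reduce_andbool_
  //         all_reduce_andint32
}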


@@ -1,6 +1,11 @@
// Copyright © 2023-2024 Apple Inc. // Copyright © 2023-2024 Apple Inc.
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS> template <
typename T,
typename U,
typename Op,
typename IdxT = int64_t,
int N_READS = REDUCE_N_READS>
[[kernel]] void all_reduce( [[kernel]] void all_reduce(
const device T* in [[buffer(0)]], const device T* in [[buffer(0)]],
device U* out [[buffer(1)]], device U* out [[buffer(1)]],
@@ -16,10 +21,10 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
threadgroup U shared_vals[simd_size]; threadgroup U shared_vals[simd_size];
U total = Op::init; U total = Op::init;
int64_t start_idx = gid.y * row_size; IdxT start_idx = gid.y * IdxT(row_size);
int64_t actual_row = IdxT actual_row =
(start_idx + row_size <= in_size) ? row_size : in_size - start_idx; (start_idx + row_size <= in_size) ? row_size : in_size - start_idx;
int64_t blocks = actual_row / (lsize.x * N_READS); IdxT blocks = actual_row / (lsize.x * N_READS);
int extra = actual_row - blocks * (lsize.x * N_READS); int extra = actual_row - blocks * (lsize.x * N_READS);
extra -= lid.x * N_READS; extra -= lid.x * N_READS;
start_idx += lid.x * N_READS; start_idx += lid.x * N_READS;
@@ -30,7 +35,7 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
extra = 0; extra = 0;
} }
for (int64_t b = 0; b < blocks; b++) { for (IdxT b = 0; b < blocks; b++) {
for (int i = 0; i < N_READS; i++) { for (int i = 0; i < N_READS; i++) {
total = op(static_cast<U>(in[i]), total); total = op(static_cast<U>(in[i]), total);
} }
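
all_reduce (and the row/col kernels later in this file set) gains an IdxT template parameter for its index arithmetic, and the instantiation file above registers the column/row variants twice, once with IdxT = uint and once as a "large" variant with IdxT = int64_t. A hypothetical host-side selection sketch (not MLX's dispatch code; the overflow test is an assumption, only the kernel-name layout mirrors the macros above):

#include <cstdint>
#include <cstdio>
#include <limits>
#include <string>

// "row_reduce_looped[_large]_<dim>_reduce_<tname>", choosing the 64-bit
// indexing variant only when 32-bit indexing could overflow.
std::string row_reduce_looped_name(int64_t n_elements, int dim, const std::string& tname) {
  bool large = n_elements > std::numeric_limits<int32_t>::max();
  return std::string("row_reduce_looped_") + (large ? "large_" : "") +
         std::to_string(dim) + "_reduce_" + tname;
}

int main() {
  std::printf("%s\n", row_reduce_looped_name(1 << 20, 2, "sumfloat32").c_str());
  std::printf("%s\n", row_reduce_looped_name(int64_t(1) << 40, 2, "sumfloat32").c_str());
  // prints: row_reduce_looped_2_reduce_sumfloat32
  //         row_reduce_looped_large_2_reduce_sumfloat32
}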


@@ -1,16 +1,16 @@
// Copyright © 2023-2024 Apple Inc. // Copyright © 2023-2024 Apple Inc.
template <typename T, typename U, typename Op, int NDIMS> template <typename T, typename U, typename Op, typename IdxT, int NDIMS>
[[kernel]] void col_reduce_small( [[kernel]] void col_reduce_small(
const device T* in [[buffer(0)]], const device T* in [[buffer(0)]],
device U* out [[buffer(1)]], device U* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]], const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]], const constant int64_t& reduction_stride [[buffer(3)]],
const constant int* shape [[buffer(4)]], const constant int* shape [[buffer(4)]],
const constant size_t* strides [[buffer(5)]], const constant int64_t* strides [[buffer(5)]],
const constant int& ndim [[buffer(6)]], const constant int& ndim [[buffer(6)]],
const constant int* reduce_shape [[buffer(7)]], const constant int* reduce_shape [[buffer(7)]],
const constant size_t* reduce_strides [[buffer(8)]], const constant int64_t* reduce_strides [[buffer(8)]],
const constant int& reduce_ndim [[buffer(9)]], const constant int& reduce_ndim [[buffer(9)]],
const constant size_t& non_col_reductions [[buffer(10)]], const constant size_t& non_col_reductions [[buffer(10)]],
uint3 gid [[threadgroup_position_in_grid]], uint3 gid [[threadgroup_position_in_grid]],
@@ -19,7 +19,7 @@ template <typename T, typename U, typename Op, int NDIMS>
uint3 lsize [[threads_per_threadgroup]]) { uint3 lsize [[threads_per_threadgroup]]) {
constexpr int n_reads = 4; constexpr int n_reads = 4;
Op op; Op op;
looped_elem_to_loc<NDIMS> loop; LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);
const device T* row; const device T* row;
U totals[n_reads]; U totals[n_reads];
@@ -27,20 +27,20 @@ template <typename T, typename U, typename Op, int NDIMS>
totals[i] = Op::init; totals[i] = Op::init;
} }
size_t column = size_t(gid.x) * lsize.x * n_reads + lid.x * n_reads; IdxT column = IdxT(gid.x) * lsize.x * n_reads + lid.x * n_reads;
if (column >= reduction_stride) { if (column >= reduction_stride) {
return; return;
} }
bool safe = column + n_reads <= reduction_stride; bool safe = column + n_reads <= reduction_stride;
size_t out_idx = gid.y + gsize.y * size_t(gid.z); IdxT out_idx = gid.y + gsize.y * IdxT(gid.z);
size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim); IdxT in_idx = elem_to_loc<IdxT>(out_idx, shape, strides, ndim);
in += in_idx + column; in += in_idx + column;
size_t total_rows = non_col_reductions * reduction_size; IdxT total_rows = IdxT(non_col_reductions) * IdxT(reduction_size);
loop.next(lid.y, reduce_shape, reduce_strides); loop.next(lid.y, reduce_shape, reduce_strides);
for (size_t r = lid.y; r < total_rows; r += lsize.y) { for (IdxT r = lid.y; r < total_rows; r += lsize.y) {
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim); row = in + loop.location();
if (safe) { if (safe) {
for (int i = 0; i < n_reads; i++) { for (int i = 0; i < n_reads; i++) {
totals[i] = op(static_cast<U>(row[i]), totals[i]); totals[i] = op(static_cast<U>(row[i]), totals[i]);
@@ -80,7 +80,7 @@ template <typename T, typename U, typename Op, int NDIMS>
} }
if (lid.y == 0) { if (lid.y == 0) {
out += out_idx * reduction_stride + column; out += out_idx * IdxT(reduction_stride) + column;
if (safe) { if (safe) {
for (int i = 0; i < n_reads; i++) { for (int i = 0; i < n_reads; i++) {
out[i] = totals[i]; out[i] = totals[i];
@@ -93,17 +93,17 @@ template <typename T, typename U, typename Op, int NDIMS>
} }
} }
template <typename T, typename U, typename Op, int NDIMS> template <typename T, typename U, typename Op, typename IdxT, int NDIMS>
[[kernel]] void col_reduce_longcolumn( [[kernel]] void col_reduce_longcolumn(
const device T* in [[buffer(0)]], const device T* in [[buffer(0)]],
device U* out [[buffer(1)]], device U* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]], const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]], const constant size_t& reduction_stride [[buffer(3)]],
const constant int* shape [[buffer(4)]], const constant int* shape [[buffer(4)]],
const constant size_t* strides [[buffer(5)]], const constant int64_t* strides [[buffer(5)]],
const constant int& ndim [[buffer(6)]], const constant int& ndim [[buffer(6)]],
const constant int* reduce_shape [[buffer(7)]], const constant int* reduce_shape [[buffer(7)]],
const constant size_t* reduce_strides [[buffer(8)]], const constant int64_t* reduce_strides [[buffer(8)]],
const constant int& reduce_ndim [[buffer(9)]], const constant int& reduce_ndim [[buffer(9)]],
const constant size_t& non_col_reductions [[buffer(10)]], const constant size_t& non_col_reductions [[buffer(10)]],
const constant size_t& out_size [[buffer(11)]], const constant size_t& out_size [[buffer(11)]],
@@ -112,19 +112,19 @@ template <typename T, typename U, typename Op, int NDIMS>
uint3 lid [[thread_position_in_threadgroup]], uint3 lid [[thread_position_in_threadgroup]],
uint3 lsize [[threads_per_threadgroup]]) { uint3 lsize [[threads_per_threadgroup]]) {
Op op; Op op;
looped_elem_to_loc<NDIMS> loop; LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);
const device T* row; const device T* row;
size_t out_idx = gid.x + gsize.x * size_t(gid.y); IdxT out_idx = gid.x + gsize.x * IdxT(gid.y);
size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim); IdxT in_idx = elem_to_loc<IdxT>(out_idx, shape, strides, ndim);
in += in_idx + lid.x; in += in_idx + lid.x;
U total = Op::init; U total = Op::init;
size_t total_rows = non_col_reductions * reduction_size; IdxT total_rows = IdxT(non_col_reductions) * IdxT(reduction_size);
loop.next(gid.z * lsize.y + lid.y, reduce_shape, reduce_strides); loop.next(gid.z * lsize.y + lid.y, reduce_shape, reduce_strides);
for (size_t r = gid.z * lsize.y + lid.y; r < total_rows; for (IdxT r = gid.z * lsize.y + lid.y; r < total_rows;
r += lsize.y * gsize.z) { r += lsize.y * gsize.z) {
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim); row = in + loop.location();
total = op(static_cast<U>(*row), total); total = op(static_cast<U>(*row), total);
loop.next(lsize.y * gsize.z, reduce_shape, reduce_strides); loop.next(lsize.y * gsize.z, reduce_shape, reduce_strides);
} }
@@ -136,7 +136,8 @@ template <typename T, typename U, typename Op, int NDIMS>
for (uint i = 1; i < lsize.y; i++) { for (uint i = 1; i < lsize.y; i++) {
total = op(total, shared_vals[i * lsize.x + lid.x]); total = op(total, shared_vals[i * lsize.x + lid.x]);
} }
out[gid.z * out_size + out_idx * reduction_stride + lid.x] = total; out[gid.z * IdxT(out_size) + out_idx * IdxT(reduction_stride) + lid.x] =
total;
} }
} }
@@ -151,17 +152,24 @@ template <typename T, typename U, typename Op, int NDIMS>
* totals with a loop. * totals with a loop.
* 7. Write them to the output * 7. Write them to the output
*/ */
template <typename T, typename U, typename Op, int NDIMS, int BM, int BN> template <
typename T,
typename U,
typename Op,
typename IdxT,
int NDIMS,
int BM,
int BN>
[[kernel]] void col_reduce_looped( [[kernel]] void col_reduce_looped(
const device T* in [[buffer(0)]], const device T* in [[buffer(0)]],
device U* out [[buffer(1)]], device U* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]], const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]], const constant int64_t& reduction_stride [[buffer(3)]],
const constant int* shape [[buffer(4)]], const constant int* shape [[buffer(4)]],
const constant size_t* strides [[buffer(5)]], const constant int64_t* strides [[buffer(5)]],
const constant int& ndim [[buffer(6)]], const constant int& ndim [[buffer(6)]],
const constant int* reduce_shape [[buffer(7)]], const constant int* reduce_shape [[buffer(7)]],
const constant size_t* reduce_strides [[buffer(8)]], const constant int64_t* reduce_strides [[buffer(8)]],
const constant int& reduce_ndim [[buffer(9)]], const constant int& reduce_ndim [[buffer(9)]],
const constant size_t& non_col_reductions [[buffer(10)]], const constant size_t& non_col_reductions [[buffer(10)]],
uint3 gid [[threadgroup_position_in_grid]], uint3 gid [[threadgroup_position_in_grid]],
@@ -176,7 +184,7 @@ template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
threadgroup U shared_vals[BN * BM]; threadgroup U shared_vals[BN * BM];
U totals[n_reads]; U totals[n_reads];
looped_elem_to_loc<NDIMS> loop; LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);
const device T* row; const device T* row;
for (int i = 0; i < n_reads; i++) { for (int i = 0; i < n_reads; i++) {
@@ -185,17 +193,17 @@ template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
short lid = simd_group_id * simd_size + simd_lane_id; short lid = simd_group_id * simd_size + simd_lane_id;
short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks); short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks);
size_t column = BN * gid.x + offset.x; IdxT column = BN * gid.x + offset.x;
bool safe = column + n_reads <= reduction_stride; bool safe = column + n_reads <= reduction_stride;
size_t out_idx = gid.y + gsize.y * size_t(gid.z); IdxT out_idx = gid.y + gsize.y * IdxT(gid.z);
size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim); IdxT in_idx = elem_to_loc<IdxT>(out_idx, shape, strides, ndim);
in += in_idx + column; in += in_idx + column;
size_t total = non_col_reductions * reduction_size; IdxT total = IdxT(non_col_reductions) * IdxT(reduction_size);
loop.next(offset.y, reduce_shape, reduce_strides); loop.next(offset.y, reduce_shape, reduce_strides);
for (size_t r = offset.y; r < total; r += BM) { for (IdxT r = offset.y; r < total; r += BM) {
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim); row = in + loop.location();
if (safe) { if (safe) {
for (int i = 0; i < n_reads; i++) { for (int i = 0; i < n_reads; i++) {
@@ -235,8 +243,8 @@ template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
// Write the output. // Write the output.
if (simd_lane_id == 0) { if (simd_lane_id == 0) {
size_t out_column = BN * gid.x + out_offset.x; IdxT out_column = BN * gid.x + out_offset.x;
out += out_idx * reduction_stride + out_column; out += out_idx * IdxT(reduction_stride) + out_column;
if (out_column + n_outputs <= reduction_stride) { if (out_column + n_outputs <= reduction_stride) {
for (int i = 0; i < n_outputs; i++) { for (int i = 0; i < n_outputs; i++) {
out[i] = totals[i]; out[i] = totals[i];
@@ -269,7 +277,7 @@ template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
// Write the output. // Write the output.
if (offset.y == 0) { if (offset.y == 0) {
out += out_idx * reduction_stride + column; out += out_idx * IdxT(reduction_stride) + column;
if (safe) { if (safe) {
for (int i = 0; i < n_reads; i++) { for (int i = 0; i < n_reads; i++) {
out[i] = totals[i]; out[i] = totals[i];
@@ -283,17 +291,24 @@ template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
} }
} }
template <typename T, typename U, typename Op, int NDIMS, int BM, int BN> template <
typename T,
typename U,
typename Op,
typename IdxT,
int NDIMS,
int BM,
int BN>
[[kernel]] void col_reduce_2pass( [[kernel]] void col_reduce_2pass(
const device T* in [[buffer(0)]], const device T* in [[buffer(0)]],
device U* out [[buffer(1)]], device U* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]], const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& reduction_stride [[buffer(3)]], const constant int64_t& reduction_stride [[buffer(3)]],
const constant int* shape [[buffer(4)]], const constant int* shape [[buffer(4)]],
const constant size_t* strides [[buffer(5)]], const constant int64_t* strides [[buffer(5)]],
const constant int& ndim [[buffer(6)]], const constant int& ndim [[buffer(6)]],
const constant int* reduce_shape [[buffer(7)]], const constant int* reduce_shape [[buffer(7)]],
const constant size_t* reduce_strides [[buffer(8)]], const constant int64_t* reduce_strides [[buffer(8)]],
const constant int& reduce_ndim [[buffer(9)]], const constant int& reduce_ndim [[buffer(9)]],
const constant size_t& non_col_reductions [[buffer(10)]], const constant size_t& non_col_reductions [[buffer(10)]],
const constant size_t& out_size [[buffer(11)]], const constant size_t& out_size [[buffer(11)]],
@@ -312,7 +327,7 @@ template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
threadgroup U shared_vals[BN * BM]; threadgroup U shared_vals[BN * BM];
U totals[n_reads]; U totals[n_reads];
looped_elem_to_loc<NDIMS> loop; LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);
const device T* row; const device T* row;
for (int i = 0; i < n_reads; i++) { for (int i = 0; i < n_reads; i++) {
@@ -321,20 +336,19 @@ template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
short lid = simd_group_id * simd_size + simd_lane_id; short lid = simd_group_id * simd_size + simd_lane_id;
short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks); short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks);
size_t column = BN * gid.x + offset.x; IdxT column = BN * gid.x + offset.x;
bool safe = column + n_reads <= reduction_stride; bool safe = column + n_reads <= reduction_stride;
size_t full_idx = gid.y + gsize.y * size_t(gid.z); IdxT full_idx = gid.y + gsize.y * IdxT(gid.z);
size_t block_idx = full_idx / out_size; IdxT block_idx = full_idx / IdxT(out_size);
size_t out_idx = full_idx % out_size; IdxT out_idx = full_idx % IdxT(out_size);
size_t in_idx = elem_to_loc(out_idx, shape, strides, ndim); IdxT in_idx = elem_to_loc<IdxT>(out_idx, shape, strides, ndim);
in += in_idx + column; in += in_idx + column;
size_t total = non_col_reductions * reduction_size; IdxT total = IdxT(non_col_reductions) * IdxT(reduction_size);
loop.next(offset.y + block_idx * BM, reduce_shape, reduce_strides); loop.next(offset.y + block_idx * BM, reduce_shape, reduce_strides);
for (size_t r = offset.y + block_idx * BM; r < total; for (IdxT r = offset.y + block_idx * BM; r < total; r += outer_blocks * BM) {
r += outer_blocks * BM) { row = in + loop.location();
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim);
if (safe) { if (safe) {
for (int i = 0; i < n_reads; i++) { for (int i = 0; i < n_reads; i++) {
@@ -369,8 +383,8 @@ template <typename T, typename U, typename Op, int NDIMS, int BM, int BN>
// Write the output. // Write the output.
if (simd_lane_id == 0) { if (simd_lane_id == 0) {
size_t out_column = BN * gid.x + out_offset.x; IdxT out_column = BN * gid.x + out_offset.x;
out += full_idx * reduction_stride + out_column; out += full_idx * IdxT(reduction_stride) + out_column;
if (out_column + n_outputs <= reduction_stride) { if (out_column + n_outputs <= reduction_stride) {
for (int i = 0; i < n_outputs; i++) { for (int i = 0; i < n_outputs; i++) {
out[i] = totals[i]; out[i] = totals[i];
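
Throughout these column-reduce kernels, looped_elem_to_loc<NDIMS> with loop.location(r, reduce_shape, reduce_strides, reduce_ndim) becomes LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> constructed with reduce_ndim, and location() now takes no arguments: the helper keeps a running offset that next() advances, instead of recomputing the location from the flat index on every iteration. A standalone C++ sketch of such a stateful walker (not MLX's LoopedElemToLoc):

#include <cstdint>
#include <cstdio>
#include <vector>

// next(n, ...) advances a multi-dimensional counter by n steps while keeping
// the linear offset up to date, so location() is just a stored value.
struct LoopedOffset {
  std::vector<int> pos;
  int64_t offset = 0;
  explicit LoopedOffset(int ndim) : pos(ndim, 0) {}

  void next(int n, const int* shape, const int64_t* strides) {
    while (n-- > 0) {
      for (int d = static_cast<int>(pos.size()) - 1; d >= 0; --d) {
        offset += strides[d];
        if (++pos[d] < shape[d]) break;            // no carry, step done
        pos[d] = 0;                                // wrap this dimension
        offset -= int64_t(shape[d]) * strides[d];  // undo the full sweep
      }
    }
  }
  int64_t location() const { return offset; }
};

int main() {
  int shape[2] = {2, 3};
  int64_t strides[2] = {30, 10};  // a non-contiguous view
  LoopedOffset loop(2);
  loop.next(4, shape, strides);   // (0,0) -> (1,1)
  std::printf("%lld\n", (long long)loop.location());  // 40
}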


@@ -98,11 +98,11 @@ template <
METAL_FUNC void per_thread_row_reduce( METAL_FUNC void per_thread_row_reduce(
thread U totals[N_WRITES], thread U totals[N_WRITES],
const device T* in, const device T* in,
const size_t row_idx, const int64_t row_idx,
int blocks, int blocks,
int extra, int extra,
const constant int* shape, const constant int* shape,
const constant size_t* strides, const constant int64_t* strides,
const constant int& ndim, const constant int& ndim,
uint lsize_x, uint lsize_x,
uint lid_x) { uint lid_x) {
@@ -193,18 +193,19 @@ template <
typename T, typename T,
typename U, typename U,
typename Op, typename Op,
typename IdxT,
int NDIMS, int NDIMS,
int N_READS = REDUCE_N_READS> int N_READS = REDUCE_N_READS>
[[kernel]] void row_reduce_small( [[kernel]] void row_reduce_small(
const device T* in [[buffer(0)]], const device T* in [[buffer(0)]],
device U* out [[buffer(1)]], device U* out [[buffer(1)]],
const constant size_t& row_size [[buffer(2)]], const constant int64_t& row_size [[buffer(2)]],
const constant size_t& non_row_reductions [[buffer(3)]], const constant int64_t& non_row_reductions [[buffer(3)]],
const constant int* shape [[buffer(4)]], const constant int* shape [[buffer(4)]],
const constant size_t* strides [[buffer(5)]], const constant int64_t* strides [[buffer(5)]],
const constant int& ndim [[buffer(6)]], const constant int& ndim [[buffer(6)]],
const constant int* reduce_shape [[buffer(7)]], const constant int* reduce_shape [[buffer(7)]],
const constant size_t* reduce_strides [[buffer(8)]], const constant int64_t* reduce_strides [[buffer(8)]],
const constant int& reduce_ndim [[buffer(9)]], const constant int& reduce_ndim [[buffer(9)]],
uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_lane_id [[thread_index_in_simdgroup]],
uint3 gid [[threadgroup_position_in_grid]], uint3 gid [[threadgroup_position_in_grid]],
@@ -214,20 +215,20 @@ template <
Op op; Op op;
U total_val = Op::init; U total_val = Op::init;
looped_elem_to_loc<NDIMS> loop; LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);
// Precompute some row reduction numbers // Precompute some row reduction numbers
const device T* row; const device T* row;
int blocks = row_size / N_READS; int blocks = IdxT(row_size) / N_READS;
int extra = row_size % N_READS; int extra = IdxT(row_size) % N_READS;
if ((non_row_reductions < 32 && row_size <= 8) || non_row_reductions <= 8) { if ((non_row_reductions < 32 && row_size <= 8) || non_row_reductions <= 8) {
// Simple loop over non_row_reductions and reduce the row in the thread. // Simple loop over non_row_reductions and reduce the row in the thread.
size_t out_idx = tid.x + tsize.y * size_t(tid.y); IdxT out_idx = tid.x + tsize.y * IdxT(tid.y);
in += elem_to_loc(out_idx, shape, strides, ndim); in += elem_to_loc<IdxT>(out_idx, shape, strides, ndim);
for (uint r = 0; r < non_row_reductions; r++) { for (uint r = 0; r < non_row_reductions; r++) {
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim); row = in + loop.location();
thread_reduce<T, U, Op, N_READS>(total_val, row, blocks, extra); thread_reduce<T, U, Op, N_READS>(total_val, row, blocks, extra);
loop.next(reduce_shape, reduce_strides); loop.next(reduce_shape, reduce_strides);
} }
@@ -236,13 +237,13 @@ template <
} else { } else {
// Collaboratively reduce over non_row_reductions in the simdgroup. Each // Collaboratively reduce over non_row_reductions in the simdgroup. Each
// thread reduces every 32nd row and then a simple simd reduce. // thread reduces every 32nd row and then a simple simd reduce.
size_t out_idx = gid.y + gsize.y * size_t(gid.z); IdxT out_idx = gid.y + gsize.y * IdxT(gid.z);
in += elem_to_loc(out_idx, shape, strides, ndim); in += elem_to_loc<IdxT>(out_idx, shape, strides, ndim);
loop.next(simd_lane_id, reduce_shape, reduce_strides); loop.next(simd_lane_id, reduce_shape, reduce_strides);
for (uint r = simd_lane_id; r < non_row_reductions; r += simd_size) { for (uint r = simd_lane_id; r < non_row_reductions; r += simd_size) {
row = in + loop.location(r, reduce_shape, reduce_strides, reduce_ndim); row = in + loop.location();
thread_reduce<T, U, Op, N_READS>(total_val, row, blocks, extra); thread_reduce<T, U, Op, N_READS>(total_val, row, blocks, extra);
loop.next(simd_size, reduce_shape, reduce_strides); loop.next(simd_size, reduce_shape, reduce_strides);
} }
@@ -259,13 +260,14 @@ template <
typename T, typename T,
typename U, typename U,
typename Op, typename Op,
typename IdxT = int64_t,
int N_READS = REDUCE_N_READS, int N_READS = REDUCE_N_READS,
int N_WRITES = REDUCE_N_WRITES> int N_WRITES = REDUCE_N_WRITES>
[[kernel]] void row_reduce_simple( [[kernel]] void row_reduce_simple(
const device T* in [[buffer(0)]], const device T* in [[buffer(0)]],
device U* out [[buffer(1)]], device U* out [[buffer(1)]],
const constant size_t& reduction_size [[buffer(2)]], const constant size_t& reduction_size [[buffer(2)]],
const constant size_t& out_size [[buffer(3)]], const constant int64_t& out_size [[buffer(3)]],
uint3 gid [[threadgroup_position_in_grid]], uint3 gid [[threadgroup_position_in_grid]],
uint3 gsize [[threadgroups_per_grid]], uint3 gsize [[threadgroups_per_grid]],
uint3 lid [[thread_position_in_threadgroup]], uint3 lid [[thread_position_in_threadgroup]],
@@ -277,15 +279,15 @@ template <
U totals[N_WRITES]; U totals[N_WRITES];
// Move to the row // Move to the row
size_t out_idx = N_WRITES * (gid.y + gsize.y * size_t(gid.z)); IdxT out_idx = N_WRITES * (gid.y + gsize.y * IdxT(gid.z));
if (out_idx + N_WRITES > out_size) { if (out_idx + N_WRITES > out_size) {
out_idx = out_size - N_WRITES; out_idx = out_size - N_WRITES;
} }
in += out_idx * reduction_size; in += out_idx * IdxT(reduction_size);
out += out_idx; out += out_idx;
// Each thread reduces across the row // Each thread reduces across the row
int blocks = reduction_size / (lsize.x * N_READS); int blocks = IdxT(reduction_size) / (lsize.x * N_READS);
int extra = reduction_size - blocks * (lsize.x * N_READS); int extra = reduction_size - blocks * (lsize.x * N_READS);
per_thread_row_reduce<T, U, Op, N_READS, N_WRITES>( per_thread_row_reduce<T, U, Op, N_READS, N_WRITES>(
totals, in, reduction_size, blocks, extra, lsize.x, lid.x); totals, in, reduction_size, blocks, extra, lsize.x, lid.x);
@@ -306,18 +308,19 @@ template <
typename T, typename T,
typename U, typename U,
typename Op, typename Op,
typename IdxT,
int NDIMS, int NDIMS,
int N_READS = REDUCE_N_READS> int N_READS = REDUCE_N_READS>
[[kernel]] void row_reduce_looped( [[kernel]] void row_reduce_looped(
const device T* in [[buffer(0)]], const device T* in [[buffer(0)]],
device U* out [[buffer(1)]], device U* out [[buffer(1)]],
const constant size_t& row_size [[buffer(2)]], const constant int64_t& row_size [[buffer(2)]],
const constant size_t& non_row_reductions [[buffer(3)]], const constant int64_t& non_row_reductions [[buffer(3)]],
const constant int* shape [[buffer(4)]], const constant int* shape [[buffer(4)]],
const constant size_t* strides [[buffer(5)]], const constant int64_t* strides [[buffer(5)]],
const constant int& ndim [[buffer(6)]], const constant int& ndim [[buffer(6)]],
const constant int* reduce_shape [[buffer(7)]], const constant int* reduce_shape [[buffer(7)]],
const constant size_t* reduce_strides [[buffer(8)]], const constant int64_t* reduce_strides [[buffer(8)]],
const constant int& reduce_ndim [[buffer(9)]], const constant int& reduce_ndim [[buffer(9)]],
uint3 gid [[threadgroup_position_in_grid]], uint3 gid [[threadgroup_position_in_grid]],
uint3 gsize [[threadgroups_per_grid]], uint3 gsize [[threadgroups_per_grid]],
@@ -330,19 +333,19 @@ template <
threadgroup U shared_vals[simd_size]; threadgroup U shared_vals[simd_size];
U total = Op::init; U total = Op::init;
size_t out_idx = gid.y + gsize.y * size_t(gid.z); IdxT out_idx = gid.y + gsize.y * IdxT(gid.z);
// lid.x * N_READS breaks the per_thread_row_reduce interface a bit. Maybe it // lid.x * N_READS breaks the per_thread_row_reduce interface a bit. Maybe it
// needs a small refactor. // needs a small refactor.
in += elem_to_loc(out_idx, shape, strides, ndim) + lid.x * N_READS; in += elem_to_loc<IdxT>(out_idx, shape, strides, ndim) + lid.x * N_READS;
looped_elem_to_loc<NDIMS> loop; LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);
const device T* row; const device T* row;
int blocks = row_size / (lsize.x * N_READS); int blocks = IdxT(row_size) / (lsize.x * N_READS);
int extra = row_size - blocks * (lsize.x * N_READS); int extra = row_size - blocks * (lsize.x * N_READS);
for (size_t i = 0; i < non_row_reductions; i++) { for (IdxT i = 0; i < non_row_reductions; i++) {
row = in + loop.location(i, reduce_shape, reduce_strides, reduce_ndim); row = in + loop.location();
// Each thread reduces across the row // Each thread reduces across the row
U row_total; U row_total;
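
The row-reduce kernels above split each row into full passes of lsize.x * N_READS elements (blocks) plus a leftover tail (extra) that is handled separately; row_reduce_small uses the per-thread form, blocks = row_size / N_READS and extra = row_size % N_READS. A single-threaded C++ sketch of that split, assuming a plain sum; the GPU version interleaves lsize.x threads, which is omitted here:

#include <cstdio>
#include <vector>

int main() {
  constexpr int N_READS = 4;
  std::vector<float> row(11, 1.0f);          // row_size = 11
  int row_size = static_cast<int>(row.size());
  int blocks = row_size / N_READS;           // 2 full blocks of 4 reads
  int extra = row_size % N_READS;            // 3 leftover elements

  float total = 0.0f;
  const float* in = row.data();
  for (int b = 0; b < blocks; ++b) {
    for (int i = 0; i < N_READS; ++i) {
      total += in[i];
    }
    in += N_READS;
  }
  for (int i = 0; i < extra; ++i) {          // tail
    total += in[i];
  }
  std::printf("%g\n", total);                // 11
}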


@@ -3,8 +3,6 @@
#include <metal_common> #include <metal_common>
#include <metal_simdgroup> #include <metal_simdgroup>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/utils.h"
using namespace metal; using namespace metal;
@@ -17,12 +15,15 @@ template <typename T, int N_READS = RMS_N_READS>
constant float& eps, constant float& eps,
constant uint& axis_size, constant uint& axis_size,
constant uint& w_stride, constant uint& w_stride,
threadgroup float* local_inv_mean [[threadgroup(0)]],
threadgroup float* local_sums [[threadgroup(1)]],
uint gid [[threadgroup_position_in_grid]], uint gid [[threadgroup_position_in_grid]],
uint lid [[thread_position_in_threadgroup]], uint lid [[thread_position_in_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) { uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
constexpr int SIMD_SIZE = 32;
threadgroup float local_inv_mean[1];
threadgroup float local_sums[SIMD_SIZE];
float acc = 0; float acc = 0;
x += gid * size_t(axis_size) + lid * N_READS; x += gid * size_t(axis_size) + lid * N_READS;
w += w_stride * lid * N_READS; w += w_stride * lid * N_READS;
@@ -84,13 +85,15 @@ template <typename T, int N_READS = RMS_N_READS>
constant float& eps, constant float& eps,
constant uint& axis_size, constant uint& axis_size,
constant uint& w_stride, constant uint& w_stride,
threadgroup float* local_inv_mean [[threadgroup(0)]],
threadgroup float* local_sums [[threadgroup(1)]],
uint gid [[threadgroup_position_in_grid]], uint gid [[threadgroup_position_in_grid]],
uint lid [[thread_position_in_threadgroup]], uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]], uint lsize [[threads_per_threadgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) { uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
constexpr int SIMD_SIZE = 32;
threadgroup float local_inv_mean[1];
threadgroup float local_sums[SIMD_SIZE];
float acc = 0; float acc = 0;
x += gid * size_t(axis_size) + lid * N_READS; x += gid * size_t(axis_size) + lid * N_READS;
w += w_stride * lid * N_READS; w += w_stride * lid * N_READS;
@@ -376,8 +379,6 @@ template <typename T, int N_READS = RMS_N_READS>
constant float& eps, \ constant float& eps, \
constant uint& axis_size, \ constant uint& axis_size, \
constant uint& w_stride, \ constant uint& w_stride, \
threadgroup float* local_inv_mean [[threadgroup(0)]], \
threadgroup float* local_sums [[threadgroup(1)]], \
uint gid [[thread_position_in_grid]], \ uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \ uint lid [[thread_position_in_threadgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \ uint simd_lane_id [[thread_index_in_simdgroup]], \
@@ -407,8 +408,6 @@ template <typename T, int N_READS = RMS_N_READS>
constant float& eps, \ constant float& eps, \
constant uint& axis_size, \ constant uint& axis_size, \
constant uint& w_stride, \ constant uint& w_stride, \
threadgroup float* local_inv_mean [[threadgroup(0)]], \
threadgroup float* local_sums [[threadgroup(1)]], \
uint gid [[thread_position_in_grid]], \ uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \ uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \ uint lsize [[threads_per_threadgroup]], \
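
The RMS-norm hunks above drop the threadgroup float* arguments bound at [[threadgroup(0)]] and [[threadgroup(1)]] and instead declare fixed-size threadgroup arrays inside the kernels; the sizes are compile-time constants (one float for the inverse mean plus one slot per simdgroup), so the host presumably no longer has to size and bind that scratch memory for each dispatch (for example via setThreadgroupMemoryLength on the compute encoder). A schematic Metal sketch of the two shapes, with hypothetical kernel names and the bodies elided:

// Old shape: threadgroup scratch passed in as kernel arguments.
[[kernel]] void rms_norm_old(
    constant uint& axis_size [[buffer(0)]],
    threadgroup float* local_inv_mean [[threadgroup(0)]],
    threadgroup float* local_sums [[threadgroup(1)]],
    uint lid [[thread_position_in_threadgroup]]) {
  // ... reduction body elided ...
}

// New shape: fixed-size threadgroup arrays declared in the kernel itself.
[[kernel]] void rms_norm_new(
    constant uint& axis_size [[buffer(0)]],
    uint lid [[thread_position_in_threadgroup]]) {
  constexpr int SIMD_SIZE = 32;
  threadgroup float local_inv_mean[1];
  threadgroup float local_sums[SIMD_SIZE];
  // ... reduction body elided ...
}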


@@ -2,7 +2,6 @@
#include <metal_math> #include <metal_math>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/utils.h"
template <typename T, bool traditional, bool forward> template <typename T, bool traditional, bool forward>
void rope_single_impl( void rope_single_impl(


@@ -1,946 +1,16 @@
#include <metal_stdlib> #include <metal_stdlib>
#include "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h"
#include "mlx/backend/metal/kernels/sdpa_vector.h" #include "mlx/backend/metal/kernels/sdpa_vector.h"
#include "mlx/backend/metal/kernels/steel/defines.h"
#include "mlx/backend/metal/kernels/steel/gemm/transforms.h"
#include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/utils.h"
using namespace metal; using namespace metal;
using namespace mlx::steel;
template <
typename T,
short BROWS,
short BCOLS,
short dst_ld,
short reduction_dim,
short tgp_size,
short alignment = 1,
short n_reads = (BCOLS * BROWS) / (tgp_size),
short TCOLS = BCOLS / n_reads,
short TROWS = tgp_size / TCOLS>
struct BlockLoaderFA {
STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;
STEEL_CONST short vec_size = n_reads;
// Leading dimension for src
const int src_ld;
const int tile_stride;
// Thread location indices
const short thread_idx;
const short bi;
const short bj;
// threadgroup and device memory
threadgroup T* dst;
const device T* src;
struct alignas(alignment * sizeof(T)) ReadVector {
uint8_t v[sizeof(T) * vec_size];
};
/* Constructor */
METAL_FUNC BlockLoaderFA(
const device T* src_,
const int src_ld_,
threadgroup T* dst_,
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
ushort simd_lane_id [[thread_index_in_simdgroup]])
: src_ld(src_ld_),
tile_stride(reduction_dim ? BCOLS : BROWS * src_ld),
thread_idx(simd_group_id * 32 + simd_lane_id),
bi(thread_idx / TCOLS),
bj(vec_size * (thread_idx % TCOLS)),
dst(dst_ + bi * dst_ld + bj),
src(src_ + bi * src_ld + bj) {}
/* Load from device memory into threadgroup memory - without bound checking */
METAL_FUNC void load_unsafe() const {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
*((threadgroup ReadVector*)(&dst[i * dst_ld])) =
*((const device ReadVector*)(&src[i * src_ld]));
}
}
/* Load from device memory into threadgroup memory - with bound checking */
METAL_FUNC void load_safe(short2 src_tile_dim) const {
src_tile_dim = src_tile_dim - short2(bj, bi);
// Skip loading if thread has no valid reads
if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = T(0);
}
}
return;
}
// Use fast thread memory for bound checks
bool tmp_idx[vec_size];
T tmp_val[vec_size];
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
// Make sure tmp_idx only contains valid indices
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);
}
// Read valid indices into tmp_val
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
}
// Zero out uneeded values
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
}
// Copy values to threadgroup memory
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = tmp_val[j];
}
}
}
/* Iteration helper */
METAL_FUNC void next() {
src += tile_stride;
}
METAL_FUNC void next(short n) {
src += n * tile_stride;
}
};
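
This BlockLoaderFA (part of the attention kernel being removed in this comparison) copies tiles through an aligned ReadVector wrapper: vec_size elements are packed into a trivially copyable struct so a single assignment moves the whole chunk between device and threadgroup memory. A host-side C++ sketch of the same idea (not MLX code):

#include <cstdint>
#include <cstdio>

// Wrap vec_size elements in an aligned, trivially copyable struct so one
// assignment copies the whole chunk, mirroring the ReadVector casts above.
template <typename T, int vec_size, int alignment = 1>
struct ReadVector {
  alignas(alignment * sizeof(T)) uint8_t v[sizeof(T) * vec_size];
};

int main() {
  float src[8] = {0, 1, 2, 3, 4, 5, 6, 7};
  float dst[8] = {};
  using RV = ReadVector<float, 4>;
  *reinterpret_cast<RV*>(&dst[0]) = *reinterpret_cast<const RV*>(&src[0]);
  *reinterpret_cast<RV*>(&dst[4]) = *reinterpret_cast<const RV*>(&src[4]);
  std::printf("%g %g\n", dst[3], dst[7]);  // 3 7
}
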
template <bool M_aligned, bool N_aligned, bool K_aligned>
struct LoopAlignment {};
template <
typename T,
typename U,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
short lda_tgp,
short ldb_tgp,
typename AccumType = float,
typename Epilogue = TransformNone<U, AccumType>>
struct BlockMMAFA {
// Warp tile simdgroup matrix strides along M
STEEL_CONST short TM_stride = 8 * WM;
// Warp tile simdgroup matrix strides along M
STEEL_CONST short TN_stride = 8 * WN;
// Warp tile size along M
STEEL_CONST short TM = BM / TM_stride;
// Warp tile size along N
STEEL_CONST short TN = BN / TN_stride;
// Strides of A, B along reduction axis
STEEL_CONST short simd_stride_a = {
transpose_a ? TM_stride : TM_stride * lda_tgp};
STEEL_CONST short simd_stride_b = {
transpose_b ? TN_stride * ldb_tgp : TN_stride};
// Jump between elements
STEEL_CONST short jump_a = {transpose_a ? lda_tgp : 1};
STEEL_CONST short jump_b = {transpose_b ? ldb_tgp : 1};
STEEL_CONST short tile_stride_a = {transpose_a ? 8 * lda_tgp : 8};
STEEL_CONST short tile_stride_b = {transpose_b ? 8 : 8 * ldb_tgp};
// Simdgroup matrices
simdgroup_matrix<AccumType, 8, 8> Asimd[TM];
simdgroup_matrix<AccumType, 8, 8> Bsimd[TN];
simdgroup_matrix<AccumType, 8, 8> results[TM * TN] = {
simdgroup_matrix<AccumType, 8, 8>(0)};
// Offsets within threadgroup
const short tm;
const short tn;
short sm;
short sn;
ushort sid;
ushort slid;
short As_offset;
short Bs_offset;
/* Constructor */
METAL_FUNC BlockMMAFA(
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
ushort simd_lane_id [[thread_index_in_simdgroup]])
: tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) {
// Determine thread position in simdgroup matrix
short qid = simd_lane_id / 4;
slid = simd_lane_id;
sid = simd_group_id;
sm = (qid & 4) + (simd_lane_id / 2) % 4;
sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
// Determine thread and simdgroup offset
As_offset =
transpose_a ? ((sn)*lda_tgp + (tm + sm)) : ((sn) + (tm + sm) * lda_tgp);
Bs_offset =
transpose_b ? ((tn + sn) * ldb_tgp + (sm)) : ((sm)*ldb_tgp + (tn + sn));
}
/* (BM, BK) X (BK, BN) multiply accumulate function */
METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
// Adjust for simdgroup and thread location
As += As_offset;
Bs += Bs_offset;
// Iterate over BK in blocks of 8
STEEL_PRAGMA_UNROLL
for (short kk = 0; kk < BK; kk += 8) {
simdgroup_barrier(mem_flags::mem_none);
// Load elements from threadgroup A as simdgroup matrices
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
Asimd[i].thread_elements()[0] =
static_cast<AccumType>(As[i * simd_stride_a + 0]);
Asimd[i].thread_elements()[1] =
static_cast<AccumType>(As[i * simd_stride_a + jump_a]);
}
simdgroup_barrier(mem_flags::mem_none);
// Load elements from threadgroup B as simdgroup matrices
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
Bsimd[j].thread_elements()[0] =
static_cast<AccumType>(Bs[j * simd_stride_b + 0]);
Bsimd[j].thread_elements()[1] =
static_cast<AccumType>(Bs[j * simd_stride_b + jump_b]);
}
simdgroup_barrier(mem_flags::mem_none);
// Multiply and accumulate into result simdgroup matrices
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
short j_serp = (i % 2) ? (TN - 1 - j) : j;
simdgroup_multiply_accumulate(
results[i * TN + j_serp],
Asimd[i],
Bsimd[j_serp],
results[i * TN + j_serp]);
}
}
// Progress to next simdgroup tile
As += tile_stride_a;
Bs += tile_stride_b;
}
}
METAL_FUNC void rescale_output(const threadgroup float* Corrections) {
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
short row = sm + tm + i * TM_stride;
float scale_value = Corrections[row];
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread auto& accum = results[i * TN + j].thread_elements();
// int offset = (i * TM_stride) * ldc + (j * TN_stride);
accum[0] *= scale_value;
accum[1] *= scale_value;
}
}
}
/* Store results from simdgroup_matrix results into device memory */
METAL_FUNC void store_result(device U* C, const int ldc) const {
// Adjust for simdgroup and thread location
C += (sm + tm) * ldc + tn + sn;
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = results[i * TN + j].thread_elements();
int offset = (i * TM_stride) * ldc + (j * TN_stride);
// Apply epilogue
U outs[2] = {Epilogue::apply(accum[0]), Epilogue::apply(accum[1])};
// Write out C
C[offset] = outs[0];
C[offset + 1] = outs[1];
}
}
}
METAL_FUNC void store_result_to_tgp_memory(
threadgroup U* C,
const int ldc,
short2 dst_tile_dims) const {
// Adjust for simdgroup and thread location
C += (sm + tm) * ldc + (tn + sn);
dst_tile_dims -= short2(tn + sn, sm + tm);
STEEL_PRAGMA_UNROLL
for (int i = 0; i < TM; i++) {
if (i * TM_stride < dst_tile_dims.y) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = results[i * TN + j].thread_elements();
int offset = (i * TM_stride) * ldc + (j * TN_stride);
// Apply epilogue and output C
if (j * TN_stride < dst_tile_dims.x) {
C[offset] = Epilogue::apply(accum[0]);
}
if (j * TN_stride + 1 < dst_tile_dims.x) {
C[offset + 1] = Epilogue::apply(accum[1]);
}
}
}
}
}
METAL_FUNC void
store_result_safe(device U* C, const int ldc, short2 dst_tile_dims) const {
// Adjust for simdgroup and thread location
C += (sm + tm) * ldc + (tn + sn);
dst_tile_dims -= short2(tn + sn, sm + tm);
STEEL_PRAGMA_UNROLL
for (int i = 0; i < TM; i++) {
if (i * TM_stride < dst_tile_dims.y) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = results[i * TN + j].thread_elements();
int offset = (i * TM_stride) * ldc + (j * TN_stride);
// Apply epilogue and output C
if (j * TN_stride < dst_tile_dims.x) {
C[offset] = Epilogue::apply(accum[0]);
}
if (j * TN_stride + 1 < dst_tile_dims.x) {
C[offset + 1] = Epilogue::apply(accum[1]);
}
}
}
}
}
/* Store results from simdgroup_matrix results into device memory */
METAL_FUNC void store_result(
device U* D,
const int ldd,
const device U* C,
const int ldc,
const int fdc,
thread const Epilogue& epilogue_op) const {
// Adjust for simdgroup and thread location
C += (sm + tm) * ldc + (tn + sn) * fdc;
D += (sm + tm) * ldd + tn + sn;
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = results[i * TN + j].thread_elements();
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
// Apply epilogue
U outs[2] = {
epilogue_op.apply(accum[0], C[offset_c]),
epilogue_op.apply(accum[1], C[offset_c + fdc])};
// Write out D
D[offset_d] = outs[0];
D[offset_d + 1] = outs[1];
}
}
}
METAL_FUNC void store_result_safe(
device U* D,
const int ldd,
const device U* C,
const int ldc,
const int fdc,
short2 dst_tile_dims,
thread const Epilogue& epilogue_op) const {
// Adjust for simdgroup and thread location
C += (sm + tm) * ldc + (tn + sn) * fdc;
D += (sm + tm) * ldd + tn + sn;
dst_tile_dims -= short2(tn + sn, sm + tm);
STEEL_PRAGMA_UNROLL
for (int i = 0; i < TM; i++) {
if (i * TM_stride < dst_tile_dims.y) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = results[i * TN + j].thread_elements();
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
// Apply epilogue and output C
if (j * TN_stride < dst_tile_dims.x) {
D[offset_d] = epilogue_op.apply(accum[0], C[offset_c]);
}
if (j * TN_stride + 1 < dst_tile_dims.x) {
D[offset_d + 1] = epilogue_op.apply(accum[1], C[offset_c + fdc]);
}
}
}
}
}
METAL_FUNC void clear_results() {
STEEL_PRAGMA_UNROLL
for (int i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < TN; j++) {
results[i * TN + j] = simdgroup_matrix<AccumType, 8, 8>(0);
}
}
}
};
template <
typename T,
typename U,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_q,
bool transpose_k,
bool transpose_v,
bool MN_aligned,
bool K_aligned,
typename AccumType = typename AccumHelper<T>::accum_type,
typename Epilogue = TransformNone<U, AccumType>>
struct FastAttentionKernel {
STEEL_CONST short tgp_padding = 16 / sizeof(T);
STEEL_CONST short float_padding = 16 / sizeof(float);
STEEL_CONST short tgp_mem_size_q =
transpose_q ? BK * (BM + tgp_padding) : BM * (BK + tgp_padding);
STEEL_CONST short tgp_mem_size_k =
transpose_k ? BK * (BN + tgp_padding) : BN * (BK + tgp_padding);
STEEL_CONST short tgp_mem_size_v =
transpose_v ? BK * (BN + tgp_padding) : BN * (BK + tgp_padding);
STEEL_CONST short tgp_mem_size_s = BM * (BN + tgp_padding);
// Four per-row correction buffers: row maxes (m_i), row sums (l_i),
// output rescale factors, and final 1/l_i scales
STEEL_CONST short tgp_mem_size_corrections =
4 * (BM * sizeof(float) + float_padding);
STEEL_CONST bool share_kv_smem = transpose_k != transpose_v;
STEEL_CONST short tgp_mem_size = share_kv_smem
? tgp_mem_size_q + tgp_mem_size_k + tgp_mem_size_s +
tgp_mem_size_corrections
: tgp_mem_size_q + tgp_mem_size_k + tgp_mem_size_s +
tgp_mem_size_corrections + tgp_mem_size_v;
STEEL_CONST short tgp_size = WM * WN * 32;
static_assert(transpose_q == false, "Expected Q not transposed.");
static_assert(transpose_k == true, "Expected K transposed.");
static_assert(transpose_v == false, "Expected V not transposed.");
static_assert(tgp_mem_size <= 32768, "Excessive tgp memory requested.");
using loader_q_t = BlockLoaderFA<
T,
transpose_q ? BK : BM,
transpose_q ? BM : BK,
transpose_q ? BM + tgp_padding : BK + tgp_padding,
!transpose_q,
tgp_size>;
using loader_k_t = BlockLoaderFA<
T,
transpose_k ? BN : BK,
transpose_k ? BK : BN,
transpose_k ? BK + tgp_padding : BN + tgp_padding,
transpose_k,
tgp_size>;
using loader_v_t = BlockLoaderFA<
T,
transpose_v ? BK : BN,
transpose_v ? BN : BK,
transpose_v ? BN + tgp_padding : BK + tgp_padding,
transpose_v,
tgp_size>;
using mma_qk_t = BlockMMAFA<
T,
U,
BM,
BN,
BK,
WM,
WN,
transpose_q,
transpose_k,
transpose_q ? BM + tgp_padding : BK + tgp_padding,
transpose_k ? BK + tgp_padding : BN + tgp_padding,
AccumType,
Epilogue>;
using mma_sv_t = BlockMMAFA<
T,
U,
BM,
BK,
BN,
WM,
WN,
false,
transpose_v,
BN + tgp_padding,
BK + tgp_padding,
AccumType,
Epilogue>;
/* Main kernel function */
template <bool M_aligned, bool N_aligned, bool K_aligned_>
static METAL_FUNC void gemm_loop(
threadgroup T* As [[threadgroup(0)]],
threadgroup T* Bs [[threadgroup(1)]],
const int gemm_k_iterations,
thread loader_k_t& loader_b,
thread mma_qk_t& mma_op,
thread const short& tgp_bm,
thread const short& tgp_bn,
LoopAlignment<M_aligned, N_aligned, K_aligned_> l = {}) {
// Appease the compiler
(void)l;
(void)tgp_bm;
short2 tile_dims_B = transpose_k ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
// Only valid for gemm_k_iterations == 1 (i.e. BK == d_k): the loaders are
// not advanced between iterations.
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
if (N_aligned) {
loader_b.load_unsafe();
} else {
loader_b.load_safe(tile_dims_B);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
}
}
static METAL_FUNC void initialize_corrections(
threadgroup float* C,
uint simd_lane_id,
uint simd_group_id) {
if (simd_group_id == 0) {
threadgroup float* maxes = C;
threadgroup float* sums = C + (BM + float_padding);
threadgroup float* o_rescale = sums + (BM + float_padding);
threadgroup float* output_rescale = o_rescale + (BM + float_padding);
if (simd_lane_id < BM) {
maxes[simd_lane_id] = -INFINITY; // m_i
sums[simd_lane_id] = 0.f; // l_i
o_rescale[simd_lane_id] = 1.f; // li * exp(mi - mi_new)
output_rescale[simd_lane_id] = 1.f; // 1.0 / l_i
}
}
}
static METAL_FUNC void rescale_ss(
threadgroup T* Ss,
threadgroup float* Corrections,
uint simd_group_id,
uint simd_lane_id,
short2 local_blocks,
float alpha) {
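// One step of the streaming (online) softmax update for this KV block:
//   m_ij    = max_j(alpha * S_ij)                       (block row max)
//   m_i_new = max(m_i_old, m_ij)                        (running row max)
//   S_ij   <- exp(alpha * S_ij - m_i_new)               (rescaled probabilities)
//   l_i_new = exp(m_i_old - m_i_new) * l_i_old
//           + exp(m_ij - m_i_new) * sum_j exp(alpha * S_ij - m_ij)
// o_rescale is set to l_i_old * exp(m_i_old - m_i_new) and output_scales to
// 1 / l_i_new; the caller uses both to correct the running output tile.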
if (simd_group_id == 0) {
short row_offset = BM + float_padding;
threadgroup float* maxes = Corrections;
threadgroup float* sums = Corrections + row_offset;
threadgroup float* o_rescale = sums + row_offset;
threadgroup float* output_scales = o_rescale + row_offset;
if (simd_lane_id < uint(local_blocks.y)) {
float m_i_old = maxes[simd_lane_id];
float l_i_old = sums[simd_lane_id];
float m_i_new = m_i_old;
float l_i_new = l_i_old;
short offset = simd_lane_id * (BN + tgp_padding);
float m_ij = -INFINITY;
for (short j = 0; j < local_blocks.x; j++) {
float val = alpha * float(Ss[offset + j]);
m_ij = max(m_ij, val);
}
m_i_new = max(m_ij, m_i_new);
float rowsum = 0.f; // lij
for (short j = 0; j < local_blocks.x; j++) {
float val = alpha * float(Ss[offset + j]);
float P_i_j = exp(val - m_ij);
rowsum += P_i_j;
P_i_j = P_i_j * exp(m_ij - m_i_new);
Ss[offset + j] = T(P_i_j);
}
l_i_new =
exp(m_i_old - m_i_new) * l_i_old + exp(m_ij - m_i_new) * rowsum;
maxes[simd_lane_id] = m_i_new;
sums[simd_lane_id] = l_i_new;
float rescale = l_i_old * exp(m_i_old - m_i_new);
o_rescale[simd_lane_id] = rescale;
output_scales[simd_lane_id] = 1.0 / l_i_new;
}
}
}
/* Main kernel function */
static METAL_FUNC void run(
const device T* Q [[buffer(0)]],
const device T* K [[buffer(1)]],
const device T* V [[buffer(2)]],
device U* O [[buffer(3)]],
const constant MLXFastAttentionParams* params [[buffer(4)]],
threadgroup T* Qs [[threadgroup(0)]],
threadgroup T* Ks [[threadgroup(1)]],
threadgroup T* Ss [[threadgroup(2)]],
threadgroup T* Vs [[threadgroup(3)]],
threadgroup float* Corrections [[threadgroup(4)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
// Pacifying compiler
(void)lid;
const int tid_y = ((tid.y) << params->swizzle_log) +
((tid.x) & ((1 << params->swizzle_log) - 1));
const int tid_x = (tid.x) >> params->swizzle_log;
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
return;
}
threadgroup_barrier(mem_flags::mem_none);
// Find block in Q, O; and head in K, V.
const int c_row = tid_y * BM;
Q += transpose_q ? c_row : c_row * params->ldq;
thread loader_q_t loader_q(Q, params->ldq, Qs, simd_group_id, simd_lane_id);
short tgp_bm = min(BM, params->M - c_row);
short2 tile_dims_Q = transpose_q ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
loader_q.load_safe(tile_dims_Q);
initialize_corrections(Corrections, simd_lane_id, simd_group_id);
O += c_row * params->ldo;
// Prepare threadgroup mma operation
thread mma_qk_t mma_qk_op(simd_group_id, simd_lane_id);
thread mma_sv_t mma_softmax_sv_op(simd_group_id, simd_lane_id);
thread loader_k_t loader_k(K, params->ldk, Ks, simd_group_id, simd_lane_id);
thread loader_v_t loader_v(V, params->ldv, Vs, simd_group_id, simd_lane_id);
for (short n_block = 0; n_block < params->gemm_n_iterations_aligned;
n_block++) {
short c_col = BN;
// Prepare threadgroup loading operations
short gemm_k_iterations = params->gemm_k_iterations_aligned;
short tgp_bn_qk = min(BN, params->N - c_col * n_block);
threadgroup_barrier(mem_flags::mem_none);
///////////////////////////////////////////////////////////////////////////////
{ // Loop over K - unaligned case
if (tgp_bm == BM && tgp_bn_qk == BN) {
gemm_loop<true, true, K_aligned>(
Qs,
Ks,
gemm_k_iterations,
loader_k,
mma_qk_op,
tgp_bm,
tgp_bn_qk);
} else if (tgp_bn_qk == BN) {
gemm_loop<false, true, K_aligned>(
Qs,
Ks,
gemm_k_iterations,
loader_k,
mma_qk_op,
tgp_bm,
tgp_bn_qk);
} else if (tgp_bm == BM) {
gemm_loop<true, false, K_aligned>(
Qs,
Ks,
gemm_k_iterations,
loader_k,
mma_qk_op,
tgp_bm,
tgp_bn_qk);
} else {
gemm_loop<false, false, K_aligned>(
Qs,
Ks,
gemm_k_iterations,
loader_k,
mma_qk_op,
tgp_bm,
tgp_bn_qk);
}
}
mma_qk_op.store_result_to_tgp_memory(
Ss, BN + tgp_padding, short2(BN, BM));
threadgroup_barrier(mem_flags::mem_threadgroup);
rescale_ss(
Ss,
Corrections,
simd_group_id,
simd_lane_id,
short2(tgp_bn_qk, tgp_bm),
params->alpha);
loader_v.load_safe(short2(BK, tgp_bn_qk));
threadgroup_barrier(mem_flags::mem_threadgroup);
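// Update the running output: the accumulator (normalized by the previous
// l_i) is un-normalized and rescaled to the new row max via o_rescale, the
// new P @ V block is accumulated, and the sum is renormalized by 1 / l_i_new.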
threadgroup float* o_scales = Corrections + 2 * (BM + float_padding);
mma_softmax_sv_op.rescale_output(o_scales);
mma_softmax_sv_op.mma(Ss, Vs);
threadgroup float* final_output_scales =
Corrections + 3 * (BM + float_padding);
mma_softmax_sv_op.rescale_output(final_output_scales);
loader_v.next();
loader_k.next(BN);
mma_qk_op.clear_results();
}
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_softmax_sv_op.store_result_safe(O, params->ldo, short2(BK, tgp_bm));
}
};
template <
typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_q,
bool transpose_k,
bool transpose_v,
bool MN_aligned,
bool K_aligned>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void attention(
const device T* Q [[buffer(0)]],
const device T* K [[buffer(1)]],
const device T* V [[buffer(2)]],
device T* O [[buffer(3)]],
const constant MLXFastAttentionParams* params [[buffer(4)]],
const constant int* batch_shape [[buffer(6)]],
const constant size_t* batch_strides [[buffer(7)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using attention_kernel = FastAttentionKernel<
T,
T,
BM,
BN,
BK,
WM,
WN,
transpose_q,
transpose_k,
transpose_v,
MN_aligned,
K_aligned>;
// Adjust for batch
if (params->batch_ndim > 1) {
const constant size_t* Q_bstrides = batch_strides;
const constant size_t* KV_bstrides = batch_strides + params->batch_ndim;
ulong2 batch_offsets = elem_to_loc_broadcast(
tid.z, batch_shape, Q_bstrides, KV_bstrides, params->batch_ndim);
Q += batch_offsets.x;
K += batch_offsets.y;
V += batch_offsets.y;
} else {
Q += params->batch_stride_q * tid.z;
K += params->batch_stride_k * tid.z;
V += params->batch_stride_v * tid.z;
}
// same shape as input
O += params->batch_stride_o * tid.z;
threadgroup T Qs[attention_kernel::tgp_mem_size_q];
threadgroup T Ss[attention_kernel::tgp_mem_size_s];
threadgroup float Corrections[attention_kernel::tgp_mem_size_corrections];
if (attention_kernel::share_kv_smem) {
threadgroup T Ks[attention_kernel::tgp_mem_size_k];
threadgroup T* Vs = Ks; //[attention_kernel::tgp_mem_size_v];
attention_kernel::run(
Q,
K,
V,
O,
params,
Qs,
Ks,
Ss,
Vs,
Corrections,
simd_lane_id,
simd_group_id,
tid,
lid);
} else {
threadgroup T Ks[attention_kernel::tgp_mem_size_k];
threadgroup T Vs[attention_kernel::tgp_mem_size_v];
attention_kernel::run(
Q,
K,
V,
O,
params,
Qs,
Ks,
Ss,
Vs,
Corrections,
simd_lane_id,
simd_group_id,
tid,
lid);
}
}
// clang-format off
// SDPA full instantiations
#define instantiate_fast_inference_self_attention_kernel( \
itype, otype, bm, bn, bk, wm, wn) \
template [[host_name("steel_gemm_attention_bm_" #bm "_bn_" #bn "_bk_" #bk \
"_itype_" #itype)]] [[kernel]] void \
attention<itype, bm, bn, bk, wm, wn, false, true, false, false, true>( \
const device itype* Q [[buffer(0)]], \
const device itype* K [[buffer(1)]], \
const device itype* V [[buffer(2)]], \
device otype* O [[buffer(3)]], \
const constant MLXFastAttentionParams* params [[buffer(4)]], \
const constant int* batch_shape [[buffer(6)]], \
const constant size_t* batch_strides [[buffer(7)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
instantiate_fast_inference_self_attention_kernel(
float,
float,
16,
16,
64,
2,
2);
instantiate_fast_inference_self_attention_kernel(
float,
float,
16,
16,
128,
2,
2);
instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 64, 2, 2);
instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 128, 2, 2);
// SDPA vector instantiations
#define instantiate_sdpa_vector(type, head_dim) \
- template [[host_name("sdpa_vector_" #type "_" #head_dim)]] \
- [[kernel]] void sdpa_vector<type, head_dim>( \
- const device type* queries [[buffer(0)]], \
- const device type* keys [[buffer(1)]], \
- const device type* values [[buffer(2)]], \
- device type* out [[buffer(3)]], \
- const constant int& gqa_factor, \
- const constant int& N, \
- const constant size_t& k_stride, \
- const constant size_t& v_stride, \
- const constant float& scale, \
- uint3 tid [[threadgroup_position_in_grid]], \
- uint simd_gid [[simdgroup_index_in_threadgroup]], \
- uint simd_lid [[thread_index_in_simdgroup]]);
+ instantiate_kernel("sdpa_vector_" #type "_" #head_dim, sdpa_vector, type, head_dim) \
+ instantiate_kernel("sdpa_vector_2pass_1_" #type "_" #head_dim, sdpa_vector_2pass_1, type, head_dim) \
+ instantiate_kernel("sdpa_vector_2pass_2_" #type "_" #head_dim, sdpa_vector_2pass_2, type, head_dim)
#define instantiate_sdpa_vector_heads(type) \
instantiate_sdpa_vector(type, 64) \


@@ -1,42 +0,0 @@
//
// scaled_dot_product_attention_params.h
// mlx
#pragma once
struct MLXFastAttentionParams {
const int M;
const int N;
const int K;
const int ldq; // ldq == ldo
const int ldk;
const int ldv;
const int lds;
const int ldo;
const int tiles_n;
const int tiles_m;
const int batch_stride_q;
const int batch_stride_k;
const int batch_stride_v;
const int batch_stride_o;
const int swizzle_log;
const int gemm_n_iterations_aligned;
const int gemm_k_iterations_aligned;
const int gemm_sv_m_block_iterations;
const int batch_ndim;
const float alpha;
};
struct MLXScaledDotProductAttentionParams {
// Associated dimensions & transposition information
const uint QUERY_SEQUENCE_LENGTH = 1;
const uint N_Q_HEADS = 32;
const uint N_KV_HEADS = 32;
const uint KV_TILES = 1;
const float INV_ALPHA = 0.08838834764831843f;
};


@@ -10,16 +10,17 @@ template <
typename Op,
int NIDX,
bool UPD_ROW_CONTIG,
- int NWORK>
+ int NWORK,
+ typename LocT>
METAL_FUNC void scatter_impl(
const device T* updates,
device mlx_atomic<T>* out,
const constant int* upd_shape,
- const constant size_t* upd_strides,
+ const constant int64_t* upd_strides,
const constant size_t& upd_ndim,
const constant size_t& upd_size,
const constant int* out_shape,
- const constant size_t* out_strides,
+ const constant int64_t* out_strides,
const constant size_t& out_ndim,
const constant int* axes,
const constant size_t& idx_size,
@@ -28,29 +29,30 @@ METAL_FUNC void scatter_impl(
Op op;
auto ind_idx = gid.y * NWORK;
- size_t out_offset = 0;
+ LocT out_offset = 0;
if (upd_size > 1) {
- out_offset =
-     elem_to_loc(gid.x, upd_shape + indices.ndim, out_strides, out_ndim);
+ out_offset = elem_to_loc<LocT>(
+     gid.x, upd_shape + indices.ndim, out_strides, out_ndim);
}
for (int j = 0; j < NWORK && ind_idx < idx_size; ++j, ind_idx++) {
- size_t out_idx = out_offset;
+ LocT out_idx = out_offset;
for (int i = 0; i < NIDX; ++i) {
auto idx_loc = indices.row_contiguous[i]
? ind_idx
- : elem_to_loc(
+ : elem_to_loc<LocT>(
ind_idx,
&indices.shapes[indices.ndim * i],
&indices.strides[indices.ndim * i],
indices.ndim);
auto ax = axes[i];
auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], out_shape[ax]);
- out_idx += idx_val * out_strides[ax];
+ out_idx +=
+     static_cast<LocT>(idx_val) * static_cast<LocT>(out_strides[ax]);
}
- auto upd_idx = ind_idx * upd_size + gid.x;
+ auto upd_idx = ind_idx * static_cast<LocT>(upd_size) + gid.x;
if constexpr (!UPD_ROW_CONTIG) {
- upd_idx = elem_to_loc(upd_idx, upd_shape, upd_strides, upd_ndim);
+ upd_idx = elem_to_loc<LocT>(upd_idx, upd_shape, upd_strides, upd_ndim);
}
op.atomic_update(out, updates[upd_idx], out_idx);
}


@@ -21,8 +21,7 @@ template <typename T, int D>
constexpr int BN = 32;
constexpr int BD = 32;
constexpr int elem_per_thread = D / BD;
- const int stride = BN * D;
+ constexpr int stride = BN * D;
typedef float U;
@@ -84,7 +83,6 @@ template <typename T, int D>
keys += stride;
values += stride;
}
- threadgroup_barrier(mem_flags::mem_threadgroup);
// Each thread has a partial part of the output so we need to combine them.
@@ -114,3 +112,181 @@ template <typename T, int D>
}
}
}
template <typename T, int D>
[[kernel]] void sdpa_vector_2pass_1(
const device T* queries [[buffer(0)]],
const device T* keys [[buffer(1)]],
const device T* values [[buffer(2)]],
device float* out [[buffer(3)]],
device float* sums [[buffer(4)]],
device float* maxs [[buffer(5)]],
const constant int& gqa_factor,
const constant int& N,
const constant size_t& k_stride,
const constant size_t& v_stride,
const constant float& scale,
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
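  // Pass 1 of the two-pass decoding kernel: each threadgroup handles one
  // (head, block) pair and strides over the keys/values with a step of
  // blocks * BN rows, producing an unnormalized partial output plus the
  // running max and sum of exponentials for that block. Pass 2 combines the
  // `blocks` partials per head.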
constexpr int BN = 8;
constexpr int BD = 32;
constexpr int elem_per_thread = D / BD;
constexpr int stride = BN * D;
constexpr int blocks = 32;
typedef float U;
thread U q[elem_per_thread];
thread U k[elem_per_thread];
thread U o[elem_per_thread];
threadgroup U outputs[BN * BD];
threadgroup U max_scores[BN];
threadgroup U sum_exp_scores[BN];
// Adjust positions
const int block_idx = tid.z;
const int head_idx = tid.y;
const int kv_head_idx = head_idx / gqa_factor;
queries += head_idx * D + simd_lid * elem_per_thread;
keys += kv_head_idx * k_stride + (block_idx * BN + simd_gid) * D +
simd_lid * elem_per_thread;
values += kv_head_idx * v_stride + (block_idx * BN + simd_gid) * D +
simd_lid * elem_per_thread;
out += head_idx * blocks * D + block_idx * D + simd_lid * elem_per_thread;
sums += head_idx * blocks + block_idx;
maxs += head_idx * blocks + block_idx;
// Read the query and 0 the output accumulator
for (int i = 0; i < elem_per_thread; i++) {
q[i] = static_cast<U>(scale) * queries[i];
}
for (int i = 0; i < elem_per_thread; i++) {
o[i] = 0;
}
U max_score = -1e9;
U sum_exp_score = 0;
// For each key
for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) {
// Read the key
for (int i = 0; i < elem_per_thread; i++) {
k[i] = keys[i];
}
// Compute the i-th score
U score = 0;
for (int i = 0; i < elem_per_thread; i++) {
score += q[i] * k[i];
}
score = simd_sum(score);
// Update the accumulators
U new_max = max(max_score, score);
U factor = fast::exp(max_score - new_max);
U exp_score = fast::exp(score - new_max);
max_score = new_max;
sum_exp_score = sum_exp_score * factor + exp_score;
// Update the output accumulator
for (int i = 0; i < elem_per_thread; i++) {
o[i] = o[i] * factor + exp_score * values[i];
}
// Move the pointers to the next kv
keys += blocks * stride;
values += blocks * stride;
}
// Each thread has a partial part of the output so we need to combine them.
// First let's communicate the max and sum_exp
if (simd_lid == 0) {
max_scores[simd_gid] = max_score;
sum_exp_scores[simd_gid] = sum_exp_score;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
max_score = (simd_lid < BN) ? max_scores[simd_lid] : -1e9;
U new_max = simd_max(max_score);
U factor = fast::exp(max_score - new_max);
sum_exp_score = (simd_lid < BN) ? sum_exp_scores[simd_lid] : 0;
sum_exp_score = simd_sum(sum_exp_score * factor);
// Write the sum and new max
if (simd_gid == 0) {
sums[0] = sum_exp_score;
maxs[0] = new_max;
}
// Now we need to aggregate all the outputs
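  // Each simdgroup rescales its partial value to the shared max and writes it
  // to threadgroup memory; simdgroup 0 then sums the BN partials per lane and
  // writes the (still unnormalized) block output.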
for (int i = 0; i < elem_per_thread; i++) {
outputs[simd_lid * BN + simd_gid] =
o[i] * fast::exp(max_scores[simd_gid] - new_max);
threadgroup_barrier(mem_flags::mem_threadgroup);
// And write the output
if (simd_gid == 0) {
U output = outputs[simd_lid * BN];
for (int j = 1; j < BN; j++) {
output += outputs[simd_lid * BN + j];
}
out[i] = static_cast<T>(output);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
}
template <typename T, int D>
[[kernel]] void sdpa_vector_2pass_2(
const device float* partials [[buffer(0)]],
const device float* sums [[buffer(1)]],
const device float* maxs [[buffer(2)]],
device T* out [[buffer(3)]],
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
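  // Pass 2: combine the per-block partials produced by pass 1. Each lane
  // reads one block's max and sum, a global max is taken across the
  // simdgroup, sums and partial outputs are rescaled by exp(max_b - max),
  // and the combined sum is used to normalize the final output.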
constexpr int BN = 32;
constexpr int BD = 32;
constexpr int elem_per_thread = D / BD;
constexpr int blocks = 32;
typedef float U;
thread U o[elem_per_thread];
threadgroup U outputs[BN * BD];
// Adjust positions
const int head_idx = tid.y;
partials += head_idx * blocks * D + simd_gid * D + simd_lid * elem_per_thread;
sums += head_idx * blocks;
maxs += head_idx * blocks;
out += head_idx * D + simd_gid * elem_per_thread;
// First everybody reads the max and sum_exp
U max_score = maxs[simd_lid];
U new_max = simd_max(max_score);
U factor = fast::exp(max_score - new_max);
U sum_exp_score = simd_sum(sums[simd_lid] * factor);
// Now read the block into registers and then use shared memory to transpose
// it
for (int i = 0; i < elem_per_thread; i++) {
o[i] = partials[i];
}
for (int i = 0; i < elem_per_thread; i++) {
outputs[simd_lid * BD + simd_gid] = o[i];
threadgroup_barrier(mem_flags::mem_threadgroup);
o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score;
threadgroup_barrier(mem_flags::mem_threadgroup);
}
// And write the output
if (simd_lid == 0) {
for (int i = 0; i < elem_per_thread; i++) {
out[i] = static_cast<T>(o[i]);
}
}
}


@@ -6,8 +6,6 @@
using namespace metal;
// clang-format off
- #include "mlx/backend/metal/kernels/bf16.h"
- #include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/softmax.h"


@@ -343,8 +343,8 @@ template <
const constant int& out_stride_sorted_axis [[buffer(4)]],
const constant int& nc_dim [[buffer(5)]],
const constant int* nc_shape [[buffer(6)]],
- const constant size_t* in_nc_strides [[buffer(7)]],
+ const constant int64_t* in_nc_strides [[buffer(7)]],
- const constant size_t* out_nc_strides [[buffer(8)]],
+ const constant int64_t* out_nc_strides [[buffer(8)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using sort_kernel =
@@ -486,7 +486,7 @@ template <
const constant int& stride_sorted_axis [[buffer(4)]],
const constant int& nc_dim [[buffer(5)]],
const constant int* nc_shape [[buffer(6)]],
- const constant size_t* nc_strides [[buffer(7)]],
+ const constant int64_t* nc_strides [[buffer(7)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using sort_kernel = KernelMultiBlockMergeSort<


@@ -3,8 +3,6 @@
#include <metal_stdlib>
// clang-format off
- #include "mlx/backend/metal/kernels/bf16.h"
- #include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/sort.h"


@@ -0,0 +1,296 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/backend/metal/kernels/steel/attn/loader.h"
#include "mlx/backend/metal/kernels/steel/attn/mma.h"
#include "mlx/backend/metal/kernels/steel/attn/params.h"
#include "mlx/backend/metal/kernels/steel/attn/transforms.h"
#include "mlx/backend/metal/kernels/steel/gemm/params.h"
#include "mlx/backend/metal/kernels/steel/utils.h"
using namespace metal;
///////////////////////////////////////////////////////////////////////////////
// GEMM kernel class
///////////////////////////////////////////////////////////////////////////////
namespace mlx {
namespace steel {
template <bool M_aligned, bool N_aligned, bool K_aligned>
struct LoopAlignment {};
template <
typename T,
typename U,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
bool MN_aligned,
bool K_aligned,
typename AccumType = typename AccumHelper<T>::accum_type,
typename Epilogue = TransformNone<U, AccumType>>
struct GEMMKernel {
STEEL_CONST short tgp_padding_a = 16 / sizeof(T);
STEEL_CONST short tgp_padding_b = 16 / sizeof(T);
STEEL_CONST short tgp_mem_size_a =
transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a);
STEEL_CONST short tgp_mem_size_b =
transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b);
STEEL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b;
STEEL_CONST short tgp_size = WM * WN * 32;
using loader_a_t = BlockLoader<
T,
transpose_a ? BK : BM,
transpose_a ? BM : BK,
transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,
!transpose_a,
tgp_size>;
using loader_b_t = BlockLoader<
T,
transpose_b ? BN : BK,
transpose_b ? BK : BN,
transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b,
transpose_b,
tgp_size>;
using mma_t = BlockMMA<
T,
U,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,
transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b,
AccumType,
Epilogue>;
/* Main kernel function */
template <bool M_aligned, bool N_aligned, bool K_aligned_>
static METAL_FUNC void gemm_loop(
threadgroup T* As [[threadgroup(0)]],
threadgroup T* Bs [[threadgroup(1)]],
const int gemm_k_iterations,
thread loader_a_t& loader_a,
thread loader_b_t& loader_b,
thread mma_t& mma_op,
thread const short& tgp_bm,
thread const short& tgp_bn,
thread const short& lbk,
LoopAlignment<M_aligned, N_aligned, K_aligned_> l = {}) {
// Appease the compiler
(void)l;
short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
if (M_aligned) {
loader_a.load_unsafe();
} else {
loader_a.load_safe(tile_dims_A);
}
if (N_aligned) {
loader_b.load_unsafe();
} else {
loader_b.load_safe(tile_dims_B);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
if (!K_aligned_) {
threadgroup_barrier(mem_flags::mem_threadgroup);
short2 tile_dims_A_last =
transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm);
short2 tile_dims_B_last =
transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk);
loader_a.load_safe(tile_dims_A_last);
loader_b.load_safe(tile_dims_B_last);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
}
/* Main kernel function */
static METAL_FUNC void run(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
device U* D [[buffer(2)]],
const constant GEMMParams* params [[buffer(3)]],
threadgroup T* As [[threadgroup(0)]],
threadgroup T* Bs [[threadgroup(1)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
// Pacifying compiler
(void)lid;
const int tid_y = ((tid.y) << params->swizzle_log) +
((tid.x) & ((1 << params->swizzle_log) - 1));
const int tid_x = (tid.x) >> params->swizzle_log;
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
return;
}
threadgroup_barrier(mem_flags::mem_none);
// Find block in A, B, C
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
const size_t c_row_long = size_t(c_row);
const size_t c_col_long = size_t(c_col);
A += transpose_a ? c_row_long : c_row_long * params->lda;
B += transpose_b ? c_col_long * params->ldb : c_col_long;
D += c_row_long * params->ldd + c_col_long;
// Prepare threadgroup loading operations
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
// Prepare threadgroup mma operation
thread mma_t mma_op(simd_group_id, simd_lane_id);
int gemm_k_iterations = params->gemm_k_iterations_aligned;
///////////////////////////////////////////////////////////////////////////////
// MNK aligned loop
if (MN_aligned) {
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
threadgroup_barrier(mem_flags::mem_none);
// Loop tail
if (!K_aligned) {
int lbk = params->K - params->gemm_k_iterations_aligned * BK;
short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
loader_a.load_safe(tile_dims_A);
loader_b.load_safe(tile_dims_B);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
// Store results to device memory
mma_op.store_result(D, params->ldd);
return;
}
///////////////////////////////////////////////////////////////////////////////
// MN unaligned loop
else { // Loop over K - unaligned case
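// The edge tile may be short along M and/or N; dispatch to the variant that
// only bound-checks the dimensions that actually need it, so the fully
// aligned case keeps the check-free loads and stores.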
short tgp_bm = min(BM, params->M - c_row);
short tgp_bn = min(BN, params->N - c_col);
short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK;
if (tgp_bm == BM && tgp_bn == BN) {
gemm_loop<true, true, K_aligned>(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk);
mma_op.store_result(D, params->ldd);
return;
} else if (tgp_bn == BN) {
gemm_loop<false, true, K_aligned>(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk);
mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
return;
} else if (tgp_bm == BM) {
gemm_loop<true, false, K_aligned>(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk);
mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
return;
} else {
gemm_loop<false, false, K_aligned>(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk);
mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
return;
}
}
}
};
} // namespace steel
} // namespace mlx


@@ -0,0 +1,349 @@
// Copyright © 2024 Apple Inc.
using namespace mlx::steel;
///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////
constant bool align_Q [[function_constant(200)]];
constant bool align_K [[function_constant(201)]];
template <typename T>
struct TransformScale {
T scale;
METAL_FUNC TransformScale(T scale_) : scale(scale_) {}
METAL_FUNC T apply(T x) const {
return scale * x;
}
};
struct MaxOp {
template <typename T>
METAL_FUNC static constexpr T apply(T x, T y) {
return metal::max(x, y);
}
};
struct SumOp {
template <typename T>
METAL_FUNC static constexpr T apply(T x, T y) {
return x + y;
}
};
struct MulOp {
template <typename T>
METAL_FUNC static constexpr T apply(T x, T y) {
return x * y;
}
};
struct SubOp {
template <typename T>
METAL_FUNC static constexpr T apply(T x, T y) {
return x - y;
}
};
struct ExpSubOp {
template <typename T>
METAL_FUNC static constexpr T apply(T x, T y) {
return fast::exp(x - y);
}
};
struct DivOp {
template <typename T>
METAL_FUNC static constexpr T apply(T x, T y) {
return x / y;
}
};
// clang-format off
template <
typename T,
int BQ,
int BK,
int BD,
int WM,
int WN,
typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void attention(
const device T* Q [[buffer(0)]],
const device T* K [[buffer(1)]],
const device T* V [[buffer(2)]],
device T* O [[buffer(3)]],
const constant AttnParams* params [[buffer(4)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on
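  // One threadgroup handles a single BQ-row block of queries for one
  // (batch, head) pair, then streams over BK-sized K/V blocks while keeping
  // a per-row running max and sum (online softmax), so only a single pass
  // over the sequence is needed.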
// Pacifying compiler
(void)lid;
// Move to correct block
ulong3 tidl{tid.x, tid.y, tid.z};
Q += tidl.z * params->Q_strides[0] + // Batch
tidl.y * params->Q_strides[1] + // Head
tidl.x * BQ * params->Q_strides[2]; // Sequence
ulong kv_head_idx = int(tid.y) / params->gqa_factor;
K += tidl.z * params->K_strides[0] + // Batch
kv_head_idx * params->K_strides[1]; // Head
V += tidl.z * params->V_strides[0] + // Batch
kv_head_idx * params->V_strides[1]; // Head
O += tidl.z * params->O_strides[0] + // Batch
tidl.y * params->O_strides[1] + // Head
tidl.x * BQ * params->O_strides[2]; // Sequence
// Prepare threadgroup memory
constexpr short padQ = 0; // 16 / sizeof(T);
constexpr short padK = 0; // 16 / sizeof(T);
constexpr short padV = 0; // 16 / sizeof(T);
constexpr short LDQ_tgp = BD + padQ;
constexpr short LDK_tgp = BK + padK;
constexpr short LDV_tgp = BD + padV;
threadgroup T Qs[BQ * (BD + padQ)];
threadgroup T Ks[(BK + padK) * BD];
threadgroup T Vs[BK * (BD + padV)];
// Prepare block loaders
using QBlockLoader = BlockLoaderT<
/* typename T = */ T,
/* short BROWS = */ BQ,
/* short BCOLS = */ BD,
/* short kDstStrRow = */ LDQ_tgp,
/* short kDstStrCol = */ 1,
/* short reduction_dim = */ 1,
/* short tgp_size = */ WM * WN * 32>;
// K is loaded transposed
using KBlockLoader = BlockLoaderT<
/* typename T = */ T,
/* short BROWS = */ BK,
/* short BCOLS = */ BD,
/* short kDstStrRow = */ 1,
/* short kDstStrCol = */ LDK_tgp,
/* short reduction_dim = */ 0,
/* short tgp_size = */ WM * WN * 32>;
using VBlockLoader = BlockLoaderT<
/* typename T = */ T,
/* short BROWS = */ BK,
/* short BCOLS = */ BD,
/* short kDstStrRow = */ LDV_tgp,
/* short kDstStrCol = */ 1,
/* short reduction_dim = */ 0,
/* short tgp_size = */ WM * WN * 32>;
QBlockLoader loader_q(
Q, params->Q_strides[2], Qs, simd_group_id, simd_lane_id);
KBlockLoader loader_k(
K, params->K_strides[2], Ks, simd_group_id, simd_lane_id);
VBlockLoader loader_v(
V, params->V_strides[2], Vs, simd_group_id, simd_lane_id);
TransformScale<T> ts(static_cast<T>(params->scale));
// Prepare MMA tiles
constexpr short kFragSize = 8; // MMAFrag size
using MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize>;
constexpr int kNWarps = WM * WN;
static_assert(
BQ >= (kNWarps * kFragSize) && BQ % (kNWarps * kFragSize) == 0,
"Each simdgroup must host atleast 1 simdgroup matrix along Q sequence.");
// Q seq frags per warp
constexpr int TQ = BQ / (kNWarps * kFragSize);
// KV sequence frags (all warps load the same frags)
constexpr int TK = BK / kFragSize;
// HeadDim frags (all warps load the same frags)
constexpr int TD = BD / kFragSize;
static_assert(TQ == 1, "Check TQ");
MMATile<AccumType, TQ, 1, MMAFrag_acc_t> Qtile;
MMATile<AccumType, 1, TK, MMAFrag_acc_t> Ktile;
MMATile<AccumType, TQ, TK, MMAFrag_acc_t> Stile;
MMATile<AccumType, TK, TD, MMAFrag_acc_t> Vtile;
MMATile<AccumType, TQ, TD, MMAFrag_acc_t> Otile;
Otile.clear();
// Prepare mma tile offsets
const short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
const short sm = simd_coord.y;
const short sn = simd_coord.x;
const short tm = kFragSize * TQ * simd_group_id;
const short Qs_offset = (tm + sm) * LDQ_tgp + sn;
const short Ks_offset = sm * LDK_tgp + sn;
const short Vs_offset = sm * LDV_tgp + sn;
constexpr short Qs_tile_stride = kFragSize;
constexpr short Ks_tile_stride = kFragSize * LDK_tgp;
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load Q blocks apply scale
if (!align_Q && int(tid.x) == (params->NQ_aligned)) {
loader_q.load_safe(short2(BD, params->qL - params->NQ_aligned * BQ));
} else {
loader_q.load_unsafe();
}
loader_q.apply_inplace_op(ts);
// Init row reduction variables
constexpr short kRowsPT = decltype(Stile)::kRowsPerThread;
AccumType max_score[kRowsPT];
AccumType sum_score[kRowsPT] = {0};
// Init to -Inf
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) {
max_score[i] = Limits<AccumType>::min;
}
// Loop over KV seq length
for (int kb = 0; kb < params->NK; kb++) {
// Load K block and apply scale
threadgroup_barrier(mem_flags::mem_threadgroup);
if (!align_K && kb == (params->NK_aligned)) {
loader_k.load_safe(short2(BD, params->kL - params->NK_aligned * BK));
} else {
loader_k.load_unsafe();
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Do S = Q @ K.T
Stile.clear();
for (short dd = 0; dd < TD; dd++) {
simdgroup_barrier(mem_flags::mem_none);
Qtile.template load<T, 1, 1, LDQ_tgp, 1>(
&Qs[Qs_offset + dd * Qs_tile_stride]);
Ktile.template load<T, 1, 1, LDK_tgp, 1>(
&Ks[Ks_offset + dd * Ks_tile_stride]);
simdgroup_barrier(mem_flags::mem_none);
tile_matmad(Stile, Qtile, Ktile, Stile);
}
// Mask out of length sequence
if (!align_K && kb == (params->NK_aligned)) {
using stile_t = decltype(Stile);
using selem_t = typename stile_t::elem_type;
constexpr auto neg_inf = -metal::numeric_limits<selem_t>::infinity();
const short lim = params->kL - params->NK_aligned * BK;
STEEL_PRAGMA_UNROLL
for (short i = 0; i < stile_t::kTileRows; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < stile_t::kTileCols; j++) {
short col_pos = sn + (j * stile_t::kFragCols);
STEEL_PRAGMA_UNROLL
for (short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) {
if ((col_pos + jj) >= lim) {
Stile.frag_at(i, j)[jj] = neg_inf;
}
}
}
}
}
simdgroup_barrier(mem_flags::mem_none);
// Load V blocks
if (!align_K && kb == (params->NK_aligned)) {
loader_v.load_safe(short2(BD, params->kL - params->NK_aligned * BK));
} else {
loader_v.load_unsafe();
}
// Do softmax
// Temp variables
AccumType new_max[kRowsPT];
AccumType factor[kRowsPT];
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) {
new_max[i] = max_score[i];
}
// Row max
Stile.template row_reduce<MaxOp>(new_max);
// exp(Si - rowmax(Si))
Stile.template row_bin_op<ExpSubOp>(new_max);
// Factor exp(rowmax(Si) - rowmax(Si-1))
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) {
factor[i] = fast::exp(max_score[i] - new_max[i]);
}
// Save max for next iteration
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) {
max_score[i] = new_max[i];
}
// Row Sum
AccumType sum_score_tmp[kRowsPT] = {0};
Stile.template row_reduce<SumOp>(sum_score_tmp);
// Update norm
STEEL_PRAGMA_UNROLL
for (short i = 0; i < kRowsPT; ++i) {
sum_score[i] = sum_score[i] * factor[i] + sum_score_tmp[i];
}
// Update O
Otile.template row_bin_op<MulOp>(factor);
// Load V into registers
threadgroup_barrier(mem_flags::mem_threadgroup);
Vtile.template load<T, 1, 1, LDV_tgp, 1>(&Vs[Vs_offset]);
simdgroup_barrier(mem_flags::mem_none);
// Do O = S @ V
tile_matmad(Otile, Stile, Vtile, Otile);
// Prepare for next iteration
loader_k.next();
loader_v.next();
}
// Normalize output
Otile.template row_bin_op<DivOp>(sum_score);
threadgroup_barrier(mem_flags::mem_none);
// Store results
O += (tm + sm) * params->O_strides[2] + sn;
if (!align_Q && int(tid.x) == (params->NQ_aligned)) {
auto dst_tile_dims =
short2(BD - sn, params->qL - BQ * params->NQ_aligned - (tm + sm));
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
return;
Otile.template store_safe<T, 1, 1>(O, params->O_strides[2], dst_tile_dims);
} else {
Otile.template store<T, 1, 1>(O, params->O_strides[2]);
}
}


@@ -0,0 +1,31 @@
// Copyright © 2024 Apple Inc.
// clang-format off
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/steel/attn/attn.h"
#include "mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h"
#define instantiate_attn(tname, dtype, bq, bk, bd, wm, wn) \
template [[host_name("steel_attention_" #tname "_bq" #bq "_bk" #bk "_bd" #bd "_wm" #wm "_wn" #wn)]] \
[[kernel]] void attention<dtype, bq, bk, bd, wm, wn, float>( \
const device dtype* Q [[buffer(0)]], \
const device dtype* K [[buffer(1)]], \
const device dtype* V [[buffer(2)]], \
device dtype* O [[buffer(3)]],\
const constant AttnParams* params [[buffer(4)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
#define instantiate_attn_shapes_helper(iname, itype) \
instantiate_attn(iname, itype, 32, 16, 128, 4, 1) \
instantiate_attn(iname, itype, 32, 32, 80, 4, 1) \
instantiate_attn(iname, itype, 32, 32, 64, 4, 1)
instantiate_attn_shapes_helper(float16, half);
instantiate_attn_shapes_helper(bfloat16, bfloat16_t);
instantiate_attn_shapes_helper(float32, float);
// clang-format on


@@ -0,0 +1,264 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/backend/metal/kernels/steel/defines.h"
///////////////////////////////////////////////////////////////////////////////
// Loading helper
///////////////////////////////////////////////////////////////////////////////
namespace mlx {
namespace steel {
template <
typename T,
short BROWS,
short BCOLS,
short dst_ld,
short reduction_dim,
short tgp_size,
short alignment = 1,
short n_reads = (BCOLS * BROWS) / (tgp_size),
short TCOLS = BCOLS / n_reads,
short TROWS = tgp_size / TCOLS>
struct BlockLoader {
STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;
STEEL_CONST short vec_size = n_reads;
// Leading dimension for src
const int src_ld;
const int tile_stride;
// Thread location indices
const short thread_idx;
const short bi;
const short bj;
// threadgroup and device memory
threadgroup T* dst;
const device T* src;
struct alignas(alignment * sizeof(T)) ReadVector {
uint8_t v[sizeof(T) * vec_size];
};
/* Constructor */
METAL_FUNC BlockLoader(
const device T* src_,
const int src_ld_,
threadgroup T* dst_,
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
ushort simd_lane_id [[thread_index_in_simdgroup]])
: src_ld(src_ld_),
tile_stride(reduction_dim ? BCOLS : BROWS * src_ld),
thread_idx(simd_group_id * 32 + simd_lane_id),
bi(thread_idx / TCOLS),
bj(vec_size * (thread_idx % TCOLS)),
dst(dst_ + bi * dst_ld + bj),
src(src_ + bi * src_ld + bj) {}
/* Apply operation to threadgroup without bound checking */
template <typename UnaryOp>
METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = op.apply(dst[i * dst_ld + j]);
}
}
}
/* Load from device memory into threadgroup memory - without bound checking */
METAL_FUNC void load_unsafe() const {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
*((threadgroup ReadVector*)(&dst[i * dst_ld])) =
*((const device ReadVector*)(&src[i * src_ld]));
}
}
/* Load from device memory into threadgroup memory - with bound checking */
METAL_FUNC void load_safe(short2 src_tile_dim) const {
src_tile_dim = src_tile_dim - short2(bj, bi);
// Skip loading if thread has no valid reads
if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = T(0);
}
}
return;
}
// Use fast thread memory for bound checks
bool tmp_idx[vec_size];
T tmp_val[vec_size];
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
// Make sure tmp_idx only contains valid indices
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);
}
// Read valid indices into tmp_val
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
}
// Zero out unneeded values
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
}
// Copy values to threadgroup memory
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = tmp_val[j];
}
}
}
/* Iteration helper */
METAL_FUNC void next() {
src += tile_stride;
}
};
template <int R, int C>
struct CShape {
STEEL_CONST int kRows = R;
STEEL_CONST int kCols = C;
};
template <
typename T,
short BROWS,
short BCOLS,
short kDstStrRow,
short kDstStrCol,
short reduction_dim,
short tgp_size,
short n_reads = (BCOLS * BROWS) / (tgp_size),
short TCOLS = BCOLS / n_reads,
short TROWS = tgp_size / TCOLS>
struct BlockLoaderT {
STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;
STEEL_CONST short vec_size = n_reads;
// Leading dimension for src
const int src_ld;
const int tile_stride;
// Thread location indices
const short thread_idx;
const short bi;
const short bj;
// threadgroup and device memory
threadgroup T* dst;
const device T* src;
/* Constructor */
METAL_FUNC BlockLoaderT(
const device T* src_,
const int src_ld_,
threadgroup T* dst_,
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
ushort simd_lane_id [[thread_index_in_simdgroup]])
: src_ld(src_ld_),
tile_stride(reduction_dim ? BCOLS : BROWS * src_ld),
thread_idx(simd_group_id * 32 + simd_lane_id),
bi(thread_idx / TCOLS),
bj(vec_size * (thread_idx % TCOLS)),
dst(dst_ + bi * kDstStrRow + bj * kDstStrCol),
src(src_ + bi * src_ld + bj) {}
/* Apply operation to threadgroup without bound checking */
template <typename UnaryOp>
METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * kDstStrRow + j * kDstStrCol] =
op.apply(dst[i * kDstStrRow + j * kDstStrCol]);
}
}
}
/* Load from device memory into threadgroup memory - without bound checking */
METAL_FUNC void load_unsafe() const {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * kDstStrRow + j * kDstStrCol] = src[i * src_ld + j];
}
}
}
/* Load from device memory into threadgroup memory - with bound checking */
METAL_FUNC void load_safe(short2 src_tile_dim) const {
src_tile_dim = src_tile_dim - short2(bj, bi);
// Skip loading if thread has no valid reads
if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * kDstStrRow + j * kDstStrCol] = T(0);
}
}
return;
}
// Use fast thread memory for bound checks
bool tmp_idx[vec_size];
T tmp_val[vec_size];
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
// Make sure tmp_idx only contains valid indices
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);
}
// Read valid indices into tmp_val
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
}
// Zero out unneeded values
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
}
// Copy values to threadgroup memory
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * kDstStrRow + j * kDstStrCol] = tmp_val[j];
}
}
}
/* Iteration helper */
METAL_FUNC void next() {
src += tile_stride;
}
};
} // namespace steel
} // namespace mlx
