Mirror of https://github.com/ml-explore/mlx.git (synced 2025-09-10 21:37:50 +08:00)

Compare commits (41 commits)
cf236fc390, 27d70c7d9d, 0e585b4409, 0163a8e57a, 578842954c, 496315fe1d,
0fe6895893, 0b7d71fd2f, 83b11bc58d, 375a8bbdcc, ea9090bbc4, 81def6ac76,
3de8ce3f3c, 4d485fca24, 1865299a30, 3576b547c5, 079882495d, ab977109db,
fd1c08137b, 76b6cece46, 9f0df51f8d, e7a2a3dcd1, a87ef5bfc1, 9f9cb7a2ef,
7e26fd8032, eab2685c67, 50dfb664db, 0189ab6ab6, 9401507336, eb8321d863,
79ef49b2c2, e110ca11e2, 226748b3e7, d568c7ee36, e6fecbb3e1, da83f899bb,
7e5674d8be, 0a558577bf, fb71a82ada, 23406c9e9e, b3ec792380
@@ -71,6 +71,7 @@ jobs:
name: Install dependencies
command: |
brew install python@3.8
brew install openmpi
python3.8 -m venv env
source env/bin/activate
pip install --upgrade pip
@@ -96,10 +97,14 @@ jobs:
source env/bin/activate
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
mpirun -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
- run:
name: Build example extension
command: |
cd examples/extensions && python3.8 -m pip install .
source env/bin/activate
cd examples/extensions
pip install -r requirements.txt
python setup.py build_ext -j8
- store_test_results:
path: test-results
- run:
@@ -111,7 +116,13 @@ jobs:
name: Run CPP tests
command: |
DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./build/tests/tests
DEVICE=cpu ./build/tests/tests
- run:
name: Build small binary
command: |
source env/bin/activate
cd build/
cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel -DBUILD_SHARED_LIBS=ON -DMLX_BUILD_CPU=OFF -DMLX_BUILD_SAFETENSORS=OFF -DMLX_BUILD_GGUF=OFF -DMLX_METAL_JIT=ON
make -j

build_release:
parameters:
@@ -16,6 +16,7 @@ MLX was developed with contributions from the following individuals:
- Luca Arnaboldi: Added `Ceil` and `Floor` ops; implemented pickling, copy and deepcopy for mlx arrays.
- Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention`
- AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`. Comparison works for non array objects in `mlx.core.array`. Exception handling for invalid operations in `mlx.core.array`.
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.

<a href="https://github.com/ml-explore/mlx/graphs/contributors">
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
@@ -20,10 +20,11 @@ option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)

if(NOT MLX_VERSION)
set(MLX_VERSION 0.13.1)
set(MLX_VERSION 0.15.0)
endif()

# --------------------- Processor tests -------------------------
@@ -109,7 +110,7 @@ elseif (MLX_BUILD_METAL)
$<INSTALL_INTERFACE:include/metal_cpp>
)
target_link_libraries(
mlx
mlx PUBLIC
${METAL_LIB}
${FOUNDATION_LIB}
${QUARTZ_LIB})
@@ -122,7 +123,7 @@ if (MLX_BUILD_CPU)
if (MLX_BUILD_ARM AND ACCELERATE_LIBRARY)
message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
set(MLX_BUILD_ACCELERATE ON)
target_link_libraries(mlx ${ACCELERATE_LIBRARY})
target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY})
add_compile_definitions(ACCELERATE_NEW_LAPACK)
else()
message(STATUS "Accelerate or arm neon not found, using default backend.")
@@ -145,7 +146,7 @@ if (MLX_BUILD_CPU)
message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
target_link_libraries(mlx ${LAPACK_LIBRARIES})
target_link_libraries(mlx PUBLIC ${LAPACK_LIBRARIES})
# List blas after lapack otherwise we may accidentally include an old version
# of lapack.h from the include dirs of blas.
find_package(BLAS REQUIRED)
@@ -160,12 +161,17 @@ if (MLX_BUILD_CPU)
message(STATUS "Blas lib " ${BLAS_LIBRARIES})
message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
target_link_libraries(mlx ${BLAS_LIBRARIES})
target_link_libraries(mlx PUBLIC ${BLAS_LIBRARIES})
endif()
else()
set(MLX_BUILD_ACCELERATE OFF)
endif()

find_package(MPI)
if (MPI_FOUND)
target_include_directories(mlx PRIVATE ${MPI_INCLUDE_PATH})
endif()

add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)

target_include_directories(
@@ -175,6 +181,14 @@ target_include_directories(
$<INSTALL_INTERFACE:include>
)

FetchContent_Declare(fmt
GIT_REPOSITORY https://github.com/fmtlib/fmt.git
GIT_TAG 10.2.1
EXCLUDE_FROM_ALL
)
FetchContent_MakeAvailable(fmt)
target_link_libraries(mlx PRIVATE fmt::fmt-header-only)

if (MLX_BUILD_PYTHON_BINDINGS)
message(STATUS "Building Python bindings.")
find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
@@ -88,13 +88,13 @@ for more information on building the C++ and Python APIs from source.

## Contributing

Check out the [contribution guidelines](CONTRIBUTING.md) for more information
Check out the [contribution guidelines](https://github.com/ml-explore/mlx/tree/main/CONTRIBUTING.md) for more information
on contributing to MLX. See the
[docs](https://ml-explore.github.io/mlx/build/html/install.html) for more
information on building from source, and running tests.

We are grateful for all of [our
contributors](ACKNOWLEDGMENTS.md#Individual-Contributors). If you contribute
contributors](https://github.com/ml-explore/mlx/tree/main/ACKNOWLEDGMENTS.md#Individual-Contributors). If you contribute
to MLX and wish to be acknowledged, please add your name to the list in your
pull request.
@@ -185,7 +185,7 @@ def prelu(x: torch.Tensor) -> torch.Tensor:
def mish(x: torch.Tensor) -> torch.Tensor:
y = x
for _ in range(100):
return torch.nn.functional.mish(y)
y = torch.nn.functional.mish(y)
sync_if_needed(x)


@@ -283,6 +283,14 @@ def topk(axis, x):
sync_if_needed(x)


@torch.no_grad()
def step_function(x):
y = x
for i in range(100):
y = torch.where(y < 0, 0, 1)
sync_if_needed(x)


@torch.no_grad()
def selu(x):
y = x
@@ -446,5 +454,11 @@ if __name__ == "__main__":
elif args.benchmark == "topk":
print(bench(topk, axis, x))

elif args.benchmark == "step":
print(bench(step_function, x))

elif args.benchmark == "selu":
print(bench(selu, x))

else:
raise ValueError("Unknown benchmark")
raise ValueError(f"Unknown benchmark `{args.benchmark}`.")
@@ -16,7 +16,9 @@ def run_or_raise(*args, **kwargs):
result = run(*args, capture_output=True, **kwargs)
return float(result.stdout)
except ValueError:
raise ValueError(f"stdout: {result.stdout}\nstderr: {result.stderr}")
raise ValueError(
f"stdout: {result.stdout.decode()}\nstderr: {result.stderr.decode()}"
)


def compare(args):
@@ -9,7 +9,6 @@ from time_utils import time_fn


def bench_gelu():

def gelu(x):
return x * (1 + mx.erf(x / math.sqrt(2))) / 2

@@ -51,7 +50,6 @@ def bench_gelu():


def bench_layernorm():

weight = mx.random.uniform(shape=(4096,)).astype(mx.float16)
bias = mx.random.uniform(shape=(4096,)).astype(mx.float16)
mx.eval(weight, bias)
@@ -28,11 +28,11 @@ def bench(f, a, b):
|
||||
return (e - s) * 1e-9
|
||||
|
||||
|
||||
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0)):
|
||||
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
|
||||
def mx_conv_2D(a, b):
|
||||
ys = []
|
||||
for i in range(N_iter_func):
|
||||
y = mx.conv2d(a, b, stride=strides, padding=padding)
|
||||
y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
|
||||
ys.append(y)
|
||||
mx.eval(ys)
|
||||
return ys
|
||||
@@ -40,12 +40,12 @@ def make_mx_conv_2D(strides=(1, 1), padding=(0, 0)):
|
||||
return mx_conv_2D
|
||||
|
||||
|
||||
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0)):
|
||||
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
|
||||
@torch.no_grad()
|
||||
def pt_conv_2D(a, b):
|
||||
ys = []
|
||||
for i in range(N_iter_func):
|
||||
y = torch.conv2d(a, b, stride=strides, padding=padding)
|
||||
y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
|
||||
ys.append(y)
|
||||
torch.mps.synchronize()
|
||||
return ys
|
||||
@@ -53,11 +53,12 @@ def make_pt_conv_2D(strides=(1, 1), padding=(0, 0)):
|
||||
return pt_conv_2D
|
||||
|
||||
|
||||
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, np_dtype):
|
||||
|
||||
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
|
||||
scale = 1.0 / math.sqrt(kH * kH * C)
|
||||
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
|
||||
b_np = np.random.uniform(-scale, scale, (O, kH, kW, C)).astype(np_dtype)
|
||||
b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
|
||||
np_dtype
|
||||
)
|
||||
|
||||
a_mx = mx.array(a_np)
|
||||
b_mx = mx.array(b_np)
|
||||
@@ -67,15 +68,15 @@ def bench_shape(N, H, W, C, kH, kW, O, strides, padding, np_dtype):
|
||||
|
||||
torch.mps.synchronize()
|
||||
|
||||
f_mx = make_mx_conv_2D(strides, padding)
|
||||
f_pt = make_pt_conv_2D(strides, padding)
|
||||
f_mx = make_mx_conv_2D(strides, padding, groups)
|
||||
f_pt = make_pt_conv_2D(strides, padding, groups)
|
||||
|
||||
time_torch = bench(f_pt, a_pt, b_pt)
|
||||
time_mlx = bench(f_mx, a_mx, b_mx)
|
||||
|
||||
out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding)
|
||||
out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
|
||||
out_pt = torch.conv2d(
|
||||
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding
|
||||
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
|
||||
)
|
||||
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
|
||||
out_pt = out_pt.numpy(force=True)
|
||||
@@ -84,7 +85,7 @@ def bench_shape(N, H, W, C, kH, kW, O, strides, padding, np_dtype):
|
||||
|
||||
if not np.allclose(out_pt, out_mx, atol=atol):
|
||||
print(
|
||||
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
|
||||
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
|
||||
)
|
||||
|
||||
return time_mlx, time_torch
|
||||
@@ -95,35 +96,40 @@ if __name__ == "__main__":
|
||||
|
||||
dtypes = ("float32",)
|
||||
shapes = (
|
||||
(4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2)),
|
||||
(4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2)),
|
||||
(4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2)),
|
||||
(4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2)),
|
||||
(4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2)),
|
||||
(4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2)),
|
||||
(4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2)),
|
||||
(4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2)),
|
||||
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2)),
|
||||
(4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2)),
|
||||
(4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2)),
|
||||
(4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2)),
|
||||
(4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2)),
|
||||
(4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2)),
|
||||
(4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2)),
|
||||
(4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2)),
|
||||
(4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
|
||||
(4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
|
||||
(4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
|
||||
(4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
|
||||
(4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
|
||||
(4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
|
||||
(4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
|
||||
(4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
|
||||
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
|
||||
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 2),
|
||||
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 16),
|
||||
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 64),
|
||||
(4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
|
||||
(4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
|
||||
(4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
|
||||
(4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
|
||||
(4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
|
||||
(4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
|
||||
(4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
|
||||
)
|
||||
|
||||
for dtype in dtypes:
|
||||
print("(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, diff%")
|
||||
for N, H, W, C, kH, kW, O, strides, padding in shapes:
|
||||
print(
|
||||
"(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
|
||||
)
|
||||
for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
|
||||
np_dtype = getattr(np, dtype)
|
||||
time_mlx, time_torch = bench_shape(
|
||||
N, H, W, C, kH, kW, O, strides, padding, np_dtype
|
||||
N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
|
||||
)
|
||||
diff = time_torch / time_mlx - 1.0
|
||||
|
||||
print(
|
||||
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {100. * diff:+5.2f}%"
|
||||
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
|
||||
)
|
||||
if time_mlx >= 2.0 * time_torch:
|
||||
print("ATTENTION ^^^^^^^")
|
||||
|
@@ -3,6 +3,8 @@
|
||||
import matplotlib
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import sympy
|
||||
import torch
|
||||
from time_utils import measure_runtime
|
||||
|
||||
matplotlib.use("Agg")
|
||||
@@ -16,41 +18,100 @@ def bandwidth_gb(runtime_ms, system_size):
|
||||
return system_size * bytes_per_fft / runtime_ms * ms_per_s / bytes_per_gb
|
||||
|
||||
|
||||
def run_bench(system_size):
|
||||
def fft(x):
|
||||
out = mx.fft.fft(x)
|
||||
def run_bench(system_size, fft_sizes, backend="mlx", dim=1):
|
||||
def fft_mlx(x):
|
||||
if dim == 1:
|
||||
out = mx.fft.fft(x)
|
||||
elif dim == 2:
|
||||
out = mx.fft.fft2(x)
|
||||
mx.eval(out)
|
||||
return out
|
||||
|
||||
bandwidths = []
|
||||
for k in range(4, 12):
|
||||
n = 2**k
|
||||
x = mx.random.uniform(shape=(system_size // n, n)).astype(mx.float32)
|
||||
x = x.astype(mx.complex64)
|
||||
mx.eval(x)
|
||||
runtime_ms = measure_runtime(fft, x=x)
|
||||
bandwidths.append(bandwidth_gb(runtime_ms, system_size))
|
||||
def fft_mps(x):
|
||||
if dim == 1:
|
||||
out = torch.fft.fft(x)
|
||||
elif dim == 2:
|
||||
out = torch.fft.fft2(x)
|
||||
torch.mps.synchronize()
|
||||
return out
|
||||
|
||||
return bandwidths
|
||||
bandwidths = []
|
||||
for n in fft_sizes:
|
||||
batch_size = system_size // n**dim
|
||||
shape = [batch_size] + [n for _ in range(dim)]
|
||||
if backend == "mlx":
|
||||
x_np = np.random.uniform(size=(system_size // n, n)).astype(np.complex64)
|
||||
x = mx.array(x_np)
|
||||
mx.eval(x)
|
||||
fft = fft_mlx
|
||||
elif backend == "mps":
|
||||
x_np = np.random.uniform(size=(system_size // n, n)).astype(np.complex64)
|
||||
x = torch.tensor(x_np, device="mps")
|
||||
torch.mps.synchronize()
|
||||
fft = fft_mps
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
runtime_ms = measure_runtime(fft, x=x)
|
||||
bandwidth = bandwidth_gb(runtime_ms, np.prod(shape))
|
||||
print(n, bandwidth)
|
||||
bandwidths.append(bandwidth)
|
||||
|
||||
return np.array(bandwidths)
|
||||
|
||||
|
||||
def time_fft():
|
||||
x = np.array(range(2, 512))
|
||||
system_size = int(2**26)
|
||||
|
||||
with mx.stream(mx.cpu):
|
||||
cpu_bandwidths = run_bench(system_size=int(2**22))
|
||||
|
||||
print("MLX GPU")
|
||||
with mx.stream(mx.gpu):
|
||||
gpu_bandwidths = run_bench(system_size=int(2**29))
|
||||
gpu_bandwidths = run_bench(system_size=system_size, fft_sizes=x)
|
||||
|
||||
# plot bandwidths
|
||||
x = [2**k for k in range(4, 12)]
|
||||
plt.scatter(x, gpu_bandwidths, color="green", label="GPU")
|
||||
plt.scatter(x, cpu_bandwidths, color="red", label="CPU")
|
||||
plt.title("MLX FFT Benchmark")
|
||||
plt.xlabel("N")
|
||||
plt.ylabel("Bandwidth (GB/s)")
|
||||
plt.legend()
|
||||
plt.savefig("fft_plot.png")
|
||||
print("MPS GPU")
|
||||
mps_bandwidths = run_bench(system_size=system_size, fft_sizes=x, backend="mps")
|
||||
|
||||
print("CPU")
|
||||
system_size = int(2**20)
|
||||
with mx.stream(mx.cpu):
|
||||
cpu_bandwidths = run_bench(system_size=system_size, fft_sizes=x)
|
||||
|
||||
x = np.array(x)
|
||||
|
||||
all_indices = x - x[0]
|
||||
radix_2to13 = (
|
||||
np.array([i for i in x if all(p <= 13 for p in sympy.primefactors(i))]) - x[0]
|
||||
)
|
||||
bluesteins = (
|
||||
np.array([i for i in x if any(p > 13 for p in sympy.primefactors(i))]) - x[0]
|
||||
)
|
||||
|
||||
for indices, name in [
|
||||
(all_indices, "All"),
|
||||
(radix_2to13, "Radix 2-13"),
|
||||
(bluesteins, "Bluestein's"),
|
||||
]:
|
||||
# plot bandwidths
|
||||
print(name)
|
||||
plt.scatter(x[indices], gpu_bandwidths[indices], color="green", label="GPU")
|
||||
plt.scatter(x[indices], mps_bandwidths[indices], color="blue", label="MPS")
|
||||
plt.scatter(x[indices], cpu_bandwidths[indices], color="red", label="CPU")
|
||||
plt.title(f"MLX FFT Benchmark -- {name}")
|
||||
plt.xlabel("N")
|
||||
plt.ylabel("Bandwidth (GB/s)")
|
||||
plt.legend()
|
||||
plt.savefig(f"{name}.png")
|
||||
plt.clf()
|
||||
|
||||
av_gpu_bandwidth = np.mean(gpu_bandwidths)
|
||||
av_mps_bandwidth = np.mean(mps_bandwidths)
|
||||
av_cpu_bandwidth = np.mean(cpu_bandwidths)
|
||||
print("Average bandwidths:")
|
||||
print("GPU:", av_gpu_bandwidth)
|
||||
print("MPS:", av_mps_bandwidth)
|
||||
print("CPU:", av_cpu_bandwidth)
|
||||
|
||||
portion_faster = len(np.where(gpu_bandwidths > mps_bandwidths)[0]) / len(x)
|
||||
print("Percent MLX faster than MPS: ", portion_faster * 100)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
benchmarks/python/sdpa_bench.py (new file, 62 lines)
@@ -0,0 +1,62 @@
import argparse
import math

import mlx.core as mx
from time_utils import time_fn

MAX_SEQ = 300
START_SEQ = 100
SEQ_INCREMENT = 50


def time_self_attention_primitives():
    mx.random.seed(3)
    B = 2
    H = 38
    D = 64
    for R in range(START_SEQ, MAX_SEQ, SEQ_INCREMENT):
        q = mx.random.uniform(shape=(B, H, R, D))
        k = mx.random.uniform(shape=(B, H, R, D))
        v = mx.random.uniform(shape=(B, H, R, D))
        scale = 1.0 / math.sqrt(float(D))
        mx.eval(q, k, v)

        def sdpa_primitives(qs, ks, vs, alpha):
            s = (alpha * qs) @ ks.transpose(0, 1, 3, 2)
            p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
            o = p @ vs
            return o

        time_fn(sdpa_primitives, q, k, v, scale)


def time_self_attention_sdpa():
    mx.random.seed(3)
    B = 2
    H = 38
    D = 64
    for R in range(START_SEQ, MAX_SEQ, SEQ_INCREMENT):
        q = mx.random.uniform(shape=(B, H, R, D))
        k = mx.random.uniform(shape=(B, H, R, D))
        v = mx.random.uniform(shape=(B, H, R, D))
        scale = 1.0 / math.sqrt(float(D))
        mx.eval(q, k, v)

        def sdpa_fused(qs, ks, vs, alpha):
            o = mx.fast.scaled_dot_product_attention(qs, ks, vs, scale=alpha)
            return o

        time_fn(sdpa_fused, q, k, v, scale)


if __name__ == "__main__":
    parser = argparse.ArgumentParser("MLX benchmarks.")
    parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
    args = parser.parse_args()
    if args.gpu:
        mx.set_default_device(mx.gpu)
    else:
        mx.set_default_device(mx.cpu)

    time_self_attention_sdpa()
    time_self_attention_primitives()
@@ -43,6 +43,7 @@ are the CPU and GPU.
usage/function_transforms
usage/compile
usage/numpy
usage/distributed
usage/using_streams

.. toctree::
@@ -69,6 +70,7 @@ are the CPU and GPU.
python/metal
python/nn
python/optimizers
python/distributed
python/tree_utils

.. toctree::
@@ -163,6 +163,8 @@ should point to the path to the built metal library.
- ON
* - MLX_BUILD_GGUF
- ON
* - MLX_METAL_JIT
- OFF

.. note::

@@ -184,21 +186,30 @@ should point to the path to the built metal library.
Binary Size Minimization
~~~~~~~~~~~~~~~~~~~~~~~~

To produce a smaller binary use the CMake flags `CMAKE_BUILD_TYPE=MinSizeRel`
and `BUILD_SHARED_LIBS=ON`.
To produce a smaller binary use the CMake flags ``CMAKE_BUILD_TYPE=MinSizeRel``
and ``BUILD_SHARED_LIBS=ON``.

The MLX CMake build has several additional options to make smaller binaries.
For example, if you don't need the CPU backend or support for safetensors and
GGUF, you can do:

```shell
cmake .. \
-DCMAKE_BUILD_TYPE=MinSizeRel \
-DBUILD_SHARED_LIBS=ON \
-DMLX_BUILD_CPU=ON \
-DMLX_BUILD_SAFETENSORS=OFF \
-DMLX_BUILD_GGUF=OFF
```
.. code-block:: shell

cmake .. \
-DCMAKE_BUILD_TYPE=MinSizeRel \
-DBUILD_SHARED_LIBS=ON \
-DMLX_BUILD_CPU=OFF \
-DMLX_BUILD_SAFETENSORS=OFF \
-DMLX_BUILD_GGUF=OFF \
-DMLX_METAL_JIT=ON

The ``MLX_METAL_JIT`` flag minimizes the size of the MLX Metal library, which
contains pre-built GPU kernels. This substantially reduces the size of the
Metal library by run-time compiling kernels the first time they are used in MLX
on a given machine. Note run-time compilation incurs a cold-start cost which can
be anywhere from a few hundred milliseconds to a few seconds depending on the
application. Once a kernel is compiled, it will be cached by the system. The
Metal kernel cache persists across reboots.

Troubleshooting
^^^^^^^^^^^^^^^
docs/src/python/distributed.rst (new file, 19 lines)
@@ -0,0 +1,19 @@
.. _distributed:

.. currentmodule:: mlx.core.distributed

Distributed Communication
==========================

MLX provides a distributed communication package using MPI. The MPI library is
loaded at runtime; if MPI is available then distributed communication is also
made available.

.. autosummary::
   :toctree: _autosummary

   Group
   is_available
   init
   all_sum
   all_gather
@@ -10,5 +10,6 @@ Linear Algebra

inv
norm
cholesky
qr
svd

@@ -17,6 +17,8 @@ simple functions.
gelu_approx
gelu_fast_approx
glu
hard_shrink
hard_tanh
hardswish
leaky_relu
log_sigmoid
@@ -29,6 +31,7 @@ simple functions.
sigmoid
silu
softmax
softmin
softplus
softshrink
step
@@ -21,10 +21,15 @@ Layers
Dropout3d
Embedding
GELU
GLU
GroupNorm
GRU
HardShrink
HardTanh
Hardswish
InstanceNorm
LayerNorm
LeakyReLU
Linear
LSTM
MaxPool1d
@@ -36,13 +41,19 @@ Layers
QuantizedLinear
RMSNorm
ReLU
ReLU6
RNN
RoPE
SELU
Sequential
SiLU
SinusoidalPositionalEncoding
Softmin
Softshrink
Softsign
Softmax
Softplus
Step
Tanh
Transformer
Upsample
@@ -35,7 +35,6 @@ Operations
bitwise_or
bitwise_xor
block_masked_mm
block_sparse_mm
broadcast_to
ceil
clip
@@ -69,6 +68,8 @@ Operations
floor
floor_divide
full
gather_mm
gather_qmm
greater
greater_equal
identity
@@ -149,11 +150,13 @@ Operations
tensordot
tile
topk
trace
transpose
tri
tril
triu
var
view
where
zeros
zeros_like
docs/src/usage/distributed.rst (new file, 166 lines)
@@ -0,0 +1,166 @@
.. _usage_distributed:

Distributed Communication
=========================

.. currentmodule:: mlx.core.distributed

MLX utilizes `MPI <https://en.wikipedia.org/wiki/Message_Passing_Interface>`_ to
provide distributed communication operations that allow the computational cost
of training or inference to be shared across many physical machines. You can
see a list of the supported operations in the :ref:`API docs<distributed>`.

.. note::
   Some operations may not be supported yet, or may not be as fast as they
   should be. We are adding more and tuning the ones we have as we figure out
   the best way to do distributed computing on Macs using MLX.

Getting Started
---------------

MLX already comes with the ability to "talk" to MPI if it is installed on the
machine. The minimal distributed program in MLX is as simple as:

.. code:: python

   import mlx.core as mx

   world = mx.distributed.init()
   x = mx.distributed.all_sum(mx.ones(10))
   print(world.rank(), x)

The program above sums the array ``mx.ones(10)`` across all
distributed processes. If simply run with ``python``, however, only one
process is launched and no distributed communication takes place.

To launch the program in distributed mode we need to use ``mpirun`` or
``mpiexec`` depending on the MPI installation. The simplest possible way is the
following:

.. code:: shell

   $ mpirun -np 2 python test.py
   1 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
   0 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)

The above launches two processes on the same (local) machine and we can see
both standard output streams. The processes send the array of 1s to each other
and compute the sum, which is printed. Launching with ``mpirun -np 4 ...`` would
print 4, and so on.

Installing MPI
--------------

MPI can be installed with Homebrew, using the Anaconda package manager, or
compiled from source. Most of our testing is done using ``openmpi`` installed
with the Anaconda package manager as follows:

.. code:: shell

   $ conda install openmpi

Installing with Homebrew may require specifying the location of ``libmpi.dyld``
so that MLX can find it and load it at runtime. This can simply be achieved by
passing the ``DYLD_LIBRARY_PATH`` environment variable to ``mpirun``.

.. code:: shell

   $ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python test.py

Setting up Remote Hosts
-----------------------

MPI can automatically connect to remote hosts and set up the communication over
the network if the remote hosts can be accessed via ssh. A good checklist to
debug connectivity issues is the following:

* ``ssh hostname`` works from all machines to all machines without asking for
  a password or host confirmation
* ``mpirun`` is accessible on all machines. You can call ``mpirun`` using its
  full path to force all machines to use a specific path.
* Ensure that the ``hostname`` used by MPI is the one that you have configured
  in the ``.ssh/config`` files on all machines.

.. note::
   For an example hostname ``foo.bar.com`` MPI can use only ``foo`` as
   the hostname passed to ssh if the current hostname matches ``*.bar.com``.

An easy way to pass the host names to MPI is using a host file. A host file
looks like the following, where ``host1`` and ``host2`` should be the fully
qualified domain names or IPs for these hosts.

.. code::

   host1 slots=1
   host2 slots=1

When using MLX, it is very likely that you want to use 1 slot per host, i.e. one
process per host. The hostfile also needs to contain the current
host if you want to run on the local host. Passing the host file to
``mpirun`` is simply done using the ``--hostfile`` command line argument.

Training Example
----------------

In this section we will adapt an MLX training loop to support data parallel
distributed training. Namely, we will average the gradients across a set of
hosts before applying them to the model.

Our training loop looks like the following code snippet if we omit the model,
dataset and optimizer initialization.

.. code:: python

   model = ...
   optimizer = ...
   dataset = ...

   def step(model, x, y):
       loss, grads = loss_grad_fn(model, x, y)
       optimizer.update(model, grads)
       return loss

   for x, y in dataset:
       loss = step(model, x, y)
       mx.eval(loss, model.parameters())

All we have to do to average the gradients across machines is perform an
:func:`all_sum` and divide by the size of the :class:`Group`. Namely, we
have to :func:`mlx.utils.tree_map` the gradients with the following function.

.. code:: python

   def all_avg(x):
       return mx.distributed.all_sum(x) / mx.distributed.init().size()

Putting everything together, our training loop step looks as follows, with
everything else remaining the same.

.. code:: python

   from mlx.utils import tree_map

   def all_reduce_grads(grads):
       N = mx.distributed.init().size()
       if N == 1:
           return grads
       return tree_map(
           lambda x: mx.distributed.all_sum(x) / N,
           grads)

   def step(model, x, y):
       loss, grads = loss_grad_fn(model, x, y)
       grads = all_reduce_grads(grads)  # <--- This line was added
       optimizer.update(model, grads)
       return loss

Tuning All Reduce
-----------------

We are working on improving the performance of all reduce on MLX, but for now
the two main things one can do to get the most out of distributed training with
MLX are:

1. Perform a few large reductions instead of many small ones to improve
   bandwidth and latency (see the sketch after this list).
2. Pass ``--mca btl_tcp_links 4`` to ``mpirun`` to configure it to use 4 TCP
   connections between each host to improve bandwidth.
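The batching suggestion in point 1 above can be sketched in a few lines. This is an illustrative sketch only, not part of the MLX documentation: it assumes ``mlx.utils.tree_flatten``/``tree_unflatten`` and simply concatenates all gradient arrays so that a single ``all_sum`` call replaces one call per parameter.

```python
import mlx.core as mx
from mlx.utils import tree_flatten, tree_unflatten

def all_avg_batched(grads):
    # One large reduction instead of one all_sum per gradient array.
    world = mx.distributed.init()
    flat = tree_flatten(grads)  # list of (name, array) pairs
    flat_vec = mx.concatenate([g.reshape(-1) for _, g in flat])
    flat_vec = mx.distributed.all_sum(flat_vec) / world.size()
    # Split the reduced vector back into the original parameter shapes.
    out, offset = [], 0
    for name, g in flat:
        out.append((name, flat_vec[offset : offset + g.size].reshape(*g.shape)))
        offset += g.size
    return tree_unflatten(out)
```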
@@ -3,7 +3,11 @@
Conversion to NumPy and Other Frameworks
========================================

MLX array implements the `Python Buffer Protocol <https://docs.python.org/3/c-api/buffer.html>`_.
MLX arrays support conversion to and from other frameworks with either:

* The `Python Buffer Protocol <https://docs.python.org/3/c-api/buffer.html>`_.
* `DLPack <https://dmlc.github.io/dlpack/latest/>`_.

Let's convert an array to NumPy and back.

.. code-block:: python
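The Python snippet under that ``.. code-block::`` directive is not included in this hunk. As a rough, illustrative sketch of the round trip the passage describes (assuming nothing beyond the public ``mx.array`` constructor and NumPy's buffer-protocol-based ``np.array``):

```python
import mlx.core as mx
import numpy as np

a = mx.arange(3.0)   # an MLX array
b = np.array(a)      # copy into a NumPy array via the buffer protocol
c = mx.array(b)      # and back into an MLX array
print(b, c)
```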
@@ -9,3 +9,4 @@ build_example(tutorial.cpp)
build_example(linear_regression.cpp)
build_example(logistic_regression.cpp)
build_example(metal_capture.cpp)
build_example(distributed.cpp)
examples/cpp/distributed.cpp (new file, 22 lines)
@@ -0,0 +1,22 @@
// Copyright © 2024 Apple Inc.

#include <iostream>

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  if (!distributed::is_available()) {
    std::cout << "No communication backend found" << std::endl;
    return 1;
  }

  auto global_group = distributed::init();
  std::cout << global_group.rank() << " / " << global_group.size() << std::endl;

  array x = ones({10});
  array out = distributed::all_sum(x, global_group);

  std::cout << out << std::endl;
}
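Assuming the example target is produced by the ``build_example(distributed.cpp)`` line added above, the resulting binary can presumably be launched under MPI much like the Python example, for instance something along the lines of ``mpirun -np 2 ./distributed`` from the examples build directory (the exact path depends on your build setup).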
@@ -21,4 +21,4 @@ python setup.py build_ext -j8 --inplace

```
python test.py
`
```
@@ -25,6 +25,7 @@ else()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_cpu)
endif()

add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/distributed)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
if (MLX_BUILD_ACCELERATE)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate)
@@ -32,8 +32,6 @@ DEFAULT(ArgReduce)
DEFAULT(ArgSort)
DEFAULT(AsStrided)
DEFAULT(BlockMaskedMM)
DEFAULT(BlockSparseMM)
DEFAULT(BlockSparseQMM)
DEFAULT(Broadcast)
DEFAULT(Ceil)
DEFAULT(Concatenate)
@@ -49,6 +47,8 @@ DEFAULT(ErfInv)
DEFAULT(FFT)
DEFAULT(Floor)
DEFAULT(Gather)
DEFAULT(GatherMM)
DEFAULT(GatherQMM)
DEFAULT(Greater)
DEFAULT(GreaterEqual)
DEFAULT(Less)
@@ -80,6 +80,7 @@ DEFAULT(StopGradient)
DEFAULT_MULTI(SVD)
DEFAULT(Transpose)
DEFAULT(Inverse)
DEFAULT(Cholesky)

void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
@@ -48,6 +48,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
@@ -56,6 +57,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
)
@@ -1,6 +1,8 @@
// Copyright © 2023 Apple Inc.

#pragma once
#include <cassert>

#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/common/utils.h"
101
mlx/backend/common/cholesky.cpp
Normal file
101
mlx/backend/common/cholesky.cpp
Normal file
@@ -0,0 +1,101 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/linalg.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#ifdef ACCELERATE_NEW_LAPACK
|
||||
#include <Accelerate/Accelerate.h>
|
||||
#else
|
||||
#include <lapack.h>
|
||||
#endif
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
// Delegate to the Cholesky factorization taking into account differences in
|
||||
// LAPACK implementations (basically how to pass the 'uplo' string to fortran).
|
||||
int spotrf_wrapper(char uplo, float* matrix, int N) {
|
||||
int info;
|
||||
|
||||
#ifdef LAPACK_FORTRAN_STRLEN_END
|
||||
spotrf_(
|
||||
/* uplo = */ &uplo,
|
||||
/* n = */ &N,
|
||||
/* a = */ matrix,
|
||||
/* lda = */ &N,
|
||||
/* info = */ &info,
|
||||
/* uplo_len = */ static_cast<size_t>(1));
|
||||
#else
|
||||
spotrf_(
|
||||
/* uplo = */ &uplo,
|
||||
/* n = */ &N,
|
||||
/* a = */ matrix,
|
||||
/* lda = */ &N,
|
||||
/* info = */ &info);
|
||||
#endif
|
||||
|
||||
return info;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void cholesky_impl(const array& a, array& factor, bool upper) {
|
||||
// Lapack uses the column-major convention. We take advantage of the fact that
|
||||
// the matrix should be symmetric:
|
||||
// (A)ᵀ = A
|
||||
// and that a column-major lower triangular matrix is a row-major upper
|
||||
// triangular matrix, so uplo is the opposite of what we would expect from
|
||||
// upper
|
||||
|
||||
char uplo = (upper) ? 'L' : 'U';
|
||||
|
||||
// The decomposition is computed in place, so just copy the input to the
|
||||
// output.
|
||||
copy(
|
||||
a,
|
||||
factor,
|
||||
a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
|
||||
|
||||
const int N = a.shape(-1);
|
||||
const size_t num_matrices = a.size() / (N * N);
|
||||
|
||||
float* matrix = factor.data<float>();
|
||||
|
||||
for (int i = 0; i < num_matrices; i++) {
|
||||
// Compute Cholesky factorization.
|
||||
int info = spotrf_wrapper(uplo, matrix, N);
|
||||
|
||||
// TODO: We do nothing when the matrix is not positive semi-definite
|
||||
// because throwing an error would result in a crash. If we figure out how
|
||||
// to catch errors from the implementation we should throw.
|
||||
if (info < 0) {
|
||||
std::stringstream msg;
|
||||
msg << "[cholesky] Cholesky decomposition failed with error code "
|
||||
<< info;
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
// Zero out the upper/lower triangle while advancing the pointer to the
|
||||
// next matrix at the same time.
|
||||
for (int row = 0; row < N; row++) {
|
||||
if (upper) {
|
||||
std::fill(matrix, matrix + row, 0);
|
||||
} else {
|
||||
std::fill(matrix + row + 1, matrix + N, 0);
|
||||
}
|
||||
matrix += N;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Cholesky::eval(const std::vector<array>& inputs, array& output) {
|
||||
if (inputs[0].dtype() != float32) {
|
||||
throw std::runtime_error("[Cholesky::eval] only supports float32.");
|
||||
}
|
||||
cholesky_impl(inputs[0], output, upper_);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -250,49 +250,6 @@ void Split::eval(
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<bool, int64_t, std::vector<int64_t>> Slice::prepare_slice(
|
||||
const array& in) {
|
||||
int64_t data_offset = 0;
|
||||
bool copy_needed = false;
|
||||
std::vector<int64_t> inp_strides(in.ndim(), 0);
|
||||
for (int i = 0; i < in.ndim(); ++i) {
|
||||
data_offset += start_indices_[i] * in.strides()[i];
|
||||
inp_strides[i] = in.strides()[i] * strides_[i];
|
||||
|
||||
copy_needed |= strides_[i] < 0;
|
||||
}
|
||||
|
||||
return std::make_tuple(copy_needed, data_offset, inp_strides);
|
||||
}
|
||||
|
||||
void Slice::shared_buffer_slice(
|
||||
const array& in,
|
||||
const std::vector<size_t>& out_strides,
|
||||
size_t data_offset,
|
||||
array& out) {
|
||||
// Compute row/col contiguity
|
||||
auto [data_size, is_row_contiguous, is_col_contiguous] =
|
||||
check_contiguity(out.shape(), out_strides);
|
||||
|
||||
auto flags = in.flags();
|
||||
flags.row_contiguous = is_row_contiguous;
|
||||
flags.col_contiguous = is_col_contiguous;
|
||||
|
||||
if (data_size == 1) {
|
||||
// Broadcasted scalar array is contiguous.
|
||||
flags.contiguous = true;
|
||||
} else if (data_size == in.data_size()) {
|
||||
// Means we sliced a broadcasted dimension so leave the "no holes" flag
|
||||
// alone.
|
||||
} else {
|
||||
// We sliced something. So either we are row or col contiguous or we
|
||||
// punched a hole.
|
||||
flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
|
||||
}
|
||||
|
||||
out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset);
|
||||
}
|
||||
|
||||
std::tuple<int64_t, std::vector<int64_t>> SliceUpdate::prepare_slice(
|
||||
const array& in) {
|
||||
int64_t data_offset = 0;
|
||||
|
@@ -111,13 +111,17 @@ void slow_conv_2D(
|
||||
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
|
||||
const int iH = 1 + in_dilation[0] * (in.shape(1) - 1); // Input spatial dim
|
||||
const int iW = 1 + in_dilation[1] * (in.shape(2) - 1); // Input spatial dim
|
||||
const int C = in.shape(3); // In channels
|
||||
const int oH = out.shape(1); // Output spatial dim
|
||||
const int oW = out.shape(2); // Output spatial dim
|
||||
const int O = wt.shape(0); // Out channels
|
||||
const int C = wt.shape(3); // In channels
|
||||
const int wH = wt.shape(1); // Weight spatial dim
|
||||
const int wW = wt.shape(2); // Weight spatial dim
|
||||
|
||||
const int groups = C / wt.shape(3);
|
||||
const int C_per_group = wt.shape(3);
|
||||
const int O_per_group = O / groups;
|
||||
|
||||
const size_t in_stride_N = in.strides()[0];
|
||||
const size_t in_stride_H = in.strides()[1];
|
||||
const size_t in_stride_W = in.strides()[2];
|
||||
@@ -141,33 +145,35 @@ void slow_conv_2D(
|
||||
int ih_base = oh * wt_strides[0] - padding[0];
|
||||
int iw_base = ow * wt_strides[1] - padding[1];
|
||||
|
||||
for (int o = 0; o < O; ++o) {
|
||||
float r = 0.;
|
||||
for (int g = 0; g < groups; ++g) {
|
||||
for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
|
||||
float r = 0.;
|
||||
|
||||
for (int wh = 0; wh < wH; ++wh) {
|
||||
for (int ww = 0; ww < wW; ++ww) {
|
||||
int wh_flip = flip ? wH - wh - 1 : wh;
|
||||
int ww_flip = flip ? wW - ww - 1 : ww;
|
||||
int ih = ih_base + wh_flip * wt_dilation[0];
|
||||
int iw = iw_base + ww_flip * wt_dilation[1];
|
||||
for (int wh = 0; wh < wH; ++wh) {
|
||||
for (int ww = 0; ww < wW; ++ww) {
|
||||
int wh_flip = flip ? wH - wh - 1 : wh;
|
||||
int ww_flip = flip ? wW - ww - 1 : ww;
|
||||
int ih = ih_base + wh_flip * wt_dilation[0];
|
||||
int iw = iw_base + ww_flip * wt_dilation[1];
|
||||
|
||||
const T* wt_ptr_pt = wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
|
||||
const T* in_ptr_pt = in_ptr + ih * in_stride_H + iw * in_stride_W;
|
||||
const T* wt_ptr_pt =
|
||||
wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
|
||||
const T* in_ptr_pt =
|
||||
in_ptr + ih * in_stride_H + iw * in_stride_W;
|
||||
|
||||
for (int c = 0; c < C; ++c) {
|
||||
r += static_cast<float>(in_ptr_pt[0]) *
|
||||
static_cast<float>(wt_ptr_pt[0]);
|
||||
in_ptr_pt += in_stride_C;
|
||||
wt_ptr_pt += wt_stride_C;
|
||||
} // c
|
||||
for (int c = g * C_per_group; c < (g + 1) * C_per_group; ++c) {
|
||||
r += static_cast<float>(in_ptr_pt[c * in_stride_C]) *
|
||||
static_cast<float>(
|
||||
wt_ptr_pt[(c % C_per_group) * wt_stride_C]);
|
||||
} // c
|
||||
} // ww
|
||||
} // wh
|
||||
|
||||
} // ww
|
||||
} // wh
|
||||
|
||||
out_ptr[0] = static_cast<T>(r);
|
||||
out_ptr += out_stride_O;
|
||||
wt_ptr += wt_stride_O;
|
||||
} // o
|
||||
out_ptr[0] = static_cast<T>(r);
|
||||
out_ptr += out_stride_O;
|
||||
wt_ptr += wt_stride_O;
|
||||
} // o
|
||||
} // g
|
||||
};
|
||||
|
||||
int jump_h = flip ? -wt_dilation[0] : wt_dilation[0];
|
||||
@@ -219,41 +225,43 @@ void slow_conv_2D(
|
||||
int wh_base = base_h[oh % f_out_jump_h];
|
||||
int ww_base = base_w[ow % f_out_jump_w];
|
||||
|
||||
for (int o = 0; o < O; ++o) {
|
||||
float r = 0.;
|
||||
for (int g = 0; g < groups; ++g) {
|
||||
for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
|
||||
float r = 0.;
|
||||
|
||||
for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) {
|
||||
for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) {
|
||||
int wh_flip = flip ? wH - wh - 1 : wh;
|
||||
int ww_flip = flip ? wW - ww - 1 : ww;
|
||||
int ih = ih_base + wh_flip * wt_dilation[0];
|
||||
int iw = iw_base + ww_flip * wt_dilation[1];
|
||||
for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) {
|
||||
for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) {
|
||||
int wh_flip = flip ? wH - wh - 1 : wh;
|
||||
int ww_flip = flip ? wW - ww - 1 : ww;
|
||||
int ih = ih_base + wh_flip * wt_dilation[0];
|
||||
int iw = iw_base + ww_flip * wt_dilation[1];
|
||||
|
||||
if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) {
|
||||
const T* wt_ptr_pt =
|
||||
wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
|
||||
if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) {
|
||||
const T* wt_ptr_pt =
|
||||
wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
|
||||
|
||||
int ih_dil = !is_idil_one ? (ih / in_dilation[0]) : ih;
|
||||
int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw;
|
||||
int ih_dil = !is_idil_one ? (ih / in_dilation[0]) : ih;
|
||||
int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw;
|
||||
|
||||
const T* in_ptr_pt =
|
||||
in_ptr + ih_dil * in_stride_H + iw_dil * in_stride_W;
|
||||
const T* in_ptr_pt =
|
||||
in_ptr + ih_dil * in_stride_H + iw_dil * in_stride_W;
|
||||
|
||||
for (int c = 0; c < C; ++c) {
|
||||
r += static_cast<float>(in_ptr_pt[0]) *
|
||||
static_cast<float>(wt_ptr_pt[0]);
|
||||
in_ptr_pt += in_stride_C;
|
||||
wt_ptr_pt += wt_stride_C;
|
||||
} // c
|
||||
for (int c = g * C_per_group; c < (g + 1) * C_per_group;
|
||||
++c) {
|
||||
r += static_cast<float>(in_ptr_pt[c * in_stride_C]) *
|
||||
static_cast<float>(
|
||||
wt_ptr_pt[(c % C_per_group) * wt_stride_C]);
|
||||
} // c
|
||||
|
||||
} // ih, iw check
|
||||
} // ww
|
||||
} // wh
|
||||
} // ih, iw check
|
||||
} // ww
|
||||
} // wh
|
||||
|
||||
out_ptr[0] = static_cast<T>(r);
|
||||
out_ptr += out_stride_O;
|
||||
wt_ptr += wt_stride_O;
|
||||
} // o
|
||||
out_ptr[0] = static_cast<T>(r);
|
||||
out_ptr += out_stride_O;
|
||||
wt_ptr += wt_stride_O;
|
||||
} // o
|
||||
} // g
|
||||
};
|
||||
|
||||
int oH_border_0 = 0;
|
||||
|
@@ -256,7 +256,7 @@ void copy_general_general(
}

int size = std::accumulate(
data_shape.begin() - 5, data_shape.end(), 1, std::multiplies<int>());
data_shape.end() - 5, data_shape.end(), 1, std::multiplies<int>());
for (int i = 0; i < src.size(); i += size) {
stride_t src_offset = i_offset + elem_to_loc(i, data_shape, i_strides);
stride_t dst_offset = o_offset + elem_to_loc(i, dst.shape(), o_strides);
@@ -5,7 +5,6 @@
#else
#include <cblas.h>
#endif

#include <cstring>

#include "mlx/array.h"
@@ -43,8 +42,8 @@ DEFAULT(AsType)
DEFAULT(AsStrided)
DEFAULT(Broadcast)
DEFAULT(BlockMaskedMM)
DEFAULT(BlockSparseMM)
DEFAULT(BlockSparseQMM)
DEFAULT(GatherMM)
DEFAULT(GatherQMM)
DEFAULT_MULTI(DivMod)
DEFAULT(Ceil)
DEFAULT(Concatenate)
@@ -113,6 +112,7 @@ DEFAULT(Tan)
DEFAULT(Tanh)
DEFAULT(Transpose)
DEFAULT(Inverse)
DEFAULT(Cholesky)

namespace {
@@ -28,6 +28,7 @@ const char* get_kernel_preamble() {
return R"preamble(
$INCLUDES
$CONTENT
using namespace mlx::core;
using namespace mlx::core::detail;
)preamble";
}
@@ -17,24 +17,25 @@ namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T>
|
||||
template <typename T, typename mask_t>
|
||||
inline void mask_matrix(
|
||||
T* data,
|
||||
const bool* mask,
|
||||
const mask_t* mask,
|
||||
int block_size,
|
||||
const int X,
|
||||
const int Y,
|
||||
const size_t X_data_str,
|
||||
const size_t Y_data_str,
|
||||
const size_t X_mask_str,
|
||||
const size_t Y_mask_str) {
|
||||
const size_t Y_mask_str,
|
||||
const size_t mask_offset) {
|
||||
int tX = (X + block_size - 1) / block_size;
|
||||
int tY = (Y + block_size - 1) / block_size;
|
||||
|
||||
for (int i = 0; i < tX; i++) {
|
||||
for (int j = 0; j < tY; j++) {
|
||||
bool do_mask = mask[i * X_mask_str + j * Y_mask_str];
|
||||
if (!do_mask) {
|
||||
mask_t do_mask = mask[mask_offset + i * X_mask_str + j * Y_mask_str];
|
||||
if (do_mask != 1) {
|
||||
int loc_x = i * block_size;
|
||||
int loc_y = j * block_size;
|
||||
T* data_block = data + loc_x * X_data_str + loc_y * Y_data_str;
|
||||
@@ -43,7 +44,11 @@ inline void mask_matrix(
|
||||
int size_y = std::min(block_size, Y - loc_y);
|
||||
for (int ii = 0; ii < size_x; ii++) {
|
||||
for (int jj = 0; jj < size_y; jj++) {
|
||||
data_block[ii * X_data_str + jj * Y_data_str] = T(0.);
|
||||
if constexpr (std::is_same_v<mask_t, bool>) {
|
||||
data_block[ii * X_data_str + jj * Y_data_str] = T(0.);
|
||||
} else {
|
||||
data_block[ii * X_data_str + jj * Y_data_str] *= do_mask;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -62,36 +67,39 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
auto& a_pre = inputs[0];
|
||||
auto& b_pre = inputs[1];
|
||||
auto& out_mask = inputs[2];
|
||||
|
||||
auto check_transpose = [](const array& arr, bool do_copy) {
|
||||
auto stx = arr.strides()[arr.ndim() - 2];
|
||||
auto sty = arr.strides()[arr.ndim() - 1];
|
||||
if (stx == arr.shape(-1) && sty == 1) {
|
||||
if (do_copy) {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy(arr, arr_copy, CopyType::Vector);
|
||||
return std::make_tuple(false, stx, arr_copy);
|
||||
}
|
||||
return std::make_tuple(false, stx, arr);
|
||||
} else if (stx == 1 && sty == arr.shape(-2)) {
|
||||
if (do_copy) {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy(arr, arr_copy, CopyType::Vector);
|
||||
return std::make_tuple(true, sty, arr_copy);
|
||||
}
|
||||
return std::make_tuple(true, sty, arr);
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy(arr, arr_copy, CopyType::General);
|
||||
size_t stx = arr.shape(-1);
|
||||
return std::make_tuple(false, stx, arr_copy);
|
||||
}
|
||||
};
|
||||
auto check_transpose =
|
||||
[](const array& arr, bool do_copy, bool expand_all = false) {
|
||||
auto stx = arr.strides()[arr.ndim() - 2];
|
||||
auto sty = arr.strides()[arr.ndim() - 1];
|
||||
if (!expand_all && stx == arr.shape(-1) && sty == 1) {
|
||||
if (do_copy) {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy(arr, arr_copy, CopyType::Vector);
|
||||
return std::make_tuple(false, stx, arr_copy);
|
||||
}
|
||||
return std::make_tuple(false, stx, arr);
|
||||
} else if (!expand_all && stx == 1 && sty == arr.shape(-2)) {
|
||||
if (do_copy) {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy(arr, arr_copy, CopyType::Vector);
|
||||
return std::make_tuple(true, sty, arr_copy);
|
||||
}
|
||||
return std::make_tuple(true, sty, arr);
|
||||
} else {
|
||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||
copy(arr, arr_copy, CopyType::General);
|
||||
size_t stx = arr.shape(-1);
|
||||
return std::make_tuple(false, stx, arr_copy);
|
||||
}
|
||||
};
|
||||
|
||||
bool has_op_mask = inputs.size() > 3;
|
||||
auto [a_transposed, lda, a] = check_transpose(a_pre, has_op_mask);
|
||||
auto [b_transposed, ldb, b] = check_transpose(b_pre, has_op_mask);
|
||||
bool has_out_mask = inputs.size() == 3 || inputs.size() == 5;
|
||||
auto [a_transposed, lda, a] =
|
||||
check_transpose(a_pre, has_op_mask, inputs.back().dtype() != bool_);
|
||||
auto [b_transposed, ldb, b] =
|
||||
check_transpose(b_pre, has_op_mask, inputs.back().dtype() != bool_);
|
||||
|
||||
size_t M = a.shape(-2);
|
||||
size_t N = b.shape(-1);
|
||||
@@ -114,27 +122,42 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
|
||||
int Y,
|
||||
size_t X_data_str,
|
||||
size_t Y_data_str) {
|
||||
const bool* mask_ptr = mask.data<bool>() +
|
||||
elem_to_loc(mask.shape(-1) * mask.shape(-2) * batch_idx,
|
||||
mask.shape(),
|
||||
mask.strides());
|
||||
size_t mask_offset = elem_to_loc(
|
||||
mask.shape(-1) * mask.shape(-2) * batch_idx,
|
||||
mask.shape(),
|
||||
mask.strides());
|
||||
|
||||
size_t X_mask_str = mask.strides()[mask.ndim() - 2];
|
||||
size_t Y_mask_str = mask.strides()[mask.ndim() - 1];
|
||||
|
||||
return mask_matrix(
|
||||
data,
|
||||
mask_ptr,
|
||||
block_size,
|
||||
X,
|
||||
Y,
|
||||
X_data_str,
|
||||
Y_data_str,
|
||||
X_mask_str,
|
||||
Y_mask_str);
|
||||
if (mask.dtype() == bool_) {
|
||||
return mask_matrix(
|
||||
data,
|
||||
mask.data<bool>(),
|
||||
block_size,
|
||||
X,
|
||||
Y,
|
||||
X_data_str,
|
||||
Y_data_str,
|
||||
X_mask_str,
|
||||
Y_mask_str,
|
||||
mask_offset);
|
||||
} else {
|
||||
return mask_matrix(
|
||||
data,
|
||||
mask.data<float>(),
|
||||
block_size,
|
||||
X,
|
||||
Y,
|
||||
X_data_str,
|
||||
Y_data_str,
|
||||
X_mask_str,
|
||||
Y_mask_str,
|
||||
mask_offset);
|
||||
}
|
||||
};
|
||||
|
||||
for (int i = 0; i < (a.size() / (M * K)); ++i) {
|
||||
for (int i = 0; i < (out.size() / (M * size_t(N))); ++i) {
|
||||
// Adjust pointer
|
||||
float* ai =
|
||||
a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides());
|
||||
@@ -144,7 +167,7 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
// Zero out blocks in a and b if needed
|
||||
if (has_op_mask) {
|
||||
auto& a_mask = inputs[3];
|
||||
auto& a_mask = inputs[inputs.size() - 2];
|
||||
mask_array(
|
||||
a_mask,
|
||||
ai,
|
||||
@@ -155,7 +178,7 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
|
||||
a_transposed ? 1 : lda,
|
||||
a_transposed ? lda : 1);
|
||||
|
||||
auto& b_mask = inputs[4];
|
||||
auto& b_mask = inputs[inputs.size() - 1];
|
||||
mask_array(
|
||||
b_mask,
|
||||
bi,
|
||||
@@ -186,14 +209,16 @@ void BlockMaskedMM::eval(const std::vector<array>& inputs, array& out) {
|
||||
);
|
||||
|
||||
// Zero out blocks in out
|
||||
mask_array(out_mask, ci, block_size_, i, M, N, N, 1);
|
||||
if (has_out_mask) {
|
||||
mask_array(inputs[2], ci, block_size_, i, M, N, N, 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void BlockSparseMM::eval(const std::vector<array>& inputs, array& out) {
|
||||
void GatherMM::eval(const std::vector<array>& inputs, array& out) {
|
||||
if (out.dtype() != float32) {
|
||||
throw std::runtime_error(
|
||||
"[BlockSparseMM::eval] Currently only supports float32.");
|
||||
"[GatherMM::eval] Currently only supports float32.");
|
||||
}
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
@@ -277,4 +302,4 @@ void BlockSparseMM::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
} // namespace mlx::core
|
||||
|
@@ -11,6 +11,7 @@
|
||||
#include "mlx/backend/common/binary.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/ops.h"
|
||||
#include "mlx/backend/common/slicing.h"
|
||||
#include "mlx/backend/common/threefry.h"
|
||||
#include "mlx/backend/common/unary.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
@@ -492,7 +493,8 @@ void Slice::eval(const std::vector<array>& inputs, array& out) {
|
||||
auto& in = inputs[0];
|
||||
|
||||
// Calculate out strides, initial offset and if copy needs to be made
|
||||
auto [copy_needed, data_offset, inp_strides] = prepare_slice(in);
|
||||
auto [copy_needed, data_offset, inp_strides] =
|
||||
prepare_slice(in, start_indices_, strides_);
|
||||
|
||||
// Do copy if needed
|
||||
if (copy_needed) {
|
||||
@@ -590,4 +592,36 @@ void Tanh::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void View::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
auto ibytes = size_of(in.dtype());
|
||||
auto obytes = size_of(out.dtype());
|
||||
// Conditions for buffer copying (disjunction):
|
||||
// - type size is the same
|
||||
// - type size is smaller and the last axis is contiguous
|
||||
// - the entire array is row contiguous
|
||||
if (ibytes == obytes || obytes < ibytes && in.strides().back() == 1 ||
|
||||
in.flags().row_contiguous) {
|
||||
auto strides = in.strides();
|
||||
for (int i = 0; i < strides.size() - 1; ++i) {
|
||||
strides[i] *= ibytes;
|
||||
strides[i] /= obytes;
|
||||
}
|
||||
out.copy_shared_buffer(
|
||||
in, strides, in.flags(), in.data_size() * obytes / ibytes);
|
||||
} else {
|
||||
auto tmp = array(in.shape(), in.dtype(), nullptr, {});
|
||||
tmp.set_data(allocator::malloc_or_wait(tmp.nbytes()));
|
||||
copy_inplace(in, tmp, CopyType::General);
|
||||
|
||||
auto flags = out.flags();
|
||||
flags.contiguous = true;
|
||||
flags.row_contiguous = true;
|
||||
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
|
||||
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
|
||||
out.move_shared_buffer(tmp, out.strides(), flags, out.size());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
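A hedged side note on the View::eval_cpu change above (not part of the diff): when the output element is smaller than the input element and the buffer can be reused in place, the non-innermost strides are rescaled from input elements to output elements. A minimal sketch of that arithmetic, with made-up shapes:

// Minimal sketch (illustrative values only): a row-contiguous 2x4 int32
// array viewed as int16. ibytes/obytes mirror the names used above.
#include <cstdio>
#include <vector>

int main() {
  size_t ibytes = 4, obytes = 2;         // int32 reinterpreted as int16
  std::vector<size_t> strides = {4, 1};  // strides of the 2x4 input, in elements
  // All strides except the innermost are rescaled; the last axis simply
  // holds more (smaller) elements, so its stride of 1 is unchanged.
  for (size_t i = 0; i + 1 < strides.size(); ++i) {
    strides[i] = strides[i] * ibytes / obytes;
  }
  printf("output strides: %zu, %zu\n", strides[0], strides[1]);  // 8, 1
  return 0;
}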
@@ -357,7 +357,7 @@ void QuantizedMatmul::eval(const std::vector<array>& inputs, array& out) {
|
||||
_qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
|
||||
}
|
||||
|
||||
void BlockSparseQMM::eval(const std::vector<array>& inputs, array& out) {
|
||||
void GatherQMM::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 6);
|
||||
|
||||
auto& x_pre = inputs[0];
|
||||
|
@@ -234,7 +234,7 @@ void scan_dispatch(
      auto op = [](U* o, const U* y, const T* x) { *o = (*x < *y) ? *y : *x; };
      auto init = (issubdtype(input.dtype(), floating))
          ? static_cast<U>(-std::numeric_limits<float>::infinity())
          : std::numeric_limits<U>::max();
          : std::numeric_limits<U>::min();
      auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init);
      auto opss = DefaultStridedScan<T, U, decltype(op)>(op, init);
      scan_op<T, U>(opcs, opss, input, output, axis, reverse, inclusive);
mlx/backend/common/slicing.cpp (new file, 52 lines)
@@ -0,0 +1,52 @@
// Copyright © 2024 Apple Inc.

#include "mlx/backend/common/utils.h"

namespace mlx::core {

std::tuple<bool, int64_t, std::vector<int64_t>> prepare_slice(
    const array& in,
    std::vector<int>& start_indices,
    std::vector<int>& strides) {
  int64_t data_offset = 0;
  bool copy_needed = false;
  std::vector<int64_t> inp_strides(in.ndim(), 0);
  for (int i = 0; i < in.ndim(); ++i) {
    data_offset += start_indices[i] * in.strides()[i];
    inp_strides[i] = in.strides()[i] * strides[i];

    copy_needed |= strides[i] < 0;
  }

  return std::make_tuple(copy_needed, data_offset, inp_strides);
}

void shared_buffer_slice(
    const array& in,
    const std::vector<size_t>& out_strides,
    size_t data_offset,
    array& out) {
  // Compute row/col contiguity
  auto [data_size, is_row_contiguous, is_col_contiguous] =
      check_contiguity(out.shape(), out_strides);

  auto flags = in.flags();
  flags.row_contiguous = is_row_contiguous;
  flags.col_contiguous = is_col_contiguous;

  if (data_size == 1) {
    // Broadcasted scalar array is contiguous.
    flags.contiguous = true;
  } else if (data_size == in.data_size()) {
    // Means we sliced a broadcasted dimension so leave the "no holes" flag
    // alone.
  } else {
    // We sliced something. So either we are row or col contiguous or we
    // punched a hole.
    flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
  }

  out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset);
}

} // namespace mlx::core
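A hedged worked example of the prepare_slice arithmetic in the new file above; the shape, starts, and steps are made up for illustration:

// Sketch only: a row-contiguous 4x6 input sliced with starts {1, 4} and
// steps {3, -2}. Mirrors the loop in prepare_slice by hand.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int64_t> in_strides = {6, 1};  // row-contiguous 4x6
  std::vector<int> start = {1, 4};
  std::vector<int> step = {3, -2};
  int64_t offset = 0;
  bool copy_needed = false;
  std::vector<int64_t> out_strides(2);
  for (int i = 0; i < 2; ++i) {
    offset += start[i] * in_strides[i];
    out_strides[i] = in_strides[i] * step[i];
    copy_needed |= step[i] < 0;  // negative steps force a copy
  }
  // offset = 10, out_strides = {18, -2}, copy_needed = true
  printf("offset=%lld strides={%lld, %lld} copy=%d\n",
         (long long)offset, (long long)out_strides[0],
         (long long)out_strides[1], (int)copy_needed);
  return 0;
}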
mlx/backend/common/slicing.h (new file, 20 lines)
@@ -0,0 +1,20 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include "mlx/array.h"

namespace mlx::core {

std::tuple<bool, int64_t, std::vector<int64_t>> prepare_slice(
    const array& in,
    std::vector<int>& start_indices,
    std::vector<int>& strides);

void shared_buffer_slice(
    const array& in,
    const std::vector<size_t>& out_strides,
    size_t data_offset,
    array& out);

} // namespace mlx::core
@@ -1,33 +1,130 @@
|
||||
add_custom_command(
|
||||
OUTPUT compiled_preamble.cpp
|
||||
function(make_jit_source SRC_FILE)
|
||||
# This function takes a metal header file,
|
||||
# runs the C preprocessesor on it, and makes
|
||||
# the processed contents available as a string in a C++ function
|
||||
# mlx::core::metal::${SRC_NAME}()
|
||||
#
|
||||
# To use the function, declare it in jit/includes.h and
|
||||
# include jit/includes.h.
|
||||
#
|
||||
# Additional arguments to this function are treated as dependencies
|
||||
# in the Cmake build system.
|
||||
get_filename_component(SRC_NAME ${SRC_FILE} NAME)
|
||||
add_custom_command(
|
||||
OUTPUT jit/${SRC_NAME}.cpp
|
||||
COMMAND /bin/bash
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
|
||||
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
|
||||
${CMAKE_CURRENT_BINARY_DIR}/jit
|
||||
${CMAKE_C_COMPILER}
|
||||
${PROJECT_SOURCE_DIR}
|
||||
${SRC_FILE}
|
||||
"-D${MLX_METAL_VERSION}"
|
||||
DEPENDS make_compiled_preamble.sh
|
||||
kernels/compiled_preamble.h
|
||||
kernels/unary.h
|
||||
kernels/binary.h
|
||||
kernels/bf16.h
|
||||
kernels/erf.h
|
||||
kernels/expm1f.h
|
||||
kernels/utils.h
|
||||
kernels/bf16_math.h
|
||||
)
|
||||
kernels/${SRC_FILE}.h
|
||||
${ARGN}
|
||||
)
|
||||
add_custom_target(${SRC_NAME} DEPENDS jit/${SRC_NAME}.cpp)
|
||||
add_dependencies(mlx ${SRC_NAME})
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE
|
||||
${CMAKE_CURRENT_BINARY_DIR}/jit/${SRC_NAME}.cpp
|
||||
)
|
||||
endfunction(make_jit_source)
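For context (an assumption-laden sketch, not part of the diff): the strings produced by make_jit_source are consumed on the C++ side roughly the way the Compiled::eval_gpu hunk later in this compare shows, by concatenating the preprocessed headers into a JIT kernel source before compiling it. The function names come from that hunk; the exact contents of jit/includes.h are assumed.

// Hypothetical helper assembling a fused elementwise kernel source from the
// generated jit strings (declared in mlx/backend/metal/jit/includes.h).
#include <sstream>
#include <string>
#include "mlx/backend/metal/jit/includes.h"

std::string build_fused_source(const std::string& kernel_body) {
  std::ostringstream os;
  os << mlx::core::metal::utils()       // bf16 / complex / defines helpers
     << mlx::core::metal::unary_ops()   // erf, expm1f, ...
     << mlx::core::metal::binary_ops()
     << kernel_body;                    // generated body for the fused op
  return os.str();
}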
|
||||
|
||||
add_custom_target(
|
||||
compiled_preamble
|
||||
DEPENDS compiled_preamble.cpp
|
||||
make_jit_source(
|
||||
utils
|
||||
kernels/bf16.h
|
||||
kernels/complex.h
|
||||
kernels/defines.h
|
||||
)
|
||||
make_jit_source(
|
||||
unary_ops
|
||||
kernels/erf.h
|
||||
kernels/expm1f.h
|
||||
)
|
||||
make_jit_source(binary_ops)
|
||||
make_jit_source(ternary_ops)
|
||||
make_jit_source(
|
||||
reduce_utils
|
||||
kernels/atomic.h
|
||||
kernels/reduction/ops.h
|
||||
)
|
||||
make_jit_source(scatter)
|
||||
make_jit_source(gather)
|
||||
|
||||
add_dependencies(mlx compiled_preamble)
|
||||
if (MLX_METAL_JIT)
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/jit_kernels.cpp
|
||||
)
|
||||
make_jit_source(arange)
|
||||
make_jit_source(copy)
|
||||
make_jit_source(unary)
|
||||
make_jit_source(binary)
|
||||
make_jit_source(binary_two)
|
||||
make_jit_source(
|
||||
fft
|
||||
kernels/fft/radix.h
|
||||
kernels/fft/readwrite.h
|
||||
)
|
||||
make_jit_source(ternary)
|
||||
make_jit_source(softmax)
|
||||
make_jit_source(scan)
|
||||
make_jit_source(sort)
|
||||
make_jit_source(
|
||||
reduce
|
||||
kernels/reduction/reduce_all.h
|
||||
kernels/reduction/reduce_col.h
|
||||
kernels/reduction/reduce_row.h
|
||||
)
|
||||
make_jit_source(
|
||||
steel/gemm/gemm
|
||||
kernels/steel/utils.h
|
||||
kernels/steel/gemm/loader.h
|
||||
kernels/steel/gemm/mma.h
|
||||
kernels/steel/gemm/params.h
|
||||
kernels/steel/gemm/transforms.h
|
||||
)
|
||||
make_jit_source(steel/gemm/kernels/steel_gemm_fused)
|
||||
make_jit_source(
|
||||
steel/gemm/kernels/steel_gemm_masked
|
||||
kernels/steel/defines.h
|
||||
)
|
||||
make_jit_source(steel/gemm/kernels/steel_gemm_splitk)
|
||||
make_jit_source(
|
||||
steel/conv/conv
|
||||
kernels/steel/utils.h
|
||||
kernels/steel/defines.h
|
||||
kernels/steel/gemm/mma.h
|
||||
kernels/steel/gemm/transforms.h
|
||||
kernels/steel/conv/params.h
|
||||
kernels/steel/conv/loader.h
|
||||
kernels/steel/conv/loaders/loader_channel_l.h
|
||||
kernels/steel/conv/loaders/loader_channel_n.h
|
||||
)
|
||||
make_jit_source(
|
||||
steel/conv/kernels/steel_conv
|
||||
)
|
||||
make_jit_source(
|
||||
steel/conv/kernels/steel_conv_general
|
||||
kernels/steel/defines.h
|
||||
kernels/steel/conv/loaders/loader_general.h
|
||||
)
|
||||
else()
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/nojit_kernels.cpp
|
||||
)
|
||||
endif()
|
||||
|
||||
target_sources(
|
||||
mlx
|
||||
PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
|
||||
@@ -43,10 +140,12 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/normalization.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
|
||||
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp
|
||||
)
|
||||
|
||||
if (NOT MLX_METAL_PATH)
|
||||
|
mlx/backend/metal/binary.cpp (new file, 360 lines)
@@ -0,0 +1,360 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/binary.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
constexpr int MAX_BINARY_SPECIALIZED_DIMS = 5;
|
||||
|
||||
void binary_op_gpu_inplace(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
const std::string op,
|
||||
const Stream& s) {
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
|
||||
auto& out = outputs[0];
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Try to collapse contiguous dims
|
||||
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
|
||||
auto& strides_a = strides[0];
|
||||
auto& strides_b = strides[1];
|
||||
auto& strides_out = strides[2];
|
||||
|
||||
std::string kernel_name;
|
||||
{
|
||||
std::ostringstream kname;
|
||||
switch (bopt) {
|
||||
case BinaryOpType::ScalarScalar:
|
||||
kname << "ss";
|
||||
break;
|
||||
case BinaryOpType::ScalarVector:
|
||||
kname << "sv";
|
||||
break;
|
||||
case BinaryOpType::VectorScalar:
|
||||
kname << "vs";
|
||||
break;
|
||||
case BinaryOpType::VectorVector:
|
||||
kname << "vv";
|
||||
break;
|
||||
case BinaryOpType::General:
|
||||
kname << "g";
|
||||
if (shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
|
||||
kname << shape.size();
|
||||
} else {
|
||||
kname << "n";
|
||||
}
|
||||
break;
|
||||
}
|
||||
kname << op << type_to_name(a);
|
||||
kernel_name = kname.str();
|
||||
}
|
||||
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
auto kernel = get_binary_two_kernel(d, kernel_name, a, outputs[0]);
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// - If a is donated it goes to the first output
|
||||
// - If b is donated it goes to the first output if a was not donated
|
||||
// otherwise it goes to the second output
|
||||
bool donate_a = a.data_shared_ptr() == nullptr;
|
||||
bool donate_b = b.data_shared_ptr() == nullptr;
|
||||
compute_encoder.set_input_array(donate_a ? outputs[0] : a, 0);
|
||||
compute_encoder.set_input_array(
|
||||
donate_b ? (donate_a ? outputs[1] : outputs[0]) : b, 1);
|
||||
compute_encoder.set_output_array(outputs[0], 2);
|
||||
compute_encoder.set_output_array(outputs[1], 3);
|
||||
|
||||
if (bopt == BinaryOpType::General) {
|
||||
auto ndim = shape.size();
|
||||
if (ndim > 3) {
|
||||
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 4);
|
||||
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 5);
|
||||
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 6);
|
||||
} else {
|
||||
// The shape is implicit in the grid for <= 3D
|
||||
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
|
||||
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
|
||||
}
|
||||
|
||||
if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
|
||||
compute_encoder->setBytes(&ndim, sizeof(int), 7);
|
||||
}
|
||||
|
||||
// Launch up to 3D grid of threads
|
||||
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
|
||||
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
|
||||
size_t rest = out.size() / (dim0 * dim1);
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (thread_group_size != 1024) {
|
||||
throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
|
||||
}
|
||||
auto group_dims = get_block_dims(dim0, dim1, rest);
|
||||
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
} else {
|
||||
// Launch a 1D grid of threads
|
||||
size_t nthreads = out.data_size();
|
||||
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (thread_group_size > nthreads) {
|
||||
thread_group_size = nthreads;
|
||||
}
|
||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
}
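A hedged numeric illustration (shape values made up) of the grid sizing used above for the general, non-contiguous case: the two innermost dimensions map to the x/y grid axes and everything else is folded into z.

#include <cstdio>

int main() {
  int shape[] = {2, 3, 4, 5};                  // collapsed output shape
  size_t ndim = 4, out_size = 2 * 3 * 4 * 5;   // 120 elements
  size_t dim0 = shape[ndim - 1];               // 5 -> grid x
  size_t dim1 = shape[ndim - 2];               // 4 -> grid y
  size_t rest = out_size / (dim0 * dim1);      // 6 -> grid z
  printf("grid = (%zu, %zu, %zu)\n", dim0, dim1, rest);
  return 0;
}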
|
||||
|
||||
void binary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
const std::string op,
|
||||
const Stream& s) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, outputs[0], bopt, true);
|
||||
set_binary_op_output_data(a, b, outputs[1], bopt, true);
|
||||
binary_op_gpu_inplace(inputs, outputs, op, s);
|
||||
}
|
||||
|
||||
void binary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
const std::string op) {
|
||||
auto& s = outputs[0].primitive().stream();
|
||||
binary_op_gpu(inputs, outputs, op, s);
|
||||
}
|
||||
|
||||
void binary_op_gpu_inplace(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const std::string op,
|
||||
const Stream& s) {
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Try to collapse contiguous dims
|
||||
auto [shape, strides] = collapse_contiguous_dims(a, b, out);
|
||||
auto& strides_a = strides[0];
|
||||
auto& strides_b = strides[1];
|
||||
auto& strides_out = strides[2];
|
||||
|
||||
std::string kernel_name;
|
||||
{
|
||||
std::ostringstream kname;
|
||||
switch (bopt) {
|
||||
case BinaryOpType::ScalarScalar:
|
||||
kname << "ss";
|
||||
break;
|
||||
case BinaryOpType::ScalarVector:
|
||||
kname << "sv";
|
||||
break;
|
||||
case BinaryOpType::VectorScalar:
|
||||
kname << "vs";
|
||||
break;
|
||||
case BinaryOpType::VectorVector:
|
||||
kname << "vv";
|
||||
break;
|
||||
case BinaryOpType::General:
|
||||
kname << "g";
|
||||
if (shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
|
||||
kname << shape.size();
|
||||
} else {
|
||||
kname << "n";
|
||||
}
|
||||
break;
|
||||
}
|
||||
kname << op << type_to_name(a);
|
||||
kernel_name = kname.str();
|
||||
}
|
||||
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
auto kernel = get_binary_kernel(d, kernel_name, a, out);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
bool donate_a = a.data_shared_ptr() == nullptr;
|
||||
bool donate_b = b.data_shared_ptr() == nullptr;
|
||||
compute_encoder.set_input_array(donate_a ? out : a, 0);
|
||||
compute_encoder.set_input_array(donate_b ? out : b, 1);
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
|
||||
if (bopt == BinaryOpType::General) {
|
||||
auto ndim = shape.size();
|
||||
if (ndim > 3) {
|
||||
compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 3);
|
||||
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4);
|
||||
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5);
|
||||
} else {
|
||||
// The shape is implicit in the grid for <= 3D
|
||||
compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 3);
|
||||
compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 4);
|
||||
}
|
||||
|
||||
if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
|
||||
compute_encoder->setBytes(&ndim, sizeof(int), 6);
|
||||
}
|
||||
|
||||
// Launch up to 3D grid of threads
|
||||
size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
|
||||
size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
|
||||
size_t rest = out.size() / (dim0 * dim1);
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (thread_group_size != 1024) {
|
||||
throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
|
||||
}
|
||||
auto group_dims = get_block_dims(dim0, dim1, rest);
|
||||
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
} else {
|
||||
// Launch a 1D grid of threads
|
||||
size_t nthreads =
|
||||
bopt == BinaryOpType::General ? out.size() : out.data_size();
|
||||
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (thread_group_size > nthreads) {
|
||||
thread_group_size = nthreads;
|
||||
}
|
||||
MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
}
|
||||
|
||||
void binary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const std::string op,
|
||||
const Stream& s) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt, true);
|
||||
binary_op_gpu_inplace(inputs, out, op, s);
|
||||
}
|
||||
|
||||
void binary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const std::string op) {
|
||||
auto& s = out.primitive().stream();
|
||||
binary_op_gpu(inputs, out, op, s);
|
||||
}
|
||||
|
||||
void Add::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op_gpu(inputs, out, "add");
|
||||
}
|
||||
|
||||
void ArcTan2::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op_gpu(inputs, out, "arctan2");
|
||||
}
|
||||
|
||||
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
switch (op_) {
|
||||
case BitwiseBinary::And:
|
||||
binary_op_gpu(inputs, out, "bitwise_and");
|
||||
break;
|
||||
case BitwiseBinary::Or:
|
||||
binary_op_gpu(inputs, out, "bitwise_or");
|
||||
break;
|
||||
case BitwiseBinary::Xor:
|
||||
binary_op_gpu(inputs, out, "bitwise_xor");
|
||||
break;
|
||||
case BitwiseBinary::LeftShift:
|
||||
binary_op_gpu(inputs, out, "left_shift");
|
||||
break;
|
||||
case BitwiseBinary::RightShift:
|
||||
binary_op_gpu(inputs, out, "right_shift");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void Divide::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op_gpu(inputs, out, "div");
|
||||
}
|
||||
|
||||
void DivMod::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
binary_op_gpu(inputs, outputs, "divmod");
|
||||
}
|
||||
|
||||
void Remainder::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op_gpu(inputs, out, "rem");
|
||||
}
|
||||
|
||||
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op_gpu(inputs, out, equal_nan_ ? "naneq" : "eq");
|
||||
}
|
||||
|
||||
void Greater::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op_gpu(inputs, out, "ge");
|
||||
}
|
||||
|
||||
void GreaterEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op_gpu(inputs, out, "geq");
|
||||
}
|
||||
|
||||
void Less::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op_gpu(inputs, out, "le");
|
||||
}
|
||||
|
||||
void LessEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op_gpu(inputs, out, "leq");
|
||||
}
|
||||
|
||||
void LogicalAnd::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op_gpu(inputs, out, "land");
|
||||
}
|
||||
|
||||
void LogicalOr::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op_gpu(inputs, out, "lor");
|
||||
}
|
||||
|
||||
void LogAddExp::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op_gpu(inputs, out, "lae");
|
||||
}
|
||||
|
||||
void Maximum::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op_gpu(inputs, out, "max");
|
||||
}
|
||||
|
||||
void Minimum::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op_gpu(inputs, out, "min");
|
||||
}
|
||||
|
||||
void Multiply::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op_gpu(inputs, out, "mul");
|
||||
}
|
||||
|
||||
void NotEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op_gpu(inputs, out, "neq");
|
||||
}
|
||||
|
||||
void Power::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op_gpu(inputs, out, "pow");
|
||||
}
|
||||
|
||||
void Subtract::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op_gpu(inputs, out, "sub");
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
mlx/backend/metal/binary.h (new file, 33 lines)
@@ -0,0 +1,33 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include "mlx/array.h"

namespace mlx::core {

void binary_op_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs,
    const std::string op,
    const Stream& s);

void binary_op_gpu(
    const std::vector<array>& inputs,
    array& out,
    const std::string op,
    const Stream& s);

void binary_op_gpu_inplace(
    const std::vector<array>& inputs,
    std::vector<array>& outputs,
    const std::string op,
    const Stream& s);

void binary_op_gpu_inplace(
    const std::vector<array>& inputs,
    array& out,
    const std::string op,
    const Stream& s);

} // namespace mlx::core
@@ -4,8 +4,8 @@
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/metal/compiled_preamble.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/jit/includes.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/graph_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
@@ -56,12 +56,15 @@ inline void build_kernel(
|
||||
} else {
|
||||
add_indices = true;
|
||||
os << " device const " << get_type_string(x.dtype()) << "* " << xname
|
||||
<< " [[buffer(" << cnt++ << ")]]," << std::endl
|
||||
<< " constant const size_t* " << xname << "_strides [[buffer("
|
||||
<< cnt++ << ")]]," << std::endl;
|
||||
<< " [[buffer(" << cnt++ << ")]]," << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
if (add_indices) {
|
||||
os << " constant const size_t* in_strides [[buffer(" << cnt++
|
||||
<< ")]],\n";
|
||||
}
|
||||
|
||||
// Add the output arguments
|
||||
for (auto& x : outputs) {
|
||||
os << " device " << get_type_string(x.dtype()) << "* "
|
||||
@@ -110,13 +113,17 @@ inline void build_kernel(
|
||||
}
|
||||
|
||||
// Read the inputs in tmps
|
||||
for (auto& x : inputs) {
|
||||
int nc_in_count = 0;
|
||||
for (int i = 0; i < inputs.size(); ++i) {
|
||||
auto& x = inputs[i];
|
||||
auto& xname = namer.get_name(x);
|
||||
|
||||
if (is_constant(x)) {
|
||||
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = ";
|
||||
auto type_str = get_type_string(x.dtype());
|
||||
os << " auto tmp_" << xname << " = static_cast<"
|
||||
<< get_type_string(x.dtype()) << ">(";
|
||||
print_constant(os, x);
|
||||
os << ";" << std::endl;
|
||||
os << ");" << std::endl;
|
||||
} else if (is_scalar(x)) {
|
||||
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
|
||||
<< xname << "[0];" << std::endl;
|
||||
@@ -124,17 +131,20 @@ inline void build_kernel(
|
||||
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
|
||||
<< xname << "[index];" << std::endl;
|
||||
} else if (!dynamic_dims) {
|
||||
int offset = nc_in_count * ndim;
|
||||
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
|
||||
<< xname << "[";
|
||||
os << "index_0 * " << xname << "_strides[0]";
|
||||
os << "index_0 * " << "in_strides[" << offset << "]";
|
||||
for (int i = 1; i < ndim; i++) {
|
||||
os << " + index_" << i << " * " << xname << "_strides[" << i << "]";
|
||||
os << " + index_" << i << " * " << "in_strides[" << offset + i << "]";
|
||||
}
|
||||
os << "];" << std::endl;
|
||||
nc_in_count++;
|
||||
} else {
|
||||
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
|
||||
<< xname << "[elem_to_loc(index, output_shape, " << xname
|
||||
<< "_strides, ndim)];" << std::endl;
|
||||
<< xname << "[elem_to_loc(index, output_shape, in_strides + "
|
||||
<< nc_in_count * ndim << ", ndim)];" << std::endl;
|
||||
nc_in_count++;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -190,7 +200,8 @@ void Compiled::eval_gpu(
|
||||
// If not we have to build it ourselves
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel;
|
||||
kernel << metal::get_kernel_preamble() << std::endl;
|
||||
kernel << metal::utils() << metal::unary_ops() << metal::binary_ops()
|
||||
<< metal::ternary_ops();
|
||||
build_kernel(
|
||||
kernel,
|
||||
kernel_lib_ + "_contiguous",
|
||||
@@ -295,6 +306,7 @@ void Compiled::eval_gpu(
|
||||
// Put the inputs in
|
||||
int cnt = 0;
|
||||
int stride_idx = 1; // idx 0 is the output strides
|
||||
std::vector<size_t> in_strides;
|
||||
for (int i = 0; i < inputs.size(); i++) {
|
||||
if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
|
||||
continue;
|
||||
@@ -302,13 +314,17 @@ void Compiled::eval_gpu(
|
||||
auto& x = inputs[i];
|
||||
compute_encoder.set_input_array(x, cnt++);
|
||||
if (!contiguous && !is_scalar(x)) {
|
||||
compute_encoder->setBytes(
|
||||
strides[stride_idx].data(),
|
||||
strides[stride_idx].size() * sizeof(size_t),
|
||||
cnt++);
|
||||
in_strides.insert(
|
||||
in_strides.end(),
|
||||
strides[stride_idx].begin(),
|
||||
strides[stride_idx].end());
|
||||
stride_idx++;
|
||||
}
|
||||
}
|
||||
if (!in_strides.empty()) {
|
||||
compute_encoder->setBytes(
|
||||
in_strides.data(), in_strides.size() * sizeof(size_t), cnt++);
|
||||
}
|
||||
|
||||
compiled_allocate_outputs(
|
||||
inputs, outputs, inputs_, constant_ids_, contiguous, true);
|
||||
|
@@ -1,9 +0,0 @@
// Copyright © 2023-24 Apple Inc.

#pragma once

namespace mlx::core::metal {

const char* get_kernel_preamble();

}
@@ -7,6 +7,7 @@
|
||||
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/kernels/steel/conv/params.h"
|
||||
#include "mlx/backend/metal/matmul.h"
|
||||
@@ -257,15 +258,19 @@ void implicit_gemm_conv_2D_gpu(
|
||||
const array& wt,
|
||||
array out,
|
||||
const MLXConvParams<2>& conv_params) {
|
||||
const int groups = conv_params.groups;
|
||||
const int C_per_group = conv_params.C / conv_params.groups;
|
||||
const int O_per_group = conv_params.O / conv_params.groups;
|
||||
|
||||
// Deduce implicit gemm size
|
||||
int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1];
|
||||
int implicit_N = conv_params.O;
|
||||
int implicit_K = conv_params.wS[0] * conv_params.wS[1] * conv_params.C;
|
||||
const int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1];
|
||||
const int implicit_N = O_per_group;
|
||||
const int implicit_K = conv_params.wS[0] * conv_params.wS[1] * C_per_group;
|
||||
|
||||
// Determine block and warp tiles
|
||||
int wm = 2, wn = 2;
|
||||
|
||||
int bm = implicit_M >= 8192 && conv_params.C >= 64 ? 64 : 32;
|
||||
int bm = implicit_M >= 8192 && C_per_group >= 64 ? 64 : 32;
|
||||
int bn = (bm == 64 || implicit_N >= 64) ? 64 : 32;
|
||||
int bk = 16;
|
||||
|
||||
@@ -281,15 +286,15 @@ void implicit_gemm_conv_2D_gpu(
|
||||
|
||||
// Fix small channel specialization
|
||||
int n_channel_specialization = 0;
|
||||
int channel_k_iters = ((conv_params.C + bk - 1) / bk);
|
||||
int channel_k_iters = ((C_per_group + bk - 1) / bk);
|
||||
int gemm_k_iters = conv_params.wS[0] * conv_params.wS[1] * channel_k_iters;
|
||||
|
||||
if (conv_params.C <= 2) {
|
||||
if (C_per_group <= 2) {
|
||||
gemm_k_iters = (implicit_K + bk - 1) / bk;
|
||||
n_channel_specialization = conv_params.C;
|
||||
} else if (conv_params.C <= 4) {
|
||||
n_channel_specialization = C_per_group;
|
||||
} else if (C_per_group <= 4) {
|
||||
gemm_k_iters = ((conv_params.wS[0] * conv_params.wS[1] * 4) + bk - 1) / bk;
|
||||
n_channel_specialization = conv_params.C;
|
||||
n_channel_specialization = C_per_group;
|
||||
}
|
||||
|
||||
bool small_filter = (!n_channel_specialization) &&
|
||||
@@ -331,7 +336,17 @@ void implicit_gemm_conv_2D_gpu(
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
auto kernel = get_steel_conv_kernel(
|
||||
d,
|
||||
kname.str(),
|
||||
out,
|
||||
bm,
|
||||
bn,
|
||||
bk,
|
||||
wm,
|
||||
wn,
|
||||
n_channel_specialization,
|
||||
small_filter);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// Deduce grid launch dimensions
|
||||
@@ -340,7 +355,7 @@ void implicit_gemm_conv_2D_gpu(
|
||||
size_t grid_dim_x = tn * tile;
|
||||
|
||||
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
||||
MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, 1);
|
||||
MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, groups);
|
||||
|
||||
// Encode arrays
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
@@ -484,7 +499,8 @@ void implicit_gemm_conv_2D_general_gpu(
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
auto kernel =
|
||||
get_steel_conv_general_kernel(d, kname.str(), out, bm, bn, bk, wm, wn);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
// Deduce grid launch dimensions
|
||||
@@ -703,6 +719,7 @@ void conv_2D_gpu(
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation,
|
||||
const std::vector<int>& in_dilation,
|
||||
const int groups,
|
||||
bool flip,
|
||||
std::vector<array>& copies) {
|
||||
// Make conv params
|
||||
@@ -718,12 +735,12 @@ void conv_2D_gpu(
|
||||
/* const int kdil[NDIM] = */ {wt_dilation[0], wt_dilation[1]},
|
||||
/* const int idil[NDIM] = */ {in_dilation[0], in_dilation[1]},
|
||||
/* const size_t in_strides[NDIM + 2] = */
|
||||
{in.strides()[0], in.strides()[1], in.strides()[2], in.strides()[3]},
|
||||
{in.strides(0), in.strides(1), in.strides(2), in.strides(3)},
|
||||
/* const size_t wt_strides[NDIM + 2] = */
|
||||
{wt.strides()[0], wt.strides()[1], wt.strides()[2], wt.strides()[3]},
|
||||
{wt.strides(0), wt.strides(1), wt.strides(2), wt.strides(3)},
|
||||
/* const size_t out_strides[NDIM + 2] = */
|
||||
{out.strides()[0], out.strides()[1], out.strides()[2], out.strides()[3]},
|
||||
/* const int groups = */ 1,
|
||||
{out.strides(0), out.strides(1), out.strides(2), out.strides(3)},
|
||||
/* const int groups = */ groups,
|
||||
/* const bool flip = */ flip,
|
||||
};
|
||||
|
||||
@@ -735,6 +752,18 @@ void conv_2D_gpu(
|
||||
bool channels_large = (conv_params.C + conv_params.O) >= 512;
|
||||
bool channels_med = (conv_params.C + conv_params.O) >= 256;
|
||||
|
||||
if (groups > 1) {
|
||||
const int C_per_group = conv_params.C / groups;
|
||||
const int O_per_group = conv_params.O / groups;
|
||||
|
||||
if (is_idil_one && (C_per_group <= 4 || C_per_group % 16 == 0) &&
|
||||
(O_per_group <= 16 || O_per_group % 16 == 0)) {
|
||||
return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
|
||||
} else {
|
||||
return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params);
|
||||
}
|
||||
}
|
||||
|
||||
// Direct to winograd conv
|
||||
if (!flip && is_stride_one && is_kdil_one && is_idil_one &&
|
||||
conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
|
||||
@@ -860,6 +889,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
kernel_strides_,
|
||||
kernel_dilation_,
|
||||
input_dilation_,
|
||||
groups_,
|
||||
flip_,
|
||||
copies);
|
||||
}
|
||||
|
@@ -4,12 +4,14 @@
|
||||
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
constexpr int MAX_COPY_SPECIALIZED_DIMS = 5;
|
||||
|
||||
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
|
||||
if (ctype == CopyType::Vector) {
|
||||
// If the input is donateable, we are doing a vector copy and the types
|
||||
@@ -31,9 +33,6 @@ void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
}
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
if (ctype == CopyType::GeneralGeneral) {
|
||||
ctype = CopyType::General;
|
||||
}
|
||||
@@ -55,6 +54,10 @@ void copy_gpu_inplace(
|
||||
int64_t out_offset,
|
||||
CopyType ctype,
|
||||
const Stream& s) {
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Try to collapse contiguous dims
|
||||
auto [shape, strides] = collapse_contiguous_dims(
|
||||
data_shape, std::vector{strides_in_pre, strides_out_pre});
|
||||
@@ -62,27 +65,34 @@ void copy_gpu_inplace(
|
||||
auto& strides_out_ = strides[1];
|
||||
|
||||
auto& d = metal::device(s.device);
|
||||
std::ostringstream kname;
|
||||
switch (ctype) {
|
||||
case CopyType::Scalar:
|
||||
kname << "scopy";
|
||||
break;
|
||||
case CopyType::Vector:
|
||||
kname << "vcopy";
|
||||
break;
|
||||
case CopyType::General:
|
||||
kname << "gcopy";
|
||||
break;
|
||||
case CopyType::GeneralGeneral:
|
||||
kname << "ggcopy";
|
||||
break;
|
||||
std::string kernel_name;
|
||||
{
|
||||
std::ostringstream kname;
|
||||
switch (ctype) {
|
||||
case CopyType::Scalar:
|
||||
kname << "s";
|
||||
break;
|
||||
case CopyType::Vector:
|
||||
kname << "v";
|
||||
break;
|
||||
case CopyType::General:
|
||||
kname << "g";
|
||||
break;
|
||||
case CopyType::GeneralGeneral:
|
||||
kname << "gg";
|
||||
break;
|
||||
}
|
||||
if ((ctype == CopyType::General || ctype == CopyType::GeneralGeneral) &&
|
||||
shape.size() <= MAX_COPY_SPECIALIZED_DIMS) {
|
||||
kname << shape.size();
|
||||
}
|
||||
kname << "_copy";
|
||||
kname << type_to_name(in) << type_to_name(out);
|
||||
kernel_name = kname.str();
|
||||
}
|
||||
kname << type_to_name(in) << type_to_name(out);
|
||||
if ((ctype == CopyType::General || ctype == CopyType::GeneralGeneral) &&
|
||||
shape.size() <= MAX_COPY_SPECIALIZED_DIMS) {
|
||||
kname << "_" << shape.size();
|
||||
}
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
|
||||
auto kernel = get_copy_kernel(d, kernel_name, in, out);
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
bool donate_in = in.data_shared_ptr() == nullptr;
|
||||
@@ -106,7 +116,7 @@ void copy_gpu_inplace(
|
||||
set_vector_bytes(compute_encoder, strides_out, ndim, 4);
|
||||
}
|
||||
|
||||
if (ndim > MAX_BINARY_SPECIALIZED_DIMS) {
|
||||
if (ndim > MAX_COPY_SPECIALIZED_DIMS) {
|
||||
compute_encoder->setBytes(&ndim, sizeof(int), 5);
|
||||
}
|
||||
|
||||
|
@@ -285,7 +285,6 @@ MTL::Library* Device::get_library_(const std::string& source_string) {
  NS::Error* error = nullptr;
  auto options = MTL::CompileOptions::alloc()->init();
  options->setFastMathEnabled(false);

  options->setLanguageVersion(get_metal_version());
  auto mtl_lib = device_->newLibrary(ns_code, options, &error);
  options->release();
@@ -63,7 +63,7 @@ struct CommandEncoder {
    return enc;
  }

  void set_input_array(const array& a, int idx, int offset = 0) {
  void set_input_array(const array& a, int idx, int64_t offset = 0) {
    auto r_buf =
        static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
    if (auto it = outputs.find(r_buf); it != outputs.end()) {
@@ -80,7 +80,7 @@ struct CommandEncoder {
    enc->setBuffer(a_buf, base_offset, idx);
  }

  void set_output_array(array& a, int idx, int offset = 0) {
  void set_output_array(array& a, int idx, int64_t offset = 0) {
    // Add barriers before adding the output to the output set
    set_input_array(a, idx, offset);
    auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
@@ -1,106 +1,794 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
#include <cassert>
|
||||
#include <complex>
|
||||
#include <map>
|
||||
#include <numeric>
|
||||
#include <set>
|
||||
|
||||
#include "mlx/3rdparty/pocketfft.h"
|
||||
#include "mlx/backend/metal/binary.h"
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/slicing.h"
|
||||
#include "mlx/backend/metal/unary.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/mlx.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void FFT::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& s = out.primitive().stream();
|
||||
auto& d = metal::device(s.device);
|
||||
using MTLFC = std::tuple<const void*, MTL::DataType, NS::UInteger>;
|
||||
|
||||
auto& in = inputs[0];
|
||||
#define MAX_STOCKHAM_FFT_SIZE 4096
|
||||
#define MAX_RADER_FFT_SIZE 2048
|
||||
#define MAX_BLUESTEIN_FFT_SIZE 2048
|
||||
// Threadgroup memory batching improves throughput for small n
|
||||
#define MIN_THREADGROUP_MEM_SIZE 256
|
||||
// For strided reads/writes, coalesce at least this many complex64s
|
||||
#define MIN_COALESCE_WIDTH 4
|
||||
|
||||
if (axes_.size() == 0 || axes_.size() > 1 || inverse_ ||
|
||||
in.dtype() != complex64 || out.dtype() != complex64) {
|
||||
// Could also fallback to CPU implementation here.
|
||||
throw std::runtime_error(
|
||||
"GPU FFT is only implemented for 1D, forward, complex FFTs.");
|
||||
inline const std::vector<int> supported_radices() {
|
||||
// Ordered by preference in decomposition.
|
||||
return {13, 11, 8, 7, 6, 5, 4, 3, 2};
|
||||
}
|
||||
|
||||
std::vector<int> prime_factors(int n) {
|
||||
int z = 2;
|
||||
std::vector<int> factors;
|
||||
while (z * z <= n) {
|
||||
if (n % z == 0) {
|
||||
factors.push_back(z);
|
||||
n /= z;
|
||||
} else {
|
||||
z++;
|
||||
}
|
||||
}
|
||||
if (n > 1) {
|
||||
factors.push_back(n);
|
||||
}
|
||||
return factors;
|
||||
}
|
||||
|
||||
struct FourStepParams {
|
||||
bool required = false;
|
||||
bool first_step = true;
|
||||
int n1 = 0;
|
||||
int n2 = 0;
|
||||
};
|
||||
|
||||
// Forward Declaration
|
||||
void fft_op(
|
||||
const array& in,
|
||||
array& out,
|
||||
size_t axis,
|
||||
bool inverse,
|
||||
bool real,
|
||||
const FourStepParams four_step_params,
|
||||
bool inplace,
|
||||
const Stream& s);
|
||||
|
||||
struct FFTPlan {
|
||||
int n = 0;
|
||||
// Number of steps for each radix in the Stockham decomposition
|
||||
std::vector<int> stockham;
|
||||
// Number of steps for each radix in the Rader decomposition
|
||||
std::vector<int> rader;
|
||||
// Rader factor, 1 if no rader factors
|
||||
int rader_n = 1;
|
||||
int bluestein_n = -1;
|
||||
// Four step FFT
|
||||
bool four_step = false;
|
||||
int n1 = 0;
|
||||
int n2 = 0;
|
||||
};
|
||||
|
||||
int next_fast_n(int n) {
|
||||
return next_power_of_2(n);
|
||||
}
|
||||
|
||||
std::vector<int> plan_stockham_fft(int n) {
|
||||
auto radices = supported_radices();
|
||||
std::vector<int> plan(radices.size(), 0);
|
||||
int orig_n = n;
|
||||
if (n == 1) {
|
||||
return plan;
|
||||
}
|
||||
for (int i = 0; i < radices.size(); i++) {
|
||||
int radix = radices[i];
|
||||
// Manually tuned radices for powers of 2
|
||||
if (is_power_of_2(orig_n) && orig_n < 512 && radix > 4) {
|
||||
continue;
|
||||
}
|
||||
while (n % radix == 0) {
|
||||
plan[i] += 1;
|
||||
n /= radix;
|
||||
if (n == 1) {
|
||||
return plan;
|
||||
}
|
||||
}
|
||||
}
|
||||
throw std::runtime_error("Unplannable");
|
||||
}
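A hedged worked example of the greedy decomposition above (n = 120 is an arbitrary composite size; the power-of-two tuning branch does not apply to it):

#include <cstdio>
#include <vector>

int main() {
  std::vector<int> radices = {13, 11, 8, 7, 6, 5, 4, 3, 2};  // preference order
  int n = 120;
  for (int r : radices) {
    while (n % r == 0) {             // greedily peel off the preferred radices
      printf("radix %d step\n", r);  // prints 8, then 5, then 3
      n /= r;
    }
  }
  return 0;
}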
|
||||
|
||||
FFTPlan plan_fft(int n) {
|
||||
auto radices = supported_radices();
|
||||
std::set<int> radices_set(radices.begin(), radices.end());
|
||||
|
||||
FFTPlan plan;
|
||||
plan.n = n;
|
||||
plan.rader = std::vector<int>(radices.size(), 0);
|
||||
auto factors = prime_factors(n);
|
||||
int remaining_n = n;
|
||||
|
||||
// Four Step FFT when N is too large for shared mem.
|
||||
if (n > MAX_STOCKHAM_FFT_SIZE && is_power_of_2(n)) {
|
||||
// For power's of two we have a fast, no transpose four step implementation.
|
||||
plan.four_step = true;
|
||||
// Rough heuristic for choosing faster powers of two when we can
|
||||
plan.n2 = n > 65536 ? 1024 : 64;
|
||||
plan.n1 = n / plan.n2;
|
||||
return plan;
|
||||
} else if (n > MAX_STOCKHAM_FFT_SIZE) {
|
||||
// Otherwise we use a multi-upload Bluestein's
|
||||
plan.four_step = true;
|
||||
plan.bluestein_n = next_fast_n(2 * n - 1);
|
||||
return plan;
|
||||
}
|
||||
|
||||
size_t n = in.shape(axes_[0]);
|
||||
for (int factor : factors) {
|
||||
// Make sure the factor is a supported radix
|
||||
if (radices_set.find(factor) == radices_set.end()) {
|
||||
// We only support a single Rader factor currently
|
||||
// TODO(alexbarron) investigate weirdness with large
|
||||
// Rader sizes -- possibly a compiler issue?
|
||||
if (plan.rader_n > 1 || n > MAX_RADER_FFT_SIZE) {
|
||||
plan.four_step = n > MAX_BLUESTEIN_FFT_SIZE;
|
||||
plan.bluestein_n = next_fast_n(2 * n - 1);
|
||||
plan.stockham = plan_stockham_fft(plan.bluestein_n);
|
||||
plan.rader = std::vector<int>(radices.size(), 0);
|
||||
return plan;
|
||||
}
|
||||
// See if we can use Rader's algorithm to Stockham decompose n - 1
|
||||
auto rader_factors = prime_factors(factor - 1);
|
||||
int last_factor = -1;
|
||||
for (int rf : rader_factors) {
|
||||
// We don't nest Rader's algorithm so if `factor - 1`
|
||||
// isn't Stockham decomposable we give up and do Bluestein's.
|
||||
if (radices_set.find(rf) == radices_set.end()) {
|
||||
plan.four_step = n > MAX_BLUESTEIN_FFT_SIZE;
|
||||
plan.bluestein_n = next_fast_n(2 * n - 1);
|
||||
plan.stockham = plan_stockham_fft(plan.bluestein_n);
|
||||
plan.rader = std::vector<int>(radices.size(), 0);
|
||||
return plan;
|
||||
}
|
||||
}
|
||||
plan.rader = plan_stockham_fft(factor - 1);
|
||||
plan.rader_n = factor;
|
||||
remaining_n /= factor;
|
||||
}
|
||||
}
|
||||
|
||||
if (!is_power_of_2(n) || n > 2048 || n < 4) {
|
||||
throw std::runtime_error(
|
||||
"GPU FFT is only implemented for the powers of 2 from 4 -> 2048");
|
||||
plan.stockham = plan_stockham_fft(remaining_n);
|
||||
return plan;
|
||||
}
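A small hedged illustration of the four-step split chosen above for large power-of-two sizes (constants as defined earlier in this file; the size is made up):

#include <cstdio>

int main() {
  int n = 1 << 20;                 // 1,048,576, well above MAX_STOCKHAM_FFT_SIZE
  int n2 = n > 65536 ? 1024 : 64;  // same heuristic as plan_fft
  int n1 = n / n2;
  printf("n = %d -> n1 = %d, n2 = %d\n", n, n1, n2);  // 1024 x 1024
  return 0;
}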
|
||||
|
||||
int compute_elems_per_thread(FFTPlan plan) {
|
||||
// Heuristics for selecting an efficient number
|
||||
// of threads to use for a particular mixed-radix FFT.
|
||||
auto n = plan.n;
|
||||
|
||||
std::vector<int> steps;
|
||||
auto radices = supported_radices();
|
||||
steps.insert(steps.end(), plan.stockham.begin(), plan.stockham.end());
|
||||
steps.insert(steps.end(), plan.rader.begin(), plan.rader.end());
|
||||
std::set<int> used_radices;
|
||||
for (int i = 0; i < steps.size(); i++) {
|
||||
int radix = radices[i % radices.size()];
|
||||
if (steps[i] > 0) {
|
||||
used_radices.insert(radix);
|
||||
}
|
||||
}
|
||||
|
||||
// Manual tuning for 7/11/13
|
||||
if (used_radices.find(7) != used_radices.end() &&
|
||||
(used_radices.find(11) != used_radices.end() ||
|
||||
used_radices.find(13) != used_radices.end())) {
|
||||
return 7;
|
||||
} else if (
|
||||
used_radices.find(11) != used_radices.end() &&
|
||||
used_radices.find(13) != used_radices.end()) {
|
||||
return 11;
|
||||
}
|
||||
|
||||
// TODO(alexbarron) Some really weird stuff is going on
|
||||
// for certain `elems_per_thread` on large composite n.
|
||||
// Possibly a compiler issue?
|
||||
if (n == 3159)
|
||||
return 13;
|
||||
if (n == 3645)
|
||||
return 5;
|
||||
if (n == 3969)
|
||||
return 7;
|
||||
if (n == 1982)
|
||||
return 5;
|
||||
|
||||
if (used_radices.size() == 1) {
|
||||
return *(used_radices.begin());
|
||||
}
|
||||
if (used_radices.size() == 2) {
|
||||
if (used_radices.find(11) != used_radices.end() ||
|
||||
used_radices.find(13) != used_radices.end()) {
|
||||
return std::accumulate(used_radices.begin(), used_radices.end(), 0) / 2;
|
||||
}
|
||||
std::vector<int> radix_vec(used_radices.begin(), used_radices.end());
|
||||
return radix_vec[1];
|
||||
}
|
||||
// In all other cases use the second smallest radix.
|
||||
std::vector<int> radix_vec(used_radices.begin(), used_radices.end());
|
||||
return radix_vec[1];
|
||||
}
|
||||
|
||||
// Rader
|
||||
int mod_exp(int x, int y, int n) {
|
||||
int out = 1;
|
||||
while (y) {
|
||||
if (y & 1) {
|
||||
out = out * x % n;
|
||||
}
|
||||
y >>= 1;
|
||||
x = x * x % n;
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
int primitive_root(int n) {
|
||||
auto factors = prime_factors(n - 1);
|
||||
|
||||
for (int r = 2; r < n - 1; r++) {
|
||||
bool found = true;
|
||||
for (int factor : factors) {
|
||||
if (mod_exp(r, (n - 1) / factor, n) == 1) {
|
||||
found = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (found) {
|
||||
return r;
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
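A hedged worked example (values chosen for illustration) of the re-indexing these helpers enable: Rader's algorithm maps a prime-length-p DFT onto a length p-1 cyclic convolution by walking the nonzero residues as powers of a primitive root, which is exactly the g_q / g_minus_q tables built in compute_raders_constants below.

#include <cstdio>

// Same square-and-multiply idea as mod_exp above, repeated so the sketch is
// self-contained.
static int pow_mod(int x, int y, int n) {
  int out = 1;
  while (y) {
    if (y & 1) out = out * x % n;
    y >>= 1;
    x = x * x % n;
  }
  return out;
}

int main() {
  const int p = 7, proot = 3;                // 3 is the smallest primitive root mod 7
  const int inv = pow_mod(proot, p - 2, p);  // 5, since 3 * 5 = 1 (mod 7)
  for (int i = 0; i < p - 1; ++i) {
    printf("%d ", pow_mod(proot, i, p));     // 1 3 2 6 4 5: every nonzero residue once
  }
  printf("| inverse root: %d\n", inv);
  return 0;
}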
|
||||
|
||||
std::tuple<array, array, array> compute_raders_constants(
|
||||
int rader_n,
|
||||
const Stream& s) {
|
||||
int proot = primitive_root(rader_n);
|
||||
// Fermat's little theorem
|
||||
int inv = mod_exp(proot, rader_n - 2, rader_n);
|
||||
std::vector<short> g_q(rader_n - 1);
|
||||
std::vector<short> g_minus_q(rader_n - 1);
|
||||
for (int i = 0; i < rader_n - 1; i++) {
|
||||
g_q[i] = mod_exp(proot, i, rader_n);
|
||||
g_minus_q[i] = mod_exp(inv, i, rader_n);
|
||||
}
|
||||
array g_q_arr(g_q.begin(), {rader_n - 1});
|
||||
array g_minus_q_arr(g_minus_q.begin(), {rader_n - 1});
|
||||
|
||||
std::vector<std::complex<float>> b_q(rader_n - 1);
|
||||
for (int i = 0; i < rader_n - 1; i++) {
|
||||
float pi_i = (float)g_minus_q[i] * -2.0 * M_PI / rader_n;
|
||||
b_q[i] = std::exp(std::complex<float>(0, pi_i));
|
||||
}
|
||||
|
||||
array b_q_fft({rader_n - 1}, complex64, nullptr, {});
|
||||
b_q_fft.set_data(allocator::malloc_or_wait(b_q_fft.nbytes()));
|
||||
auto b_q_fft_ptr =
|
||||
reinterpret_cast<std::complex<float>*>(b_q_fft.data<complex64_t>());
|
||||
std::ptrdiff_t item_size = b_q_fft.itemsize();
|
||||
size_t fft_size = rader_n - 1;
|
||||
// This FFT is always small (<4096, batch 1) so save some overhead
|
||||
// and do it on the CPU
|
||||
pocketfft::c2c(
|
||||
/* shape= */ {fft_size},
|
||||
/* stride_in= */ {item_size},
|
||||
/* stride_out= */ {item_size},
|
||||
/* axes= */ {0},
|
||||
/* forward= */ true,
|
||||
/* data_in= */ b_q.data(),
|
||||
/* data_out= */ b_q_fft_ptr,
|
||||
/* scale= */ 1.0f);
|
||||
return std::make_tuple(b_q_fft, g_q_arr, g_minus_q_arr);
|
||||
}
|
||||
|
||||
// Bluestein
|
||||
std::pair<array, array> compute_bluestein_constants(int n, int bluestein_n) {
|
||||
// We need to calculate the Bluestein twiddle factors
|
||||
// in double precision for the overall numerical stability
|
||||
// of Bluestein's FFT algorithm to be acceptable.
|
||||
//
|
||||
// Metal doesn't support float64, so instead we
|
||||
// manually implement the required operations on cpu.
|
||||
//
|
||||
// In numpy:
|
||||
// w_k = np.exp(-1j * np.pi / N * (np.arange(-N + 1, N) ** 2))
|
||||
// w_q = np.fft.fft(1/w_k)
|
||||
// return w_k, w_q
|
||||
int length = 2 * n - 1;
|
||||
|
||||
std::vector<std::complex<float>> w_k_vec(n);
|
||||
std::vector<std::complex<float>> w_q_vec(bluestein_n, 0);
|
||||
|
||||
for (int i = -n + 1; i < n; i++) {
|
||||
double theta = pow(i, 2) * M_PI / (double)n;
|
||||
w_q_vec[i + n - 1] = std::exp(std::complex<double>(0, theta));
|
||||
if (i >= 0) {
|
||||
w_k_vec[i] = std::exp(std::complex<double>(0, -theta));
|
||||
}
|
||||
}
|
||||
|
||||
array w_k({n}, complex64, nullptr, {});
|
||||
w_k.set_data(allocator::malloc_or_wait(w_k.nbytes()));
|
||||
std::copy(w_k_vec.begin(), w_k_vec.end(), w_k.data<complex64_t>());
|
||||
|
||||
array w_q({bluestein_n}, complex64, nullptr, {});
|
||||
w_q.set_data(allocator::malloc_or_wait(w_q.nbytes()));
|
||||
auto w_q_ptr =
|
||||
reinterpret_cast<std::complex<float>*>(w_q.data<complex64_t>());
|
||||
|
||||
std::ptrdiff_t item_size = w_q.itemsize();
|
||||
size_t fft_size = bluestein_n;
|
||||
pocketfft::c2c(
|
||||
/* shape= */ {fft_size},
|
||||
/* stride_in= */ {item_size},
|
||||
/* stride_out= */ {item_size},
|
||||
/* axes= */ {0},
|
||||
/* forward= */ true,
|
||||
/* data_in= */ w_q_vec.data(),
|
||||
/* data_out= */ w_q_ptr,
|
||||
/* scale= */ 1.0f);
|
||||
return std::make_tuple(w_k, w_q);
|
||||
}
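For reference, the identity behind the w_k / w_q tables computed above (the standard Bluestein chirp-z reformulation, stated here for clarity rather than taken from the diff): substituting 2nk = n^2 + k^2 - (k - n)^2 into the DFT turns it into a linear convolution, which can then be zero-padded to any fast length such as bluestein_n.

X_k = \sum_{n=0}^{N-1} x_n e^{-2\pi i nk/N}
    = e^{-i\pi k^2/N} \sum_{n=0}^{N-1} \left( x_n e^{-i\pi n^2/N} \right) e^{i\pi (k-n)^2/N}

Here w_k = e^{-i\pi k^2/N} is the pre/post multiplication chirp, and w_q is the FFT of the convolution kernel e^{i\pi m^2/N} (that is, of 1/w_k extended to length 2N - 1), matching the numpy recipe in the comment above.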
|
||||
|
||||
void multi_upload_bluestein_fft(
|
||||
const array& in,
|
||||
array& out,
|
||||
size_t axis,
|
||||
bool inverse,
|
||||
bool real,
|
||||
FFTPlan& plan,
|
||||
std::vector<array> copies,
|
||||
const Stream& s) {
|
||||
// TODO(alexbarron) Implement fused kernels for mutli upload bluestein's
|
||||
// algorithm
|
||||
int n = inverse ? out.shape(axis) : in.shape(axis);
|
||||
auto [w_k, w_q] = compute_bluestein_constants(n, plan.bluestein_n);
|
||||
|
||||
// Broadcast w_q and w_k to the batch size
|
||||
std::vector<size_t> b_strides(in.ndim(), 0);
|
||||
b_strides[axis] = 1;
|
||||
array w_k_broadcast({}, complex64, nullptr, {});
|
||||
array w_q_broadcast({}, complex64, nullptr, {});
|
||||
w_k_broadcast.copy_shared_buffer(w_k, b_strides, {}, w_k.data_size());
|
||||
w_q_broadcast.copy_shared_buffer(w_q, b_strides, {}, w_q.data_size());
|
||||
|
||||
auto temp_shape = inverse ? out.shape() : in.shape();
|
||||
array temp(temp_shape, complex64, nullptr, {});
|
||||
array temp1(temp_shape, complex64, nullptr, {});
|
||||
|
||||
if (real && !inverse) {
|
||||
// Convert float32->complex64
|
||||
copy_gpu(in, temp, CopyType::General, s);
|
||||
} else if (real && inverse) {
|
||||
int back_offset = n % 2 == 0 ? 2 : 1;
|
||||
auto slice_shape = in.shape();
|
||||
slice_shape[axis] -= back_offset;
|
||||
array slice_temp(slice_shape, complex64, nullptr, {});
|
||||
array conj_temp(in.shape(), complex64, nullptr, {});
|
||||
copies.push_back(slice_temp);
|
||||
copies.push_back(conj_temp);
|
||||
|
||||
std::vector<int> rstarts(in.ndim(), 0);
|
||||
std::vector<int> rstrides(in.ndim(), 1);
|
||||
rstarts[axis] = in.shape(axis) - back_offset;
|
||||
rstrides[axis] = -1;
|
||||
unary_op_gpu({in}, conj_temp, "conj", s);
|
||||
slice_gpu(in, slice_temp, rstarts, rstrides, s);
|
||||
concatenate_gpu({conj_temp, slice_temp}, temp, (int)axis, s);
|
||||
} else if (inverse) {
|
||||
unary_op_gpu({in}, temp, "conj", s);
|
||||
} else {
|
||||
temp.copy_shared_buffer(in);
|
||||
}
|
||||
|
||||
binary_op_gpu({temp, w_k_broadcast}, temp1, "mul", s);
|
||||
|
||||
std::vector<std::pair<int, int>> pads;
|
||||
auto padded_shape = out.shape();
|
||||
padded_shape[axis] = plan.bluestein_n;
|
||||
array pad_temp(padded_shape, complex64, nullptr, {});
|
||||
pad_gpu(temp1, array(complex64_t{0.0f, 0.0f}), pad_temp, {(int)axis}, {0}, s);
|
||||
|
||||
array pad_temp1(padded_shape, complex64, nullptr, {});
|
||||
fft_op(
|
||||
pad_temp,
|
||||
pad_temp1,
|
||||
axis,
|
||||
/*inverse=*/false,
|
||||
/*real=*/false,
|
||||
FourStepParams(),
|
||||
/*inplace=*/false,
|
||||
s);
|
||||
|
||||
binary_op_gpu_inplace({pad_temp1, w_q_broadcast}, pad_temp, "mul", s);
|
||||
|
||||
fft_op(
|
||||
pad_temp,
|
||||
pad_temp1,
|
||||
axis,
|
||||
/* inverse= */ true,
|
||||
/* real= */ false,
|
||||
FourStepParams(),
|
||||
/*inplace=*/true,
|
||||
s);
|
||||
|
||||
int offset = plan.bluestein_n - (2 * n - 1);
|
||||
std::vector<int> starts(in.ndim(), 0);
|
||||
std::vector<int> strides(in.ndim(), 1);
|
||||
starts[axis] = plan.bluestein_n - offset - n;
|
||||
slice_gpu(pad_temp1, temp, starts, strides, s);
|
||||
|
||||
binary_op_gpu_inplace({temp, w_k_broadcast}, temp1, "mul", s);
|
||||
|
||||
if (real && !inverse) {
|
||||
std::vector<int> rstarts(in.ndim(), 0);
|
||||
std::vector<int> rstrides(in.ndim(), 1);
|
||||
slice_gpu(temp1, out, rstarts, strides, s);
|
||||
} else if (real && inverse) {
|
||||
std::vector<size_t> b_strides(in.ndim(), 0);
|
||||
auto inv_n = array({1.0f / n}, {1}, float32);
|
||||
array temp_float(out.shape(), out.dtype(), nullptr, {});
|
||||
copies.push_back(temp_float);
|
||||
copies.push_back(inv_n);
|
||||
|
||||
copy_gpu(temp1, temp_float, CopyType::General, s);
|
||||
binary_op_gpu({temp_float, inv_n}, out, "mul", s);
|
||||
} else if (inverse) {
|
||||
auto inv_n = array({1.0f / n}, {1}, complex64);
|
||||
unary_op_gpu({temp1}, temp, "conj", s);
|
||||
binary_op_gpu({temp, inv_n}, out, "mul", s);
|
||||
copies.push_back(inv_n);
|
||||
} else {
|
||||
out.copy_shared_buffer(temp1);
|
||||
}
|
||||
|
||||
copies.push_back(w_k);
|
||||
copies.push_back(w_q);
|
||||
copies.push_back(w_k_broadcast);
|
||||
copies.push_back(w_q_broadcast);
|
||||
copies.push_back(temp);
|
||||
copies.push_back(temp1);
|
||||
copies.push_back(pad_temp);
|
||||
copies.push_back(pad_temp1);
|
||||
}
|
||||
|
||||
void four_step_fft(
|
||||
const array& in,
|
||||
array& out,
|
||||
size_t axis,
|
||||
bool inverse,
|
||||
bool real,
|
||||
FFTPlan& plan,
|
||||
std::vector<array> copies,
|
||||
const Stream& s) {
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
if (plan.bluestein_n == -1) {
|
||||
// Fast no transpose implementation for powers of 2.
|
||||
FourStepParams four_step_params = {
|
||||
/* required= */ true, /* first_step= */ true, plan.n1, plan.n2};
|
||||
auto temp_shape = (real && inverse) ? out.shape() : in.shape();
|
||||
array temp(temp_shape, complex64, nullptr, {});
|
||||
fft_op(
|
||||
in, temp, axis, inverse, real, four_step_params, /*inplace=*/false, s);
|
||||
four_step_params.first_step = false;
|
||||
fft_op(
|
||||
temp, out, axis, inverse, real, four_step_params, /*inplace=*/false, s);
|
||||
copies.push_back(temp);
|
||||
} else {
|
||||
multi_upload_bluestein_fft(in, out, axis, inverse, real, plan, copies, s);
|
||||
}
|
||||
}
|
||||
|
||||
void fft_op(
|
||||
const array& in,
|
||||
array& out,
|
||||
size_t axis,
|
||||
bool inverse,
|
||||
bool real,
|
||||
const FourStepParams four_step_params,
|
||||
bool inplace,
|
||||
const Stream& s) {
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
size_t n = out.dtype() == float32 ? out.shape(axis) : in.shape(axis);
|
||||
if (n == 1) {
|
||||
out.copy_shared_buffer(in);
|
||||
return;
|
||||
}
|
||||
|
||||
if (four_step_params.required) {
|
||||
// Four Step FFT decomposes into two FFTs: n1 on columns, n2 on rows
|
||||
n = four_step_params.first_step ? four_step_params.n1 : four_step_params.n2;
|
||||
}
|
||||
|
||||
  // Make sure that the array is contiguous and has stride 1 in the FFT dim
  std::vector<array> copies;
- auto check_input = [this, &copies, &s](const array& x) {
+ auto check_input = [&axis, &copies, &s](const array& x) {
    // TODO: Pass the strides to the kernel so
    // we can avoid the copy when x is not contiguous.
-   bool no_copy = x.strides()[axes_[0]] == 1 && x.flags().row_contiguous ||
-       x.flags().col_contiguous;
+   bool no_copy = x.strides()[axis] == 1 &&
+       (x.flags().row_contiguous || x.flags().col_contiguous);
    if (no_copy) {
      return x;
    } else {
      array x_copy(x.shape(), x.dtype(), nullptr, {});
      std::vector<size_t> strides;
-     size_t cur_stride = x.shape(axes_[0]);
-     for (int axis = 0; axis < x.ndim(); axis++) {
-       if (axis == axes_[0]) {
+     size_t cur_stride = x.shape(axis);
+     for (int a = 0; a < x.ndim(); a++) {
+       if (a == axis) {
          strides.push_back(1);
        } else {
          strides.push_back(cur_stride);
-         cur_stride *= x.shape(axis);
+         cur_stride *= x.shape(a);
        }
      }

      auto flags = x.flags();
-     size_t f_stride = 1;
-     size_t b_stride = 1;
-     flags.col_contiguous = true;
-     flags.row_contiguous = true;
-     for (int i = 0, ri = x.ndim() - 1; i < x.ndim(); ++i, --ri) {
-       flags.col_contiguous &= (strides[i] == f_stride || x.shape(i) == 1);
-       f_stride *= x.shape(i);
-       flags.row_contiguous &= (strides[ri] == b_stride || x.shape(ri) == 1);
-       b_stride *= x.shape(ri);
-     }
-     // This is probably over-conservative
-     flags.contiguous = false;
+     auto [data_size, is_row_contiguous, is_col_contiguous] =
+         check_contiguity(x.shape(), strides);
+
+     flags.col_contiguous = is_row_contiguous;
+     flags.row_contiguous = is_col_contiguous;
+     flags.contiguous = data_size == x_copy.size();

      x_copy.set_data(
-         allocator::malloc_or_wait(x.nbytes()), x.data_size(), strides, flags);
+         allocator::malloc_or_wait(x.nbytes()), data_size, strides, flags);
      copy_gpu_inplace(x, x_copy, CopyType::GeneralGeneral, s);
      copies.push_back(x_copy);
      return x_copy;
    }
  };
- const array& in_contiguous = check_input(inputs[0]);
+ const array& in_contiguous = check_input(in);
|
||||
// real to complex: n -> (n/2)+1
|
||||
// complex to real: (n/2)+1 -> n
|
||||
auto out_strides = in_contiguous.strides();
|
||||
size_t out_data_size = in_contiguous.data_size();
|
||||
if (in.shape(axis) != out.shape(axis)) {
|
||||
for (int i = 0; i < out_strides.size(); i++) {
|
||||
if (out_strides[i] != 1) {
|
||||
out_strides[i] = out_strides[i] / in.shape(axis) * out.shape(axis);
|
||||
}
|
||||
}
|
||||
out_data_size = out_data_size / in.shape(axis) * out.shape(axis);
|
||||
}
|
||||
|
||||
auto plan = plan_fft(n);
|
||||
if (plan.four_step) {
|
||||
four_step_fft(in, out, axis, inverse, real, plan, copies, s);
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||
return;
|
||||
}
|
||||
|
||||
  // TODO: allow donation here
- out.set_data(
-     allocator::malloc_or_wait(out.nbytes()),
-     in_contiguous.data_size(),
-     in_contiguous.strides(),
-     in_contiguous.flags());
+ if (!inplace) {
+   out.set_data(
+       allocator::malloc_or_wait(out.nbytes()),
+       out_data_size,
+       out_strides,
+       in_contiguous.flags());
+ }
|
||||
|
||||
// We use n / 4 threads by default since radix-4
|
||||
// is the largest single threaded radix butterfly
|
||||
// we currently implement.
|
||||
size_t m = n / 4;
|
||||
size_t batch = in.size() / in.shape(axes_[0]);
|
||||
auto radices = supported_radices();
|
||||
int fft_size = plan.bluestein_n > 0 ? plan.bluestein_n : n;
|
||||
|
||||
// Setup function constants
|
||||
bool power_of_2 = is_power_of_2(fft_size);
|
||||
|
||||
auto make_int = [](int* a, int i) {
|
||||
return std::make_tuple(a, MTL::DataType::DataTypeInt, i);
|
||||
};
|
||||
auto make_bool = [](bool* a, int i) {
|
||||
return std::make_tuple(a, MTL::DataType::DataTypeBool, i);
|
||||
};
|
||||
|
||||
std::vector<MTLFC> func_consts = {
|
||||
make_bool(&inverse, 0), make_bool(&power_of_2, 1)};
|
||||
|
||||
// Start of radix/rader step constants
|
||||
int index = 4;
|
||||
for (int i = 0; i < plan.stockham.size(); i++) {
|
||||
func_consts.push_back(make_int(&plan.stockham[i], index));
|
||||
index += 1;
|
||||
}
|
||||
for (int i = 0; i < plan.rader.size(); i++) {
|
||||
func_consts.push_back(make_int(&plan.rader[i], index));
|
||||
index += 1;
|
||||
}
|
||||
int elems_per_thread = compute_elems_per_thread(plan);
|
||||
func_consts.push_back(make_int(&elems_per_thread, 2));
|
||||
|
||||
int rader_m = n / plan.rader_n;
|
||||
func_consts.push_back(make_int(&rader_m, 3));
|
||||
|
||||
// The overall number of FFTs we're going to compute for this input
|
||||
int size = out.dtype() == float32 ? out.size() : in.size();
|
||||
if (real && inverse && four_step_params.required) {
|
||||
size = out.size();
|
||||
}
|
||||
int total_batch_size = size / n;
|
||||
int threads_per_fft = (fft_size + elems_per_thread - 1) / elems_per_thread;
|
||||
|
||||
// We batch among threadgroups for improved efficiency when n is small
|
||||
int threadgroup_batch_size = std::max(MIN_THREADGROUP_MEM_SIZE / fft_size, 1);
|
||||
if (four_step_params.required) {
|
||||
// Require a threadgroup batch size of at least 4 for four step FFT
|
||||
// so we can coalesce the memory accesses.
|
||||
threadgroup_batch_size =
|
||||
std::max(threadgroup_batch_size, MIN_COALESCE_WIDTH);
|
||||
}
|
||||
int threadgroup_mem_size = next_power_of_2(threadgroup_batch_size * fft_size);
|
||||
// FFTs up to 2^20 are currently supported
|
||||
assert(threadgroup_mem_size <= MAX_STOCKHAM_FFT_SIZE);
|
||||
|
||||
// ceil divide
|
||||
int batch_size =
|
||||
(total_batch_size + threadgroup_batch_size - 1) / threadgroup_batch_size;
|
||||
|
||||
if (real && !four_step_params.required) {
|
||||
// We can perform 2 RFFTs at once so the batch size is halved.
|
||||
batch_size = (batch_size + 2 - 1) / 2;
|
||||
}
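
The dispatch sizing above packs several small FFTs into one threadgroup and then ceil-divides the overall batch across threadgroups. A compact sketch of that arithmetic (the constant and function name are placeholders for illustration, not the values used by the kernels above):

#include <algorithm>

struct FFTLaunch {
  int threads_per_fft;        // threads needed for one FFT
  int threadgroup_batch_size; // FFTs packed into one threadgroup
  int threadgroups;           // threadgroups along the batch dimension
};

// Sketch only: kMinThreadgroupMem is an assumed placeholder.
FFTLaunch plan_fft_launch(int fft_size, int total_batch, int elems_per_thread) {
  constexpr int kMinThreadgroupMem = 256;
  int threads_per_fft = (fft_size + elems_per_thread - 1) / elems_per_thread;
  int tg_batch = std::max(kMinThreadgroupMem / fft_size, 1);
  int groups = (total_batch + tg_batch - 1) / tg_batch; // ceil divide
  return {threads_per_fft, tg_batch, groups};
}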
|
||||
int out_buffer_size = out.size();
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto in_type_str = in.dtype() == float32 ? "float" : "float2";
|
||||
auto out_type_str = out.dtype() == float32 ? "float" : "float2";
|
||||
// Only required by four step
|
||||
int step = -1;
|
||||
{
|
||||
std::ostringstream kname;
|
||||
kname << "fft_" << n;
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
std::string inv_string = inverse ? "true" : "false";
|
||||
std::string real_string = real ? "true" : "false";
|
||||
if (plan.bluestein_n > 0) {
|
||||
kname << "bluestein_fft_mem_" << threadgroup_mem_size << "_"
|
||||
<< in_type_str << "_" << out_type_str;
|
||||
} else if (plan.rader_n > 1) {
|
||||
kname << "rader_fft_mem_" << threadgroup_mem_size << "_" << in_type_str
|
||||
<< "_" << out_type_str;
|
||||
} else if (four_step_params.required) {
|
||||
step = four_step_params.first_step ? 0 : 1;
|
||||
kname << "four_step_mem_" << threadgroup_mem_size << "_" << in_type_str
|
||||
<< "_" << out_type_str << "_" << step << "_" << real_string;
|
||||
} else {
|
||||
kname << "fft_mem_" << threadgroup_mem_size << "_" << in_type_str << "_"
|
||||
<< out_type_str;
|
||||
}
|
||||
std::string base_name = kname.str();
|
||||
// We use a specialized kernel for each FFT size
|
||||
kname << "_n" << fft_size << "_inv_" << inverse;
|
||||
std::string hash_name = kname.str();
|
||||
auto kernel = get_fft_kernel(
|
||||
d,
|
||||
base_name,
|
||||
hash_name,
|
||||
threadgroup_mem_size,
|
||||
in_type_str,
|
||||
out_type_str,
|
||||
step,
|
||||
real,
|
||||
func_consts);
|
||||
|
||||
bool donated = in.data_shared_ptr() == nullptr;
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
compute_encoder.set_input_array(in_contiguous, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
|
||||
auto group_dims = MTL::Size(1, m, 1);
|
||||
auto grid_dims = MTL::Size(batch, m, 1);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
if (plan.bluestein_n > 0) {
|
||||
// Precomputed twiddle factors for Bluestein's
|
||||
auto [w_k, w_q] = compute_bluestein_constants(n, plan.bluestein_n);
|
||||
copies.push_back(w_q);
|
||||
copies.push_back(w_k);
|
||||
|
||||
compute_encoder.set_input_array(w_q, 2); // w_q
|
||||
compute_encoder.set_input_array(w_k, 3); // w_k
|
||||
compute_encoder->setBytes(&n, sizeof(int), 4);
|
||||
compute_encoder->setBytes(&plan.bluestein_n, sizeof(int), 5);
|
||||
compute_encoder->setBytes(&total_batch_size, sizeof(int), 6);
|
||||
} else if (plan.rader_n > 1) {
|
||||
auto [b_q, g_q, g_minus_q] = compute_raders_constants(plan.rader_n, s);
|
||||
copies.push_back(b_q);
|
||||
copies.push_back(g_q);
|
||||
copies.push_back(g_minus_q);
|
||||
|
||||
compute_encoder.set_input_array(b_q, 2);
|
||||
compute_encoder.set_input_array(g_q, 3);
|
||||
compute_encoder.set_input_array(g_minus_q, 4);
|
||||
compute_encoder->setBytes(&n, sizeof(int), 5);
|
||||
compute_encoder->setBytes(&total_batch_size, sizeof(int), 6);
|
||||
compute_encoder->setBytes(&plan.rader_n, sizeof(int), 7);
|
||||
} else if (four_step_params.required) {
|
||||
compute_encoder->setBytes(&four_step_params.n1, sizeof(int), 2);
|
||||
compute_encoder->setBytes(&four_step_params.n2, sizeof(int), 3);
|
||||
compute_encoder->setBytes(&total_batch_size, sizeof(int), 4);
|
||||
} else {
|
||||
compute_encoder->setBytes(&n, sizeof(int), 2);
|
||||
compute_encoder->setBytes(&total_batch_size, sizeof(int), 3);
|
||||
}
|
||||
|
||||
auto group_dims = MTL::Size(1, threadgroup_batch_size, threads_per_fft);
|
||||
auto grid_dims =
|
||||
MTL::Size(batch_size, threadgroup_batch_size, threads_per_fft);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||
}
|
||||
|
||||
void fft_op(
|
||||
const array& in,
|
||||
array& out,
|
||||
size_t axis,
|
||||
bool inverse,
|
||||
bool real,
|
||||
bool inplace,
|
||||
const Stream& s) {
|
||||
fft_op(in, out, axis, inverse, real, FourStepParams(), inplace, s);
|
||||
}
|
||||
|
||||
void nd_fft_op(
|
||||
const array& in,
|
||||
array& out,
|
||||
const std::vector<size_t>& axes,
|
||||
bool inverse,
|
||||
bool real,
|
||||
const Stream& s) {
|
||||
// Perform ND FFT on GPU as a series of 1D FFTs
|
||||
auto temp_shape = inverse ? in.shape() : out.shape();
|
||||
array temp1(temp_shape, complex64, nullptr, {});
|
||||
array temp2(temp_shape, complex64, nullptr, {});
|
||||
std::vector<array> temp_arrs = {temp1, temp2};
|
||||
for (int i = axes.size() - 1; i >= 0; i--) {
|
||||
int reverse_index = axes.size() - i - 1;
|
||||
// For 5D and above, we don't want to reallocate our two temporary arrays
|
||||
bool inplace = reverse_index >= 3 && i != 0;
|
||||
// Opposite order for fft vs ifft
|
||||
int index = inverse ? reverse_index : i;
|
||||
size_t axis = axes[index];
|
||||
// Mirror np.fft.(i)rfftn and perform a real transform
|
||||
// only on the final axis.
|
||||
bool step_real = (real && index == axes.size() - 1);
|
||||
int step_shape = inverse ? out.shape(axis) : in.shape(axis);
|
||||
const array& in_arr = i == axes.size() - 1 ? in : temp_arrs[1 - i % 2];
|
||||
array& out_arr = i == 0 ? out : temp_arrs[i % 2];
|
||||
fft_op(in_arr, out_arr, axis, inverse, step_real, inplace, s);
|
||||
}
|
||||
|
||||
std::vector<array> copies = {temp1, temp2};
|
||||
auto& d = metal::device(s.device);
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||
}
|
||||
|
||||
void FFT::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& s = stream();
|
||||
auto& in = inputs[0];
|
||||
|
||||
if (axes_.size() > 1) {
|
||||
nd_fft_op(in, out, axes_, inverse_, real_, s);
|
||||
} else {
|
||||
fft_op(in, out, axes_[0], inverse_, real_, /*inplace=*/false, s);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -1,24 +1,35 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
#include <fmt/format.h>
|
||||
|
||||
#include "mlx/backend/common/binary.h"
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/jit/includes.h"
|
||||
#include "mlx/backend/metal/jit/indexing.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
namespace {

-constexpr int METAL_MAX_INDEX_ARRAYS = 20;
+constexpr int METAL_MAX_INDEX_ARRAYS = 10;

} // namespace
|
||||
std::pair<std::string, std::string> make_index_args(
|
||||
const std::string& idx_type,
|
||||
int nidx) {
|
||||
std::ostringstream idx_args;
|
||||
std::ostringstream idx_arr;
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
idx_args << fmt::format(
|
||||
"const device {0} *idx{1} [[buffer({2})]],", idx_type, i, 20 + i);
|
||||
idx_arr << fmt::format("idx{0}", i);
|
||||
if (i < nidx - 1) {
|
||||
idx_args << "\n";
|
||||
idx_arr << ",";
|
||||
}
|
||||
}
|
||||
return {idx_args.str(), idx_arr.str()};
|
||||
}
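
As a concrete reading of the helper above, this is roughly what it produces for two `int` index arrays; the two strings are then spliced into the gather/scatter kernel templates:

// make_index_args("int", 2) yields approximately:
//   idx_args: "const device int *idx0 [[buffer(20)]],\n"
//             "const device int *idx1 [[buffer(21)]],"
//   idx_arr:  "idx0,idx1"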
|
||||
|
||||
void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& src = inputs[0];
|
||||
@@ -42,15 +53,41 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
int idx_ndim = nidx ? inputs[1].ndim() : 0;
|
||||
size_t ndim = src.ndim();
|
||||
|
||||
-  std::ostringstream kname;
+  std::string lib_name;
+  std::string kernel_name;
   std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
-  kname << "gather" << type_to_name(src) << idx_type_name << "_" << nidx;
-  if (idx_ndim <= 1) {
-    kname << "_" << idx_ndim;
+  {
+    std::ostringstream kname;
+    kname << "gather" << type_to_name(out) << idx_type_name << "_" << nidx
+          << "_" << idx_ndim;
+    lib_name = kname.str();
+    kernel_name = lib_name;
   }
|
||||
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::gather();
|
||||
std::string out_type_str = get_type_string(out.dtype());
|
||||
std::string idx_type_str =
|
||||
nidx ? get_type_string(inputs[1].dtype()) : "bool";
|
||||
auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx);
|
||||
|
||||
// Index dimension specializations
|
||||
kernel_source << fmt::format(
|
||||
gather_kernels,
|
||||
type_to_name(out) + idx_type_name,
|
||||
out_type_str,
|
||||
idx_type_str,
|
||||
nidx,
|
||||
idx_args,
|
||||
idx_arr,
|
||||
idx_ndim);
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
-  auto kernel = d.get_kernel(kname.str());
+  auto kernel = d.get_kernel(kernel_name, lib);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
size_t slice_size = 1;
|
||||
@@ -102,8 +139,8 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
compute_encoder->setBytes(&idx_ndim, sizeof(int), 9);
|
||||
|
||||
   // Set index buffers
-  for (int i = 1; i < nidx + 1; ++i) {
-    compute_encoder.set_input_array(inputs[i], 20 + i);
+  for (int i = 0; i < nidx; ++i) {
+    compute_encoder.set_input_array(inputs[i + 1], 20 + i);
   }
|
||||
|
||||
// Launch grid
|
||||
@@ -139,10 +176,6 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& s = stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
-  // Get kernel name
-  std::ostringstream kname;
-  std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
|
||||
|
||||
int idx_ndim = nidx ? inputs[1].ndim() : 0;
|
||||
bool index_nd1_specialization = (idx_ndim == 1);
|
||||
|
||||
@@ -159,32 +192,86 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
index_nd1_specialization &= inputs[i].flags().row_contiguous;
|
||||
}
|
||||
|
||||
-  if (index_nd1_specialization) {
-    kname << "scatter_1d_index" << type_to_name(out) << idx_type_name;
-  } else {
-    kname << "scatter" << type_to_name(out) << idx_type_name;
-  }
+  std::string lib_name;
+  std::string kernel_name;
+  std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
+  std::string op_name;
   switch (reduce_type_) {
     case Scatter::None:
-      kname << "_none";
+      op_name = "none";
       break;
     case Scatter::Sum:
-      kname << "_sum";
+      op_name = "sum";
       break;
     case Scatter::Prod:
-      kname << "_prod";
+      op_name = "prod";
       break;
     case Scatter::Max:
-      kname << "_max";
+      op_name = "max";
       break;
     case Scatter::Min:
-      kname << "_min";
+      op_name = "min";
       break;
   }
-  kname << "_" << nidx;

+  {
+    std::ostringstream kname;
+    if (index_nd1_specialization) {
+      kname << "scatter_1d_index" << type_to_name(out) << idx_type_name;
+    } else {
+      kname << "scatter" << type_to_name(out) << idx_type_name;
+    }
+    kname << "_" << op_name << "_" << nidx;
+    lib_name = kname.str();
+    kernel_name = kname.str();
+  }
|
||||
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::reduce_utils()
|
||||
<< metal::scatter();
|
||||
|
||||
std::string out_type_str = get_type_string(out.dtype());
|
||||
std::string idx_type_str =
|
||||
nidx ? get_type_string(inputs[1].dtype()) : "bool";
|
||||
std::string op_type;
|
||||
switch (reduce_type_) {
|
||||
case Scatter::None:
|
||||
op_type = "None";
|
||||
break;
|
||||
case Scatter::Sum:
|
||||
op_type = "Sum<{0}>";
|
||||
break;
|
||||
case Scatter::Prod:
|
||||
op_type = "Prod<{0}>";
|
||||
break;
|
||||
case Scatter::Max:
|
||||
op_type = "Max<{0}>";
|
||||
break;
|
||||
case Scatter::Min:
|
||||
op_type = "Min<{0}>";
|
||||
break;
|
||||
}
|
||||
if (reduce_type_ != Scatter::None) {
|
||||
op_type = fmt::format(op_type, out_type_str);
|
||||
}
|
||||
auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx);
|
||||
|
||||
kernel_source << fmt::format(
|
||||
scatter_kernels,
|
||||
type_to_name(out) + idx_type_name + "_" + op_name,
|
||||
out_type_str,
|
||||
idx_type_str,
|
||||
op_type,
|
||||
nidx,
|
||||
idx_args,
|
||||
idx_arr);
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
-  auto kernel = d.get_kernel(kname.str());
+  auto kernel = d.get_kernel(kernel_name, lib);
|
||||
|
||||
auto& upd = inputs.back();
|
||||
size_t nthreads = upd.size();
|
||||
@@ -209,8 +296,8 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
compute_encoder->setBytes(&upd_size, sizeof(size_t), 5);
|
||||
|
||||
   // Set index buffers
-  for (int i = 1; i < nidx + 1; ++i) {
-    compute_encoder.set_input_array(inputs[i], 20 + i);
+  for (int i = 0; i < nidx; ++i) {
+    compute_encoder.set_input_array(inputs[i + 1], 20 + i);
   }
|
||||
|
||||
// Launch grid
|
||||
@@ -279,8 +366,8 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
compute_encoder->setBytes(&idx_ndim, sizeof(int), 13);
|
||||
|
||||
   // Set index buffers
-  for (int i = 1; i < nidx + 1; ++i) {
-    compute_encoder.set_input_array(inputs[i], 20 + i);
+  for (int i = 0; i < nidx; ++i) {
+    compute_encoder.set_input_array(inputs[i + 1], 20 + i);
   }
|
||||
|
||||
// Launch grid
|
||||
|
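The pattern in the gather/scatter paths above, and in the JIT headers that follow, is to keep each kernel family as a format-string template, specialize it at run time with the dtypes and index counts in play, and compile the result into a Metal library cached under `lib_name`. A minimal sketch of that flow with illustrative names (not the real MLX templates):

#include <fmt/format.h>
#include <string>

// Illustrative template only; doubled braces survive fmt substitution.
constexpr std::string_view demo_copy_kernel = R"(
[[kernel]] void demo_copy_{0}(
    device const {1}* in [[buffer(0)]],
    device {1}* out [[buffer(1)]],
    uint i [[thread_position_in_grid]]) {{
  out[i] = in[i];
}}
)";

// Build the source for one specialization, e.g. build_demo_source("float32", "float").
std::string build_demo_source(const std::string& name, const std::string& type) {
  return fmt::format(demo_copy_kernel, name, type);
}

The resulting string would then be handed to the device, much as the gather path above does with d.get_library(lib_name, kernel_source.str()), and the compiled pipeline is looked up later by the same specialized name.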
9  mlx/backend/metal/jit/arange.h  Normal file
@@ -0,0 +1,9 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
constexpr std::string_view arange_kernels = R"(
|
||||
template [[host_name("{0}")]] [[kernel]] void arange<{1}>(
|
||||
constant const {1}& start,
|
||||
constant const {1}& step,
|
||||
device {1}* out,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
)";
|
87  mlx/backend/metal/jit/binary.h  Normal file
@@ -0,0 +1,87 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
constexpr std::string_view binary_kernels = R"(
|
||||
template [[host_name("ss{0}")]] [[kernel]]
|
||||
void binary_ss<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("vs{0}")]] [[kernel]]
|
||||
void binary_vs<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("sv{0}")]] [[kernel]]
|
||||
void binary_sv<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("vv{0}")]] [[kernel]]
|
||||
void binary_vv<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("g4{0}")]] [[kernel]] void
|
||||
binary_g_nd<{1}, {2}, {3}, 4>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
constant const int shape[4],
|
||||
constant const size_t a_strides[4],
|
||||
constant const size_t b_strides[4],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
template [[host_name("g5{0}")]] [[kernel]] void
|
||||
binary_g_nd<{1}, {2}, {3}, 5>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
constant const int shape[5],
|
||||
constant const size_t a_strides[5],
|
||||
constant const size_t b_strides[5],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
|
||||
template [[host_name("g1{0}")]] [[kernel]] void
|
||||
binary_g_nd1<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
constant const size_t& a_stride,
|
||||
constant const size_t& b_stride,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("g2{0}")]] [[kernel]] void
|
||||
binary_g_nd2<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
constant const size_t a_strides[2],
|
||||
constant const size_t b_strides[2],
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]);
|
||||
template [[host_name("g3{0}")]] [[kernel]] void
|
||||
binary_g_nd3<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
constant const size_t a_strides[3],
|
||||
constant const size_t b_strides[3],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
|
||||
template [[host_name("gn{0}")]] [[kernel]]
|
||||
void binary_g<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
constant const int* shape,
|
||||
constant const size_t* a_strides,
|
||||
constant const size_t* b_strides,
|
||||
constant const int& ndim,
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
)";
|
98  mlx/backend/metal/jit/binary_two.h  Normal file
@@ -0,0 +1,98 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
constexpr std::string_view binary_two_kernels = R"(
|
||||
template [[host_name("ss{0}")]] [[kernel]]
|
||||
void binary_ss<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
device {2}* d,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("vs{0}")]] [[kernel]]
|
||||
void binary_vs<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
device {2}* d,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("sv{0}")]] [[kernel]]
|
||||
void binary_sv<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
device {2}* d,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("vv{0}")]] [[kernel]]
|
||||
void binary_vv<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
device {2}* d,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
|
||||
template [[host_name("g4{0}")]] [[kernel]] void
|
||||
binary_g_nd<{1}, {2}, {3}, 4>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
device {2}* d,
|
||||
constant const int shape[4],
|
||||
constant const size_t a_strides[4],
|
||||
constant const size_t b_strides[4],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
template [[host_name("g5{0}")]] [[kernel]] void
|
||||
binary_g_nd<{1}, {2}, {3}, 5>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
device {2}* d,
|
||||
constant const int shape[5],
|
||||
constant const size_t a_strides[5],
|
||||
constant const size_t b_strides[5],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
|
||||
template [[host_name("g1{0}")]] [[kernel]] void
|
||||
binary_g_nd1<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
device {2}* d,
|
||||
constant const size_t& a_stride,
|
||||
constant const size_t& b_stride,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("g2{0}")]] [[kernel]] void
|
||||
binary_g_nd2<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
device {2}* d,
|
||||
constant const size_t a_strides[2],
|
||||
constant const size_t b_strides[2],
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]);
|
||||
template [[host_name("g3{0}")]] [[kernel]] void
|
||||
binary_g_nd3<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
device {2}* d,
|
||||
constant const size_t a_strides[3],
|
||||
constant const size_t b_strides[3],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
|
||||
template [[host_name("gn{0}")]] [[kernel]]
|
||||
void binary_g<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
device {2}* d,
|
||||
constant const int* shape,
|
||||
constant const size_t* a_strides,
|
||||
constant const size_t* b_strides,
|
||||
constant const int& ndim,
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
)";
|
100  mlx/backend/metal/jit/copy.h  Normal file
@@ -0,0 +1,100 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
constexpr std::string_view copy_kernels = R"(
|
||||
template [[host_name("s_{0}")]] [[kernel]] void copy_s<{1}, {2}>(
|
||||
device const {1}* src [[buffer(0)]],
|
||||
device {2}* dst [[buffer(1)]],
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("v_{0}")]] [[kernel]] void copy_v<{1}, {2}>(
|
||||
device const {1}* src [[buffer(0)]],
|
||||
device {2}* dst [[buffer(1)]],
|
||||
uint index [[thread_position_in_grid]]);
|
||||
|
||||
template [[host_name("g4_{0}")]] [[kernel]] void
|
||||
copy_g_nd<{1}, {2}, 4>(
|
||||
device const {1}* src [[buffer(0)]],
|
||||
device {2}* dst [[buffer(1)]],
|
||||
constant const int* src_shape [[buffer(2)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
template [[host_name("gg4_{0}")]] [[kernel]] void
|
||||
copy_gg_nd<{1}, {2}, 4>(
|
||||
device const {1}* src [[buffer(0)]],
|
||||
device {2}* dst [[buffer(1)]],
|
||||
constant const int* src_shape [[buffer(2)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
constant const int64_t* dst_strides [[buffer(4)]],
|
||||
uint3 index [[thread_position_in_grid]]);
|
||||
template [[host_name("g5_{0}")]] [[kernel]] void
|
||||
copy_g_nd<{1}, {2}, 5>(
|
||||
device const {1}* src [[buffer(0)]],
|
||||
device {2}* dst [[buffer(1)]],
|
||||
constant const int* src_shape [[buffer(2)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
template [[host_name("gg5_{0}")]] [[kernel]] void
|
||||
copy_gg_nd<{1}, {2}, 5>(
|
||||
device const {1}* src [[buffer(0)]],
|
||||
device {2}* dst [[buffer(1)]],
|
||||
constant const int* src_shape [[buffer(2)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
constant const int64_t* dst_strides [[buffer(4)]],
|
||||
uint3 index [[thread_position_in_grid]]);
|
||||
template [[host_name("g1_{0}")]] [[kernel]] void copy_g_nd1<{1}, {2}>(
|
||||
device const {1}* src [[buffer(0)]],
|
||||
device {2}* dst [[buffer(1)]],
|
||||
constant const int64_t& src_stride [[buffer(3)]],
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("g2_{0}")]] [[kernel]] void copy_g_nd2<{1}, {2}>(
|
||||
device const {1}* src [[buffer(0)]],
|
||||
device {2}* dst [[buffer(1)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]);
|
||||
template [[host_name("g3_{0}")]] [[kernel]] void copy_g_nd3<{1}, {2}>(
|
||||
device const {1}* src [[buffer(0)]],
|
||||
device {2}* dst [[buffer(1)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
template [[host_name("gg1_{0}")]] [[kernel]] void
|
||||
copy_gg_nd1<{1}, {2}>(
|
||||
device const {1}* src [[buffer(0)]],
|
||||
device {2}* dst [[buffer(1)]],
|
||||
constant const int64_t& src_stride [[buffer(3)]],
|
||||
constant const int64_t& dst_stride [[buffer(4)]],
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("gg2_{0}")]] [[kernel]] void
|
||||
copy_gg_nd2<{1}, {2}>(
|
||||
device const {1}* src [[buffer(0)]],
|
||||
device {2}* dst [[buffer(1)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
constant const int64_t* dst_strides [[buffer(4)]],
|
||||
uint2 index [[thread_position_in_grid]]);
|
||||
template [[host_name("gg3_{0}")]] [[kernel]] void
|
||||
copy_gg_nd3<{1}, {2}>(
|
||||
device const {1}* src [[buffer(0)]],
|
||||
device {2}* dst [[buffer(1)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
constant const int64_t* dst_strides [[buffer(4)]],
|
||||
uint3 index [[thread_position_in_grid]]);
|
||||
|
||||
template [[host_name("g_{0}")]] [[kernel]] void copy_g<{1}, {2}>(
|
||||
device const {1}* src [[buffer(0)]],
|
||||
device {2}* dst [[buffer(1)]],
|
||||
constant const int* src_shape [[buffer(2)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
constant const int& ndim [[buffer(5)]],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
template [[host_name("gg_{0}")]] [[kernel]] void copy_gg<{1}, {2}>(
|
||||
device const {1}* src [[buffer(0)]],
|
||||
device {2}* dst [[buffer(1)]],
|
||||
constant const int* src_shape [[buffer(2)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
constant const int64_t* dst_strides [[buffer(4)]],
|
||||
constant const int& ndim [[buffer(5)]],
|
||||
uint3 index [[thread_position_in_grid]]);
|
||||
)";
|
53  mlx/backend/metal/jit/fft.h  Normal file
@@ -0,0 +1,53 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
constexpr std::string_view fft_kernel = R"(
|
||||
template [[host_name("{name}")]] [[kernel]] void
|
||||
fft<{tg_mem_size}, {in_T}, {out_T}>(
|
||||
const device {in_T}* in [[buffer(0)]],
|
||||
device {out_T}* out [[buffer(1)]],
|
||||
constant const int& n,
|
||||
constant const int& batch_size,
|
||||
uint3 elem [[thread_position_in_grid]],
|
||||
uint3 grid [[threads_per_grid]]);
|
||||
)";
|
||||
|
||||
constexpr std::string_view rader_fft_kernel = R"(
|
||||
template [[host_name("{name}")]] [[kernel]] void
|
||||
rader_fft<{tg_mem_size}, {in_T}, {out_T}>(
|
||||
const device {in_T}* in [[buffer(0)]],
|
||||
device {out_T}* out [[buffer(1)]],
|
||||
const device float2* raders_b_q [[buffer(2)]],
|
||||
const device short* raders_g_q [[buffer(3)]],
|
||||
const device short* raders_g_minus_q [[buffer(4)]],
|
||||
constant const int& n,
|
||||
constant const int& batch_size,
|
||||
constant const int& rader_n,
|
||||
uint3 elem [[thread_position_in_grid]],
|
||||
uint3 grid [[threads_per_grid]]);
|
||||
)";
|
||||
|
||||
constexpr std::string_view bluestein_fft_kernel = R"(
|
||||
template [[host_name("{name}")]] [[kernel]] void
|
||||
bluestein_fft<{tg_mem_size}, {in_T}, {out_T}>(
|
||||
const device {in_T}* in [[buffer(0)]],
|
||||
device {out_T}* out [[buffer(1)]],
|
||||
const device float2* w_q [[buffer(2)]],
|
||||
const device float2* w_k [[buffer(3)]],
|
||||
constant const int& length,
|
||||
constant const int& n,
|
||||
constant const int& batch_size,
|
||||
uint3 elem [[thread_position_in_grid]],
|
||||
uint3 grid [[threads_per_grid]]);
|
||||
)";
|
||||
|
||||
constexpr std::string_view four_step_fft_kernel = R"(
|
||||
template [[host_name("{name}")]] [[kernel]] void
|
||||
four_step_fft<{tg_mem_size}, {in_T}, {out_T}, {step}, {real}>(
|
||||
const device {in_T}* in [[buffer(0)]],
|
||||
device {out_T}* out [[buffer(1)]],
|
||||
constant const int& n1,
|
||||
constant const int& n2,
|
||||
constant const int& batch_size,
|
||||
uint3 elem [[thread_position_in_grid]],
|
||||
uint3 grid [[threads_per_grid]]);
|
||||
)";
|
35  mlx/backend/metal/jit/includes.h  Normal file
@@ -0,0 +1,35 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace mlx::core::metal {
|
||||
|
||||
const char* utils();
|
||||
const char* binary_ops();
|
||||
const char* unary_ops();
|
||||
const char* ternary_ops();
|
||||
const char* reduce_utils();
|
||||
const char* gather();
|
||||
const char* scatter();
|
||||
|
||||
const char* arange();
|
||||
const char* unary();
|
||||
const char* binary();
|
||||
const char* binary_two();
|
||||
const char* copy();
|
||||
const char* fft();
|
||||
const char* ternary();
|
||||
const char* scan();
|
||||
const char* softmax();
|
||||
const char* sort();
|
||||
const char* reduce();
|
||||
|
||||
const char* gemm();
|
||||
const char* steel_gemm_fused();
|
||||
const char* steel_gemm_masked();
|
||||
const char* steel_gemm_splitk();
|
||||
const char* conv();
|
||||
const char* steel_conv();
|
||||
const char* steel_conv_general();
|
||||
|
||||
} // namespace mlx::core::metal
|
81  mlx/backend/metal/jit/indexing.h  Normal file
@@ -0,0 +1,81 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
constexpr std::string_view gather_kernels = R"(
|
||||
[[kernel]] void gather{0}_{3}_{6}(
|
||||
const device {1}* src [[buffer(0)]],
|
||||
device {1}* out [[buffer(1)]],
|
||||
const constant int* src_shape [[buffer(2)]],
|
||||
const constant size_t* src_strides [[buffer(3)]],
|
||||
const constant size_t& src_ndim [[buffer(4)]],
|
||||
const constant int* slice_sizes [[buffer(5)]],
|
||||
const constant int* axes [[buffer(6)]],
|
||||
const constant int* idx_shapes [[buffer(7)]],
|
||||
const constant size_t* idx_strides [[buffer(8)]],
|
||||
const constant int& idx_ndim [[buffer(9)]],
|
||||
{4}
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {{
|
||||
Indices<{2}, {3}> idxs{{
|
||||
{{ {5} }}, idx_shapes, idx_strides, idx_ndim}};
|
||||
|
||||
return gather_impl<{1}, {2}, {3}, {6}>(
|
||||
src,
|
||||
out,
|
||||
src_shape,
|
||||
src_strides,
|
||||
src_ndim,
|
||||
slice_sizes,
|
||||
axes,
|
||||
idxs,
|
||||
index,
|
||||
grid_dim);
|
||||
}}
|
||||
)";
|
||||
|
||||
constexpr std::string_view scatter_kernels = R"(
|
||||
[[kernel]] void scatter_1d_index{0}_{4}(
|
||||
const device {1}* updates [[buffer(1)]],
|
||||
device mlx_atomic<{1}>* out [[buffer(2)]],
|
||||
const constant int* out_shape [[buffer(3)]],
|
||||
const constant size_t* out_strides [[buffer(4)]],
|
||||
const constant size_t& upd_size [[buffer(5)]],
|
||||
{5}
|
||||
uint2 gid [[thread_position_in_grid]]) {{
|
||||
const array<const device {2}*, {4}> idx_buffers = {{ {6} }};
|
||||
return scatter_1d_index_impl<{1}, {2}, {3}, {4}>(
|
||||
updates, out, out_shape, out_strides, upd_size, idx_buffers, gid);
|
||||
}}
|
||||
|
||||
[[kernel]] void scatter{0}_{4}(
|
||||
const device {1}* updates [[buffer(1)]],
|
||||
device mlx_atomic<{1}>* out [[buffer(2)]],
|
||||
const constant int* upd_shape [[buffer(3)]],
|
||||
const constant size_t* upd_strides [[buffer(4)]],
|
||||
const constant size_t& upd_ndim [[buffer(5)]],
|
||||
const constant size_t& upd_size [[buffer(6)]],
|
||||
const constant int* out_shape [[buffer(7)]],
|
||||
const constant size_t* out_strides [[buffer(8)]],
|
||||
const constant size_t& out_ndim [[buffer(9)]],
|
||||
const constant int* axes [[buffer(10)]],
|
||||
const constant int* idx_shapes [[buffer(11)]],
|
||||
const constant size_t* idx_strides [[buffer(12)]],
|
||||
const constant int& idx_ndim [[buffer(13)]],
|
||||
{5}
|
||||
uint2 gid [[thread_position_in_grid]]) {{
|
||||
Indices<{2}, {4}> idxs{{ {{ {6} }}, idx_shapes, idx_strides, idx_ndim}};
|
||||
|
||||
return scatter_impl<{1}, {2}, {3}, {4}>(
|
||||
updates,
|
||||
out,
|
||||
upd_shape,
|
||||
upd_strides,
|
||||
upd_ndim,
|
||||
upd_size,
|
||||
out_shape,
|
||||
out_strides,
|
||||
out_ndim,
|
||||
axes,
|
||||
idxs,
|
||||
gid);
|
||||
}}
|
||||
)";
|
168  mlx/backend/metal/jit/reduce.h  Normal file
@@ -0,0 +1,168 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
constexpr std::string_view reduce_init_kernels = R"(
|
||||
[[kernel]] void {0}(
|
||||
device {1}* out [[buffer(0)]],
|
||||
uint tid [[thread_position_in_grid]]) {{
|
||||
out[tid] = {2}<{1}>::init;
|
||||
}}
|
||||
)";
|
||||
|
||||
constexpr std::string_view reduce_kernels = R"(
|
||||
template [[host_name("all_{0}")]] [[kernel]] void
|
||||
all_reduce<{1}, {2}, {3}<{2}>>(
|
||||
const device {1}* in [[buffer(0)]],
|
||||
device mlx_atomic<{2}>* out [[buffer(1)]],
|
||||
const device size_t& in_size [[buffer(2)]],
|
||||
uint gid [[thread_position_in_grid]],
|
||||
uint lid [[thread_position_in_threadgroup]],
|
||||
uint grid_size [[threads_per_grid]],
|
||||
uint simd_per_group [[simdgroups_per_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
template [[host_name("colGeneral_{0}")]] [[kernel]] void
|
||||
col_reduce_general<{1}, {2}, {3}<{2}>>(
|
||||
const device {1}* in [[buffer(0)]],
|
||||
device mlx_atomic<{2}>* out [[buffer(1)]],
|
||||
const constant size_t& reduction_size [[buffer(2)]],
|
||||
const constant size_t& reduction_stride [[buffer(3)]],
|
||||
const constant size_t& out_size [[buffer(4)]],
|
||||
const constant int* shape [[buffer(5)]],
|
||||
const constant size_t* strides [[buffer(6)]],
|
||||
const constant int& ndim [[buffer(7)]],
|
||||
threadgroup {2}* local_data [[threadgroup(0)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint3 lsize [[threads_per_threadgroup]]);
|
||||
template [[host_name("colSmall_{0}")]] [[kernel]] void
|
||||
col_reduce_small<{1}, {2}, {3}<{2}>>(
|
||||
const device {1}* in [[buffer(0)]],
|
||||
device {2}* out [[buffer(1)]],
|
||||
const constant size_t& reduction_size [[buffer(2)]],
|
||||
const constant size_t& reduction_stride [[buffer(3)]],
|
||||
const constant size_t& out_size [[buffer(4)]],
|
||||
const constant int* shape [[buffer(5)]],
|
||||
const constant size_t* strides [[buffer(6)]],
|
||||
const constant int& ndim [[buffer(7)]],
|
||||
const constant size_t& non_col_reductions [[buffer(8)]],
|
||||
const constant int* non_col_shapes [[buffer(9)]],
|
||||
const constant size_t* non_col_strides [[buffer(10)]],
|
||||
const constant int& non_col_ndim [[buffer(11)]],
|
||||
uint tid [[thread_position_in_grid]]);
|
||||
template [[host_name("rowGeneralSmall_{0}")]] [[kernel]] void
|
||||
row_reduce_general_small<{1}, {2}, {3}<{2}>>(
|
||||
const device {1}* in [[buffer(0)]],
|
||||
device {2}* out [[buffer(1)]],
|
||||
const constant size_t& reduction_size [[buffer(2)]],
|
||||
const constant size_t& out_size [[buffer(3)]],
|
||||
const constant size_t& non_row_reductions [[buffer(4)]],
|
||||
const constant int* shape [[buffer(5)]],
|
||||
const constant size_t* strides [[buffer(6)]],
|
||||
const constant int& ndim [[buffer(7)]],
|
||||
uint lid [[thread_position_in_grid]]);
|
||||
template [[host_name("rowGeneralMed_{0}")]] [[kernel]] void
|
||||
row_reduce_general_med<{1}, {2}, {3}<{2}>>(
|
||||
const device {1}* in [[buffer(0)]],
|
||||
device {2}* out [[buffer(1)]],
|
||||
const constant size_t& reduction_size [[buffer(2)]],
|
||||
const constant size_t& out_size [[buffer(3)]],
|
||||
const constant size_t& non_row_reductions [[buffer(4)]],
|
||||
const constant int* shape [[buffer(5)]],
|
||||
const constant size_t* strides [[buffer(6)]],
|
||||
const constant int& ndim [[buffer(7)]],
|
||||
uint tid [[threadgroup_position_in_grid]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_per_group [[dispatch_simdgroups_per_threadgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
template [[host_name("rowGeneral_{0}")]] [[kernel]] void
|
||||
row_reduce_general<{1}, {2}, {3}<{2}>>(
|
||||
const device {1}* in [[buffer(0)]],
|
||||
device mlx_atomic<{2}>* out [[buffer(1)]],
|
||||
const constant size_t& reduction_size [[buffer(2)]],
|
||||
const constant size_t& out_size [[buffer(3)]],
|
||||
const constant size_t& non_row_reductions [[buffer(4)]],
|
||||
const constant int* shape [[buffer(5)]],
|
||||
const constant size_t* strides [[buffer(6)]],
|
||||
const constant int& ndim [[buffer(7)]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint3 lsize [[threads_per_threadgroup]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_per_group [[simdgroups_per_threadgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
)";
|
||||
|
||||
constexpr std::string_view reduce_non_atomic_kernels = R"(
|
||||
template [[host_name("allNoAtomics_{0}")]] [[kernel]] void
|
||||
all_reduce_no_atomics<{1}, {2}, {3}<{2}>>(
|
||||
const device {1}* in [[buffer(0)]],
|
||||
device {2}* out [[buffer(1)]],
|
||||
const device size_t& in_size [[buffer(2)]],
|
||||
uint gid [[thread_position_in_grid]],
|
||||
uint lid [[thread_position_in_threadgroup]],
|
||||
uint grid_size [[threads_per_grid]],
|
||||
uint simd_per_group [[simdgroups_per_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
||||
uint thread_group_id [[threadgroup_position_in_grid]]);
|
||||
|
||||
template [[host_name("colGeneralNoAtomics_{0}")]] [[kernel]] void
|
||||
col_reduce_general_no_atomics<{1}, {2}, {3}<{2}>>(
|
||||
const device {1}* in [[buffer(0)]],
|
||||
device {2}* out [[buffer(1)]],
|
||||
const constant size_t& reduction_size [[buffer(2)]],
|
||||
const constant size_t& reduction_stride [[buffer(3)]],
|
||||
const constant size_t& out_size [[buffer(4)]],
|
||||
const constant int* shape [[buffer(5)]],
|
||||
const constant size_t* strides [[buffer(6)]],
|
||||
const constant int& ndim [[buffer(7)]],
|
||||
threadgroup {2}* local_data [[threadgroup(0)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint3 gid [[thread_position_in_grid]],
|
||||
uint3 lsize [[threads_per_threadgroup]],
|
||||
uint3 gsize [[threads_per_grid]]);
|
||||
template [[host_name("colSmall_{0}")]] [[kernel]] void
|
||||
col_reduce_small<{1}, {2}, {3}<{2}>>(
|
||||
const device {1}* in [[buffer(0)]],
|
||||
device {2}* out [[buffer(1)]],
|
||||
const constant size_t& reduction_size [[buffer(2)]],
|
||||
const constant size_t& reduction_stride [[buffer(3)]],
|
||||
const constant size_t& out_size [[buffer(4)]],
|
||||
const constant int* shape [[buffer(5)]],
|
||||
const constant size_t* strides [[buffer(6)]],
|
||||
const constant int& ndim [[buffer(7)]],
|
||||
const constant size_t& non_col_reductions [[buffer(8)]],
|
||||
const constant int* non_col_shapes [[buffer(9)]],
|
||||
const constant size_t* non_col_strides [[buffer(10)]],
|
||||
const constant int& non_col_ndim [[buffer(11)]],
|
||||
uint tid [[thread_position_in_grid]]);
|
||||
template [[host_name("rowGeneralSmall_{0}")]] [[kernel]] void
|
||||
row_reduce_general_small<{1}, {2}, {3}<{2}>>(
|
||||
const device {1}* in [[buffer(0)]],
|
||||
device {2}* out [[buffer(1)]],
|
||||
const constant size_t& reduction_size [[buffer(2)]],
|
||||
const constant size_t& out_size [[buffer(3)]],
|
||||
const constant size_t& non_row_reductions [[buffer(4)]],
|
||||
const constant int* shape [[buffer(5)]],
|
||||
const constant size_t* strides [[buffer(6)]],
|
||||
const constant int& ndim [[buffer(7)]],
|
||||
uint lid [[thread_position_in_grid]]);
|
||||
template [[host_name("rowGeneralNoAtomics_{0}")]] [[kernel]] void
|
||||
row_reduce_general_no_atomics<{1}, {2}, {3}<{2}>>(
|
||||
const device {1}* in [[buffer(0)]],
|
||||
device {2}* out [[buffer(1)]],
|
||||
const constant size_t& reduction_size [[buffer(2)]],
|
||||
const constant size_t& out_size [[buffer(3)]],
|
||||
const constant size_t& non_row_reductions [[buffer(4)]],
|
||||
const constant int* shape [[buffer(5)]],
|
||||
const constant size_t* strides [[buffer(6)]],
|
||||
const constant int& ndim [[buffer(7)]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint3 lsize [[threads_per_threadgroup]],
|
||||
uint3 gsize [[threads_per_grid]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_per_group [[simdgroups_per_threadgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
)";
|
26  mlx/backend/metal/jit/scan.h  Normal file
@@ -0,0 +1,26 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
constexpr std::string_view scan_kernels = R"(
|
||||
template [[host_name("contig_{0}")]] [[kernel]] void
|
||||
contiguous_scan<{1}, {2}, {3}<{2}>, 4, {4}, {5}>(
|
||||
const device {1}* in [[buffer(0)]],
|
||||
device {2}* out [[buffer(1)]],
|
||||
const constant size_t& axis_size [[buffer(2)]],
|
||||
uint gid [[thread_position_in_grid]],
|
||||
uint lid [[thread_position_in_threadgroup]],
|
||||
uint lsize [[threads_per_threadgroup]],
|
||||
uint simd_size [[threads_per_simdgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
template [[host_name("strided_{0}")]] [[kernel]] void
|
||||
strided_scan<{1}, {2}, {3}<{2}>, 4, {4}, {5}>(
|
||||
const device {1}* in [[buffer(0)]],
|
||||
device {2}* out [[buffer(1)]],
|
||||
const constant size_t& axis_size [[buffer(2)]],
|
||||
const constant size_t& stride [[buffer(3)]],
|
||||
uint2 gid [[thread_position_in_grid]],
|
||||
uint2 lid [[thread_position_in_threadgroup]],
|
||||
uint2 lsize [[threads_per_threadgroup]],
|
||||
uint simd_size [[threads_per_simdgroup]]);
|
||||
)";
|
23  mlx/backend/metal/jit/softmax.h  Normal file
@@ -0,0 +1,23 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
constexpr std::string_view softmax_kernels = R"(
|
||||
template [[host_name("block_{0}")]] [[kernel]] void
|
||||
softmax_single_row<{1}, {2}>(
|
||||
const device {1}* in,
|
||||
device {1}* out,
|
||||
constant int& axis_size,
|
||||
uint gid [[thread_position_in_grid]],
|
||||
uint _lid [[thread_position_in_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
template [[host_name("looped_{0}")]] [[kernel]] void
|
||||
softmax_looped<{1}, {2}>(
|
||||
const device {1}* in,
|
||||
device {1}* out,
|
||||
constant int& axis_size,
|
||||
uint gid [[threadgroup_position_in_grid]],
|
||||
uint lid [[thread_position_in_threadgroup]],
|
||||
uint lsize [[threads_per_threadgroup]],
|
||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
||||
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
|
||||
)";
|
81  mlx/backend/metal/jit/sort.h  Normal file
@@ -0,0 +1,81 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
constexpr std::string_view block_sort_kernels = R"(
|
||||
template [[host_name("carg_{0}")]] [[kernel]] void
|
||||
block_sort<{1}, {2}, true, {3}, {4}>(
|
||||
const device {1}* inp [[buffer(0)]],
|
||||
device {2}* out [[buffer(1)]],
|
||||
const constant int& size_sorted_axis [[buffer(2)]],
|
||||
const constant int& stride_sorted_axis [[buffer(3)]],
|
||||
const constant int& stride_segment_axis [[buffer(4)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
template [[host_name("ncarg_{0}")]] [[kernel]] void
|
||||
block_sort_nc<{1}, {2}, true, {3}, {4}>(
|
||||
const device {1}* inp [[buffer(0)]],
|
||||
device {2}* out [[buffer(1)]],
|
||||
const constant int& size_sorted_axis [[buffer(2)]],
|
||||
const constant int& stride_sorted_axis [[buffer(3)]],
|
||||
const constant int& nc_dim [[buffer(4)]],
|
||||
const device int* nc_shape [[buffer(5)]],
|
||||
const device size_t* nc_strides [[buffer(6)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
template [[host_name("c_{0}")]] [[kernel]] void
|
||||
block_sort<{1}, {2}, false, {3}, {4}>(
|
||||
const device {1}* inp [[buffer(0)]],
|
||||
device {2}* out [[buffer(1)]],
|
||||
const constant int& size_sorted_axis [[buffer(2)]],
|
||||
const constant int& stride_sorted_axis [[buffer(3)]],
|
||||
const constant int& stride_segment_axis [[buffer(4)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
template [[host_name("nc_{0}")]] [[kernel]] void
|
||||
block_sort_nc<{1}, {2}, false, {3}, {4}>(
|
||||
const device {1}* inp [[buffer(0)]],
|
||||
device {2}* out [[buffer(1)]],
|
||||
const constant int& size_sorted_axis [[buffer(2)]],
|
||||
const constant int& stride_sorted_axis [[buffer(3)]],
|
||||
const constant int& nc_dim [[buffer(4)]],
|
||||
const device int* nc_shape [[buffer(5)]],
|
||||
const device size_t* nc_strides [[buffer(6)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
)";
|
||||
|
||||
constexpr std::string_view multiblock_sort_kernels = R"(
|
||||
template [[host_name("sort_{0}")]] [[kernel]] void
|
||||
mb_block_sort<{1}, {2}, true, {3}, {4}>(
|
||||
const device {1}* inp [[buffer(0)]],
|
||||
device {1}* out_vals [[buffer(1)]],
|
||||
device {2}* out_idxs [[buffer(2)]],
|
||||
const constant int& size_sorted_axis [[buffer(3)]],
|
||||
const constant int& stride_sorted_axis [[buffer(4)]],
|
||||
const constant int& nc_dim [[buffer(5)]],
|
||||
const device int* nc_shape [[buffer(6)]],
|
||||
const device size_t* nc_strides [[buffer(7)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
template [[host_name("partition_{0}")]] [[kernel]] void
|
||||
mb_block_partition<{1}, {2}, true, {3}, {4}>(
|
||||
device {2}* block_partitions [[buffer(0)]],
|
||||
const device {1}* dev_vals [[buffer(1)]],
|
||||
const device {2}* dev_idxs [[buffer(2)]],
|
||||
const constant int& size_sorted_axis [[buffer(3)]],
|
||||
const constant int& merge_tiles [[buffer(4)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint3 tgp_dims [[threads_per_threadgroup]]);
|
||||
template [[host_name("merge_{0}")]] [[kernel]] void
|
||||
mb_block_merge<{1}, {2}, true, {3}, {4}>(
|
||||
const device {2}* block_partitions [[buffer(0)]],
|
||||
const device {1}* dev_vals_in [[buffer(1)]],
|
||||
const device {2}* dev_idxs_in [[buffer(2)]],
|
||||
device {1}* dev_vals_out [[buffer(3)]],
|
||||
device {2}* dev_idxs_out [[buffer(4)]],
|
||||
const constant int& size_sorted_axis [[buffer(5)]],
|
||||
const constant int& merge_tiles [[buffer(6)]],
|
||||
const constant int& num_tiles [[buffer(7)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
)";
|
32  mlx/backend/metal/jit/steel_conv.h  Normal file
@@ -0,0 +1,32 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
constexpr std::string_view steel_conv_kernels = R"(
|
||||
template [[host_name("{name}")]] [[kernel]] void
|
||||
implicit_gemm_conv_2d<{itype}, {bm}, {bn}, {bk}, {wm}, {wn}, {n_channels}, {small_filter}>(
|
||||
const device {itype}* A [[buffer(0)]],
|
||||
const device {itype}* B [[buffer(1)]],
|
||||
device {itype}* C [[buffer(2)]],
|
||||
const constant MLXConvParams<2>* params [[buffer(3)]],
|
||||
const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
)";
|
||||
|
||||
constexpr std::string_view steel_conv_general_kernels = R"(
|
||||
template [[host_name("{name}")]] [[kernel]] void
|
||||
implicit_gemm_conv_2d_general<{itype}, {bm}, {bn}, {bk}, {wm}, {wn}>(
|
||||
const device {itype}* A [[buffer(0)]],
|
||||
const device {itype}* B [[buffer(1)]],
|
||||
device {itype}* C [[buffer(2)]],
|
||||
const constant MLXConvParams<2>* params [[buffer(3)]],
|
||||
const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
|
||||
const constant Conv2DGeneralJumpParams* jump_params [[buffer(5)]],
|
||||
const constant Conv2DGeneralBaseInfo* base_h [[buffer(6)]],
|
||||
const constant Conv2DGeneralBaseInfo* base_w [[buffer(7)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
)";
|
106  mlx/backend/metal/jit/steel_gemm.h  Normal file
@@ -0,0 +1,106 @@
// Copyright © 2024 Apple Inc.

constexpr std::string_view steel_gemm_fused_kernels = R"(
template [[host_name("{name}")]]
[[kernel]] void gemm<{itype}, {bm}, {bn}, {bk}, {wm}, {wn}, {trans_a}, {trans_b}, float>(
    const device {itype} *A [[buffer(0)]],
    const device {itype} *B [[buffer(1)]],
    const device {itype} *C [[buffer(2), function_constant(use_out_source)]],
    device {itype} *D [[buffer(3)]],
    const constant GEMMParams* params [[buffer(4)]],
    const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
    const constant int* batch_shape [[buffer(6)]],
    const constant size_t* batch_strides [[buffer(7)]],
    const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]],
    const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]],
    const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]],
    const constant int* operand_shape [[buffer(13), function_constant(do_gather)]],
    const constant size_t* operand_strides [[buffer(14), function_constant(do_gather)]],
    const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]);
)";

constexpr std::string_view steel_gemm_masked_kernels = R"(
template [[host_name("{name}")]] [[kernel]] void
block_masked_gemm<
    {itype},
    {outmasktype},
    {opmasktype},
    {bm},
    {bn},
    {bk},
    {wm},
    {wn},
    {trans_a},
    {trans_b},
    {mn_aligned},
    {k_aligned}>(
    const device {itype}* A [[buffer(0)]],
    const device {itype}* B [[buffer(1)]],
    device {itype}* D [[buffer(3)]],
    const constant GEMMParams* params [[buffer(4)]],
    const constant int* batch_shape [[buffer(6)]],
    const constant size_t* batch_strides [[buffer(7)]],
    const device {outmasktype}* out_mask [[buffer(10)]],
    const device {opmasktype}* lhs_mask [[buffer(11)]],
    const device {opmasktype}* rhs_mask [[buffer(12)]],
    const constant int* mask_strides [[buffer(13)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]);
)";

constexpr std::string_view steel_gemm_splitk_kernels = R"(
template [[host_name("{name}")]] [[kernel]] void
gemm_splitk<
    {itype},
    {otype},
    {bm},
    {bn},
    {bk},
    {wm},
    {wn},
    {trans_a},
    {trans_b},
    {mn_aligned},
    {k_aligned}>(
    const device {itype}* A [[buffer(0)]],
    const device {itype}* B [[buffer(1)]],
    device {otype}* C [[buffer(2)]],
    const constant GEMMSpiltKParams* params [[buffer(3)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]);
)";

constexpr std::string_view steel_gemm_splitk_accum_kernels = R"(
template [[host_name("{name}")]] [[kernel]] void
gemm_splitk_accum<{atype}, {otype}>(
    const device {atype}* C_split [[buffer(0)]],
    device {otype}* D [[buffer(1)]],
    const constant int& k_partitions [[buffer(2)]],
    const constant int& partition_stride [[buffer(3)]],
    const constant int& ldd [[buffer(4)]],
    uint2 gid [[thread_position_in_grid]]);
)";

constexpr std::string_view steel_gemm_splitk_accum_axbpy_kernels = R"(
template [[host_name("{name}")]] [[kernel]] void
gemm_splitk_accum_axpby<{atype}, {otype}>(
    const device {atype}* C_split [[buffer(0)]],
    device {otype}* D [[buffer(1)]],
    const constant int& k_partitions [[buffer(2)]],
    const constant int& partition_stride [[buffer(3)]],
    const constant int& ldd [[buffer(4)]],
    const device {otype}* C [[buffer(5)]],
    const constant int& ldc [[buffer(6)]],
    const constant int& fdc [[buffer(7)]],
    const constant float& alpha [[buffer(8)]],
    const constant float& beta [[buffer(9)]],
    uint2 gid [[thread_position_in_grid]]);
)";
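
The steel GEMM and conv templates above use named fields ({name}, {itype}, {bm}, ...), so they are filled with fmt named arguments, as the get_steel_* helpers in jit_kernels.cpp further below do. A minimal illustrative sketch follows; the concrete "float" / "float16_t" values are hypothetical examples, not taken from the diff.

// Illustrative sketch only: filling the named fields of
// steel_gemm_splitk_accum_kernels (declared above). Example values are
// hypothetical.
#include <fmt/format.h>
#include <string>

using namespace fmt::literals;

inline std::string specialize_splitk_accum_source(const std::string& lib_name) {
  return fmt::format(
      steel_gemm_splitk_accum_kernels,
      "name"_a = lib_name,       // host_name of the instantiation
      "atype"_a = "float",       // element type of the C_split accumulator
      "otype"_a = "float16_t");  // element type of the output D
}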
80  mlx/backend/metal/jit/ternary.h  Normal file
@@ -0,0 +1,80 @@
// Copyright © 2024 Apple Inc.

constexpr std::string_view ternary_kernels = R"(
template [[host_name("v_{0}")]] [[kernel]] void ternary_v<{1}, {2}>(
    device const bool* a,
    device const {1}* b,
    device const {1}* c,
    device {1}* d,
    uint index [[thread_position_in_grid]]);

template [[host_name("g_{0}")]] [[kernel]] void ternary_g<{1}, {2}>(
    device const bool* a,
    device const {1}* b,
    device const {1}* c,
    device {1}* d,
    constant const int* shape,
    constant const size_t* a_strides,
    constant const size_t* b_strides,
    constant const size_t* c_strides,
    constant const int& ndim,
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]);

template [[host_name("g1_{0}")]] [[kernel]] void
ternary_g_nd1<{1}, {2}>(
    device const bool* a,
    device const {1}* b,
    device const {1}* c,
    device {1}* d,
    constant const size_t& a_strides,
    constant const size_t& b_strides,
    constant const size_t& c_strides,
    uint index [[thread_position_in_grid]]);
template [[host_name("g2_{0}")]] [[kernel]] void
ternary_g_nd2<{1}, {2}>(
    device const bool* a,
    device const {1}* b,
    device const {1}* c,
    device {1}* d,
    constant const size_t a_strides[2],
    constant const size_t b_strides[2],
    constant const size_t c_strides[2],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]);
template [[host_name("g3_{0}")]] [[kernel]] void
ternary_g_nd3<{1}, {2}>(
    device const bool* a,
    device const {1}* b,
    device const {1}* c,
    device {1}* d,
    constant const size_t a_strides[3],
    constant const size_t b_strides[3],
    constant const size_t c_strides[3],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]);
template [[host_name("g4_{0}")]] [[kernel]] void
ternary_g_nd<{1}, {2}, 4>(
    device const bool* a,
    device const {1}* b,
    device const {1}* c,
    device {1}* d,
    constant const int shape[4],
    constant const size_t a_strides[4],
    constant const size_t b_strides[4],
    constant const size_t c_strides[4],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]);
template [[host_name("g5_{0}")]] [[kernel]] void
ternary_g_nd<{1}, {2}, 5>(
    device const bool* a,
    device const {1}* b,
    device const {1}* c,
    device {1}* d,
    constant const int shape[5],
    constant const size_t a_strides[5],
    constant const size_t b_strides[5],
    constant const size_t c_strides[5],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]);
)";
16  mlx/backend/metal/jit/unary.h  Normal file
@@ -0,0 +1,16 @@
// Copyright © 2024 Apple Inc.

constexpr std::string_view unary_kernels = R"(
template [[host_name("v{0}")]] [[kernel]] void unary_v<{1}, {2}>(
    device const {1}* in,
    device {1}* out,
    uint index [[thread_position_in_grid]]);

template [[host_name("g{0}")]] [[kernel]] void unary_g<{1}, {2}>(
    device const {1}* in,
    device {1}* out,
    device const int* in_shape,
    device const size_t* in_strides,
    device const int& ndim,
    uint index [[thread_position_in_grid]]);
)";
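
The strings in these jit/*.h headers are only templates; they become valid Metal source once the numbered fields are substituted, which is what the jit_kernels.cpp below does with fmt::format. A minimal illustrative sketch of that step follows; the suffix, type, and op names ("abs_float32", "float", "Abs") are hypothetical examples.

// Illustrative sketch only: specializing the unary_kernels template above.
#include <fmt/format.h>
#include <string>

inline std::string specialize_unary_source(
    const std::string& lib_name,  // fills {0}, the host_name suffix
    const std::string& type_str,  // fills {1}, the element type
    const std::string& op_str) {  // fills {2}, the op functor
  return fmt::format(unary_kernels, lib_name, type_str, op_str);
}

// e.g. specialize_unary_source("abs_float32", "float", "Abs") produces the
// instantiations unary_v<float, Abs> and unary_g<float, Abs>.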
540  mlx/backend/metal/jit_kernels.cpp  Normal file
@@ -0,0 +1,540 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
#include <fmt/format.h>
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/metal/jit/arange.h"
|
||||
#include "mlx/backend/metal/jit/binary.h"
|
||||
#include "mlx/backend/metal/jit/binary_two.h"
|
||||
#include "mlx/backend/metal/jit/copy.h"
|
||||
#include "mlx/backend/metal/jit/fft.h"
|
||||
#include "mlx/backend/metal/jit/includes.h"
|
||||
#include "mlx/backend/metal/jit/reduce.h"
|
||||
#include "mlx/backend/metal/jit/scan.h"
|
||||
#include "mlx/backend/metal/jit/softmax.h"
|
||||
#include "mlx/backend/metal/jit/sort.h"
|
||||
#include "mlx/backend/metal/jit/steel_conv.h"
|
||||
#include "mlx/backend/metal/jit/steel_gemm.h"
|
||||
#include "mlx/backend/metal/jit/ternary.h"
|
||||
#include "mlx/backend/metal/jit/unary.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
|
||||
using namespace fmt::literals;
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
std::string op_name(const array& arr) {
|
||||
std::ostringstream op_t;
|
||||
arr.primitive().print(op_t);
|
||||
return op_t.str();
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_arange_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array& out) {
|
||||
const auto& lib_name = kernel_name;
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source
|
||||
<< metal::utils() << metal::arange()
|
||||
<< fmt::format(arange_kernels, lib_name, get_type_string(out.dtype()));
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_unary_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array& out) {
|
||||
std::string lib_name = kernel_name.substr(1);
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::unary_ops() << metal::unary()
|
||||
<< fmt::format(
|
||||
unary_kernels,
|
||||
lib_name,
|
||||
get_type_string(out.dtype()),
|
||||
op_name(out));
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_binary_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array& in,
|
||||
const array& out) {
|
||||
std::string lib_name = kernel_name.substr(2);
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::binary_ops() << metal::binary()
|
||||
<< fmt::format(
|
||||
binary_kernels,
|
||||
lib_name,
|
||||
get_type_string(in.dtype()),
|
||||
get_type_string(out.dtype()),
|
||||
op_name(out));
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_binary_two_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array& in,
|
||||
const array& out) {
|
||||
std::string lib_name = kernel_name.substr(2);
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::binary_ops()
|
||||
<< metal::binary_two()
|
||||
<< fmt::format(
|
||||
binary_two_kernels,
|
||||
lib_name,
|
||||
get_type_string(in.dtype()),
|
||||
get_type_string(out.dtype()),
|
||||
op_name(out));
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_ternary_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array& out) {
|
||||
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::ternary_ops() << metal::ternary()
|
||||
<< fmt::format(
|
||||
ternary_kernels,
|
||||
lib_name,
|
||||
get_type_string(out.dtype()),
|
||||
op_name(out));
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_copy_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array& in,
|
||||
const array& out) {
|
||||
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::copy()
|
||||
<< fmt::format(
|
||||
copy_kernels,
|
||||
lib_name,
|
||||
get_type_string(in.dtype()),
|
||||
get_type_string(out.dtype()));
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_softmax_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
bool precise,
|
||||
const array& out) {
|
||||
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::softmax()
|
||||
<< fmt::format(
|
||||
softmax_kernels,
|
||||
lib_name,
|
||||
get_type_string(out.dtype()),
|
||||
get_type_string(precise ? float32 : out.dtype()));
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_scan_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
bool reverse,
|
||||
bool inclusive,
|
||||
const std::string& reduce_type,
|
||||
const array& in,
|
||||
const array& out) {
|
||||
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
std::string op_name = "Cum" + reduce_type;
|
||||
op_name[3] = toupper(op_name[3]);
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::scan()
|
||||
<< fmt::format(
|
||||
scan_kernels,
|
||||
lib_name,
|
||||
get_type_string(in.dtype()),
|
||||
get_type_string(out.dtype()),
|
||||
op_name,
|
||||
inclusive,
|
||||
reverse);
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_sort_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array& in,
|
||||
const array& out,
|
||||
int bn,
|
||||
int tn) {
|
||||
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::sort()
|
||||
<< fmt::format(
|
||||
block_sort_kernels,
|
||||
lib_name,
|
||||
get_type_string(in.dtype()),
|
||||
get_type_string(out.dtype()),
|
||||
bn,
|
||||
tn);
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_mb_sort_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array& in,
|
||||
const array& idx,
|
||||
int bn,
|
||||
int tn) {
|
||||
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::sort()
|
||||
<< fmt::format(
|
||||
multiblock_sort_kernels,
|
||||
lib_name,
|
||||
get_type_string(in.dtype()),
|
||||
get_type_string(idx.dtype()),
|
||||
bn,
|
||||
tn);
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_reduce_init_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array& out) {
|
||||
auto lib = d.get_library(kernel_name);
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::reduce_utils()
|
||||
<< fmt::format(
|
||||
reduce_init_kernels,
|
||||
kernel_name,
|
||||
get_type_string(out.dtype()),
|
||||
op_name(out));
|
||||
lib = d.get_library(kernel_name, kernel_source.str());
|
||||
}
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_reduce_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const std::string& op_name,
|
||||
const array& in,
|
||||
const array& out) {
|
||||
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
std::string op_type = op_name;
|
||||
op_type[0] = std::toupper(op_name[0]);
|
||||
bool non_atomic = out.dtype() == int64 || out.dtype() == uint64;
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce()
|
||||
<< fmt::format(
|
||||
non_atomic ? reduce_non_atomic_kernels
|
||||
: reduce_kernels,
|
||||
lib_name,
|
||||
get_type_string(in.dtype()),
|
||||
get_type_string(out.dtype()),
|
||||
op_type);
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const std::string& hash_name,
|
||||
const metal::MTLFCList& func_consts,
|
||||
const array& out,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
int bm,
|
||||
int bn,
|
||||
int bk,
|
||||
int wm,
|
||||
int wn) {
|
||||
const auto& lib_name = kernel_name;
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::gemm()
|
||||
<< metal::steel_gemm_fused()
|
||||
<< fmt::format(
|
||||
steel_gemm_fused_kernels,
|
||||
"name"_a = lib_name,
|
||||
"itype"_a = get_type_string(out.dtype()),
|
||||
"bm"_a = bm,
|
||||
"bn"_a = bn,
|
||||
"bk"_a = bk,
|
||||
"wm"_a = wm,
|
||||
"wn"_a = wn,
|
||||
"trans_a"_a = transpose_a,
|
||||
"trans_b"_a = transpose_b);
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array& in,
|
||||
const array& out,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
int bm,
|
||||
int bn,
|
||||
int bk,
|
||||
int wm,
|
||||
int wn,
|
||||
bool mn_aligned,
|
||||
bool k_aligned) {
|
||||
const auto& lib_name = kernel_name;
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::gemm()
|
||||
<< metal::steel_gemm_splitk()
|
||||
<< fmt::format(
|
||||
steel_gemm_splitk_kernels,
|
||||
"name"_a = lib_name,
|
||||
"itype"_a = get_type_string(in.dtype()),
|
||||
"otype"_a = get_type_string(out.dtype()),
|
||||
"bm"_a = bm,
|
||||
"bn"_a = bn,
|
||||
"bk"_a = bk,
|
||||
"wm"_a = wm,
|
||||
"wn"_a = wn,
|
||||
"trans_a"_a = transpose_a,
|
||||
"trans_b"_a = transpose_b,
|
||||
"mn_aligned"_a = mn_aligned,
|
||||
"k_aligned"_a = k_aligned);
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array& in,
|
||||
const array& out,
|
||||
bool axbpy) {
|
||||
const auto& lib_name = kernel_name;
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::gemm()
|
||||
<< metal::steel_gemm_splitk()
|
||||
<< fmt::format(
|
||||
axbpy ? steel_gemm_splitk_accum_axbpy_kernels
|
||||
: steel_gemm_splitk_accum_kernels,
|
||||
"name"_a = lib_name,
|
||||
"atype"_a = get_type_string(in.dtype()),
|
||||
"otype"_a = get_type_string(out.dtype()));
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array& out,
|
||||
const std::optional<array>& mask_out,
|
||||
const std::optional<array>& mask_op,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
int bm,
|
||||
int bn,
|
||||
int bk,
|
||||
int wm,
|
||||
int wn,
|
||||
bool mn_aligned,
|
||||
bool k_aligned) {
|
||||
const auto& lib_name = kernel_name;
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel_source;
|
||||
auto out_mask_type = mask_out.has_value()
|
||||
? get_type_string((*mask_out).dtype())
|
||||
: "nomask_t";
|
||||
auto op_mask_type =
|
||||
mask_op.has_value() ? get_type_string((*mask_op).dtype()) : "nomask_t";
|
||||
kernel_source << metal::utils() << metal::gemm()
|
||||
<< metal::steel_gemm_masked()
|
||||
<< fmt::format(
|
||||
steel_gemm_masked_kernels,
|
||||
"name"_a = lib_name,
|
||||
"itype"_a = get_type_string(out.dtype()),
|
||||
"outmasktype"_a = out_mask_type,
|
||||
"opmasktype"_a = op_mask_type,
|
||||
"bm"_a = bm,
|
||||
"bn"_a = bn,
|
||||
"bk"_a = bk,
|
||||
"wm"_a = wm,
|
||||
"wn"_a = wn,
|
||||
"trans_a"_a = transpose_a,
|
||||
"trans_b"_a = transpose_b,
|
||||
"mn_aligned"_a = mn_aligned,
|
||||
"k_aligned"_a = k_aligned);
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_steel_conv_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array& out,
|
||||
int bm,
|
||||
int bn,
|
||||
int bk,
|
||||
int wm,
|
||||
int wn,
|
||||
int n_channel_specialization,
|
||||
bool small_filter) {
|
||||
const auto& lib_name = kernel_name;
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::conv() << metal::steel_conv()
|
||||
<< fmt::format(
|
||||
steel_conv_kernels,
|
||||
"name"_a = lib_name,
|
||||
"itype"_a = get_type_string(out.dtype()),
|
||||
"bm"_a = bm,
|
||||
"bn"_a = bn,
|
||||
"bk"_a = bk,
|
||||
"wm"_a = wm,
|
||||
"wn"_a = wn,
|
||||
"n_channels"_a = n_channel_specialization,
|
||||
"small_filter"_a = small_filter);
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_steel_conv_general_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array& out,
|
||||
int bm,
|
||||
int bn,
|
||||
int bk,
|
||||
int wm,
|
||||
int wn) {
|
||||
const auto& lib_name = kernel_name;
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::conv()
|
||||
<< metal::steel_conv_general()
|
||||
<< fmt::format(
|
||||
steel_conv_general_kernels,
|
||||
"name"_a = lib_name,
|
||||
"itype"_a = get_type_string(out.dtype()),
|
||||
"bm"_a = bm,
|
||||
"bn"_a = bn,
|
||||
"bk"_a = bk,
|
||||
"wm"_a = wm,
|
||||
"wn"_a = wn);
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_fft_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const std::string& hash_name,
|
||||
const int tg_mem_size,
|
||||
const std::string& in_type,
|
||||
const std::string& out_type,
|
||||
int step,
|
||||
bool real,
|
||||
const metal::MTLFCList& func_consts) {
|
||||
const auto& lib_name = kernel_name;
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel_source;
|
||||
std::string kernel_string;
|
||||
if (lib_name.find("bluestein") != std::string::npos) {
|
||||
kernel_string = bluestein_fft_kernel;
|
||||
} else if (lib_name.find("rader") != std::string::npos) {
|
||||
kernel_string = rader_fft_kernel;
|
||||
} else if (lib_name.find("four_step") != std::string::npos) {
|
||||
kernel_string = four_step_fft_kernel;
|
||||
} else {
|
||||
kernel_string = fft_kernel;
|
||||
}
|
||||
kernel_source << metal::fft();
|
||||
if (lib_name.find("four_step") != std::string::npos) {
|
||||
kernel_source << fmt::format(
|
||||
kernel_string,
|
||||
"name"_a = lib_name,
|
||||
"tg_mem_size"_a = tg_mem_size,
|
||||
"in_T"_a = in_type,
|
||||
"out_T"_a = out_type,
|
||||
"step"_a = step,
|
||||
"real"_a = real);
|
||||
} else {
|
||||
kernel_source << fmt::format(
|
||||
kernel_string,
|
||||
"name"_a = lib_name,
|
||||
"tg_mem_size"_a = tg_mem_size,
|
||||
"in_T"_a = in_type,
|
||||
"out_T"_a = out_type);
|
||||
}
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
169  mlx/backend/metal/kernels.h  Normal file
@@ -0,0 +1,169 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
MTL::ComputePipelineState* get_arange_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array& out);
|
||||
|
||||
MTL::ComputePipelineState* get_unary_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array& out);
|
||||
|
||||
MTL::ComputePipelineState* get_binary_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array& in,
|
||||
const array& out);
|
||||
|
||||
MTL::ComputePipelineState* get_binary_two_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array& in,
|
||||
const array& out);
|
||||
|
||||
MTL::ComputePipelineState* get_ternary_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array& out);
|
||||
|
||||
MTL::ComputePipelineState* get_copy_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array& in,
|
||||
const array& out);
|
||||
|
||||
MTL::ComputePipelineState* get_softmax_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
bool precise,
|
||||
const array& out);
|
||||
|
||||
MTL::ComputePipelineState* get_scan_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
bool reverse,
|
||||
bool inclusive,
|
||||
const std::string& reduce_type,
|
||||
const array& in,
|
||||
const array& out);
|
||||
|
||||
MTL::ComputePipelineState* get_sort_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array& in,
|
||||
const array& out,
|
||||
int bn,
|
||||
int tn);
|
||||
|
||||
MTL::ComputePipelineState* get_mb_sort_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array& in,
|
||||
const array& idx,
|
||||
int bn,
|
||||
int tn);
|
||||
|
||||
MTL::ComputePipelineState* get_reduce_init_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array& out);
|
||||
|
||||
MTL::ComputePipelineState* get_reduce_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const std::string& op_name,
|
||||
const array& in,
|
||||
const array& out);
|
||||
|
||||
MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const std::string& hash_name,
|
||||
const metal::MTLFCList& func_consts,
|
||||
const array& out,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
int bm,
|
||||
int bn,
|
||||
int bk,
|
||||
int wm,
|
||||
int wn);
|
||||
|
||||
MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array& in,
|
||||
const array& out,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
int bm,
|
||||
int bn,
|
||||
int bk,
|
||||
int wm,
|
||||
int wn,
|
||||
bool mn_aligned,
|
||||
bool k_aligned);
|
||||
|
||||
MTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array& in,
|
||||
const array& out,
|
||||
bool axbpy);
|
||||
|
||||
MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array& out,
|
||||
const std::optional<array>& mask_out,
|
||||
const std::optional<array>& mask_op,
|
||||
bool transpose_a,
|
||||
bool transpose_b,
|
||||
int bm,
|
||||
int bn,
|
||||
int bk,
|
||||
int wm,
|
||||
int wn,
|
||||
bool mn_aligned,
|
||||
bool k_aligned);
|
||||
|
||||
MTL::ComputePipelineState* get_steel_conv_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array& out,
|
||||
int bm,
|
||||
int bn,
|
||||
int bk,
|
||||
int wm,
|
||||
int wn,
|
||||
int n_channel_specialization,
|
||||
bool small_filter);
|
||||
|
||||
MTL::ComputePipelineState* get_steel_conv_general_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array& out,
|
||||
int bm,
|
||||
int bn,
|
||||
int bk,
|
||||
int wm,
|
||||
int wn);
|
||||
|
||||
MTL::ComputePipelineState* get_fft_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const std::string& hash_name,
|
||||
const int tg_mem_size,
|
||||
const std::string& in_type,
|
||||
const std::string& out_type,
|
||||
int step,
|
||||
bool real,
|
||||
const metal::MTLFCList& func_consts);
|
||||
|
||||
} // namespace mlx::core
|
@@ -1,26 +1,17 @@
|
||||
set(
|
||||
HEADERS
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/atomic.h
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/bf16.h
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/bf16_math.h
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/binary.h
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/complex.h
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/defines.h
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/erf.h
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/expm1f.h
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.h
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/unary.h
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.h
|
||||
bf16.h
|
||||
bf16_math.h
|
||||
complex.h
|
||||
defines.h
|
||||
utils.h
|
||||
steel/conv/params.h
|
||||
)
|
||||
|
||||
set(
|
||||
KERNELS
|
||||
"arange"
|
||||
"arg_reduce"
|
||||
"binary"
|
||||
"binary_two"
|
||||
"conv"
|
||||
"copy"
|
||||
"fft"
|
||||
"gemv"
|
||||
"quantized"
|
||||
@@ -28,15 +19,48 @@ set(
|
||||
"rms_norm"
|
||||
"layer_norm"
|
||||
"rope"
|
||||
"scan"
|
||||
"scaled_dot_product_attention"
|
||||
)
|
||||
|
||||
if (NOT MLX_METAL_JIT)
|
||||
set(
|
||||
KERNELS
|
||||
${KERNELS}
|
||||
"arange"
|
||||
"binary"
|
||||
"binary_two"
|
||||
"unary"
|
||||
"ternary"
|
||||
"copy"
|
||||
"softmax"
|
||||
"sort"
|
||||
"ternary"
|
||||
"unary"
|
||||
"gather"
|
||||
"scatter"
|
||||
"scan"
|
||||
"reduce"
|
||||
)
|
||||
set(
|
||||
HEADERS
|
||||
${HEADERS}
|
||||
atomic.h
|
||||
arange.h
|
||||
unary_ops.h
|
||||
unary.h
|
||||
binary_ops.h
|
||||
binary.h
|
||||
ternary.h
|
||||
copy.h
|
||||
fft.h
|
||||
fft/radix.h
|
||||
fft/readwrite.h
|
||||
softmax.h
|
||||
sort.h
|
||||
scan.h
|
||||
reduction/ops.h
|
||||
reduction/reduce_init.h
|
||||
reduction/reduce_all.h
|
||||
reduction/reduce_col.h
|
||||
reduction/reduce_row.h
|
||||
)
|
||||
endif()
|
||||
|
||||
function(build_kernel_base TARGET SRCFILE DEPS)
|
||||
set(METAL_FLAGS -Wall -Wextra -fno-fast-math -D${MLX_METAL_VERSION})
|
||||
@@ -68,23 +92,40 @@ foreach(KERNEL ${KERNELS})
|
||||
set(KERNEL_AIR ${KERNEL}.air ${KERNEL_AIR})
|
||||
endforeach()
|
||||
|
||||
file(GLOB_RECURSE STEEL_KERNELS ${CMAKE_CURRENT_SOURCE_DIR}/steel/*.metal)
|
||||
file(GLOB_RECURSE STEEL_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/steel/*.h)
|
||||
|
||||
foreach(KERNEL ${STEEL_KERNELS})
|
||||
cmake_path(GET KERNEL STEM TARGET)
|
||||
build_kernel_base(${TARGET} ${KERNEL} "${STEEL_HEADERS}")
|
||||
set(KERNEL_AIR ${TARGET}.air ${KERNEL_AIR})
|
||||
endforeach()
|
||||
|
||||
file(GLOB_RECURSE REDUCE_KERNELS ${CMAKE_CURRENT_SOURCE_DIR}/reduction/*.metal)
|
||||
file(GLOB_RECURSE REDUCE_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/reduction/*.h)
|
||||
|
||||
foreach(KERNEL ${REDUCE_KERNELS})
|
||||
cmake_path(GET KERNEL STEM TARGET)
|
||||
build_kernel_base(${TARGET} ${KERNEL} "${REDUCE_HEADERS}")
|
||||
set(KERNEL_AIR ${TARGET}.air ${KERNEL_AIR})
|
||||
endforeach()
|
||||
if (NOT MLX_METAL_JIT)
|
||||
set(
|
||||
STEEL_KERNELS
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/steel/conv/kernels/steel_conv.metal
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/steel/conv/kernels/steel_conv_general.metal
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/steel/gemm/kernels/steel_gemm_fused.metal
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/steel/gemm/kernels/steel_gemm_masked.metal
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/steel/gemm/kernels/steel_gemm_splitk.metal
|
||||
)
|
||||
set(
|
||||
STEEL_HEADERS
|
||||
steel/defines.h
|
||||
steel/utils.h
|
||||
steel/conv/conv.h
|
||||
steel/conv/loader.h
|
||||
steel/conv/loaders/loader_channel_l.h
|
||||
steel/conv/loaders/loader_channel_n.h
|
||||
steel/conv/loaders/loader_general.h
|
||||
steel/conv/kernels/steel_conv.h
|
||||
steel/conv/kernels/steel_conv_general.h
|
||||
steel/gemm/gemm.h
|
||||
steel/gemm/mma.h
|
||||
steel/gemm/loader.h
|
||||
steel/gemm/transforms.h
|
||||
steel/gemm/kernels/steel_gemm_fused.h
|
||||
steel/gemm/kernels/steel_gemm_masked.h
|
||||
steel/gemm/kernels/steel_gemm_splitk.h
|
||||
)
|
||||
foreach(KERNEL ${STEEL_KERNELS})
|
||||
cmake_path(GET KERNEL STEM TARGET)
|
||||
build_kernel_base(${TARGET} ${KERNEL} "${STEEL_HEADERS}")
|
||||
set(KERNEL_AIR ${TARGET}.air ${KERNEL_AIR})
|
||||
endforeach()
|
||||
endif()
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT ${MLX_METAL_PATH}/mlx.metallib
|
||||
|
9  mlx/backend/metal/kernels/arange.h  Normal file
@@ -0,0 +1,9 @@
// Copyright © 2023-2024 Apple Inc.
template <typename T>
[[kernel]] void arange(
    constant const T& start,
    constant const T& step,
    device T* out,
    uint index [[thread_position_in_grid]]) {
  out[index] = start + index * step;
}
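
For reference, each thread of the arange kernel above writes out[i] = start + i * step. A host-side sketch of the same computation (illustrative only; the helper name is hypothetical):

// Illustrative host-side equivalent of the arange kernel above.
#include <cstddef>
#include <vector>

template <typename T>
std::vector<T> arange_reference(T start, T step, std::size_t n) {
  std::vector<T> out(n);
  for (std::size_t i = 0; i < n; ++i) {
    out[i] = start + static_cast<T>(i) * step;  // same formula as the kernel
  }
  return out;
}

// e.g. arange_reference<float>(0.f, 2.f, 4) -> {0, 2, 4, 6}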
@@ -1,15 +1,8 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
// clang-format off
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
|
||||
template <typename T>
|
||||
[[kernel]] void arange(
|
||||
constant const T& start,
|
||||
constant const T& step,
|
||||
device T* out,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
out[index] = start + index * step;
|
||||
}
|
||||
#include "mlx/backend/metal/kernels/arange.h"
|
||||
|
||||
#define instantiate_arange(tname, type) \
|
||||
template [[host_name("arange" #tname)]] [[kernel]] void arange<type>( \
|
||||
@@ -18,7 +11,6 @@ template <typename T>
|
||||
device type* out, \
|
||||
uint index [[thread_position_in_grid]]);
|
||||
|
||||
// clang-format off
|
||||
instantiate_arange(uint8, uint8_t)
|
||||
instantiate_arange(uint16, uint16_t)
|
||||
instantiate_arange(uint32, uint32_t)
|
||||
|
@@ -1,6 +1,5 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <metal_atomic>
|
||||
#include <metal_simdgroup>
|
||||
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
@@ -194,4 +193,4 @@ instantiate_arg_reduce(int32, int32_t)
|
||||
instantiate_arg_reduce(int64, int64_t)
|
||||
instantiate_arg_reduce(float16, half)
|
||||
instantiate_arg_reduce(float32, float)
|
||||
instantiate_arg_reduce(bfloat16, bfloat16_t) // clang-format on
|
||||
instantiate_arg_reduce(bfloat16, bfloat16_t) // clang-format on
|
||||
|
@@ -4,7 +4,6 @@
|
||||
|
||||
#include <metal_atomic>
|
||||
#include <metal_stdlib>
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
|
||||
using namespace metal;
|
||||
|
||||
|
@@ -6,9 +6,7 @@
|
||||
|
||||
using namespace metal;
|
||||
|
||||
// No support for less than metal 3.0
|
||||
// anything greater has native bfloat
|
||||
#ifndef METAL_3_0
|
||||
#if defined METAL_3_1 || (__METAL_VERSION__ >= 310)
|
||||
|
||||
typedef bfloat bfloat16_t;
|
||||
|
||||
|
@@ -369,7 +369,7 @@ instantiate_metal_math_funcs(
|
||||
return static_cast<otype>(__metal_simd_xor(static_cast<ctype>(data))); \
|
||||
}
|
||||
|
||||
#ifndef METAL_3_0
|
||||
#if defined METAL_3_1 || (__METAL_VERSION__ >= 310)
|
||||
|
||||
#define bfloat16_to_uint16(x) as_type<uint16_t>(x)
|
||||
#define uint16_to_bfloat16(x) as_type<bfloat16_t>(x)
|
||||
|
@@ -1,273 +1,113 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void binary_ss(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
c[index] = Op()(a[0], b[0]);
|
||||
}
|
||||
|
||||
#include <metal_integer>
|
||||
#include <metal_math>
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void binary_sv(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
c[index] = Op()(a[0], b[index]);
|
||||
}
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void binary_vs(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
c[index] = Op()(a[index], b[0]);
|
||||
}
|
||||
|
||||
struct Add {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x + y;
|
||||
}
|
||||
};
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void binary_vv(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
c[index] = Op()(a[index], b[index]);
|
||||
}
|
||||
|
||||
struct Divide {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x / y;
|
||||
}
|
||||
};
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void binary_g_nd1(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
constant const size_t& a_stride,
|
||||
constant const size_t& b_stride,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
auto a_idx = elem_to_loc_1(index, a_stride);
|
||||
auto b_idx = elem_to_loc_1(index, b_stride);
|
||||
c[index] = Op()(a[a_idx], b[b_idx]);
|
||||
}
|
||||
|
||||
struct Remainder {
|
||||
template <typename T>
|
||||
metal::enable_if_t<metal::is_integral_v<T> & !metal::is_signed_v<T>, T>
|
||||
operator()(T x, T y) {
|
||||
return x % y;
|
||||
}
|
||||
template <typename T>
|
||||
metal::enable_if_t<metal::is_integral_v<T> & metal::is_signed_v<T>, T>
|
||||
operator()(T x, T y) {
|
||||
auto r = x % y;
|
||||
if (r != 0 && (r < 0 != y < 0)) {
|
||||
r += y;
|
||||
}
|
||||
return r;
|
||||
}
|
||||
template <typename T>
|
||||
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
|
||||
T r = fmod(x, y);
|
||||
if (r != 0 && (r < 0 != y < 0)) {
|
||||
r += y;
|
||||
}
|
||||
return r;
|
||||
}
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x, complex64_t y) {
|
||||
return x % y;
|
||||
}
|
||||
};
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void binary_g_nd2(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
constant const size_t a_strides[2],
|
||||
constant const size_t b_strides[2],
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
auto a_idx = elem_to_loc_2(index, a_strides);
|
||||
auto b_idx = elem_to_loc_2(index, b_strides);
|
||||
size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
|
||||
c[out_idx] = Op()(a[a_idx], b[b_idx]);
|
||||
}
|
||||
|
||||
struct Equal {
|
||||
template <typename T>
|
||||
bool operator()(T x, T y) {
|
||||
return x == y;
|
||||
}
|
||||
};
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void binary_g_nd3(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
constant const size_t a_strides[3],
|
||||
constant const size_t b_strides[3],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto a_idx = elem_to_loc_3(index, a_strides);
|
||||
auto b_idx = elem_to_loc_3(index, b_strides);
|
||||
size_t out_idx =
|
||||
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
|
||||
c[out_idx] = Op()(a[a_idx], b[b_idx]);
|
||||
}
|
||||
|
||||
struct NaNEqual {
|
||||
template <typename T>
|
||||
bool operator()(T x, T y) {
|
||||
return x == y || (metal::isnan(x) && metal::isnan(y));
|
||||
}
|
||||
template <>
|
||||
bool operator()(complex64_t x, complex64_t y) {
|
||||
return x == y ||
|
||||
(metal::isnan(x.real) && metal::isnan(y.real) && metal::isnan(x.imag) &&
|
||||
metal::isnan(y.imag)) ||
|
||||
(x.real == y.real && metal::isnan(x.imag) && metal::isnan(y.imag)) ||
|
||||
(metal::isnan(x.real) && metal::isnan(y.real) && x.imag == y.imag);
|
||||
}
|
||||
};
|
||||
template <typename T, typename U, typename Op, int DIM>
|
||||
[[kernel]] void binary_g_nd(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
constant const int shape[DIM],
|
||||
constant const size_t a_strides[DIM],
|
||||
constant const size_t b_strides[DIM],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
|
||||
size_t out_idx =
|
||||
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
|
||||
c[out_idx] = Op()(a[idx.x], b[idx.y]);
|
||||
}
|
||||
|
||||
struct Greater {
|
||||
template <typename T>
|
||||
bool operator()(T x, T y) {
|
||||
return x > y;
|
||||
}
|
||||
};
|
||||
|
||||
struct GreaterEqual {
|
||||
template <typename T>
|
||||
bool operator()(T x, T y) {
|
||||
return x >= y;
|
||||
}
|
||||
};
|
||||
|
||||
struct Less {
|
||||
template <typename T>
|
||||
bool operator()(T x, T y) {
|
||||
return x < y;
|
||||
}
|
||||
};
|
||||
|
||||
struct LessEqual {
|
||||
template <typename T>
|
||||
bool operator()(T x, T y) {
|
||||
return x <= y;
|
||||
}
|
||||
};
|
||||
|
||||
struct LogAddExp {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
if (metal::isnan(x) || metal::isnan(y)) {
|
||||
return metal::numeric_limits<T>::quiet_NaN();
|
||||
}
|
||||
constexpr T inf = metal::numeric_limits<T>::infinity();
|
||||
T maxval = metal::max(x, y);
|
||||
T minval = metal::min(x, y);
|
||||
return (minval == -inf || maxval == inf)
|
||||
? maxval
|
||||
: (maxval + log1p(metal::exp(minval - maxval)));
|
||||
};
|
||||
};
|
||||
|
||||
struct Maximum {
|
||||
template <typename T>
|
||||
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {
|
||||
return metal::max(x, y);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
|
||||
if (metal::isnan(x)) {
|
||||
return x;
|
||||
}
|
||||
return x > y ? x : y;
|
||||
}
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x, complex64_t y) {
|
||||
if (metal::isnan(x.real) || metal::isnan(x.imag)) {
|
||||
return x;
|
||||
}
|
||||
return x > y ? x : y;
|
||||
}
|
||||
};
|
||||
|
||||
struct Minimum {
|
||||
template <typename T>
|
||||
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {
|
||||
return metal::min(x, y);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
|
||||
if (metal::isnan(x)) {
|
||||
return x;
|
||||
}
|
||||
return x < y ? x : y;
|
||||
}
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x, complex64_t y) {
|
||||
if (metal::isnan(x.real) || metal::isnan(x.imag)) {
|
||||
return x;
|
||||
}
|
||||
return x < y ? x : y;
|
||||
}
|
||||
};
|
||||
|
||||
struct Multiply {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x * y;
|
||||
}
|
||||
};
|
||||
|
||||
struct NotEqual {
|
||||
template <typename T>
|
||||
bool operator()(T x, T y) {
|
||||
return x != y;
|
||||
}
|
||||
template <>
|
||||
bool operator()(complex64_t x, complex64_t y) {
|
||||
return x.real != y.real || x.imag != y.imag;
|
||||
}
|
||||
};
|
||||
|
||||
struct Power {
|
||||
template <typename T>
|
||||
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T base, T exp) {
|
||||
return metal::pow(base, exp);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T base, T exp) {
|
||||
T res = 1;
|
||||
while (exp) {
|
||||
if (exp & 1) {
|
||||
res *= base;
|
||||
}
|
||||
exp >>= 1;
|
||||
base *= base;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x, complex64_t y) {
|
||||
auto x_theta = metal::atan(x.imag / x.real);
|
||||
auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag);
|
||||
auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta);
|
||||
auto phase = y.imag * x_ln_r + y.real * x_theta;
|
||||
return {mag * metal::cos(phase), mag * metal::sin(phase)};
|
||||
}
|
||||
};
|
||||
|
||||
struct Subtract {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x - y;
|
||||
}
|
||||
};
|
||||
|
||||
struct LogicalAnd {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x && y;
|
||||
};
|
||||
};
|
||||
|
||||
struct LogicalOr {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x || y;
|
||||
};
|
||||
};
|
||||
|
||||
struct BitwiseAnd {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x & y;
|
||||
};
|
||||
};
|
||||
|
||||
struct BitwiseOr {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x | y;
|
||||
};
|
||||
};
|
||||
|
||||
struct BitwiseXor {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x ^ y;
|
||||
};
|
||||
};
|
||||
|
||||
struct LeftShift {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x << y;
|
||||
};
|
||||
};
|
||||
|
||||
struct RightShift {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x >> y;
|
||||
};
|
||||
};
|
||||
|
||||
struct ArcTan2 {
|
||||
template <typename T>
|
||||
T operator()(T y, T x) {
|
||||
return metal::precise::atan2(y, x);
|
||||
}
|
||||
};
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void binary_g(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
constant const int* shape,
|
||||
constant const size_t* a_strides,
|
||||
constant const size_t* b_strides,
|
||||
constant const int& ndim,
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim);
|
||||
size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
|
||||
c[out_idx] = Op()(a[idx.x], b[idx.y]);
|
||||
}
|
||||
|
@@ -1,130 +1,24 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <metal_integer>
|
||||
#include <metal_math>
|
||||
|
||||
// clang-format off
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
#include "mlx/backend/metal/kernels/binary_ops.h"
|
||||
#include "mlx/backend/metal/kernels/binary.h"
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void binary_op_ss(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
c[index] = Op()(a[0], b[0]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void binary_op_sv(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
c[index] = Op()(a[0], b[index]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void binary_op_vs(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
c[index] = Op()(a[index], b[0]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void binary_op_vv(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
c[index] = Op()(a[index], b[index]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void binary_op_g_nd1(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
constant const size_t& a_stride,
|
||||
constant const size_t& b_stride,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
auto a_idx = elem_to_loc_1(index, a_stride);
|
||||
auto b_idx = elem_to_loc_1(index, b_stride);
|
||||
c[index] = Op()(a[a_idx], b[b_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void binary_op_g_nd2(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
constant const size_t a_strides[2],
|
||||
constant const size_t b_strides[2],
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
auto a_idx = elem_to_loc_2(index, a_strides);
|
||||
auto b_idx = elem_to_loc_2(index, b_strides);
|
||||
size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
|
||||
c[out_idx] = Op()(a[a_idx], b[b_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void binary_op_g_nd3(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
constant const size_t a_strides[3],
|
||||
constant const size_t b_strides[3],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto a_idx = elem_to_loc_3(index, a_strides);
|
||||
auto b_idx = elem_to_loc_3(index, b_strides);
|
||||
size_t out_idx =
|
||||
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
|
||||
c[out_idx] = Op()(a[a_idx], b[b_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op, int DIM>
|
||||
[[kernel]] void binary_op_g_nd(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
constant const int shape[DIM],
|
||||
constant const size_t a_strides[DIM],
|
||||
constant const size_t b_strides[DIM],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
|
||||
size_t out_idx =
|
||||
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
|
||||
c[out_idx] = Op()(a[idx.x], b[idx.y]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void binary_op_g(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
constant const int* shape,
|
||||
constant const size_t* a_strides,
|
||||
constant const size_t* b_strides,
|
||||
constant const int& ndim,
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim);
|
||||
size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
|
||||
c[out_idx] = Op()(a[idx.x], b[idx.y]);
|
||||
}
|
||||
|
||||
#define instantiate_binary(name, itype, otype, op, bopt) \
|
||||
template \
|
||||
[[host_name(name)]] [[kernel]] void binary_op_##bopt<itype, otype, op>( \
|
||||
[[host_name(name)]] [[kernel]] void binary_##bopt<itype, otype, op>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
uint index [[thread_position_in_grid]]);
|
||||
|
||||
#define instantiate_binary_g_dim(name, itype, otype, op, dims) \
|
||||
template [[host_name(name "_" #dims)]] [[kernel]] void \
|
||||
binary_op_g_nd<itype, otype, op, dims>( \
|
||||
template [[host_name("g" #dims name)]] [[kernel]] void \
|
||||
binary_g_nd<itype, otype, op, dims>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
@@ -135,16 +29,16 @@ template <typename T, typename U, typename Op>
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
|
||||
#define instantiate_binary_g_nd(name, itype, otype, op) \
|
||||
template [[host_name(name "_1")]] [[kernel]] void \
|
||||
binary_op_g_nd1<itype, otype, op>( \
|
||||
template [[host_name("g1" name)]] [[kernel]] void \
|
||||
binary_g_nd1<itype, otype, op>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
constant const size_t& a_stride, \
|
||||
constant const size_t& b_stride, \
|
||||
uint index [[thread_position_in_grid]]); \
|
||||
template [[host_name(name "_2")]] [[kernel]] void \
|
||||
binary_op_g_nd2<itype, otype, op>( \
|
||||
template [[host_name("g2" name)]] [[kernel]] void \
|
||||
binary_g_nd2<itype, otype, op>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
@@ -152,8 +46,8 @@ template <typename T, typename U, typename Op>
|
||||
constant const size_t b_strides[2], \
|
||||
uint2 index [[thread_position_in_grid]], \
|
||||
uint2 grid_dim [[threads_per_grid]]); \
|
||||
template [[host_name(name "_3")]] [[kernel]] void \
|
||||
binary_op_g_nd3<itype, otype, op>( \
|
||||
template [[host_name("g3" name)]] [[kernel]] void \
|
||||
binary_g_nd3<itype, otype, op>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
@@ -162,30 +56,28 @@ template <typename T, typename U, typename Op>
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
    uint3 grid_dim [[threads_per_grid]]); \
  instantiate_binary_g_dim(name, itype, otype, op, 4) \
  instantiate_binary_g_dim(name, itype, otype, op, 5)

#define instantiate_binary_g(name, itype, otype, op) \
  template [[host_name(name)]] [[kernel]] void binary_op_g<itype, otype, op>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
      constant const int* shape, \
      constant const size_t* a_strides, \
      constant const size_t* b_strides, \
      constant const int& ndim, \
      uint3 index [[thread_position_in_grid]], \
#define instantiate_binary_g(name, itype, otype, op) \
  template [[host_name("gn" name)]] [[kernel]] void binary_g<itype, otype, op>( \
      device const itype* a, \
      device const itype* b, \
      device otype* c, \
      constant const int* shape, \
      constant const size_t* a_strides, \
      constant const size_t* b_strides, \
      constant const int& ndim, \
      uint3 index [[thread_position_in_grid]], \
      uint3 grid_dim [[threads_per_grid]]);

// clang-format off
#define instantiate_binary_all(name, tname, itype, otype, op) \
  instantiate_binary("ss" #name #tname, itype, otype, op, ss) \
  instantiate_binary("sv" #name #tname, itype, otype, op, sv) \
  instantiate_binary("vs" #name #tname, itype, otype, op, vs) \
  instantiate_binary("vv" #name #tname, itype, otype, op, vv) \
  instantiate_binary_g("g" #name #tname, itype, otype, op) \
  instantiate_binary_g_nd("g" #name #tname, itype, otype, op) // clang-format on
  instantiate_binary_g(#name #tname, itype, otype, op) \
  instantiate_binary_g_nd(#name #tname, itype, otype, op)

// clang-format off
#define instantiate_binary_integer(name, op) \
  instantiate_binary_all(name, uint8, uint8_t, uint8_t, op) \
  instantiate_binary_all(name, uint16, uint16_t, uint16_t, op) \
@@ -194,22 +86,19 @@ template <typename T, typename U, typename Op>
  instantiate_binary_all(name, int8, int8_t, int8_t, op) \
  instantiate_binary_all(name, int16, int16_t, int16_t, op) \
  instantiate_binary_all(name, int32, int32_t, int32_t, op) \
  instantiate_binary_all(name, int64, int64_t, int64_t, op) // clang-format on
  instantiate_binary_all(name, int64, int64_t, int64_t, op)

// clang-format off
#define instantiate_binary_float(name, op) \
  instantiate_binary_all(name, float16, half, half, op) \
  instantiate_binary_all(name, float32, float, float, op) \
  instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op) // clang-format on
  instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op)

// clang-format off
#define instantiate_binary_types(name, op) \
  instantiate_binary_all(name, bool_, bool, bool, op) \
  instantiate_binary_integer(name, op) \
  instantiate_binary_all(name, complex64, complex64_t, complex64_t, op) \
  instantiate_binary_float(name, op) // clang-format on
  instantiate_binary_float(name, op)

// clang-format off
#define instantiate_binary_types_bool(name, op) \
  instantiate_binary_all(name, bool_, bool, bool, op) \
  instantiate_binary_all(name, uint8, uint8_t, bool, op) \
@@ -223,9 +112,8 @@ template <typename T, typename U, typename Op>
  instantiate_binary_all(name, float16, half, bool, op) \
  instantiate_binary_all(name, float32, float, bool, op) \
  instantiate_binary_all(name, bfloat16, bfloat16_t, bool, op) \
  instantiate_binary_all(name, complex64, complex64_t, bool, op) // clang-format on
  instantiate_binary_all(name, complex64, complex64_t, bool, op)

// clang-format off
instantiate_binary_types(add, Add)
instantiate_binary_types(div, Divide)
instantiate_binary_types_bool(eq, Equal)
mlx/backend/metal/kernels/binary_ops.h (new file, 296 lines)
@@ -0,0 +1,296 @@
// Copyright © 2023-2024 Apple Inc.

#pragma once

#include <metal_integer>
#include <metal_math>

struct Add {
  template <typename T>
  T operator()(T x, T y) {
    return x + y;
  }
};

struct FloorDivide {
  template <typename T>
  T operator()(T x, T y) {
    return x / y;
  }
  template <>
  float operator()(float x, float y) {
    return trunc(x / y);
  }
  template <>
  half operator()(half x, half y) {
    return trunc(x / y);
  }
  template <>
  bfloat16_t operator()(bfloat16_t x, bfloat16_t y) {
    return trunc(x / y);
  }
};

struct Divide {
  template <typename T>
  T operator()(T x, T y) {
    return x / y;
  }
};

struct Remainder {
  template <typename T>
  metal::enable_if_t<metal::is_integral_v<T> & !metal::is_signed_v<T>, T>
  operator()(T x, T y) {
    return x % y;
  }
  template <typename T>
  metal::enable_if_t<metal::is_integral_v<T> & metal::is_signed_v<T>, T>
  operator()(T x, T y) {
    auto r = x % y;
    if (r != 0 && (r < 0 != y < 0)) {
      r += y;
    }
    return r;
  }
  template <typename T>
  metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
    T r = fmod(x, y);
    if (r != 0 && (r < 0 != y < 0)) {
      r += y;
    }
    return r;
  }
  template <>
  complex64_t operator()(complex64_t x, complex64_t y) {
    return x % y;
  }
};
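The signed-integer and floating-point overloads above implement a floored remainder (the result takes the sign of the divisor), not C's truncated `%`. A minimal host-side C++ sketch of the same adjustment, useful for sanity-checking the kernel's behavior; the `floored_mod` helper is illustrative, not part of MLX:

#include <cassert>
#include <cmath>

// Floored remainder: start from the truncated remainder and shift by the
// divisor whenever the signs of remainder and divisor disagree.
template <typename T>
T floored_mod(T x, T y) {
  T r = std::fmod(x, y);
  if (r != 0 && ((r < 0) != (y < 0))) {
    r += y;
  }
  return r;
}

int main() {
  assert(floored_mod(-7.0, 3.0) == 2.0);  // plain fmod would give -1
  assert(floored_mod(7.0, -3.0) == -2.0); // result takes the divisor's sign
  assert(floored_mod(7.0, 3.0) == 1.0);
  return 0;
}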

struct Equal {
  template <typename T>
  bool operator()(T x, T y) {
    return x == y;
  }
};

struct NaNEqual {
  template <typename T>
  bool operator()(T x, T y) {
    return x == y || (metal::isnan(x) && metal::isnan(y));
  }
  template <>
  bool operator()(complex64_t x, complex64_t y) {
    return x == y ||
        (metal::isnan(x.real) && metal::isnan(y.real) && metal::isnan(x.imag) &&
         metal::isnan(y.imag)) ||
        (x.real == y.real && metal::isnan(x.imag) && metal::isnan(y.imag)) ||
        (metal::isnan(x.real) && metal::isnan(y.real) && x.imag == y.imag);
  }
};

struct Greater {
  template <typename T>
  bool operator()(T x, T y) {
    return x > y;
  }
};

struct GreaterEqual {
  template <typename T>
  bool operator()(T x, T y) {
    return x >= y;
  }
};

struct Less {
  template <typename T>
  bool operator()(T x, T y) {
    return x < y;
  }
};

struct LessEqual {
  template <typename T>
  bool operator()(T x, T y) {
    return x <= y;
  }
};

struct LogAddExp {
  template <typename T>
  T operator()(T x, T y) {
    if (metal::isnan(x) || metal::isnan(y)) {
      return metal::numeric_limits<T>::quiet_NaN();
    }
    constexpr T inf = metal::numeric_limits<T>::infinity();
    T maxval = metal::max(x, y);
    T minval = metal::min(x, y);
    return (minval == -inf || maxval == inf)
        ? maxval
        : (maxval + log1p(metal::exp(minval - maxval)));
  };
};
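LogAddExp evaluates log(exp(x) + exp(y)) by factoring out the larger argument so the intermediate exponential never overflows. A host-side C++ sketch of the same identity, assuming plain doubles instead of Metal's half/bfloat types:

#include <algorithm>
#include <cmath>
#include <cstdio>

// log(exp(x) + exp(y)) = max + log1p(exp(min - max)); exp(min - max) <= 1,
// so nothing overflows even when x or y is very large.
double log_add_exp(double x, double y) {
  double maxval = std::max(x, y);
  double minval = std::min(x, y);
  if (std::isinf(maxval) || std::isinf(minval)) {
    return maxval;  // mirrors the kernel's +/-inf short-circuit
  }
  return maxval + std::log1p(std::exp(minval - maxval));
}

int main() {
  // Naive log(exp(1000) + exp(999)) overflows; the stable form does not.
  std::printf("%f\n", log_add_exp(1000.0, 999.0));  // ~1000.3133
  return 0;
}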

struct Maximum {
  template <typename T>
  metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {
    return metal::max(x, y);
  }

  template <typename T>
  metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
    if (metal::isnan(x)) {
      return x;
    }
    return x > y ? x : y;
  }

  template <>
  complex64_t operator()(complex64_t x, complex64_t y) {
    if (metal::isnan(x.real) || metal::isnan(x.imag)) {
      return x;
    }
    return x > y ? x : y;
  }
};

struct Minimum {
  template <typename T>
  metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T x, T y) {
    return metal::min(x, y);
  }

  template <typename T>
  metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
    if (metal::isnan(x)) {
      return x;
    }
    return x < y ? x : y;
  }

  template <>
  complex64_t operator()(complex64_t x, complex64_t y) {
    if (metal::isnan(x.real) || metal::isnan(x.imag)) {
      return x;
    }
    return x < y ? x : y;
  }
};

struct Multiply {
  template <typename T>
  T operator()(T x, T y) {
    return x * y;
  }
};

struct NotEqual {
  template <typename T>
  bool operator()(T x, T y) {
    return x != y;
  }
  template <>
  bool operator()(complex64_t x, complex64_t y) {
    return x.real != y.real || x.imag != y.imag;
  }
};

struct Power {
  template <typename T>
  metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T base, T exp) {
    return metal::pow(base, exp);
  }

  template <typename T>
  metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T base, T exp) {
    T res = 1;
    while (exp) {
      if (exp & 1) {
        res *= base;
      }
      exp >>= 1;
      base *= base;
    }
    return res;
  }

  template <>
  complex64_t operator()(complex64_t x, complex64_t y) {
    auto x_theta = metal::atan(x.imag / x.real);
    auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag);
    auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta);
    auto phase = y.imag * x_ln_r + y.real * x_theta;
    return {mag * metal::cos(phase), mag * metal::sin(phase)};
  }
};
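The integral overload of Power above is binary exponentiation: square the base and consume one exponent bit per iteration, so it needs O(log exp) multiplies instead of exp of them. A host-side C++ sketch with a quick check (illustrative only):

#include <cassert>
#include <cstdint>

// Exponentiation by squaring, mirroring the integral Power overload above.
uint64_t ipow(uint64_t base, uint64_t exp) {
  uint64_t res = 1;
  while (exp) {
    if (exp & 1) {   // current low bit set: fold this power of the base in
      res *= base;
    }
    exp >>= 1;       // move to the next bit
    base *= base;    // base^(2^k) for the next bit position
  }
  return res;
}

int main() {
  assert(ipow(3, 0) == 1);
  assert(ipow(3, 5) == 243);
  assert(ipow(2, 10) == 1024);
  return 0;
}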

struct Subtract {
  template <typename T>
  T operator()(T x, T y) {
    return x - y;
  }
};

struct LogicalAnd {
  template <typename T>
  T operator()(T x, T y) {
    return x && y;
  };
};

struct LogicalOr {
  template <typename T>
  T operator()(T x, T y) {
    return x || y;
  };
};

struct BitwiseAnd {
  template <typename T>
  T operator()(T x, T y) {
    return x & y;
  };
};

struct BitwiseOr {
  template <typename T>
  T operator()(T x, T y) {
    return x | y;
  };
};

struct BitwiseXor {
  template <typename T>
  T operator()(T x, T y) {
    return x ^ y;
  };
};

struct LeftShift {
  template <typename T>
  T operator()(T x, T y) {
    return x << y;
  };
};

struct RightShift {
  template <typename T>
  T operator()(T x, T y) {
    return x >> y;
  };
};

struct ArcTan2 {
  template <typename T>
  T operator()(T y, T x) {
    return metal::precise::atan2(y, x);
  }
};

struct DivMod {
  template <typename T>
  metal::array<T, 2> operator()(T x, T y) {
    return {FloorDivide{}(x, y), Remainder{}(x, y)};
  };
};
|
mlx/backend/metal/kernels/binary_two.h (new file, 140 lines)
@@ -0,0 +1,140 @@
// Copyright © 2024 Apple Inc.

template <typename T, typename U, typename Op>
[[kernel]] void binary_ss(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    uint index [[thread_position_in_grid]]) {
  auto out = Op()(a[0], b[0]);
  c[index] = out[0];
  d[index] = out[1];
}

template <typename T, typename U, typename Op>
[[kernel]] void binary_sv(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    uint index [[thread_position_in_grid]]) {
  auto out = Op()(a[0], b[index]);
  c[index] = out[0];
  d[index] = out[1];
}

template <typename T, typename U, typename Op>
[[kernel]] void binary_vs(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    uint index [[thread_position_in_grid]]) {
  auto out = Op()(a[index], b[0]);
  c[index] = out[0];
  d[index] = out[1];
}

template <typename T, typename U, typename Op>
[[kernel]] void binary_vv(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    uint index [[thread_position_in_grid]]) {
  auto out = Op()(a[index], b[index]);
  c[index] = out[0];
  d[index] = out[1];
}

template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd1(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    constant const size_t& a_stride,
    constant const size_t& b_stride,
    uint index [[thread_position_in_grid]]) {
  auto a_idx = elem_to_loc_1(index, a_stride);
  auto b_idx = elem_to_loc_1(index, b_stride);
  auto out = Op()(a[a_idx], b[b_idx]);
  c[index] = out[0];
  d[index] = out[1];
}

template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd2(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    constant const size_t a_strides[2],
    constant const size_t b_strides[2],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  auto a_idx = elem_to_loc_2(index, a_strides);
  auto b_idx = elem_to_loc_2(index, b_strides);
  size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
  auto out = Op()(a[a_idx], b[b_idx]);
  c[out_idx] = out[0];
  d[out_idx] = out[1];
}

template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd3(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    constant const size_t a_strides[3],
    constant const size_t b_strides[3],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto a_idx = elem_to_loc_3(index, a_strides);
  auto b_idx = elem_to_loc_3(index, b_strides);
  size_t out_idx =
      index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
  auto out = Op()(a[a_idx], b[b_idx]);
  c[out_idx] = out[0];
  d[out_idx] = out[1];
}

template <typename T, typename U, typename Op, int DIM>
[[kernel]] void binary_g_nd(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    constant const int shape[DIM],
    constant const size_t a_strides[DIM],
    constant const size_t b_strides[DIM],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
  size_t out_idx =
      index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
  auto out = Op()(a[idx.x], b[idx.y]);
  c[out_idx] = out[0];
  d[out_idx] = out[1];
}

template <typename T, typename U, typename Op>
[[kernel]] void binary_g(
    device const T* a,
    device const T* b,
    device U* c,
    device U* d,
    constant const int* shape,
    constant const size_t* a_strides,
    constant const size_t* b_strides,
    constant const int& ndim,
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim);
  size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
  auto out = Op()(a[idx.x], b[idx.y]);
  c[out_idx] = out[0];
  d[out_idx] = out[1];
}
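Each kernel in this header evaluates a single Op that yields two values per element pair (DivMod, for example, yields a floor-divide result and a remainder) and scatters them into the two output buffers c and d. A host-side C++ sketch of the contiguous vv case; the DivModOp functor below is a hypothetical stand-in that uses C++'s truncated / and %, not the kernel's floored Remainder:

#include <array>
#include <cstdio>
#include <vector>

// Stand-in for an Op whose operator() yields two outputs per element pair.
struct DivModOp {
  std::array<long, 2> operator()(long x, long y) const {
    return {x / y, x % y};
  }
};

// Mirrors binary_vv: one thread (here, one loop iteration) per element,
// writing out[0] into c and out[1] into d at the same index.
template <typename Op>
void binary_vv_host(const std::vector<long>& a, const std::vector<long>& b,
                    std::vector<long>& c, std::vector<long>& d, Op op) {
  for (size_t i = 0; i < a.size(); ++i) {
    auto out = op(a[i], b[i]);
    c[i] = out[0];
    d[i] = out[1];
  }
}

int main() {
  std::vector<long> a{7, 9, 10}, b{2, 4, 5}, c(3), d(3);
  binary_vv_host(a, b, c, d, DivModOp{});
  std::printf("%ld %ld\n", c[0], d[0]);  // 3 1
  return 0;
}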
|
@@ -1,212 +1,24 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
// Copyright © 2024 Apple Inc.
|
||||
#include <metal_integer>
|
||||
#include <metal_math>
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
// clang-format off
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
#include "mlx/backend/metal/kernels/binary_ops.h"
|
||||
#include "mlx/backend/metal/kernels/binary_two.h"
|
||||
|
||||
struct FloorDivide {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x / y;
|
||||
}
|
||||
template <>
|
||||
float operator()(float x, float y) {
|
||||
return trunc(x / y);
|
||||
}
|
||||
template <>
|
||||
half operator()(half x, half y) {
|
||||
return trunc(x / y);
|
||||
}
|
||||
template <>
|
||||
bfloat16_t operator()(bfloat16_t x, bfloat16_t y) {
|
||||
return trunc(x / y);
|
||||
}
|
||||
};
|
||||
|
||||
struct Remainder {
|
||||
template <typename T>
|
||||
metal::enable_if_t<metal::is_integral_v<T> & !metal::is_signed_v<T>, T>
|
||||
operator()(T x, T y) {
|
||||
return x % y;
|
||||
}
|
||||
template <typename T>
|
||||
metal::enable_if_t<metal::is_integral_v<T> & metal::is_signed_v<T>, T>
|
||||
operator()(T x, T y) {
|
||||
auto r = x % y;
|
||||
if (r != 0 && (r < 0 != y < 0)) {
|
||||
r += y;
|
||||
}
|
||||
return r;
|
||||
}
|
||||
template <typename T>
|
||||
metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T x, T y) {
|
||||
T r = fmod(x, y);
|
||||
if (r != 0 && (r < 0 != y < 0)) {
|
||||
r += y;
|
||||
}
|
||||
return r;
|
||||
}
|
||||
template <>
|
||||
complex64_t operator()(complex64_t x, complex64_t y) {
|
||||
return x % y;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename U, typename Op1, typename Op2>
|
||||
[[kernel]] void binary_op_s2s(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
device U* d,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
c[index] = Op1()(a[0], b[0]);
|
||||
d[index] = Op2()(a[0], b[0]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op1, typename Op2>
|
||||
[[kernel]] void binary_op_ss(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
device U* d,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
c[index] = Op1()(a[0], b[0]);
|
||||
d[index] = Op2()(a[0], b[0]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op1, typename Op2>
|
||||
[[kernel]] void binary_op_sv(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
device U* d,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
c[index] = Op1()(a[0], b[index]);
|
||||
d[index] = Op2()(a[0], b[index]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op1, typename Op2>
|
||||
[[kernel]] void binary_op_vs(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
device U* d,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
c[index] = Op1()(a[index], b[0]);
|
||||
d[index] = Op2()(a[index], b[0]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op1, typename Op2>
|
||||
[[kernel]] void binary_op_vv(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
device U* d,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
c[index] = Op1()(a[index], b[index]);
|
||||
d[index] = Op2()(a[index], b[index]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op1, typename Op2>
|
||||
[[kernel]] void binary_op_g_nd1(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
device U* d,
|
||||
constant const size_t& a_stride,
|
||||
constant const size_t& b_stride,
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
auto a_idx = elem_to_loc_1(index, a_stride);
|
||||
auto b_idx = elem_to_loc_1(index, b_stride);
|
||||
c[index] = Op1()(a[a_idx], b[b_idx]);
|
||||
d[index] = Op2()(a[a_idx], b[b_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op1, typename Op2>
|
||||
[[kernel]] void binary_op_g_nd2(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
device U* d,
|
||||
constant const size_t a_strides[2],
|
||||
constant const size_t b_strides[2],
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
auto a_idx = elem_to_loc_2(index, a_strides);
|
||||
auto b_idx = elem_to_loc_2(index, b_strides);
|
||||
size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
|
||||
c[out_idx] = Op1()(a[a_idx], b[b_idx]);
|
||||
d[out_idx] = Op2()(a[a_idx], b[b_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op1, typename Op2>
|
||||
[[kernel]] void binary_op_g_nd3(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
device U* d,
|
||||
constant const size_t a_strides[3],
|
||||
constant const size_t b_strides[3],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto a_idx = elem_to_loc_3(index, a_strides);
|
||||
auto b_idx = elem_to_loc_3(index, b_strides);
|
||||
size_t out_idx =
|
||||
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
|
||||
c[out_idx] = Op1()(a[a_idx], b[b_idx]);
|
||||
d[out_idx] = Op2()(a[a_idx], b[b_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op1, typename Op2, int DIM>
|
||||
[[kernel]] void binary_op_g_nd(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
device U* d,
|
||||
constant const int shape[DIM],
|
||||
constant const size_t a_strides[DIM],
|
||||
constant const size_t b_strides[DIM],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
|
||||
size_t out_idx =
|
||||
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
|
||||
c[out_idx] = Op1()(a[idx.x], b[idx.y]);
|
||||
d[out_idx] = Op2()(a[idx.x], b[idx.y]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op1, typename Op2>
|
||||
[[kernel]] void binary_op_g(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
device U* d,
|
||||
constant const int* shape,
|
||||
constant const size_t* a_strides,
|
||||
constant const size_t* b_strides,
|
||||
constant const int& ndim,
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim);
|
||||
size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
|
||||
c[out_idx] = Op1()(a[idx.x], b[idx.y]);
|
||||
d[out_idx] = Op2()(a[idx.x], b[idx.y]);
|
||||
}
|
||||
|
||||
#define instantiate_binary(name, itype, otype, op1, op2, bopt) \
|
||||
#define instantiate_binary(name, itype, otype, op, bopt) \
|
||||
template [[host_name(name)]] [[kernel]] void \
|
||||
binary_op_##bopt<itype, otype, op1, op2>( \
|
||||
binary_##bopt<itype, otype, op>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
device otype* d, \
|
||||
uint index [[thread_position_in_grid]]);
|
||||
|
||||
#define instantiate_binary_g_dim(name, itype, otype, op1, op2, dims) \
|
||||
template [[host_name(name "_" #dims)]] [[kernel]] void \
|
||||
binary_op_g_nd<itype, otype, op1, op2, dims>( \
|
||||
#define instantiate_binary_g_dim(name, itype, otype, op, dims) \
|
||||
template [[host_name("g" #dims name)]] [[kernel]] void \
|
||||
binary_g_nd<itype, otype, op, dims>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
@@ -217,10 +29,9 @@ template <typename T, typename U, typename Op1, typename Op2>
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_binary_g_nd(name, itype, otype, op1, op2) \
|
||||
template [[host_name(name "_1")]] [[kernel]] void \
|
||||
binary_op_g_nd1<itype, otype, op1, op2>( \
|
||||
#define instantiate_binary_g_nd(name, itype, otype, op) \
|
||||
template [[host_name("g1" name)]] [[kernel]] void \
|
||||
binary_g_nd1<itype, otype, op>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
@@ -228,8 +39,8 @@ template <typename T, typename U, typename Op1, typename Op2>
|
||||
constant const size_t& a_stride, \
|
||||
constant const size_t& b_stride, \
|
||||
uint index [[thread_position_in_grid]]); \
|
||||
template [[host_name(name "_2")]] [[kernel]] void \
|
||||
binary_op_g_nd2<itype, otype, op1, op2>( \
|
||||
template [[host_name("g2" name)]] [[kernel]] void \
|
||||
binary_g_nd2<itype, otype, op>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
@@ -238,8 +49,8 @@ template <typename T, typename U, typename Op1, typename Op2>
|
||||
constant const size_t b_strides[2], \
|
||||
uint2 index [[thread_position_in_grid]], \
|
||||
uint2 grid_dim [[threads_per_grid]]); \
|
||||
template [[host_name(name "_3")]] [[kernel]] void \
|
||||
binary_op_g_nd3<itype, otype, op1, op2>( \
|
||||
template [[host_name("g3" name)]] [[kernel]] void \
|
||||
binary_g_nd3<itype, otype, op>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
@@ -248,12 +59,12 @@ template <typename T, typename U, typename Op1, typename Op2>
|
||||
constant const size_t b_strides[3], \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]); \
|
||||
instantiate_binary_g_dim(name, itype, otype, op1, op2, 4) \
|
||||
instantiate_binary_g_dim(name, itype, otype, op1, op2, 5) // clang-format on
|
||||
instantiate_binary_g_dim(name, itype, otype, op, 4) \
|
||||
instantiate_binary_g_dim(name, itype, otype, op, 5)
|
||||
|
||||
#define instantiate_binary_g(name, itype, otype, op1, op2) \
|
||||
template [[host_name(name)]] [[kernel]] void \
|
||||
binary_op_g<itype, otype, op2, op2>( \
|
||||
#define instantiate_binary_g(name, itype, otype, op) \
|
||||
template [[host_name("gn" name)]] [[kernel]] void \
|
||||
binary_g<itype, otype, op>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
@@ -265,33 +76,30 @@ template <typename T, typename U, typename Op1, typename Op2>
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_binary_all(name, tname, itype, otype, op1, op2) \
|
||||
instantiate_binary("ss" #name #tname, itype, otype, op1, op2, ss) \
|
||||
instantiate_binary("sv" #name #tname, itype, otype, op1, op2, sv) \
|
||||
instantiate_binary("vs" #name #tname, itype, otype, op1, op2, vs) \
|
||||
instantiate_binary("vv" #name #tname, itype, otype, op1, op2, vv) \
|
||||
instantiate_binary_g("g" #name #tname, itype, otype, op1, op2) \
|
||||
instantiate_binary_g_nd("g" #name #tname, itype, otype, op1, op2) // clang-format on
|
||||
#define instantiate_binary_all(name, tname, itype, otype, op) \
|
||||
instantiate_binary("ss" #name #tname, itype, otype, op, ss) \
|
||||
instantiate_binary("sv" #name #tname, itype, otype, op, sv) \
|
||||
instantiate_binary("vs" #name #tname, itype, otype, op, vs) \
|
||||
instantiate_binary("vv" #name #tname, itype, otype, op, vv) \
|
||||
instantiate_binary_g(#name #tname, itype, otype, op) \
|
||||
instantiate_binary_g_nd(#name #tname, itype, otype, op)
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_binary_float(name, op1, op2) \
|
||||
instantiate_binary_all(name, float16, half, half, op1, op2) \
|
||||
instantiate_binary_all(name, float32, float, float, op1, op2) \
|
||||
instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op1, op2) // clang-format on
|
||||
#define instantiate_binary_float(name, op) \
|
||||
instantiate_binary_all(name, float16, half, half, op) \
|
||||
instantiate_binary_all(name, float32, float, float, op) \
|
||||
instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op)
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_binary_types(name, op1, op2) \
|
||||
instantiate_binary_all(name, bool_, bool, bool, op1, op2) \
|
||||
instantiate_binary_all(name, uint8, uint8_t, uint8_t, op1, op2) \
|
||||
instantiate_binary_all(name, uint16, uint16_t, uint16_t, op1, op2) \
|
||||
instantiate_binary_all(name, uint32, uint32_t, uint32_t, op1, op2) \
|
||||
instantiate_binary_all(name, uint64, uint64_t, uint64_t, op1, op2) \
|
||||
instantiate_binary_all(name, int8, int8_t, int8_t, op1, op2) \
|
||||
instantiate_binary_all(name, int16, int16_t, int16_t, op1, op2) \
|
||||
instantiate_binary_all(name, int32, int32_t, int32_t, op1, op2) \
|
||||
instantiate_binary_all(name, int64, int64_t, int64_t, op1, op2) \
|
||||
instantiate_binary_all(name, complex64, complex64_t, complex64_t, op1, op2) \
|
||||
instantiate_binary_float(name, op1, op2)
|
||||
#define instantiate_binary_types(name, op) \
|
||||
instantiate_binary_all(name, bool_, bool, bool, op) \
|
||||
instantiate_binary_all(name, uint8, uint8_t, uint8_t, op) \
|
||||
instantiate_binary_all(name, uint16, uint16_t, uint16_t, op) \
|
||||
instantiate_binary_all(name, uint32, uint32_t, uint32_t, op) \
|
||||
instantiate_binary_all(name, uint64, uint64_t, uint64_t, op) \
|
||||
instantiate_binary_all(name, int8, int8_t, int8_t, op) \
|
||||
instantiate_binary_all(name, int16, int16_t, int16_t, op) \
|
||||
instantiate_binary_all(name, int32, int32_t, int32_t, op) \
|
||||
instantiate_binary_all(name, int64, int64_t, int64_t, op) \
|
||||
instantiate_binary_all(name, complex64, complex64_t, complex64_t, op) \
|
||||
instantiate_binary_float(name, op)
|
||||
|
||||
instantiate_binary_types(divmod, FloorDivide, Remainder) // clang-format on
|
||||
instantiate_binary_types(divmod, DivMod) // clang-format on
|
||||
|
@@ -1,7 +0,0 @@
// Copyright © 2023-2024 Apple Inc.

#include "mlx/backend/metal/kernels/binary.h"
#include "mlx/backend/metal/kernels/ternary.h"
#include "mlx/backend/metal/kernels/unary.h"

typedef half float16_t;
|
@@ -109,6 +109,7 @@ template <typename T, int N>
  bool valid = n < params->N;

  // Unroll dimensions
  int kernel_stride = 1;
  for (int i = N - 1; i >= 0; --i) {
    int os_ = (oS % params->oS[i]);
    int ws_ = (wS % params->wS[i]);
@@ -125,7 +126,8 @@ template <typename T, int N>
    oS /= params->oS[i];
    wS /= params->wS[i];

    out += ws_ * params->str[i];
    out += ws_ * kernel_stride;
    kernel_stride *= params->wS[i];
  }

  if (valid) {
@@ -648,4 +650,4 @@ winograd_conv_2d_output_transform(

// clang-format off
instantiate_winograd_conv_2d(float32, float);
instantiate_winograd_conv_2d(float16, half); // clang-format on
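The change above replaces the output stride `params->str[i]` with a running `kernel_stride` when accumulating the per-dimension weight offsets, so the unrolled-weight index is flattened in row-major order over the kernel's spatial extents. A small C++ sketch of the same running-stride flattening; the helper name is illustrative:

#include <cassert>

// Flatten a multi-dimensional index idx[0..N) over extents wS[0..N) into one
// offset, accumulating the stride from the innermost dimension outward, the
// way the loop above accumulates kernel_stride.
template <int N>
int flatten_with_running_stride(const int (&idx)[N], const int (&wS)[N]) {
  int out = 0;
  int kernel_stride = 1;
  for (int i = N - 1; i >= 0; --i) {
    out += idx[i] * kernel_stride;
    kernel_stride *= wS[i];
  }
  return out;
}

int main() {
  int idx[2] = {1, 2};  // (row, col) within a 3x4 kernel window
  int wS[2] = {3, 4};
  assert(flatten_with_running_stride(idx, wS) == 1 * 4 + 2);  // row-major
  return 0;
}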
mlx/backend/metal/kernels/copy.h (new file, 144 lines)
@@ -0,0 +1,144 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
template <typename T, typename U>
|
||||
[[kernel]] void copy_s(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
dst[index] = static_cast<U>(src[0]);
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
[[kernel]] void copy_v(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
dst[index] = static_cast<U>(src[index]);
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
[[kernel]] void copy_g_nd1(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int64_t& src_stride [[buffer(3)]],
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
auto src_idx = elem_to_loc_1(index, src_stride);
|
||||
dst[index] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
[[kernel]] void copy_g_nd2(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
auto src_idx = elem_to_loc_2(index, src_strides);
|
||||
int64_t dst_idx = index.x + (int64_t)grid_dim.x * index.y;
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
[[kernel]] void copy_g_nd3(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto src_idx = elem_to_loc_3(index, src_strides);
|
||||
int64_t dst_idx =
|
||||
index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, int DIM>
|
||||
[[kernel]] void copy_g_nd(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int* src_shape [[buffer(2)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
|
||||
int64_t dst_idx =
|
||||
index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
[[kernel]] void copy_g(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int* src_shape [[buffer(2)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
constant const int& ndim [[buffer(5)]],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim);
|
||||
int64_t dst_idx =
|
||||
index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
[[kernel]] void copy_gg_nd1(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int64_t& src_stride [[buffer(3)]],
|
||||
constant const int64_t& dst_stride [[buffer(4)]],
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
auto src_idx = elem_to_loc_1(index, src_stride);
|
||||
auto dst_idx = elem_to_loc_1(index, dst_stride);
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
[[kernel]] void copy_gg_nd2(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
constant const int64_t* dst_strides [[buffer(4)]],
|
||||
uint2 index [[thread_position_in_grid]]) {
|
||||
auto src_idx = elem_to_loc_2(index, src_strides);
|
||||
auto dst_idx = elem_to_loc_2(index, dst_strides);
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
[[kernel]] void copy_gg_nd3(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
constant const int64_t* dst_strides [[buffer(4)]],
|
||||
uint3 index [[thread_position_in_grid]]) {
|
||||
auto src_idx = elem_to_loc_3(index, src_strides);
|
||||
auto dst_idx = elem_to_loc_3(index, dst_strides);
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, int DIM>
|
||||
[[kernel]] void copy_gg_nd(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int* src_shape [[buffer(2)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
constant const int64_t* dst_strides [[buffer(4)]],
|
||||
uint3 index [[thread_position_in_grid]]) {
|
||||
auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
|
||||
auto dst_idx = elem_to_loc_nd<DIM>(index, src_shape, dst_strides);
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
[[kernel]] void copy_gg(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int* src_shape [[buffer(2)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
constant const int64_t* dst_strides [[buffer(4)]],
|
||||
constant const int& ndim [[buffer(5)]],
|
||||
uint3 index [[thread_position_in_grid]]) {
|
||||
auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim);
|
||||
auto dst_idx = elem_to_loc(index, src_shape, dst_strides, ndim);
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
}
|
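All of the copy kernels above reduce to the same indexing step: elem_to_loc maps a row-major element position onto an offset into a (possibly non-contiguous) strided buffer. A host-side C++ sketch of the general n-dimensional form used by copy_g / copy_gg, assuming the usual shape/strides convention; this is an illustration, not the exact MLX helper:

#include <cassert>
#include <cstdint>
#include <vector>

// Convert a flat row-major element index into an offset into a strided
// buffer: peel off one coordinate per dimension, innermost dimension first.
int64_t elem_to_loc(int64_t elem, const std::vector<int>& shape,
                    const std::vector<int64_t>& strides) {
  int64_t loc = 0;
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
    loc += (elem % shape[i]) * strides[i];
    elem /= shape[i];
  }
  return loc;
}

int main() {
  // A 2x3 array stored transposed: strides {1, 2} instead of {3, 1}.
  std::vector<int> shape{2, 3};
  std::vector<int64_t> strides{1, 2};
  assert(elem_to_loc(4, shape, strides) == 1 * 1 + 1 * 2);  // element (1, 1)
  return 0;
}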
@@ -1,150 +1,9 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
// clang-format off
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
template <typename T, typename U>
|
||||
[[kernel]] void copy_s(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
dst[index] = static_cast<U>(src[0]);
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
[[kernel]] void copy_v(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
dst[index] = static_cast<U>(src[index]);
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
[[kernel]] void copy_g_nd1(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int64_t& src_stride [[buffer(3)]],
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
auto src_idx = elem_to_loc_1(index, src_stride);
|
||||
dst[index] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
[[kernel]] void copy_g_nd2(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
auto src_idx = elem_to_loc_2(index, src_strides);
|
||||
int64_t dst_idx = index.x + (int64_t)grid_dim.x * index.y;
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
[[kernel]] void copy_g_nd3(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto src_idx = elem_to_loc_3(index, src_strides);
|
||||
int64_t dst_idx =
|
||||
index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, int DIM>
|
||||
[[kernel]] void copy_g_nd(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int* src_shape [[buffer(2)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
|
||||
int64_t dst_idx =
|
||||
index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
[[kernel]] void copy_g(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int* src_shape [[buffer(2)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
constant const int& ndim [[buffer(5)]],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim);
|
||||
int64_t dst_idx =
|
||||
index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z);
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
[[kernel]] void copy_gg_nd1(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int64_t& src_stride [[buffer(3)]],
|
||||
constant const int64_t& dst_stride [[buffer(4)]],
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
auto src_idx = elem_to_loc_1(index, src_stride);
|
||||
auto dst_idx = elem_to_loc_1(index, dst_stride);
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
[[kernel]] void copy_gg_nd2(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
constant const int64_t* dst_strides [[buffer(4)]],
|
||||
uint2 index [[thread_position_in_grid]]) {
|
||||
auto src_idx = elem_to_loc_2(index, src_strides);
|
||||
auto dst_idx = elem_to_loc_2(index, dst_strides);
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
[[kernel]] void copy_gg_nd3(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
constant const int64_t* dst_strides [[buffer(4)]],
|
||||
uint3 index [[thread_position_in_grid]]) {
|
||||
auto src_idx = elem_to_loc_3(index, src_strides);
|
||||
auto dst_idx = elem_to_loc_3(index, dst_strides);
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, int DIM>
|
||||
[[kernel]] void copy_gg_nd(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int* src_shape [[buffer(2)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
constant const int64_t* dst_strides [[buffer(4)]],
|
||||
uint3 index [[thread_position_in_grid]]) {
|
||||
auto src_idx = elem_to_loc_nd<DIM>(index, src_shape, src_strides);
|
||||
auto dst_idx = elem_to_loc_nd<DIM>(index, src_shape, dst_strides);
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
[[kernel]] void copy_gg(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
constant const int* src_shape [[buffer(2)]],
|
||||
constant const int64_t* src_strides [[buffer(3)]],
|
||||
constant const int64_t* dst_strides [[buffer(4)]],
|
||||
constant const int& ndim [[buffer(5)]],
|
||||
uint3 index [[thread_position_in_grid]]) {
|
||||
auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim);
|
||||
auto dst_idx = elem_to_loc(index, src_shape, dst_strides, ndim);
|
||||
dst[dst_idx] = static_cast<U>(src[src_idx]);
|
||||
}
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/copy.h"
|
||||
|
||||
#define instantiate_copy(name, itype, otype, ctype) \
|
||||
template [[host_name(name)]] [[kernel]] void copy_##ctype<itype, otype>( \
|
||||
@@ -152,92 +11,90 @@ template <typename T, typename U>
|
||||
device otype* dst [[buffer(1)]], \
|
||||
uint index [[thread_position_in_grid]]);
|
||||
|
||||
#define instantiate_copy_g_dim(name, itype, otype, dims) \
|
||||
template [[host_name(name "_" #dims)]] [[kernel]] void \
|
||||
copy_g_nd<itype, otype, dims>( \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int* src_shape [[buffer(2)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]); \
|
||||
template [[host_name("g" name "_" #dims)]] [[kernel]] void \
|
||||
copy_gg_nd<itype, otype, dims>( \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int* src_shape [[buffer(2)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
constant const int64_t* dst_strides [[buffer(4)]], \
|
||||
#define instantiate_copy_g_dim(name, itype, otype, dims) \
|
||||
template [[host_name("g" #dims "_" name)]] [[kernel]] void \
|
||||
copy_g_nd<itype, otype, dims>( \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int* src_shape [[buffer(2)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]); \
|
||||
template [[host_name("gg" #dims "_" name)]] [[kernel]] void \
|
||||
copy_gg_nd<itype, otype, dims>( \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int* src_shape [[buffer(2)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
constant const int64_t* dst_strides [[buffer(4)]], \
|
||||
uint3 index [[thread_position_in_grid]]);
|
||||
|
||||
#define instantiate_copy_g_nd(name, itype, otype) \
|
||||
template [[host_name(name "_1")]] [[kernel]] void copy_g_nd1<itype, otype>( \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int64_t& src_stride [[buffer(3)]], \
|
||||
uint index [[thread_position_in_grid]]); \
|
||||
template [[host_name(name "_2")]] [[kernel]] void copy_g_nd2<itype, otype>( \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
uint2 index [[thread_position_in_grid]], \
|
||||
uint2 grid_dim [[threads_per_grid]]); \
|
||||
template [[host_name(name "_3")]] [[kernel]] void copy_g_nd3<itype, otype>( \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]); \
|
||||
template [[host_name("g" name "_1")]] [[kernel]] void \
|
||||
copy_gg_nd1<itype, otype>( \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int64_t& src_stride [[buffer(3)]], \
|
||||
constant const int64_t& dst_stride [[buffer(4)]], \
|
||||
uint index [[thread_position_in_grid]]); \
|
||||
template [[host_name("g" name "_2")]] [[kernel]] void \
|
||||
copy_gg_nd2<itype, otype>( \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
constant const int64_t* dst_strides [[buffer(4)]], \
|
||||
uint2 index [[thread_position_in_grid]]); \
|
||||
template [[host_name("g" name "_3")]] [[kernel]] void \
|
||||
copy_gg_nd3<itype, otype>( \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
constant const int64_t* dst_strides [[buffer(4)]], \
|
||||
uint3 index [[thread_position_in_grid]]); \
|
||||
instantiate_copy_g_dim(name, itype, otype, 4) \
|
||||
instantiate_copy_g_dim(name, itype, otype, 5)
|
||||
#define instantiate_copy_g_nd(name, itype, otype) \
|
||||
template [[host_name("g1_" name)]] [[kernel]] void copy_g_nd1<itype, otype>( \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int64_t& src_stride [[buffer(3)]], \
|
||||
uint index [[thread_position_in_grid]]); \
|
||||
template [[host_name("g2_" name)]] [[kernel]] void copy_g_nd2<itype, otype>( \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
uint2 index [[thread_position_in_grid]], \
|
||||
uint2 grid_dim [[threads_per_grid]]); \
|
||||
template [[host_name("g3_" name)]] [[kernel]] void copy_g_nd3<itype, otype>( \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]); \
|
||||
template [[host_name("gg1_" name )]] [[kernel]] void \
|
||||
copy_gg_nd1<itype, otype>( \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int64_t& src_stride [[buffer(3)]], \
|
||||
constant const int64_t& dst_stride [[buffer(4)]], \
|
||||
uint index [[thread_position_in_grid]]); \
|
||||
template [[host_name("gg2_" name)]] [[kernel]] void \
|
||||
copy_gg_nd2<itype, otype>( \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
constant const int64_t* dst_strides [[buffer(4)]], \
|
||||
uint2 index [[thread_position_in_grid]]); \
|
||||
template [[host_name("gg3_" name)]] [[kernel]] void \
|
||||
copy_gg_nd3<itype, otype>( \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
constant const int64_t* dst_strides [[buffer(4)]], \
|
||||
uint3 index [[thread_position_in_grid]]); \
|
||||
instantiate_copy_g_dim(name, itype, otype, 4) \
|
||||
instantiate_copy_g_dim(name, itype, otype, 5)
|
||||
|
||||
#define instantiate_copy_g(name, itype, otype) \
|
||||
template [[host_name(name)]] [[kernel]] void copy_g<itype, otype>( \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int* src_shape [[buffer(2)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
constant const int& ndim [[buffer(5)]], \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]); \
|
||||
template [[host_name("g" name)]] [[kernel]] void copy_gg<itype, otype>( \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int* src_shape [[buffer(2)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
constant const int64_t* dst_strides [[buffer(4)]], \
|
||||
constant const int& ndim [[buffer(5)]], \
|
||||
#define instantiate_copy_g(name, itype, otype) \
|
||||
template [[host_name("g_" name)]] [[kernel]] void copy_g<itype, otype>( \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int* src_shape [[buffer(2)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
constant const int& ndim [[buffer(5)]], \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]); \
|
||||
template [[host_name("gg_" name)]] [[kernel]] void copy_gg<itype, otype>( \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int* src_shape [[buffer(2)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
constant const int64_t* dst_strides [[buffer(4)]], \
|
||||
constant const int& ndim [[buffer(5)]], \
|
||||
uint3 index [[thread_position_in_grid]]);
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_copy_all(tname, itype, otype) \
|
||||
instantiate_copy("scopy" #tname, itype, otype, s) \
|
||||
instantiate_copy("vcopy" #tname, itype, otype, v) \
|
||||
instantiate_copy_g("gcopy" #tname, itype, otype) \
|
||||
instantiate_copy_g_nd("gcopy" #tname, itype, otype) // clang-format on
|
||||
#define instantiate_copy_all(tname, itype, otype) \
|
||||
instantiate_copy("s_copy" #tname, itype, otype, s) \
|
||||
instantiate_copy("v_copy" #tname, itype, otype, v) \
|
||||
instantiate_copy_g("copy" #tname, itype, otype) \
|
||||
instantiate_copy_g_nd("copy" #tname, itype, otype)
|
||||
|
||||
// clang-format off
|
||||
#define instantiate_copy_itype(itname, itype) \
|
||||
instantiate_copy_all(itname ##bool_, itype, bool) \
|
||||
instantiate_copy_all(itname ##uint8, itype, uint8_t) \
|
||||
|
@@ -2,17 +2,14 @@

#pragma once

#ifdef __METAL__
#if defined __METAL__ || defined MLX_METAL_JIT
#define MTL_CONST constant
#else
#define MTL_CONST
#endif

static MTL_CONST constexpr int MAX_BINARY_SPECIALIZED_DIMS = 5;
static MTL_CONST constexpr int MAX_COPY_SPECIALIZED_DIMS = 5;
static MTL_CONST constexpr int MAX_REDUCE_SPECIALIZED_DIMS = 4;
static MTL_CONST constexpr int REDUCE_N_READS = 16;
static MTL_CONST constexpr int SOFTMAX_N_READS = 4;
static MTL_CONST constexpr int SOFTMAX_LOOPED_LIMIT = 4096;
static MTL_CONST constexpr int RMS_N_READS = 4;
static MTL_CONST constexpr int RMS_LOOPED_LIMIT = 4096;
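The widened guard presumably makes MTL_CONST expand to Metal's `constant` address-space qualifier both under the Metal compiler and when the header is pulled in as source for the JIT path, while expanding to nothing for plain host C++ so the same constants can be shared. A tiny sketch of the pattern in a hypothetical host translation unit (an assumption about intent, not MLX code):

// Shared-header pattern: the qualifier only appears for device compilation.
#if defined __METAL__ || defined MLX_METAL_JIT
#define MTL_CONST constant
#else
#define MTL_CONST
#endif

static MTL_CONST constexpr int REDUCE_N_READS = 16;

// On the host the declaration above is just `static constexpr int ...`, so it
// can size loops or buffers that must agree with the kernels.
int host_buffer_elems(int n_threads) {
  return n_threads * REDUCE_N_READS;
}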
|
||||
|
@@ -1,7 +1,6 @@
// Copyright © 2023 Apple Inc.

#pragma once

#include <metal_math>

/*
@@ -67,4 +66,4 @@ float erfinv(float a) {
    p = metal::fma(p, t, 8.86226892e-1f); // 0x1.c5bf88p-1
  }
  return a * p;
}
}
|
mlx/backend/metal/kernels/fft.h (new file, 486 lines)
@@ -0,0 +1,486 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
// Metal FFT using Stockham's algorithm
|
||||
//
|
||||
// References:
|
||||
// - VkFFT (https://github.com/DTolm/VkFFT)
|
||||
// - Eric Bainville's excellent page (http://www.bealto.com/gpu-fft.html)
|
||||
|
||||
#include <metal_common>
|
||||
|
||||
#include "mlx/backend/metal/kernels/fft/radix.h"
|
||||
#include "mlx/backend/metal/kernels/fft/readwrite.h"
|
||||
#include "mlx/backend/metal/kernels/steel/defines.h"
|
||||
|
||||
using namespace metal;
|
||||
|
||||
#define MAX_RADIX 13
|
||||
// Reached when elems_per_thread_ = 6, max_radix = 13
|
||||
// and some threads have to do 3 radix 6s requiring 18 float2s.
|
||||
#define MAX_OUTPUT_SIZE 18
|
||||
|
||||
// Specialize for a particular value of N at runtime
|
||||
STEEL_CONST bool inv_ [[function_constant(0)]];
|
||||
STEEL_CONST bool is_power_of_2_ [[function_constant(1)]];
|
||||
STEEL_CONST int elems_per_thread_ [[function_constant(2)]];
|
||||
// rader_m = n / rader_n
|
||||
STEEL_CONST int rader_m_ [[function_constant(3)]];
|
||||
// Stockham steps
|
||||
STEEL_CONST int radix_13_steps_ [[function_constant(4)]];
|
||||
STEEL_CONST int radix_11_steps_ [[function_constant(5)]];
|
||||
STEEL_CONST int radix_8_steps_ [[function_constant(6)]];
|
||||
STEEL_CONST int radix_7_steps_ [[function_constant(7)]];
|
||||
STEEL_CONST int radix_6_steps_ [[function_constant(8)]];
|
||||
STEEL_CONST int radix_5_steps_ [[function_constant(9)]];
|
||||
STEEL_CONST int radix_4_steps_ [[function_constant(10)]];
|
||||
STEEL_CONST int radix_3_steps_ [[function_constant(11)]];
|
||||
STEEL_CONST int radix_2_steps_ [[function_constant(12)]];
|
||||
// Rader steps
|
||||
STEEL_CONST int rader_13_steps_ [[function_constant(13)]];
|
||||
STEEL_CONST int rader_11_steps_ [[function_constant(14)]];
|
||||
STEEL_CONST int rader_8_steps_ [[function_constant(15)]];
|
||||
STEEL_CONST int rader_7_steps_ [[function_constant(16)]];
|
||||
STEEL_CONST int rader_6_steps_ [[function_constant(17)]];
|
||||
STEEL_CONST int rader_5_steps_ [[function_constant(18)]];
|
||||
STEEL_CONST int rader_4_steps_ [[function_constant(19)]];
|
||||
STEEL_CONST int rader_3_steps_ [[function_constant(20)]];
|
||||
STEEL_CONST int rader_2_steps_ [[function_constant(21)]];
|
||||
|
||||
// See "radix.h" for radix codelets
|
||||
typedef void (*RadixFunc)(thread float2*, thread float2*);
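The per-size specialization above is driven entirely by function constants: the host factors n into the supported radices and passes the number of Stockham steps for each. A host-side C++ sketch of such a factorization; the radix set mirrors the declarations above, the helper name is illustrative, and the real MLX planner may choose a different decomposition:

#include <cstdio>
#include <map>

// Greedily factor n into the radices the kernel implements, largest first.
// Each factor corresponds to one Stockham radix step.
std::map<int, int> plan_radix_steps(int n) {
  const int radices[] = {13, 11, 8, 7, 6, 5, 4, 3, 2};
  std::map<int, int> steps;
  for (int r : radices) {
    while (n % r == 0) {
      steps[r] += 1;
      n /= r;
    }
  }
  // Any remaining factor > 13 would need Rader's or Bluestein's algorithm.
  return steps;
}

int main() {
  // One possible decomposition of 128 (2 * 8 * 8); the planner may differ.
  for (auto [radix, count] : plan_radix_steps(128)) {
    std::printf("radix %d: %d step(s)\n", radix, count);
  }
  return 0;
}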
|
||||
|
||||
// Perform a single radix n butterfly with appropriate twiddles
|
||||
template <int radix, RadixFunc radix_func>
|
||||
METAL_FUNC void radix_butterfly(
|
||||
int i,
|
||||
int p,
|
||||
thread float2* x,
|
||||
thread short* indices,
|
||||
thread float2* y) {
|
||||
// i: the index in the overall DFT that we're processing.
|
||||
// p: the size of the DFTs we're merging at this step.
|
||||
// m: how many threads are working on this DFT.
|
||||
int k, j;
|
||||
|
||||
// Use faster bitwise operations when working with powers of two
|
||||
constexpr bool radix_p_2 = (radix & (radix - 1)) == 0;
|
||||
if (radix_p_2 && is_power_of_2_) {
|
||||
constexpr short power = __builtin_ctz(radix);
|
||||
k = i & (p - 1);
|
||||
j = ((i - k) << power) + k;
|
||||
} else {
|
||||
k = i % p;
|
||||
j = (i / p) * radix * p + k;
|
||||
}
|
||||
|
||||
// Apply twiddles
|
||||
if (p > 1) {
|
||||
float2 twiddle_1 = get_twiddle(k, radix * p);
|
||||
float2 twiddle = twiddle_1;
|
||||
x[1] = complex_mul(x[1], twiddle);
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (int t = 2; t < radix; t++) {
|
||||
twiddle = complex_mul(twiddle, twiddle_1);
|
||||
x[t] = complex_mul(x[t], twiddle);
|
||||
}
|
||||
}
|
||||
|
||||
radix_func(x, y);
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (int t = 0; t < radix; t++) {
|
||||
indices[t] = j + t * p;
|
||||
}
|
||||
}
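When both the radix and the transform size are powers of two, the butterfly index math above swaps the divide/modulo for mask-and-shift. A short C++ check that the two forms agree (illustrative):

#include <cassert>

int main() {
  const int radix = 4;                  // a power-of-two radix
  const int power = __builtin_ctz(radix);
  for (int p = 1; p <= 64; p *= 2) {    // p stays a power of two per step
    for (int i = 0; i < 256; ++i) {
      int k_div = i % p;
      int j_div = (i / p) * radix * p + k_div;
      int k_bit = i & (p - 1);                     // i % p for power-of-two p
      int j_bit = ((i - k_bit) << power) + k_bit;  // (i / p) * radix * p + k
      assert(k_div == k_bit && j_div == j_bit);
    }
  }
  return 0;
}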
|
||||
|
||||
// Perform all the radix steps required for a
|
||||
// particular radix size n.
|
||||
template <int radix, RadixFunc radix_func>
|
||||
METAL_FUNC void radix_n_steps(
|
||||
int i,
|
||||
thread int* p,
|
||||
int m,
|
||||
int n,
|
||||
int num_steps,
|
||||
thread float2* inputs,
|
||||
thread short* indices,
|
||||
thread float2* values,
|
||||
threadgroup float2* buf) {
|
||||
int m_r = n / radix;
|
||||
// When combining different sized radices, we have to do
|
||||
// multiple butterflies in a single thread.
|
||||
// E.g. n = 28 = 4 * 7
|
||||
// 4 threads, 7 elems_per_thread
|
||||
// All threads do 1 radix7 butterfly.
|
||||
// 3 threads do 2 radix4 butterflies.
|
||||
// 1 thread does 1 radix4 butterfly.
|
||||
int max_radices_per_thread = (elems_per_thread_ + radix - 1) / radix;
|
||||
|
||||
int index = 0;
|
||||
int r_index = 0;
|
||||
for (int s = 0; s < num_steps; s++) {
|
||||
for (int t = 0; t < max_radices_per_thread; t++) {
|
||||
index = i + t * m;
|
||||
if (index < m_r) {
|
||||
for (int r = 0; r < radix; r++) {
|
||||
inputs[r] = buf[index + r * m_r];
|
||||
}
|
||||
radix_butterfly<radix, radix_func>(
|
||||
index, *p, inputs, indices + t * radix, values + t * radix);
|
||||
}
|
||||
}
|
||||
|
||||
// Wait until all threads have read their inputs into thread local mem
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
for (int t = 0; t < max_radices_per_thread; t++) {
|
||||
index = i + t * m;
|
||||
if (index < m_r) {
|
||||
for (int r = 0; r < radix; r++) {
|
||||
r_index = t * radix + r;
|
||||
buf[indices[r_index]] = values[r_index];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Wait until all threads have written back to threadgroup mem
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
*p *= radix;
|
||||
}
|
||||
}
|
||||
|
||||
#define RADIX_STEP(radix, radix_func, num_steps) \
|
||||
radix_n_steps<radix, radix_func>( \
|
||||
fft_idx, p, m, n, num_steps, inputs, indices, values, buf);
|
||||
|
||||
template <bool rader = false>
|
||||
METAL_FUNC void
|
||||
perform_fft(int fft_idx, thread int* p, int m, int n, threadgroup float2* buf) {
|
||||
float2 inputs[MAX_RADIX];
|
||||
short indices[MAX_OUTPUT_SIZE];
|
||||
float2 values[MAX_OUTPUT_SIZE];
|
||||
|
||||
RADIX_STEP(2, radix2, rader ? rader_2_steps_ : radix_2_steps_);
|
||||
RADIX_STEP(3, radix3, rader ? rader_3_steps_ : radix_3_steps_);
|
||||
RADIX_STEP(4, radix4, rader ? rader_4_steps_ : radix_4_steps_);
|
||||
RADIX_STEP(5, radix5, rader ? rader_5_steps_ : radix_5_steps_);
|
||||
RADIX_STEP(6, radix6, rader ? rader_6_steps_ : radix_6_steps_);
|
||||
RADIX_STEP(7, radix7, rader ? rader_7_steps_ : radix_7_steps_);
|
||||
RADIX_STEP(8, radix8, rader ? rader_8_steps_ : radix_8_steps_);
|
||||
RADIX_STEP(11, radix11, rader ? rader_11_steps_ : radix_11_steps_);
|
||||
RADIX_STEP(13, radix13, rader ? rader_13_steps_ : radix_13_steps_);
|
||||
}
|
||||
|
||||
// Each FFT is computed entirely in shared GPU memory.
|
||||
//
|
||||
// N is decomposed into radix-n DFTs:
|
||||
// e.g. 128 = 2 * 4 * 4 * 4
|
||||
template <int tg_mem_size, typename in_T, typename out_T>
|
||||
[[kernel]] void fft(
|
||||
const device in_T* in [[buffer(0)]],
|
||||
device out_T* out [[buffer(1)]],
|
||||
constant const int& n,
|
||||
constant const int& batch_size,
|
||||
uint3 elem [[thread_position_in_grid]],
|
||||
uint3 grid [[threads_per_grid]]) {
|
||||
threadgroup float2 shared_in[tg_mem_size];
|
||||
|
||||
thread ReadWriter<in_T, out_T> read_writer = ReadWriter<in_T, out_T>(
|
||||
in,
|
||||
&shared_in[0],
|
||||
out,
|
||||
n,
|
||||
batch_size,
|
||||
elems_per_thread_,
|
||||
elem,
|
||||
grid,
|
||||
inv_);
|
||||
|
||||
if (read_writer.out_of_bounds()) {
|
||||
return;
|
||||
};
|
||||
read_writer.load();
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
int p = 1;
|
||||
int fft_idx = elem.z; // Thread index in DFT
|
||||
int m = grid.z; // Threads per DFT
|
||||
int tg_idx = elem.y * n; // Index of this DFT in threadgroup
|
||||
threadgroup float2* buf = &shared_in[tg_idx];
|
||||
|
||||
perform_fft(fft_idx, &p, m, n, buf);
|
||||
|
||||
read_writer.write();
|
||||
}
|
||||
|
||||
template <int tg_mem_size, typename in_T, typename out_T>
|
||||
[[kernel]] void rader_fft(
|
||||
const device in_T* in [[buffer(0)]],
|
||||
device out_T* out [[buffer(1)]],
|
||||
const device float2* raders_b_q [[buffer(2)]],
|
||||
const device short* raders_g_q [[buffer(3)]],
|
||||
const device short* raders_g_minus_q [[buffer(4)]],
|
||||
constant const int& n,
|
||||
constant const int& batch_size,
|
||||
constant const int& rader_n,
|
||||
uint3 elem [[thread_position_in_grid]],
|
||||
uint3 grid [[threads_per_grid]]) {
|
||||
// Use Rader's algorithm to compute fast FFTs
|
||||
// when a prime factor `p` of `n` is greater than 13 but
|
||||
// has `p - 1` Stockham-decomposable into prime factors <= 13.
|
||||
//
|
||||
// E.g. n = 102
|
||||
// = 2 * 3 * 17
|
||||
// . = 2 * 3 * RADER(16)
|
||||
// . = 2 * 3 * RADER(4 * 4)
|
||||
//
|
||||
// In numpy:
|
||||
// x_perm = x[g_q]
|
||||
// y = np.fft.fft(x_perm) * b_q
|
||||
// z = np.fft.ifft(y) + x[0]
|
||||
// out = z[g_minus_q]
|
||||
// out[0] = x[1:].sum()
|
||||
//
|
||||
// Where the g_q and g_minus_q are permutations formed
|
||||
// by the multiplicative group modulo N using the
|
||||
// primitive root of N and b_q is a constant.
|
||||
// See https://en.wikipedia.org/wiki/Rader%27s_FFT_algorithm
|
||||
//
|
||||
// Rader's uses fewer operations than Bluestein's and so
|
||||
// is more accurate. It's also faster in most cases.
|
||||
threadgroup float2 shared_in[tg_mem_size];
|
||||
|
||||
thread ReadWriter<in_T, out_T> read_writer = ReadWriter<in_T, out_T>(
|
||||
in,
|
||||
&shared_in[0],
|
||||
out,
|
||||
n,
|
||||
batch_size,
|
||||
elems_per_thread_,
|
||||
elem,
|
||||
grid,
|
||||
inv_);
|
||||
|
||||
if (read_writer.out_of_bounds()) {
|
||||
return;
|
||||
};
|
||||
read_writer.load();
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// The number of the threads we're using for each DFT
|
||||
int m = grid.z;
|
||||
|
||||
int fft_idx = elem.z;
|
||||
int tg_idx = elem.y * n;
|
||||
threadgroup float2* buf = &shared_in[tg_idx];
|
||||
|
||||
// rader_m = n / rader_n;
|
||||
int rader_m = rader_m_;
|
||||
|
||||
// We have to load two x_0s for each thread since sometimes
|
||||
// elems_per_thread_ crosses a boundary.
|
||||
// E.g. with n = 34, rader_n = 17, elems_per_thread_ = 4
|
||||
// 0 0 0 0 1 1 1 1 2 2 2 2 3 3 3 3 4 4 4 4 5 5 5 5 6 6 6 6 7 7 7 7 8 8
|
||||
// 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
|
||||
short x_0_index =
|
||||
metal::min(fft_idx * elems_per_thread_ / (rader_n - 1), rader_m - 1);
|
||||
float2 x_0[2] = {buf[x_0_index], buf[x_0_index + 1]};
|
||||
|
||||
// Do the Rader permutation in shared memory
|
||||
float2 temp[MAX_RADIX];
|
||||
int max_index = n - rader_m - 1;
|
||||
for (int e = 0; e < elems_per_thread_; e++) {
|
||||
short index = metal::min(fft_idx * elems_per_thread_ + e, max_index);
|
||||
short g_q = raders_g_q[index / rader_m];
|
||||
temp[e] = buf[rader_m + (g_q - 1) * rader_m + index % rader_m];
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
for (int e = 0; e < elems_per_thread_; e++) {
|
||||
short index = metal::min(fft_idx * elems_per_thread_ + e, max_index);
|
||||
buf[index + rader_m] = temp[e];
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Rader FFT on x[rader_m:]
|
||||
int p = 1;
|
||||
perform_fft</*rader=*/true>(fft_idx, &p, m, n - rader_m, buf + rader_m);
|
||||
|
||||
// x_1 + ... + x_n is computed for us in the first FFT step so
|
||||
// we save it in the first rader_m indices of the array for later.
|
||||
int x_sum_index = metal::min(fft_idx, rader_m - 1);
|
||||
buf[x_sum_index] = buf[rader_m + x_sum_index * (rader_n - 1)];
|
||||
|
||||
float2 inv = {1.0f, -1.0f};
|
||||
for (int e = 0; e < elems_per_thread_; e++) {
|
||||
short index = metal::min(fft_idx * elems_per_thread_ + e, max_index);
|
||||
short interleaved_index =
|
||||
index / rader_m + (index % rader_m) * (rader_n - 1);
|
||||
temp[e] = complex_mul(
|
||||
buf[rader_m + interleaved_index],
|
||||
raders_b_q[interleaved_index % (rader_n - 1)]);
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
for (int e = 0; e < elems_per_thread_; e++) {
|
||||
short index = metal::min(fft_idx * elems_per_thread_ + e, max_index);
|
||||
buf[rader_m + index] = temp[e] * inv;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Rader IFFT on x[rader_m:]
|
||||
p = 1;
|
||||
perform_fft</*rader=*/true>(fft_idx, &p, m, n - rader_m, buf + rader_m);
|
||||
|
||||
float2 rader_inv_factor = {1.0f / (rader_n - 1), -1.0f / (rader_n - 1)};
|
||||
|
||||
for (int e = 0; e < elems_per_thread_; e++) {
|
||||
short index = metal::min(fft_idx * elems_per_thread_ + e, n - rader_m - 1);
|
||||
short diff_index = index / (rader_n - 1) - x_0_index;
|
||||
temp[e] = buf[rader_m + index] * rader_inv_factor + x_0[diff_index];
|
||||
}
|
||||
|
||||
// Use the sum of elements that was computed in the first FFT
|
||||
float2 x_sum = buf[x_0_index] + x_0[0];
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
for (int e = 0; e < elems_per_thread_; e++) {
|
||||
short index = metal::min(fft_idx * elems_per_thread_ + e, max_index);
|
||||
short g_q_index = index % (rader_n - 1);
|
||||
short g_q = raders_g_minus_q[g_q_index];
|
||||
short out_index = index - g_q_index + g_q + (index / (rader_n - 1));
|
||||
buf[out_index] = temp[e];
|
||||
}
|
||||
|
||||
buf[x_0_index * rader_n] = x_sum;
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
p = rader_n;
|
||||
perform_fft(fft_idx, &p, m, n, buf);
|
||||
|
||||
read_writer.write();
|
||||
}
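The numpy recipe in the comment above leaves the permutation bookkeeping implicit. For reference, a self-contained NumPy sketch of Rader's construction (primitive root found by brute force; these names are illustrative and this is not the layout of the kernel's precomputed raders_* buffers):

import numpy as np

def primitive_root(p):
    # Smallest generator of the multiplicative group mod prime p (brute force).
    for g in range(2, p):
        if len({pow(g, q, p) for q in range(p - 1)}) == p - 1:
            return g

def rader_fft(x):
    # DFT of prime length p expressed as a length-(p - 1) circular convolution.
    p = len(x)
    g = primitive_root(p)
    g_inv = pow(g, p - 2, p)                                    # modular inverse of g
    q = np.arange(p - 1)
    g_q = np.array([pow(g, int(t), p) for t in q])              # n = g^q mod p
    g_minus_q = np.array([pow(g_inv, int(t), p) for t in q])    # k = g^-q mod p
    a = x[g_q]
    b = np.exp(-2j * np.pi * g_minus_q / p)
    conv = np.fft.ifft(np.fft.fft(a) * np.fft.fft(b))           # circular convolution
    out = np.empty(p, dtype=complex)
    out[0] = x.sum()
    out[g_minus_q] = x[0] + conv
    return out

x = np.random.randn(17) + 1j * np.random.randn(17)
assert np.allclose(rader_fft(x), np.fft.fft(x))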
|
||||
|
||||
template <int tg_mem_size, typename in_T, typename out_T>
|
||||
[[kernel]] void bluestein_fft(
|
||||
const device in_T* in [[buffer(0)]],
|
||||
device out_T* out [[buffer(1)]],
|
||||
const device float2* w_q [[buffer(2)]],
|
||||
const device float2* w_k [[buffer(3)]],
|
||||
constant const int& length,
|
||||
constant const int& n,
|
||||
constant const int& batch_size,
|
||||
uint3 elem [[thread_position_in_grid]],
|
||||
uint3 grid [[threads_per_grid]]) {
|
||||
// Computes arbitrary length FFTs with Bluestein's algorithm
|
||||
//
|
||||
// In numpy:
|
||||
// bluestein_n = next_power_of_2(2*n - 1)
|
||||
// out = w_k * np.fft.ifft(np.fft.fft(w_k * in, bluestein_n) * w_q)
|
||||
//
|
||||
// Where w_k and w_q are precomputed on CPU in high precision as:
|
||||
// w_k = np.exp(-1j * np.pi / n * (np.arange(-n + 1, n) ** 2))
|
||||
// w_q = np.fft.fft(1/w_k[-n:])
|
||||
threadgroup float2 shared_in[tg_mem_size];
|
||||
|
||||
thread ReadWriter<in_T, out_T> read_writer = ReadWriter<in_T, out_T>(
|
||||
in,
|
||||
&shared_in[0],
|
||||
out,
|
||||
n,
|
||||
batch_size,
|
||||
elems_per_thread_,
|
||||
elem,
|
||||
grid,
|
||||
inv_);
|
||||
|
||||
if (read_writer.out_of_bounds()) {
|
||||
return;
|
||||
};
|
||||
read_writer.load_padded(length, w_k);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
int p = 1;
|
||||
int fft_idx = elem.z; // Thread index in DFT
|
||||
int m = grid.z; // Threads per DFT
|
||||
int tg_idx = elem.y * n; // Index of this DFT in threadgroup
|
||||
threadgroup float2* buf = &shared_in[tg_idx];
|
||||
|
||||
// fft
|
||||
perform_fft(fft_idx, &p, m, n, buf);
|
||||
|
||||
float2 inv = float2(1.0f, -1.0f);
|
||||
for (int t = 0; t < elems_per_thread_; t++) {
|
||||
int index = fft_idx + t * m;
|
||||
buf[index] = complex_mul(buf[index], w_q[index]) * inv;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// ifft
|
||||
p = 1;
|
||||
perform_fft(fft_idx, &p, m, n, buf);
|
||||
|
||||
read_writer.write_padded(length, w_k);
|
||||
}
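The comment above gives the w_k / w_q recipe in the form the kernel consumes (w_q precomputed on the CPU). As a cross-check, a standalone NumPy sketch of the same idea that builds the chirp inline instead of using the kernel's w_k/w_q buffers:

import numpy as np

def bluestein_fft(x):
    # Arbitrary-length DFT as a linear convolution with a chirp, evaluated
    # with power-of-two FFTs of length >= 2n - 1.
    n = len(x)
    m = 1 << (2 * n - 1).bit_length()
    k = np.arange(n)
    w_k = np.exp(-1j * np.pi * k**2 / n)          # chirp e^{-i pi k^2 / n}
    a = np.zeros(m, dtype=complex)
    a[:n] = x * w_k
    b = np.zeros(m, dtype=complex)
    b[:n] = np.conj(w_k)
    if n > 1:
        b[-(n - 1):] = np.conj(w_k[1:])[::-1]     # wrap negative lags for circular conv
    conv = np.fft.ifft(np.fft.fft(a) * np.fft.fft(b))
    return w_k * conv[:n]

x = np.random.randn(12) + 1j * np.random.randn(12)
assert np.allclose(bluestein_fft(x), np.fft.fft(x))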
|
||||
|
||||
template <
|
||||
int tg_mem_size,
|
||||
typename in_T,
|
||||
typename out_T,
|
||||
int step,
|
||||
bool real = false>
|
||||
[[kernel]] void four_step_fft(
|
||||
const device in_T* in [[buffer(0)]],
|
||||
device out_T* out [[buffer(1)]],
|
||||
constant const int& n1,
|
||||
constant const int& n2,
|
||||
constant const int& batch_size,
|
||||
uint3 elem [[thread_position_in_grid]],
|
||||
uint3 grid [[threads_per_grid]]) {
|
||||
// Fast four step FFT implementation for powers of 2.
|
||||
int overall_n = n1 * n2;
|
||||
int n = step == 0 ? n1 : n2;
|
||||
int stride = step == 0 ? n2 : n1;
|
||||
|
||||
// The number of the threads we're using for each DFT
|
||||
int m = grid.z;
|
||||
int fft_idx = elem.z;
|
||||
|
||||
threadgroup float2 shared_in[tg_mem_size];
|
||||
threadgroup float2* buf = &shared_in[elem.y * n];
|
||||
|
||||
using read_writer_t = ReadWriter<in_T, out_T, step, real>;
|
||||
read_writer_t read_writer = read_writer_t(
|
||||
in,
|
||||
&shared_in[0],
|
||||
out,
|
||||
n,
|
||||
batch_size,
|
||||
elems_per_thread_,
|
||||
elem,
|
||||
grid,
|
||||
inv_);
|
||||
|
||||
if (read_writer.out_of_bounds()) {
|
||||
return;
|
||||
};
|
||||
read_writer.load_strided(stride, overall_n);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
int p = 1;
|
||||
perform_fft(fft_idx, &p, m, n, buf);
|
||||
|
||||
read_writer.write_strided(stride, overall_n);
|
||||
}
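Minus the strided shared-memory plumbing, the two `step` kernels implement the classic four step (Cooley-Tukey n = n1 * n2) split: column FFTs, twiddles, row FFTs, transpose. A minimal NumPy sketch of that decomposition, with n1/n2 named as in the kernel arguments:

import numpy as np

def four_step_fft(x, n1, n2):
    n = n1 * n2
    a = np.asarray(x, dtype=complex).reshape(n1, n2)
    a = np.fft.fft(a, axis=0)                          # step 1: n2 FFTs of length n1
    k1 = np.arange(n1)[:, None]
    j2 = np.arange(n2)[None, :]
    a = a * np.exp(-2j * np.pi * k1 * j2 / n)          # step 2: twiddle factors
    a = np.fft.fft(a, axis=1)                          # step 3: n1 FFTs of length n2
    return a.T.reshape(n)                              # step 4: transpose

x = np.random.randn(12) + 1j * np.random.randn(12)
assert np.allclose(four_step_fft(x, 3, 4), np.fft.fft(x))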
|
@@ -1,199 +1,84 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
// Metal FFT using Stockham's algorithm
|
||||
//
|
||||
// References:
|
||||
// - VkFFT (https://github.com/DTolm/VkFFT)
|
||||
// - Eric Bainville's excellent page (http://www.bealto.com/gpu-fft.html)
|
||||
#include "mlx/backend/metal/kernels/fft.h"
|
||||
|
||||
#include <metal_common>
|
||||
#include <metal_math>
|
||||
#define instantiate_fft(tg_mem_size, in_T, out_T) \
|
||||
template [[host_name("fft_mem_" #tg_mem_size "_" #in_T \
|
||||
"_" #out_T)]] [[kernel]] void \
|
||||
fft<tg_mem_size, in_T, out_T>( \
|
||||
const device in_T* in [[buffer(0)]], \
|
||||
device out_T* out [[buffer(1)]], \
|
||||
constant const int& n, \
|
||||
constant const int& batch_size, \
|
||||
uint3 elem [[thread_position_in_grid]], \
|
||||
uint3 grid [[threads_per_grid]]);
|
||||
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
#define instantiate_rader(tg_mem_size, in_T, out_T) \
|
||||
template [[host_name("rader_fft_mem_" #tg_mem_size "_" #in_T \
|
||||
"_" #out_T)]] [[kernel]] void \
|
||||
rader_fft<tg_mem_size, in_T, out_T>( \
|
||||
const device in_T* in [[buffer(0)]], \
|
||||
device out_T* out [[buffer(1)]], \
|
||||
const device float2* raders_b_q [[buffer(2)]], \
|
||||
const device short* raders_g_q [[buffer(3)]], \
|
||||
const device short* raders_g_minus_q [[buffer(4)]], \
|
||||
constant const int& n, \
|
||||
constant const int& batch_size, \
|
||||
constant const int& rader_n, \
|
||||
uint3 elem [[thread_position_in_grid]], \
|
||||
uint3 grid [[threads_per_grid]]);
|
||||
|
||||
using namespace metal;
|
||||
#define instantiate_bluestein(tg_mem_size, in_T, out_T) \
|
||||
template [[host_name("bluestein_fft_mem_" #tg_mem_size "_" #in_T \
|
||||
"_" #out_T)]] [[kernel]] void \
|
||||
bluestein_fft<tg_mem_size, in_T, out_T>( \
|
||||
const device in_T* in [[buffer(0)]], \
|
||||
device out_T* out [[buffer(1)]], \
|
||||
const device float2* w_q [[buffer(2)]], \
|
||||
const device float2* w_k [[buffer(3)]], \
|
||||
constant const int& length, \
|
||||
constant const int& n, \
|
||||
constant const int& batch_size, \
|
||||
uint3 elem [[thread_position_in_grid]], \
|
||||
uint3 grid [[threads_per_grid]]);
|
||||
|
||||
float2 complex_mul(float2 a, float2 b) {
|
||||
float2 c;
|
||||
c.x = a.x * b.x - a.y * b.y;
|
||||
c.y = a.x * b.y + a.y * b.x;
|
||||
return c;
|
||||
}
|
||||
#define instantiate_four_step(tg_mem_size, in_T, out_T, step, real) \
|
||||
template [[host_name("four_step_mem_" #tg_mem_size "_" #in_T "_" #out_T \
|
||||
"_" #step "_" #real)]] [[kernel]] void \
|
||||
four_step_fft<tg_mem_size, in_T, out_T, step, real>( \
|
||||
const device in_T* in [[buffer(0)]], \
|
||||
device out_T* out [[buffer(1)]], \
|
||||
constant const int& n1, \
|
||||
constant const int& n2, \
|
||||
constant const int& batch_size, \
|
||||
uint3 elem [[thread_position_in_grid]], \
|
||||
uint3 grid [[threads_per_grid]]);
|
||||
|
||||
float2 get_twiddle(int k, int p) {
|
||||
float theta = -1.0f * k * M_PI_F / (2 * p);
|
||||
|
||||
float2 twiddle;
|
||||
twiddle.x = metal::fast::cos(theta);
|
||||
twiddle.y = metal::fast::sin(theta);
|
||||
return twiddle;
|
||||
}
|
||||
|
||||
// single threaded radix2 implementation
|
||||
void radix2(
|
||||
int i,
|
||||
int p,
|
||||
int m,
|
||||
threadgroup float2* read_buf,
|
||||
threadgroup float2* write_buf) {
|
||||
float2 x_0 = read_buf[i];
|
||||
float2 x_1 = read_buf[i + m];
|
||||
|
||||
// The index within this sub-DFT
|
||||
int k = i & (p - 1);
|
||||
|
||||
float2 twiddle = get_twiddle(k, p);
|
||||
|
||||
float2 z = complex_mul(x_1, twiddle);
|
||||
|
||||
float2 y_0 = x_0 + z;
|
||||
float2 y_1 = x_0 - z;
|
||||
|
||||
int j = (i << 1) - k;
|
||||
|
||||
write_buf[j] = y_0;
|
||||
write_buf[j + p] = y_1;
|
||||
}
|
||||
|
||||
// single threaded radix4 implementation
|
||||
void radix4(
|
||||
int i,
|
||||
int p,
|
||||
int m,
|
||||
threadgroup float2* read_buf,
|
||||
threadgroup float2* write_buf) {
|
||||
float2 x_0 = read_buf[i];
|
||||
float2 x_1 = read_buf[i + m];
|
||||
float2 x_2 = read_buf[i + 2 * m];
|
||||
float2 x_3 = read_buf[i + 3 * m];
|
||||
|
||||
// The index within this sub-DFT
|
||||
int k = i & (p - 1);
|
||||
|
||||
float2 twiddle = get_twiddle(k, p);
|
||||
// e^a * e^b = e^(a + b)
|
||||
float2 twiddle_2 = complex_mul(twiddle, twiddle);
|
||||
float2 twiddle_3 = complex_mul(twiddle, twiddle_2);
|
||||
|
||||
x_1 = complex_mul(x_1, twiddle);
|
||||
x_2 = complex_mul(x_2, twiddle_2);
|
||||
x_3 = complex_mul(x_3, twiddle_3);
|
||||
|
||||
float2 minus_i;
|
||||
minus_i.x = 0;
|
||||
minus_i.y = -1;
|
||||
|
||||
// Hard coded twiddle factors for DFT4
|
||||
float2 z_0 = x_0 + x_2;
|
||||
float2 z_1 = x_0 - x_2;
|
||||
float2 z_2 = x_1 + x_3;
|
||||
float2 z_3 = complex_mul(x_1 - x_3, minus_i);
|
||||
|
||||
float2 y_0 = z_0 + z_2;
|
||||
float2 y_1 = z_1 + z_3;
|
||||
float2 y_2 = z_0 - z_2;
|
||||
float2 y_3 = z_1 - z_3;
|
||||
|
||||
int j = ((i - k) << 2) + k;
|
||||
|
||||
write_buf[j] = y_0;
|
||||
write_buf[j + p] = y_1;
|
||||
write_buf[j + 2 * p] = y_2;
|
||||
write_buf[j + 3 * p] = y_3;
|
||||
}
|
||||
|
||||
// Each FFT is computed entirely in shared GPU memory.
|
||||
//
|
||||
// N is decomposed into radix-2 and radix-4 DFTs:
|
||||
// e.g. 128 = 2 * 4 * 4 * 4
|
||||
//
|
||||
// At each step we use n / 4 threads, each performing
|
||||
// a single-threaded radix-4 or radix-2 DFT.
|
||||
//
|
||||
// We provide the number of radix-2 and radix-4
|
||||
// steps at compile time for a ~20% performance boost.
|
||||
template <size_t n, size_t radix_2_steps, size_t radix_4_steps>
|
||||
[[kernel]] void fft(
|
||||
const device float2* in [[buffer(0)]],
|
||||
device float2* out [[buffer(1)]],
|
||||
uint3 thread_position_in_grid [[thread_position_in_grid]],
|
||||
uint3 threads_per_grid [[threads_per_grid]]) {
|
||||
// Index of the DFT in batch
|
||||
int batch_idx = thread_position_in_grid.x * n;
|
||||
// The index in the DFT we're working on
|
||||
int i = thread_position_in_grid.y;
|
||||
// The number of the threads we're using for each DFT
|
||||
int m = threads_per_grid.y;
|
||||
|
||||
// Allocate 2 shared memory buffers for Stockham.
|
||||
// We alternate reading from one and writing to the other at each radix step.
|
||||
threadgroup float2 shared_in[n];
|
||||
threadgroup float2 shared_out[n];
|
||||
|
||||
// Pointers to facilitate Stockham buffer swapping
|
||||
threadgroup float2* read_buf = shared_in;
|
||||
threadgroup float2* write_buf = shared_out;
|
||||
threadgroup float2* tmp;
|
||||
|
||||
// Copy input into shared memory
|
||||
shared_in[i] = in[batch_idx + i];
|
||||
shared_in[i + m] = in[batch_idx + i + m];
|
||||
shared_in[i + 2 * m] = in[batch_idx + i + 2 * m];
|
||||
shared_in[i + 3 * m] = in[batch_idx + i + 3 * m];
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
int p = 1;
|
||||
|
||||
for (size_t r = 0; r < radix_2_steps; r++) {
|
||||
radix2(i, p, m * 2, read_buf, write_buf);
|
||||
radix2(i + m, p, m * 2, read_buf, write_buf);
|
||||
p *= 2;
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Stockham switch of buffers
|
||||
tmp = write_buf;
|
||||
write_buf = read_buf;
|
||||
read_buf = tmp;
|
||||
}
|
||||
|
||||
for (size_t r = 0; r < radix_4_steps; r++) {
|
||||
radix4(i, p, m, read_buf, write_buf);
|
||||
p *= 4;
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Stockham switch of buffers
|
||||
tmp = write_buf;
|
||||
write_buf = read_buf;
|
||||
read_buf = tmp;
|
||||
}
|
||||
|
||||
// Copy shared memory to output
|
||||
out[batch_idx + i] = read_buf[i];
|
||||
out[batch_idx + i + m] = read_buf[i + m];
|
||||
out[batch_idx + i + 2 * m] = read_buf[i + 2 * m];
|
||||
out[batch_idx + i + 3 * m] = read_buf[i + 3 * m];
|
||||
}
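The kernel above is the older fixed radix-2/4 Stockham implementation that this diff replaces. Its buffer-swapping structure has a short host-side analogue; a minimal NumPy sketch of the radix-2-only case (power-of-two n assumed):

import numpy as np

def stockham_fft(x):
    # Radix-2 Stockham autosort: ping-pong between two buffers, no bit-reversal pass.
    n = len(x)                                    # n must be a power of two
    read = np.array(x, dtype=complex)
    write = np.empty(n, dtype=complex)
    p = 1
    while p < n:
        m = n // 2
        for i in range(m):
            k = i % p                             # index within the sub-DFT
            w = np.exp(-1j * np.pi * k / p)
            z = w * read[i + m]
            j = 2 * i - k                         # same output indexing as radix2 above
            write[j] = read[i] + z
            write[j + p] = read[i] - z
        read, write = write, read                 # Stockham buffer swap
        p *= 2
    return read

x = np.random.randn(16) + 1j * np.random.randn(16)
assert np.allclose(stockham_fft(x), np.fft.fft(x))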
|
||||
|
||||
#define instantiate_fft(name, n, radix_2_steps, radix_4_steps) \
|
||||
template [[host_name("fft_" #name)]] [[kernel]] void \
|
||||
fft<n, radix_2_steps, radix_4_steps>( \
|
||||
const device float2* in [[buffer(0)]], \
|
||||
device float2* out [[buffer(1)]], \
|
||||
uint3 thread_position_in_grid [[thread_position_in_grid]], \
|
||||
uint3 threads_per_grid [[threads_per_grid]]);
|
||||
|
||||
// Explicitly define kernels for each power of 2.
|
||||
// clang-format off
|
||||
instantiate_fft(4, /* n= */ 4, /* radix_2_steps= */ 0, /* radix_4_steps= */ 1)
|
||||
instantiate_fft(8, 8, 1, 1) instantiate_fft(16, 16, 0, 2)
|
||||
instantiate_fft(32, 32, 1, 2) instantiate_fft(64, 64, 0, 3)
|
||||
instantiate_fft(128, 128, 1, 3) instantiate_fft(256, 256, 0, 4)
|
||||
instantiate_fft(512, 512, 1, 4)
|
||||
instantiate_fft(1024, 1024, 0, 5)
|
||||
// 2048 is the max that will fit into 32KB of threadgroup memory.
|
||||
// TODO: implement 4 step FFT for larger n.
|
||||
instantiate_fft(2048, 2048, 1, 5) // clang-format on
|
||||
#define instantiate_ffts(tg_mem_size) \
|
||||
instantiate_fft(tg_mem_size, float2, float2) \
|
||||
instantiate_fft(tg_mem_size, float, float2) \
|
||||
instantiate_fft(tg_mem_size, float2, float) \
|
||||
instantiate_rader(tg_mem_size, float2, float2) \
|
||||
instantiate_rader(tg_mem_size, float, float2) \
|
||||
instantiate_rader(tg_mem_size, float2, float) \
|
||||
instantiate_bluestein(tg_mem_size, float2, float2) \
|
||||
instantiate_bluestein(tg_mem_size, float, float2) \
|
||||
instantiate_bluestein(tg_mem_size, float2, float) \
|
||||
instantiate_four_step(tg_mem_size, float2, float2, 0, /*real=*/false) \
|
||||
instantiate_four_step(tg_mem_size, float2, float2, 1, /*real=*/false) \
|
||||
instantiate_four_step(tg_mem_size, float, float2, 0, /*real=*/true) \
|
||||
instantiate_four_step(tg_mem_size, float2, float2, 1, /*real=*/true) \
|
||||
instantiate_four_step(tg_mem_size, float2, float2, 0, /*real=*/true) \
|
||||
instantiate_four_step(tg_mem_size, float2, float, 1, /*real=*/true)
|
||||
|
||||
// It's substantially faster to statically define the
|
||||
// threadgroup memory size rather than using
|
||||
// `setThreadgroupMemoryLength` on the compute encoder.
|
||||
// For non-power of 2 sizes we round up the shared memory.
|
||||
instantiate_ffts(256)
|
||||
instantiate_ffts(512)
|
||||
instantiate_ffts(1024)
|
||||
instantiate_ffts(2048)
|
||||
// 4096 is the max that will fit into 32KB of threadgroup memory.
|
||||
instantiate_ffts(4096) // clang-format on
|
||||
|
328
mlx/backend/metal/kernels/fft/radix.h
Normal file
@@ -0,0 +1,328 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
/* Radix kernels
|
||||
|
||||
We provide optimized, single threaded Radix codelets
|
||||
for n=2,3,4,5,6,7,8,10,11,12,13.
|
||||
|
||||
For n=2,3,4,5,6 we hand write the codelets.
|
||||
For n=8,10,12 we combine smaller codelets.
|
||||
For n=7,11,13 we use Rader's algorithm which decomposes
|
||||
them into (n-1)=6,10,12 codelets. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <metal_common>
|
||||
#include <metal_math>
|
||||
#include <metal_stdlib>
|
||||
|
||||
METAL_FUNC float2 complex_mul(float2 a, float2 b) {
|
||||
return float2(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x);
|
||||
}
|
||||
|
||||
// Complex mul followed by conjugate
|
||||
METAL_FUNC float2 complex_mul_conj(float2 a, float2 b) {
|
||||
return float2(a.x * b.x - a.y * b.y, -a.x * b.y - a.y * b.x);
|
||||
}
|
||||
|
||||
// Compute an FFT twiddle factor
|
||||
METAL_FUNC float2 get_twiddle(int k, int p) {
|
||||
float theta = -2.0f * k * M_PI_F / p;
|
||||
|
||||
float2 twiddle = {metal::fast::cos(theta), metal::fast::sin(theta)};
|
||||
return twiddle;
|
||||
}
|
||||
|
||||
METAL_FUNC void radix2(thread float2* x, thread float2* y) {
|
||||
y[0] = x[0] + x[1];
|
||||
y[1] = x[0] - x[1];
|
||||
}
|
||||
|
||||
METAL_FUNC void radix3(thread float2* x, thread float2* y) {
|
||||
float pi_2_3 = -0.8660254037844387;
|
||||
|
||||
float2 a_1 = x[1] + x[2];
|
||||
float2 a_2 = x[1] - x[2];
|
||||
|
||||
y[0] = x[0] + a_1;
|
||||
float2 b_1 = x[0] - 0.5 * a_1;
|
||||
float2 b_2 = pi_2_3 * a_2;
|
||||
|
||||
float2 b_2_j = {-b_2.y, b_2.x};
|
||||
y[1] = b_1 + b_2_j;
|
||||
y[2] = b_1 - b_2_j;
|
||||
}
|
||||
|
||||
METAL_FUNC void radix4(thread float2* x, thread float2* y) {
|
||||
float2 z_0 = x[0] + x[2];
|
||||
float2 z_1 = x[0] - x[2];
|
||||
float2 z_2 = x[1] + x[3];
|
||||
float2 z_3 = x[1] - x[3];
|
||||
float2 z_3_i = {z_3.y, -z_3.x};
|
||||
|
||||
y[0] = z_0 + z_2;
|
||||
y[1] = z_1 + z_3_i;
|
||||
y[2] = z_0 - z_2;
|
||||
y[3] = z_1 - z_3_i;
|
||||
}
|
||||
|
||||
METAL_FUNC void radix5(thread float2* x, thread float2* y) {
|
||||
float2 root_5_4 = 0.5590169943749475;
|
||||
float2 sin_2pi_5 = 0.9510565162951535;
|
||||
float2 sin_1pi_5 = 0.5877852522924731;
|
||||
|
||||
float2 a_1 = x[1] + x[4];
|
||||
float2 a_2 = x[2] + x[3];
|
||||
float2 a_3 = x[1] - x[4];
|
||||
float2 a_4 = x[2] - x[3];
|
||||
|
||||
float2 a_5 = a_1 + a_2;
|
||||
float2 a_6 = root_5_4 * (a_1 - a_2);
|
||||
float2 a_7 = x[0] - a_5 / 4;
|
||||
float2 a_8 = a_7 + a_6;
|
||||
float2 a_9 = a_7 - a_6;
|
||||
float2 a_10 = sin_2pi_5 * a_3 + sin_1pi_5 * a_4;
|
||||
float2 a_11 = sin_1pi_5 * a_3 - sin_2pi_5 * a_4;
|
||||
float2 a_10_j = {a_10.y, -a_10.x};
|
||||
float2 a_11_j = {a_11.y, -a_11.x};
|
||||
|
||||
y[0] = x[0] + a_5;
|
||||
y[1] = a_8 + a_10_j;
|
||||
y[2] = a_9 + a_11_j;
|
||||
y[3] = a_9 - a_11_j;
|
||||
y[4] = a_8 - a_10_j;
|
||||
}
|
||||
|
||||
METAL_FUNC void radix6(thread float2* x, thread float2* y) {
|
||||
float sin_pi_3 = 0.8660254037844387;
|
||||
float2 a_1 = x[2] + x[4];
|
||||
float2 a_2 = x[0] - a_1 / 2;
|
||||
float2 a_3 = sin_pi_3 * (x[2] - x[4]);
|
||||
float2 a_4 = x[5] + x[1];
|
||||
float2 a_5 = x[3] - a_4 / 2;
|
||||
float2 a_6 = sin_pi_3 * (x[5] - x[1]);
|
||||
float2 a_7 = x[0] + a_1;
|
||||
|
||||
float2 a_3_i = {a_3.y, -a_3.x};
|
||||
float2 a_6_i = {a_6.y, -a_6.x};
|
||||
float2 a_8 = a_2 + a_3_i;
|
||||
float2 a_9 = a_2 - a_3_i;
|
||||
float2 a_10 = x[3] + a_4;
|
||||
float2 a_11 = a_5 + a_6_i;
|
||||
float2 a_12 = a_5 - a_6_i;
|
||||
|
||||
y[0] = a_7 + a_10;
|
||||
y[1] = a_8 - a_11;
|
||||
y[2] = a_9 + a_12;
|
||||
y[3] = a_7 - a_10;
|
||||
y[4] = a_8 + a_11;
|
||||
y[5] = a_9 - a_12;
|
||||
}
|
||||
|
||||
METAL_FUNC void radix7(thread float2* x, thread float2* y) {
|
||||
// Rader's algorithm
|
||||
float2 inv = {1 / 6.0, -1 / 6.0};
|
||||
|
||||
// fft
|
||||
float2 in1[6] = {x[1], x[3], x[2], x[6], x[4], x[5]};
|
||||
radix6(in1, y + 1);
|
||||
|
||||
y[0] = y[1] + x[0];
|
||||
|
||||
// b_q
|
||||
y[1] = complex_mul_conj(y[1], float2(-1, 0));
|
||||
y[2] = complex_mul_conj(y[2], float2(2.44013336, -1.02261879));
|
||||
y[3] = complex_mul_conj(y[3], float2(2.37046941, -1.17510629));
|
||||
y[4] = complex_mul_conj(y[4], float2(0, -2.64575131));
|
||||
y[5] = complex_mul_conj(y[5], float2(2.37046941, 1.17510629));
|
||||
y[6] = complex_mul_conj(y[6], float2(-2.44013336, -1.02261879));
|
||||
|
||||
// ifft
|
||||
radix6(y + 1, x + 1);
|
||||
|
||||
y[1] = x[1] * inv + x[0];
|
||||
y[5] = x[2] * inv + x[0];
|
||||
y[4] = x[3] * inv + x[0];
|
||||
y[6] = x[4] * inv + x[0];
|
||||
y[2] = x[5] * inv + x[0];
|
||||
y[3] = x[6] * inv + x[0];
|
||||
}
|
||||
|
||||
METAL_FUNC void radix8(thread float2* x, thread float2* y) {
|
||||
float cos_pi_4 = 0.7071067811865476;
|
||||
float2 w_0 = {cos_pi_4, -cos_pi_4};
|
||||
float2 w_1 = {-cos_pi_4, -cos_pi_4};
|
||||
float2 temp[8] = {x[0], x[2], x[4], x[6], x[1], x[3], x[5], x[7]};
|
||||
radix4(temp, x);
|
||||
radix4(temp + 4, x + 4);
|
||||
|
||||
y[0] = x[0] + x[4];
|
||||
y[4] = x[0] - x[4];
|
||||
float2 x_5 = complex_mul(x[5], w_0);
|
||||
y[1] = x[1] + x_5;
|
||||
y[5] = x[1] - x_5;
|
||||
float2 x_6 = {x[6].y, -x[6].x};
|
||||
y[2] = x[2] + x_6;
|
||||
y[6] = x[2] - x_6;
|
||||
float2 x_7 = complex_mul(x[7], w_1);
|
||||
y[3] = x[3] + x_7;
|
||||
y[7] = x[3] - x_7;
|
||||
}
|
||||
|
||||
template <bool raders_perm>
|
||||
METAL_FUNC void radix10(thread float2* x, thread float2* y) {
|
||||
float2 w[4];
|
||||
w[0] = {0.8090169943749475, -0.5877852522924731};
|
||||
w[1] = {0.30901699437494745, -0.9510565162951535};
|
||||
w[2] = {-w[1].x, w[1].y};
|
||||
w[3] = {-w[0].x, w[0].y};
|
||||
|
||||
if (raders_perm) {
|
||||
float2 temp[10] = {
|
||||
x[0], x[3], x[4], x[8], x[2], x[1], x[7], x[9], x[6], x[5]};
|
||||
radix5(temp, x);
|
||||
radix5(temp + 5, x + 5);
|
||||
} else {
|
||||
float2 temp[10] = {
|
||||
x[0], x[2], x[4], x[6], x[8], x[1], x[3], x[5], x[7], x[9]};
|
||||
radix5(temp, x);
|
||||
radix5(temp + 5, x + 5);
|
||||
}
|
||||
|
||||
y[0] = x[0] + x[5];
|
||||
y[5] = x[0] - x[5];
|
||||
for (int t = 1; t < 5; t++) {
|
||||
float2 a = complex_mul(x[t + 5], w[t - 1]);
|
||||
y[t] = x[t] + a;
|
||||
y[t + 5] = x[t] - a;
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC void radix11(thread float2* x, thread float2* y) {
|
||||
// Rader's algorithm
|
||||
float2 inv = {1 / 10.0, -1 / 10.0};
|
||||
|
||||
// fft
|
||||
radix10<true>(x + 1, y + 1);
|
||||
|
||||
y[0] = y[1] + x[0];
|
||||
|
||||
// b_q
|
||||
y[1] = complex_mul_conj(y[1], float2(-1, 0));
|
||||
y[2] = complex_mul_conj(y[2], float2(0.955301878, -3.17606649));
|
||||
y[3] = complex_mul_conj(y[3], float2(2.63610556, 2.01269656));
|
||||
y[4] = complex_mul_conj(y[4], float2(2.54127802, 2.13117479));
|
||||
y[5] = complex_mul_conj(y[5], float2(2.07016210, 2.59122150));
|
||||
y[6] = complex_mul_conj(y[6], float2(0, -3.31662479));
|
||||
y[7] = complex_mul_conj(y[7], float2(2.07016210, -2.59122150));
|
||||
y[8] = complex_mul_conj(y[8], float2(-2.54127802, 2.13117479));
|
||||
y[9] = complex_mul_conj(y[9], float2(2.63610556, -2.01269656));
|
||||
y[10] = complex_mul_conj(y[10], float2(-0.955301878, -3.17606649));
|
||||
|
||||
// ifft
|
||||
radix10<false>(y + 1, x + 1);
|
||||
|
||||
y[1] = x[1] * inv + x[0];
|
||||
y[6] = x[2] * inv + x[0];
|
||||
y[3] = x[3] * inv + x[0];
|
||||
y[7] = x[4] * inv + x[0];
|
||||
y[9] = x[5] * inv + x[0];
|
||||
y[10] = x[6] * inv + x[0];
|
||||
y[5] = x[7] * inv + x[0];
|
||||
y[8] = x[8] * inv + x[0];
|
||||
y[4] = x[9] * inv + x[0];
|
||||
y[2] = x[10] * inv + x[0];
|
||||
}
|
||||
|
||||
template <bool raders_perm>
|
||||
METAL_FUNC void radix12(thread float2* x, thread float2* y) {
|
||||
float2 w[6];
|
||||
float sin_pi_3 = 0.8660254037844387;
|
||||
w[0] = {sin_pi_3, -0.5};
|
||||
w[1] = {0.5, -sin_pi_3};
|
||||
w[2] = {0, -1};
|
||||
w[3] = {-0.5, -sin_pi_3};
|
||||
w[4] = {-sin_pi_3, -0.5};
|
||||
|
||||
if (raders_perm) {
|
||||
float2 temp[12] = {
|
||||
x[0],
|
||||
x[3],
|
||||
x[2],
|
||||
x[11],
|
||||
x[8],
|
||||
x[9],
|
||||
x[1],
|
||||
x[7],
|
||||
x[5],
|
||||
x[10],
|
||||
x[4],
|
||||
x[6]};
|
||||
radix6(temp, x);
|
||||
radix6(temp + 6, x + 6);
|
||||
} else {
|
||||
float2 temp[12] = {
|
||||
x[0],
|
||||
x[2],
|
||||
x[4],
|
||||
x[6],
|
||||
x[8],
|
||||
x[10],
|
||||
x[1],
|
||||
x[3],
|
||||
x[5],
|
||||
x[7],
|
||||
x[9],
|
||||
x[11]};
|
||||
radix6(temp, x);
|
||||
radix6(temp + 6, x + 6);
|
||||
}
|
||||
|
||||
y[0] = x[0] + x[6];
|
||||
y[6] = x[0] - x[6];
|
||||
for (int t = 1; t < 6; t++) {
|
||||
float2 a = complex_mul(x[t + 6], w[t - 1]);
|
||||
y[t] = x[t] + a;
|
||||
y[t + 6] = x[t] - a;
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC void radix13(thread float2* x, thread float2* y) {
|
||||
// Rader's algorithm
|
||||
float2 inv = {1 / 12.0, -1 / 12.0};
|
||||
|
||||
// fft
|
||||
radix12<true>(x + 1, y + 1);
|
||||
|
||||
y[0] = y[1] + x[0];
|
||||
|
||||
// b_q
|
||||
y[1] = complex_mul_conj(y[1], float2(-1, 0));
|
||||
y[2] = complex_mul_conj(y[2], float2(3.07497206, -1.88269669));
|
||||
y[3] = complex_mul_conj(y[3], float2(3.09912468, 1.84266823));
|
||||
y[4] = complex_mul_conj(y[4], float2(3.45084438, -1.04483161));
|
||||
y[5] = complex_mul_conj(y[5], float2(0.91083583, 3.48860690));
|
||||
y[6] = complex_mul_conj(y[6], float2(-3.60286363, 0.139189267));
|
||||
y[7] = complex_mul_conj(y[7], float2(3.60555128, 0));
|
||||
y[8] = complex_mul_conj(y[8], float2(3.60286363, 0.139189267));
|
||||
y[9] = complex_mul_conj(y[9], float2(0.91083583, -3.48860690));
|
||||
y[10] = complex_mul_conj(y[10], float2(-3.45084438, -1.04483161));
|
||||
y[11] = complex_mul_conj(y[11], float2(3.09912468, -1.84266823));
|
||||
y[12] = complex_mul_conj(y[12], float2(-3.07497206, -1.88269669));
|
||||
|
||||
// ifft
|
||||
radix12<false>(y + 1, x + 1);
|
||||
|
||||
y[1] = x[1] * inv + x[0];
|
||||
y[7] = x[2] * inv + x[0];
|
||||
y[10] = x[3] * inv + x[0];
|
||||
y[5] = x[4] * inv + x[0];
|
||||
y[9] = x[5] * inv + x[0];
|
||||
y[11] = x[6] * inv + x[0];
|
||||
y[12] = x[7] * inv + x[0];
|
||||
y[6] = x[8] * inv + x[0];
|
||||
y[3] = x[9] * inv + x[0];
|
||||
y[8] = x[10] * inv + x[0];
|
||||
y[4] = x[11] * inv + x[0];
|
||||
y[2] = x[12] * inv + x[0];
|
||||
}
|
622
mlx/backend/metal/kernels/fft/readwrite.h
Normal file
@@ -0,0 +1,622 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <metal_common>
|
||||
|
||||
#include "mlx/backend/metal/kernels/fft/radix.h"
|
||||
|
||||
/* FFT helpers for reading and writing from/to device memory.
|
||||
|
||||
For many sizes, GPU FFTs are memory bandwidth bound so
|
||||
read/write performance is important.
|
||||
|
||||
Where possible, we read 128 bits sequentially in each thread,
|
||||
coalesced with accesses from adjacent threads for optimal performance.
|
||||
|
||||
We implement specialized reading/writing for:
|
||||
- FFT
|
||||
- RFFT
|
||||
- IRFFT
|
||||
|
||||
Each with support for:
|
||||
- Contiguous reads
|
||||
- Padded reads
|
||||
- Strided reads
|
||||
*/
|
||||
|
||||
#define MAX_RADIX 13
|
||||
|
||||
using namespace metal;
|
||||
|
||||
template <
|
||||
typename in_T,
|
||||
typename out_T,
|
||||
int step = 0,
|
||||
bool four_step_real = false>
|
||||
struct ReadWriter {
|
||||
const device in_T* in;
|
||||
threadgroup float2* buf;
|
||||
device out_T* out;
|
||||
int n;
|
||||
int batch_size;
|
||||
int elems_per_thread;
|
||||
uint3 elem;
|
||||
uint3 grid;
|
||||
int threads_per_tg;
|
||||
bool inv;
|
||||
|
||||
// Used for strided access
|
||||
int strided_device_idx = 0;
|
||||
int strided_shared_idx = 0;
|
||||
|
||||
METAL_FUNC ReadWriter(
|
||||
const device in_T* in_,
|
||||
threadgroup float2* buf_,
|
||||
device out_T* out_,
|
||||
const short n_,
|
||||
const int batch_size_,
|
||||
const short elems_per_thread_,
|
||||
const uint3 elem_,
|
||||
const uint3 grid_,
|
||||
const bool inv_)
|
||||
: in(in_),
|
||||
buf(buf_),
|
||||
out(out_),
|
||||
n(n_),
|
||||
batch_size(batch_size_),
|
||||
elems_per_thread(elems_per_thread_),
|
||||
elem(elem_),
|
||||
grid(grid_),
|
||||
inv(inv_) {
|
||||
// Account for padding on last threadgroup
|
||||
threads_per_tg = elem.x == grid.x - 1
|
||||
? (batch_size - (grid.x - 1) * grid.y) * grid.z
|
||||
: grid.y * grid.z;
|
||||
}
|
||||
|
||||
// ifft(x) = 1/n * conj(fft(conj(x)))
|
||||
METAL_FUNC float2 post_in(float2 elem) const {
|
||||
return inv ? float2(elem.x, -elem.y) : elem;
|
||||
}
|
||||
|
||||
// Handle float case for generic RFFT alg
|
||||
METAL_FUNC float2 post_in(float elem) const {
|
||||
return float2(elem, 0);
|
||||
}
|
||||
|
||||
METAL_FUNC float2 pre_out(float2 elem) const {
|
||||
return inv ? float2(elem.x / n, -elem.y / n) : elem;
|
||||
}
|
||||
|
||||
METAL_FUNC float2 pre_out(float2 elem, int length) const {
|
||||
return inv ? float2(elem.x / length, -elem.y / length) : elem;
|
||||
}
|
||||
|
||||
METAL_FUNC bool out_of_bounds() const {
|
||||
// Account for possible extra threadgroups
|
||||
int grid_index = elem.x * grid.y + elem.y;
|
||||
return grid_index >= batch_size;
|
||||
}
|
||||
|
||||
METAL_FUNC void load() const {
|
||||
int batch_idx = elem.x * grid.y * n;
|
||||
short tg_idx = elem.y * grid.z + elem.z;
|
||||
short max_index = grid.y * n - 2;
|
||||
|
||||
// 2 complex64s = 128 bits
|
||||
constexpr int read_width = 2;
|
||||
for (short e = 0; e < (elems_per_thread / read_width); e++) {
|
||||
short index = read_width * tg_idx + read_width * threads_per_tg * e;
|
||||
index = metal::min(index, max_index);
|
||||
// vectorized reads
|
||||
buf[index] = post_in(in[batch_idx + index]);
|
||||
buf[index + 1] = post_in(in[batch_idx + index + 1]);
|
||||
}
|
||||
max_index += 1;
|
||||
if (elems_per_thread % 2 != 0) {
|
||||
short index = tg_idx +
|
||||
read_width * threads_per_tg * (elems_per_thread / read_width);
|
||||
index = metal::min(index, max_index);
|
||||
buf[index] = post_in(in[batch_idx + index]);
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC void write() const {
|
||||
int batch_idx = elem.x * grid.y * n;
|
||||
short tg_idx = elem.y * grid.z + elem.z;
|
||||
short max_index = grid.y * n - 2;
|
||||
|
||||
constexpr int read_width = 2;
|
||||
for (short e = 0; e < (elems_per_thread / read_width); e++) {
|
||||
short index = read_width * tg_idx + read_width * threads_per_tg * e;
|
||||
index = metal::min(index, max_index);
|
||||
// vectorized reads
|
||||
out[batch_idx + index] = pre_out(buf[index]);
|
||||
out[batch_idx + index + 1] = pre_out(buf[index + 1]);
|
||||
}
|
||||
max_index += 1;
|
||||
if (elems_per_thread % 2 != 0) {
|
||||
short index = tg_idx +
|
||||
read_width * threads_per_tg * (elems_per_thread / read_width);
|
||||
index = metal::min(index, max_index);
|
||||
out[batch_idx + index] = pre_out(buf[index]);
|
||||
}
|
||||
}
|
||||
|
||||
// Padded IO for Bluestein's algorithm
|
||||
METAL_FUNC void load_padded(int length, const device float2* w_k) const {
|
||||
int batch_idx = elem.x * grid.y * length + elem.y * length;
|
||||
int fft_idx = elem.z;
|
||||
int m = grid.z;
|
||||
|
||||
threadgroup float2* seq_buf = buf + elem.y * n;
|
||||
for (int e = 0; e < elems_per_thread; e++) {
|
||||
int index = metal::min(fft_idx + e * m, n - 1);
|
||||
if (index < length) {
|
||||
float2 elem = post_in(in[batch_idx + index]);
|
||||
seq_buf[index] = complex_mul(elem, w_k[index]);
|
||||
} else {
|
||||
seq_buf[index] = 0.0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC void write_padded(int length, const device float2* w_k) const {
|
||||
int batch_idx = elem.x * grid.y * length + elem.y * length;
|
||||
int fft_idx = elem.z;
|
||||
int m = grid.z;
|
||||
float2 inv_factor = {1.0f / n, -1.0f / n};
|
||||
|
||||
threadgroup float2* seq_buf = buf + elem.y * n;
|
||||
for (int e = 0; e < elems_per_thread; e++) {
|
||||
int index = metal::min(fft_idx + e * m, n - 1);
|
||||
if (index < length) {
|
||||
float2 elem = seq_buf[index + length - 1] * inv_factor;
|
||||
out[batch_idx + index] = pre_out(complex_mul(elem, w_k[index]), length);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Strided IO for four step FFT
|
||||
METAL_FUNC void compute_strided_indices(int stride, int overall_n) {
|
||||
// Use the batch threadgroup dimension to coalesce memory accesses:
|
||||
// e.g. stride = 12
|
||||
// device | shared mem
|
||||
// 0 1 2 3 | 0 12 - -
|
||||
// - - - - | 1 13 - -
|
||||
// - - - - | 2 14 - -
|
||||
// 12 13 14 15 | 3 15 - -
|
||||
int coalesce_width = grid.y;
|
||||
int tg_idx = elem.y * grid.z + elem.z;
|
||||
int outer_batch_size = stride / coalesce_width;
|
||||
|
||||
int strided_batch_idx = (elem.x % outer_batch_size) * coalesce_width +
|
||||
overall_n * (elem.x / outer_batch_size);
|
||||
strided_device_idx = strided_batch_idx +
|
||||
tg_idx / coalesce_width * elems_per_thread * stride +
|
||||
tg_idx % coalesce_width;
|
||||
strided_shared_idx = (tg_idx % coalesce_width) * n +
|
||||
tg_idx / coalesce_width * elems_per_thread;
|
||||
}
|
||||
|
||||
// Four Step FFT First Step
|
||||
METAL_FUNC void load_strided(int stride, int overall_n) {
|
||||
compute_strided_indices(stride, overall_n);
|
||||
for (int e = 0; e < elems_per_thread; e++) {
|
||||
buf[strided_shared_idx + e] =
|
||||
post_in(in[strided_device_idx + e * stride]);
|
||||
}
|
||||
}
|
||||
|
||||
METAL_FUNC void write_strided(int stride, int overall_n) {
|
||||
for (int e = 0; e < elems_per_thread; e++) {
|
||||
float2 output = buf[strided_shared_idx + e];
|
||||
int combined_idx = (strided_device_idx + e * stride) % overall_n;
|
||||
int ij = (combined_idx / stride) * (combined_idx % stride);
|
||||
// Apply four step twiddles at end of first step
|
||||
float2 twiddle = get_twiddle(ij, overall_n);
|
||||
out[strided_device_idx + e * stride] = complex_mul(output, twiddle);
|
||||
}
|
||||
}
|
||||
};
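post_in/pre_out above fold inverse transforms into forward ones by conjugating on the way in and out and dividing by n, per the identity in the comment. A one-line NumPy check of that identity:

import numpy as np

x = np.random.randn(8) + 1j * np.random.randn(8)
assert np.allclose(np.fft.ifft(x), np.conj(np.fft.fft(np.conj(x))) / len(x))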
|
||||
|
||||
// Four Step FFT Second Step
|
||||
template <>
|
||||
METAL_FUNC void ReadWriter<float2, float2, /*step=*/1>::load_strided(
|
||||
int stride,
|
||||
int overall_n) {
|
||||
// Silence compiler warnings
|
||||
(void)stride;
|
||||
(void)overall_n;
|
||||
// Don't invert between steps
|
||||
bool default_inv = inv;
|
||||
inv = false;
|
||||
load();
|
||||
inv = default_inv;
|
||||
}
|
||||
|
||||
template <>
|
||||
METAL_FUNC void ReadWriter<float2, float2, /*step=*/1>::write_strided(
|
||||
int stride,
|
||||
int overall_n) {
|
||||
compute_strided_indices(stride, overall_n);
|
||||
for (int e = 0; e < elems_per_thread; e++) {
|
||||
float2 output = buf[strided_shared_idx + e];
|
||||
out[strided_device_idx + e * stride] = pre_out(output, overall_n);
|
||||
}
|
||||
}
|
||||
|
||||
// For RFFT, we interleave batches of two real sequences into one complex one:
|
||||
//
|
||||
// z_k = x_k + j.y_k
|
||||
// X_k = (Z_k + Z_(N-k)*) / 2
|
||||
// Y_k = -j * ((Z_k - Z_(N-k)*) / 2)
|
||||
//
|
||||
// This roughly doubles the throughput over the regular FFT.
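The packing identities above are easy to sanity check in NumPy (note index 0 maps to itself in the Z*_{N-k} term):

import numpy as np

n = 16
x, y = np.random.randn(n), np.random.randn(n)
Z = np.fft.fft(x + 1j * y)                     # one complex FFT for two real inputs
Z_conj_rev = np.conj(Z[(-np.arange(n)) % n])   # Z*_{N-k}
X = (Z + Z_conj_rev) / 2
Y = -1j * (Z - Z_conj_rev) / 2
assert np.allclose(X, np.fft.fft(x))
assert np.allclose(Y, np.fft.fft(y))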
|
||||
template <>
|
||||
METAL_FUNC bool ReadWriter<float, float2>::out_of_bounds() const {
|
||||
int grid_index = elem.x * grid.y + elem.y;
|
||||
// We pack two sequences into one for RFFTs
|
||||
return grid_index * 2 >= batch_size;
|
||||
}
|
||||
|
||||
template <>
|
||||
METAL_FUNC void ReadWriter<float, float2>::load() const {
|
||||
int batch_idx = elem.x * grid.y * n * 2 + elem.y * n * 2;
|
||||
threadgroup float2* seq_buf = buf + elem.y * n;
|
||||
|
||||
// No out of bounds accesses on odd batch sizes
|
||||
int grid_index = elem.x * grid.y + elem.y;
|
||||
short next_in =
|
||||
batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : n;
|
||||
|
||||
short m = grid.z;
|
||||
short fft_idx = elem.z;
|
||||
|
||||
for (int e = 0; e < elems_per_thread; e++) {
|
||||
int index = metal::min(fft_idx + e * m, n - 1);
|
||||
seq_buf[index].x = in[batch_idx + index];
|
||||
seq_buf[index].y = in[batch_idx + index + next_in];
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
METAL_FUNC void ReadWriter<float, float2>::write() const {
|
||||
short n_over_2 = (n / 2) + 1;
|
||||
|
||||
int batch_idx = elem.x * grid.y * n_over_2 * 2 + elem.y * n_over_2 * 2;
|
||||
threadgroup float2* seq_buf = buf + elem.y * n;
|
||||
|
||||
int grid_index = elem.x * grid.y + elem.y;
|
||||
short next_out =
|
||||
batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : n_over_2;
|
||||
|
||||
float2 conj = {1, -1};
|
||||
float2 minus_j = {0, -1};
|
||||
|
||||
short m = grid.z;
|
||||
short fft_idx = elem.z;
|
||||
|
||||
for (int e = 0; e < elems_per_thread / 2 + 1; e++) {
|
||||
int index = metal::min(fft_idx + e * m, n_over_2 - 1);
|
||||
// x_0 = z_0.real
|
||||
// y_0 = z_0.imag
|
||||
if (index == 0) {
|
||||
out[batch_idx + index] = {seq_buf[index].x, 0};
|
||||
out[batch_idx + index + next_out] = {seq_buf[index].y, 0};
|
||||
} else {
|
||||
float2 x_k = seq_buf[index];
|
||||
float2 x_n_minus_k = seq_buf[n - index] * conj;
|
||||
out[batch_idx + index] = (x_k + x_n_minus_k) / 2;
|
||||
out[batch_idx + index + next_out] =
|
||||
complex_mul(((x_k - x_n_minus_k) / 2), minus_j);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
METAL_FUNC void ReadWriter<float, float2>::load_padded(
|
||||
int length,
|
||||
const device float2* w_k) const {
|
||||
int batch_idx = elem.x * grid.y * length * 2 + elem.y * length * 2;
|
||||
threadgroup float2* seq_buf = buf + elem.y * n;
|
||||
|
||||
// No out of bounds accesses on odd batch sizes
|
||||
int grid_index = elem.x * grid.y + elem.y;
|
||||
short next_in =
|
||||
batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : length;
|
||||
|
||||
short m = grid.z;
|
||||
short fft_idx = elem.z;
|
||||
|
||||
for (int e = 0; e < elems_per_thread; e++) {
|
||||
int index = metal::min(fft_idx + e * m, n - 1);
|
||||
if (index < length) {
|
||||
float2 elem =
|
||||
float2(in[batch_idx + index], in[batch_idx + index + next_in]);
|
||||
seq_buf[index] = complex_mul(elem, w_k[index]);
|
||||
} else {
|
||||
seq_buf[index] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
METAL_FUNC void ReadWriter<float, float2>::write_padded(
|
||||
int length,
|
||||
const device float2* w_k) const {
|
||||
int length_over_2 = (length / 2) + 1;
|
||||
int batch_idx =
|
||||
elem.x * grid.y * length_over_2 * 2 + elem.y * length_over_2 * 2;
|
||||
threadgroup float2* seq_buf = buf + elem.y * n + length - 1;
|
||||
|
||||
int grid_index = elem.x * grid.y + elem.y;
|
||||
short next_out = batch_size % 2 == 1 && grid_index * 2 == batch_size - 1
|
||||
? 0
|
||||
: length_over_2;
|
||||
|
||||
float2 conj = {1, -1};
|
||||
float2 inv_factor = {1.0f / n, -1.0f / n};
|
||||
float2 minus_j = {0, -1};
|
||||
|
||||
short m = grid.z;
|
||||
short fft_idx = elem.z;
|
||||
|
||||
for (int e = 0; e < elems_per_thread / 2 + 1; e++) {
|
||||
int index = metal::min(fft_idx + e * m, length_over_2 - 1);
|
||||
// x_0 = z_0.real
|
||||
// y_0 = z_0.imag
|
||||
if (index == 0) {
|
||||
float2 elem = complex_mul(w_k[index], seq_buf[index] * inv_factor);
|
||||
out[batch_idx + index] = float2(elem.x, 0);
|
||||
out[batch_idx + index + next_out] = float2(elem.y, 0);
|
||||
} else {
|
||||
float2 x_k = complex_mul(w_k[index], seq_buf[index] * inv_factor);
|
||||
float2 x_n_minus_k = complex_mul(
|
||||
w_k[length - index], seq_buf[length - index] * inv_factor);
|
||||
x_n_minus_k *= conj;
|
||||
// The w_k multiplication must happen before this real/imaginary extraction
|
||||
out[batch_idx + index] = (x_k + x_n_minus_k) / 2;
|
||||
out[batch_idx + index + next_out] =
|
||||
complex_mul(((x_k - x_n_minus_k) / 2), minus_j);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// For IRFFT, we do the opposite
|
||||
//
|
||||
// Z_k = X_k + j.Y_k
|
||||
// x_k = Re(Z_k)
|
||||
// y_k = Im(Z_k)
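This is the inverse direction of the same packing trick: two inverse real transforms from one complex inverse FFT. A NumPy check of the full-spectrum case (the kernel additionally reconstructs the conjugate-symmetric upper half from the packed n/2 + 1 inputs):

import numpy as np

n = 16
x, y = np.random.randn(n), np.random.randn(n)
Z = np.fft.fft(x) + 1j * np.fft.fft(y)     # pack two spectra of real signals
z = np.fft.ifft(Z)                         # one complex inverse FFT
assert np.allclose(z.real, x)
assert np.allclose(z.imag, y)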
|
||||
template <>
|
||||
METAL_FUNC bool ReadWriter<float2, float>::out_of_bounds() const {
|
||||
int grid_index = elem.x * grid.y + elem.y;
|
||||
// We pack two sequences into one for IRFFTs
|
||||
return grid_index * 2 >= batch_size;
|
||||
}
|
||||
|
||||
template <>
|
||||
METAL_FUNC void ReadWriter<float2, float>::load() const {
|
||||
short n_over_2 = (n / 2) + 1;
|
||||
int batch_idx = elem.x * grid.y * n_over_2 * 2 + elem.y * n_over_2 * 2;
|
||||
threadgroup float2* seq_buf = buf + elem.y * n;
|
||||
|
||||
// No out of bounds accesses on odd batch sizes
|
||||
int grid_index = elem.x * grid.y + elem.y;
|
||||
short next_in =
|
||||
batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : n_over_2;
|
||||
|
||||
short m = grid.z;
|
||||
short fft_idx = elem.z;
|
||||
|
||||
float2 conj = {1, -1};
|
||||
float2 plus_j = {0, 1};
|
||||
|
||||
for (int t = 0; t < elems_per_thread / 2 + 1; t++) {
|
||||
int index = metal::min(fft_idx + t * m, n_over_2 - 1);
|
||||
float2 x = in[batch_idx + index];
|
||||
float2 y = in[batch_idx + index + next_in];
|
||||
// NumPy forces first input to be real
|
||||
bool first_val = index == 0;
|
||||
// NumPy forces last input on even irffts to be real
|
||||
bool last_val = n % 2 == 0 && index == n_over_2 - 1;
|
||||
if (first_val || last_val) {
|
||||
x = float2(x.x, 0);
|
||||
y = float2(y.x, 0);
|
||||
}
|
||||
seq_buf[index] = x + complex_mul(y, plus_j);
|
||||
seq_buf[index].y = -seq_buf[index].y;
|
||||
if (index > 0 && !last_val) {
|
||||
seq_buf[n - index] = (x * conj) + complex_mul(y * conj, plus_j);
|
||||
seq_buf[n - index].y = -seq_buf[n - index].y;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
METAL_FUNC void ReadWriter<float2, float>::write() const {
|
||||
int batch_idx = elem.x * grid.y * n * 2 + elem.y * n * 2;
|
||||
threadgroup float2* seq_buf = buf + elem.y * n;
|
||||
|
||||
int grid_index = elem.x * grid.y + elem.y;
|
||||
short next_out =
|
||||
batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : n;
|
||||
|
||||
short m = grid.z;
|
||||
short fft_idx = elem.z;
|
||||
|
||||
for (int e = 0; e < elems_per_thread; e++) {
|
||||
int index = metal::min(fft_idx + e * m, n - 1);
|
||||
out[batch_idx + index] = seq_buf[index].x / n;
|
||||
out[batch_idx + index + next_out] = seq_buf[index].y / -n;
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
METAL_FUNC void ReadWriter<float2, float>::load_padded(
|
||||
int length,
|
||||
const device float2* w_k) const {
|
||||
int n_over_2 = (n / 2) + 1;
|
||||
int length_over_2 = (length / 2) + 1;
|
||||
|
||||
int batch_idx =
|
||||
elem.x * grid.y * length_over_2 * 2 + elem.y * length_over_2 * 2;
|
||||
threadgroup float2* seq_buf = buf + elem.y * n;
|
||||
|
||||
// No out of bounds accesses on odd batch sizes
|
||||
int grid_index = elem.x * grid.y + elem.y;
|
||||
short next_in = batch_size % 2 == 1 && grid_index * 2 == batch_size - 1
|
||||
? 0
|
||||
: length_over_2;
|
||||
|
||||
short m = grid.z;
|
||||
short fft_idx = elem.z;
|
||||
|
||||
float2 conj = {1, -1};
|
||||
float2 plus_j = {0, 1};
|
||||
|
||||
for (int t = 0; t < elems_per_thread / 2 + 1; t++) {
|
||||
int index = metal::min(fft_idx + t * m, n_over_2 - 1);
|
||||
float2 x = in[batch_idx + index];
|
||||
float2 y = in[batch_idx + index + next_in];
|
||||
if (index < length_over_2) {
|
||||
bool last_val = length % 2 == 0 && index == length_over_2 - 1;
|
||||
if (last_val) {
|
||||
x = float2(x.x, 0);
|
||||
y = float2(y.x, 0);
|
||||
}
|
||||
float2 elem1 = x + complex_mul(y, plus_j);
|
||||
seq_buf[index] = complex_mul(elem1 * conj, w_k[index]);
|
||||
if (index > 0 && !last_val) {
|
||||
float2 elem2 = (x * conj) + complex_mul(y * conj, plus_j);
|
||||
seq_buf[length - index] =
|
||||
complex_mul(elem2 * conj, w_k[length - index]);
|
||||
}
|
||||
} else {
|
||||
short pad_index = metal::min(length + (index - length_over_2) * 2, n - 2);
|
||||
seq_buf[pad_index] = 0;
|
||||
seq_buf[pad_index + 1] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
METAL_FUNC void ReadWriter<float2, float>::write_padded(
|
||||
int length,
|
||||
const device float2* w_k) const {
|
||||
int batch_idx = elem.x * grid.y * length * 2 + elem.y * length * 2;
|
||||
threadgroup float2* seq_buf = buf + elem.y * n + length - 1;
|
||||
|
||||
int grid_index = elem.x * grid.y + elem.y;
|
||||
short next_out =
|
||||
batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : length;
|
||||
|
||||
short m = grid.z;
|
||||
short fft_idx = elem.z;
|
||||
|
||||
float2 inv_factor = {1.0f / n, -1.0f / n};
|
||||
for (int e = 0; e < elems_per_thread; e++) {
|
||||
int index = fft_idx + e * m;
|
||||
if (index < length) {
|
||||
float2 output = complex_mul(seq_buf[index] * inv_factor, w_k[index]);
|
||||
out[batch_idx + index] = output.x / length;
|
||||
out[batch_idx + index + next_out] = output.y / -length;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Four Step RFFT
|
||||
template <>
|
||||
METAL_FUNC void
|
||||
ReadWriter<float2, float2, /*step=*/1, /*real=*/true>::load_strided(
|
||||
int stride,
|
||||
int overall_n) {
|
||||
// Silence compiler warnings
|
||||
(void)stride;
|
||||
(void)overall_n;
|
||||
// Don't invert between steps
|
||||
bool default_inv = inv;
|
||||
inv = false;
|
||||
load();
|
||||
inv = default_inv;
|
||||
}
|
||||
|
||||
template <>
|
||||
METAL_FUNC void
|
||||
ReadWriter<float2, float2, /*step=*/1, /*real=*/true>::write_strided(
|
||||
int stride,
|
||||
int overall_n) {
|
||||
int overall_n_over_2 = overall_n / 2 + 1;
|
||||
int coalesce_width = grid.y;
|
||||
int tg_idx = elem.y * grid.z + elem.z;
|
||||
int outer_batch_size = stride / coalesce_width;
|
||||
|
||||
int strided_batch_idx = (elem.x % outer_batch_size) * coalesce_width +
|
||||
overall_n_over_2 * (elem.x / outer_batch_size);
|
||||
strided_device_idx = strided_batch_idx +
|
||||
tg_idx / coalesce_width * elems_per_thread / 2 * stride +
|
||||
tg_idx % coalesce_width;
|
||||
strided_shared_idx = (tg_idx % coalesce_width) * n +
|
||||
tg_idx / coalesce_width * elems_per_thread / 2;
|
||||
for (int e = 0; e < elems_per_thread / 2; e++) {
|
||||
float2 output = buf[strided_shared_idx + e];
|
||||
out[strided_device_idx + e * stride] = output;
|
||||
}
|
||||
|
||||
// Add on n/2 + 1 element
|
||||
if (tg_idx == 0 && elem.x % outer_batch_size == 0) {
|
||||
out[strided_batch_idx + overall_n / 2] = buf[n / 2];
|
||||
}
|
||||
}
|
||||
|
||||
// Four Step IRFFT
|
||||
template <>
|
||||
METAL_FUNC void
|
||||
ReadWriter<float2, float2, /*step=*/0, /*real=*/true>::load_strided(
|
||||
int stride,
|
||||
int overall_n) {
|
||||
int overall_n_over_2 = overall_n / 2 + 1;
|
||||
auto conj = float2(1, -1);
|
||||
|
||||
compute_strided_indices(stride, overall_n);
|
||||
// Translate indices in terms of N - k
|
||||
for (int e = 0; e < elems_per_thread; e++) {
|
||||
int device_idx = strided_device_idx + e * stride;
|
||||
int overall_batch = device_idx / overall_n;
|
||||
int overall_index = device_idx % overall_n;
|
||||
if (overall_index < overall_n_over_2) {
|
||||
device_idx -= overall_batch * (overall_n - overall_n_over_2);
|
||||
buf[strided_shared_idx + e] = in[device_idx] * conj;
|
||||
} else {
|
||||
int conj_idx = overall_n - overall_index;
|
||||
device_idx = overall_batch * overall_n_over_2 + conj_idx;
|
||||
buf[strided_shared_idx + e] = in[device_idx];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
METAL_FUNC void
|
||||
ReadWriter<float2, float, /*step=*/1, /*real=*/true>::load_strided(
|
||||
int stride,
|
||||
int overall_n) {
|
||||
// Silence compiler warnings
|
||||
(void)stride;
|
||||
(void)overall_n;
|
||||
bool default_inv = inv;
|
||||
inv = false;
|
||||
load();
|
||||
inv = default_inv;
|
||||
}
|
||||
|
||||
template <>
|
||||
METAL_FUNC void
|
||||
ReadWriter<float2, float, /*step=*/1, /*real=*/true>::write_strided(
|
||||
int stride,
|
||||
int overall_n) {
|
||||
compute_strided_indices(stride, overall_n);
|
||||
|
||||
for (int e = 0; e < elems_per_thread; e++) {
|
||||
out[strided_device_idx + e * stride] =
|
||||
pre_out(buf[strided_shared_idx + e], overall_n).x;
|
||||
}
|
||||
}
|
45
mlx/backend/metal/kernels/gather.h
Normal file
@@ -0,0 +1,45 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/metal/kernels/indexing.h"
|
||||
|
||||
template <typename T, typename IdxT, int NIDX, int IDX_NDIM>
|
||||
METAL_FUNC void gather_impl(
|
||||
const device T* src [[buffer(0)]],
|
||||
device T* out [[buffer(1)]],
|
||||
const constant int* src_shape [[buffer(2)]],
|
||||
const constant size_t* src_strides [[buffer(3)]],
|
||||
const constant size_t& src_ndim [[buffer(4)]],
|
||||
const constant int* slice_sizes [[buffer(5)]],
|
||||
const constant int* axes [[buffer(6)]],
|
||||
const thread Indices<IdxT, NIDX>& indices,
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
auto ind_idx = index.x;
|
||||
auto ind_offset = index.y;
|
||||
|
||||
size_t src_idx = 0;
|
||||
for (int i = 0; i < NIDX; ++i) {
|
||||
size_t idx_loc;
|
||||
if (IDX_NDIM == 0) {
|
||||
idx_loc = 0;
|
||||
} else if (IDX_NDIM == 1) {
|
||||
idx_loc = ind_idx * indices.strides[indices.ndim * i];
|
||||
} else {
|
||||
idx_loc = elem_to_loc(
|
||||
ind_idx,
|
||||
&indices.shapes[indices.ndim * i],
|
||||
&indices.strides[indices.ndim * i],
|
||||
indices.ndim);
|
||||
}
|
||||
auto ax = axes[i];
|
||||
auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]);
|
||||
src_idx += idx_val * src_strides[ax];
|
||||
}
|
||||
|
||||
auto src_offset = elem_to_loc(ind_offset, slice_sizes, src_strides, src_ndim);
|
||||
|
||||
size_t out_idx = index.y + static_cast<size_t>(grid_dim.y) * index.x;
|
||||
out[out_idx] = src[src_offset + src_idx];
|
||||
}
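gather_impl composes two offsets per output element: one from the index arrays (ind_idx, via axes and src_strides) and one from the position inside the gathered slice (ind_offset, via slice_sizes). A hypothetical NumPy analogue of the simplest configuration it covers, a single 1-D index array along one axis with slice_sizes spanning the remaining axes in full:

import numpy as np

src = np.arange(24).reshape(4, 6)
idx = np.array([3, 0, 2])
out = np.take(src, idx, axis=0)            # rows 3, 0 and 2 of src
assert out.shape == (3, 6)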
|
@@ -1,173 +0,0 @@
// Copyright © 2023-2024 Apple Inc.

#include <metal_atomic>

#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/indexing.h"
#include "mlx/backend/metal/kernels/utils.h"

using namespace metal;

/////////////////////////////////////////////////////////////////////
// Gather kernel
/////////////////////////////////////////////////////////////////////

template <typename T, typename IdxT, int NIDX, int IDX_NDIM>
METAL_FUNC void gather_impl(
    const device T* src [[buffer(0)]],
    device T* out [[buffer(1)]],
    const constant int* src_shape [[buffer(2)]],
    const constant size_t* src_strides [[buffer(3)]],
    const constant size_t& src_ndim [[buffer(4)]],
    const constant int* slice_sizes [[buffer(5)]],
    const constant int* axes [[buffer(6)]],
    const thread Indices<IdxT, NIDX>& indices,
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  auto ind_idx = index.x;
  auto ind_offset = index.y;

  size_t src_idx = 0;
  for (int i = 0; i < NIDX; ++i) {
    size_t idx_loc;
    if (IDX_NDIM == 0) {
      idx_loc = 0;
    } else if (IDX_NDIM == 1) {
      idx_loc = ind_idx * indices.strides[indices.ndim * i];
    } else {
      idx_loc = elem_to_loc(
          ind_idx,
          &indices.shapes[indices.ndim * i],
          &indices.strides[indices.ndim * i],
          indices.ndim);
    }
    auto ax = axes[i];
    auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]);
    src_idx += idx_val * src_strides[ax];
  }

  auto src_offset = elem_to_loc(ind_offset, slice_sizes, src_strides, src_ndim);

  size_t out_idx = index.y + static_cast<size_t>(grid_dim.y) * index.x;
  out[out_idx] = src[src_offset + src_idx];
}

#define make_gather_impl(IDX_ARG, IDX_ARR) \
  template <typename T, typename IdxT, int NIDX, int IDX_NDIM> \
  [[kernel]] void gather( \
      const device T* src [[buffer(0)]], \
      device T* out [[buffer(1)]], \
      const constant int* src_shape [[buffer(2)]], \
      const constant size_t* src_strides [[buffer(3)]], \
      const constant size_t& src_ndim [[buffer(4)]], \
      const constant int* slice_sizes [[buffer(5)]], \
      const constant int* axes [[buffer(6)]], \
      const constant int* idx_shapes [[buffer(7)]], \
      const constant size_t* idx_strides [[buffer(8)]], \
      const constant int& idx_ndim [[buffer(9)]], \
      IDX_ARG(IdxT) uint2 index [[thread_position_in_grid]], \
      uint2 grid_dim [[threads_per_grid]]) { \
    Indices<IdxT, NIDX> idxs{ \
        {{IDX_ARR()}}, idx_shapes, idx_strides, idx_ndim}; \
 \
    return gather_impl<T, IdxT, NIDX, IDX_NDIM>( \
        src, \
        out, \
        src_shape, \
        src_strides, \
        src_ndim, \
        slice_sizes, \
        axes, \
        idxs, \
        index, \
        grid_dim); \
  }

#define make_gather(n) make_gather_impl(IDX_ARG_##n, IDX_ARR_##n)

make_gather(0) make_gather(1) make_gather(2) make_gather(3) make_gather(4)
make_gather(5) make_gather(6) make_gather(7) make_gather(8) make_gather(9)
make_gather(10)

/////////////////////////////////////////////////////////////////////
// Gather instantiations
/////////////////////////////////////////////////////////////////////

#define instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG, nd, nd_name) \
  template [[host_name("gather" name "_" #nidx "" #nd_name)]] [[kernel]] void \
  gather<src_t, idx_t, nidx, nd>( \
      const device src_t* src [[buffer(0)]], \
      device src_t* out [[buffer(1)]], \
      const constant int* src_shape [[buffer(2)]], \
      const constant size_t* src_strides [[buffer(3)]], \
      const constant size_t& src_ndim [[buffer(4)]], \
      const constant int* slice_sizes [[buffer(5)]], \
      const constant int* axes [[buffer(6)]], \
      const constant int* idx_shapes [[buffer(7)]], \
      const constant size_t* idx_strides [[buffer(8)]], \
      const constant int& idx_ndim [[buffer(9)]], \
      IDX_ARG(idx_t) uint2 index [[thread_position_in_grid]], \
      uint2 grid_dim [[threads_per_grid]]);

// clang-format off
#define instantiate_gather5(name, src_t, idx_t, nidx, nd, nd_name) \
  instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG_ ##nidx, nd, nd_name) // clang-format on

// clang-format off
#define instantiate_gather4(name, src_t, idx_t, nidx) \
  instantiate_gather5(name, src_t, idx_t, nidx, 0, _0) \
  instantiate_gather5(name, src_t, idx_t, nidx, 1, _1) \
  instantiate_gather5(name, src_t, idx_t, nidx, 2, )


// Special for case NIDX=0
instantiate_gather4("bool_", bool, bool, 0)
instantiate_gather4("uint8", uint8_t, bool, 0)
instantiate_gather4("uint16", uint16_t, bool, 0)
instantiate_gather4("uint32", uint32_t, bool, 0)
instantiate_gather4("uint64", uint64_t, bool, 0)
instantiate_gather4("int8", int8_t, bool, 0)
instantiate_gather4("int16", int16_t, bool, 0)
instantiate_gather4("int32", int32_t, bool, 0)
instantiate_gather4("int64", int64_t, bool, 0)
instantiate_gather4("float16", half, bool, 0)
instantiate_gather4("float32", float, bool, 0)
instantiate_gather4("bfloat16", bfloat16_t, bool, 0) // clang-format on

// clang-format off
#define instantiate_gather3(name, src_type, ind_type) \
  instantiate_gather4(name, src_type, ind_type, 1) \
  instantiate_gather4(name, src_type, ind_type, 2) \
  instantiate_gather4(name, src_type, ind_type, 3) \
  instantiate_gather4(name, src_type, ind_type, 4) \
  instantiate_gather4(name, src_type, ind_type, 5) \
  instantiate_gather4(name, src_type, ind_type, 6) \
  instantiate_gather4(name, src_type, ind_type, 7) \
  instantiate_gather4(name, src_type, ind_type, 8) \
  instantiate_gather4(name, src_type, ind_type, 9) \
  instantiate_gather4(name, src_type, ind_type, 10) // clang-format on

// clang-format off
#define instantiate_gather(name, src_type) \
  instantiate_gather3(#name "bool_", src_type, bool) \
  instantiate_gather3(#name "uint8", src_type, uint8_t) \
  instantiate_gather3(#name "uint16", src_type, uint16_t) \
  instantiate_gather3(#name "uint32", src_type, uint32_t) \
  instantiate_gather3(#name "uint64", src_type, uint64_t) \
  instantiate_gather3(#name "int8", src_type, int8_t) \
  instantiate_gather3(#name "int16", src_type, int16_t) \
  instantiate_gather3(#name "int32", src_type, int32_t) \
  instantiate_gather3(#name "int64", src_type, int64_t)

instantiate_gather(bool_, bool)
instantiate_gather(uint8, uint8_t)
instantiate_gather(uint16, uint16_t)
instantiate_gather(uint32, uint32_t)
instantiate_gather(uint64, uint64_t)
instantiate_gather(int8, int8_t)
instantiate_gather(int16, int16_t)
instantiate_gather(int32, int32_t)
instantiate_gather(int64, int64_t)
instantiate_gather(float16, half)
instantiate_gather(float32, float)
instantiate_gather(bfloat16, bfloat16_t) // clang-format on
@@ -1,13 +1,9 @@
// Copyright © 2023-2024 Apple Inc.

#pragma once

#include <metal_stdlib>

using namespace metal;

/////////////////////////////////////////////////////////////////////
// Indexing utils
/////////////////////////////////////////////////////////////////////

template <typename IdxT, int NIDX>
struct Indices {
  const array<const device IdxT*, NIDX> buffers;
@@ -24,31 +20,3 @@ METAL_FUNC size_t offset_neg_idx(IdxT idx, size_t size) {
  return (idx < 0) ? idx + size : idx;
}
}

#define IDX_ARG_N(idx_t, n) const device idx_t *idx##n [[buffer(n)]],

#define IDX_ARG_0(idx_t)
#define IDX_ARG_1(idx_t) IDX_ARG_0(idx_t) IDX_ARG_N(idx_t, 21)
#define IDX_ARG_2(idx_t) IDX_ARG_1(idx_t) IDX_ARG_N(idx_t, 22)
#define IDX_ARG_3(idx_t) IDX_ARG_2(idx_t) IDX_ARG_N(idx_t, 23)
#define IDX_ARG_4(idx_t) IDX_ARG_3(idx_t) IDX_ARG_N(idx_t, 24)
#define IDX_ARG_5(idx_t) IDX_ARG_4(idx_t) IDX_ARG_N(idx_t, 25)
#define IDX_ARG_6(idx_t) IDX_ARG_5(idx_t) IDX_ARG_N(idx_t, 26)
#define IDX_ARG_7(idx_t) IDX_ARG_6(idx_t) IDX_ARG_N(idx_t, 27)
#define IDX_ARG_8(idx_t) IDX_ARG_7(idx_t) IDX_ARG_N(idx_t, 28)
#define IDX_ARG_9(idx_t) IDX_ARG_8(idx_t) IDX_ARG_N(idx_t, 29)
#define IDX_ARG_10(idx_t) IDX_ARG_9(idx_t) IDX_ARG_N(idx_t, 30)

#define IDX_ARR_N(n) idx##n,

#define IDX_ARR_0()
#define IDX_ARR_1() IDX_ARR_0() IDX_ARR_N(21)
#define IDX_ARR_2() IDX_ARR_1() IDX_ARR_N(22)
#define IDX_ARR_3() IDX_ARR_2() IDX_ARR_N(23)
#define IDX_ARR_4() IDX_ARR_3() IDX_ARR_N(24)
#define IDX_ARR_5() IDX_ARR_4() IDX_ARR_N(25)
#define IDX_ARR_6() IDX_ARR_5() IDX_ARR_N(26)
#define IDX_ARR_7() IDX_ARR_6() IDX_ARR_N(27)
#define IDX_ARR_8() IDX_ARR_7() IDX_ARR_N(28)
#define IDX_ARR_9() IDX_ARR_8() IDX_ARR_N(29)
#define IDX_ARR_10() IDX_ARR_9() IDX_ARR_N(30)
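To make the deleted X-macro machinery concrete, a simplified standalone C++ sketch of the same trick (host types and the gather2 name are illustrative only, not MLX code): each IDX_ARG_n splices n pointer parameters into a signature and each IDX_ARR_n splices the matching argument list, which is how make_gather(n) produced a gather entry point taking n index arrays bound at buffers 21 through 20+n.

#include <cstdint>
#include <cstdio>

#define IDX_ARG_N(idx_t, n) const idx_t* idx##n,
#define IDX_ARG_0(idx_t)
#define IDX_ARG_1(idx_t) IDX_ARG_0(idx_t) IDX_ARG_N(idx_t, 21)
#define IDX_ARG_2(idx_t) IDX_ARG_1(idx_t) IDX_ARG_N(idx_t, 22)

#define IDX_ARR_N(n) idx##n,
#define IDX_ARR_0()
#define IDX_ARR_1() IDX_ARR_0() IDX_ARR_N(21)
#define IDX_ARR_2() IDX_ARR_1() IDX_ARR_N(22)

// Expands to: void gather2(const uint32_t* idx21, const uint32_t* idx22, int n)
void gather2(IDX_ARG_2(uint32_t) int n) {
  const uint32_t* bufs[] = {IDX_ARR_2()};  // expands to {idx21, idx22,}
  for (int i = 0; i < n; ++i) {
    std::printf("%u %u\n", bufs[0][i], bufs[1][i]);
  }
}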
@@ -601,14 +601,18 @@ METAL_FUNC void qvm_impl(
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int num_simdgroups = 8;
  constexpr int num_simdgroups = 2;
  constexpr int pack_factor = 32 / bits;
  constexpr int tn = 32 / pack_factor;
  constexpr int blocksize = SIMD_SIZE;

  typedef float U;
  typedef struct {
    uint32_t wi[tn];
  } vec_w;

  thread uint32_t w_local;
  thread U result[pack_factor] = {0};
  thread vec_w w_local;
  thread U result[tn * pack_factor] = {0};
  thread U scale = 1;
  thread U bias = 0;
  thread U x_local = 0;
@@ -616,11 +620,12 @@ METAL_FUNC void qvm_impl(
  // Adjust positions
  const int out_vec_size_w = out_vec_size / pack_factor;
  const int out_vec_size_g = out_vec_size / group_size;
  int out_col = tid.x * (num_simdgroups * pack_factor) + simd_gid * pack_factor;
  w += out_col / pack_factor;
  scales += out_col / group_size;
  biases += out_col / group_size;
  x += tid.y * in_vec_size;
  int out_col =
      tid.x * (num_simdgroups * pack_factor * tn) + simd_gid * pack_factor * tn;
  w += out_col / pack_factor + simd_lid * out_vec_size_w;
  scales += out_col / group_size + simd_lid * out_vec_size_g;
  biases += out_col / group_size + simd_lid * out_vec_size_g;
  x += tid.y * in_vec_size + simd_lid;
  y += tid.y * out_vec_size + out_col;

  if (out_col >= out_vec_size) {
@@ -628,40 +633,61 @@ METAL_FUNC void qvm_impl(
  }

  // Loop over in_vec in blocks of blocksize
  int i = 0;
  for (; i + blocksize <= in_vec_size; i += blocksize) {
    x_local = x[i + simd_lid];
    scale = scales[(i + simd_lid) * out_vec_size_g];
    bias = biases[(i + simd_lid) * out_vec_size_g];
    w_local = w[(i + simd_lid) * out_vec_size_w];
  int remaining = in_vec_size % blocksize;
  if (remaining == 0) {
    for (int i = 0; i < in_vec_size; i += blocksize) {
      x_local = *x;
      scale = *scales;
      bias = *biases;
      w_local = *((device vec_w*)w);

      qouter<U, pack_factor, bits>(
      qouter<U, tn * pack_factor, bits>(
          (thread uint8_t*)&w_local, x_local, scale, bias, result);

      x += blocksize;
      scales += blocksize * out_vec_size_g;
      biases += blocksize * out_vec_size_g;
      w += blocksize * out_vec_size_w;
    }
  } else {
    for (int i = blocksize; i < in_vec_size; i += blocksize) {
      x_local = *x;
      scale = *scales;
      bias = *biases;
      w_local = *((device vec_w*)w);

      qouter<U, tn * pack_factor, bits>(
          (thread uint8_t*)&w_local, x_local, scale, bias, result);

      x += blocksize;
      scales += blocksize * out_vec_size_g;
      biases += blocksize * out_vec_size_g;
      w += blocksize * out_vec_size_w;
    }
    if (static_cast<int>(simd_lid) < remaining) {
      x_local = *x;
      scale = *scales;
      bias = *biases;
      w_local = *((device vec_w*)w);
    } else {
      x_local = 0;
      scale = 0;
      bias = 0;
    }
    qouter<U, tn * pack_factor, bits>(
        (thread uint8_t*)&w_local, x_local, scale, bias, result);
  }
  if (static_cast<int>(i + simd_lid) < in_vec_size) {
    x_local = x[i + simd_lid];
    scale = scales[(i + simd_lid) * out_vec_size_g];
    bias = biases[(i + simd_lid) * out_vec_size_g];
    w_local = w[(i + simd_lid) * out_vec_size_w];
  } else {
    x_local = 0;
    scale = 0;
    bias = 0;
    w_local = 0;
  }
  qouter<U, pack_factor, bits>(
      (thread uint8_t*)&w_local, x_local, scale, bias, result);

  // Accumulate in the simdgroup
#pragma clang loop unroll(full)
  for (int k = 0; k < pack_factor; k++) {
  for (int k = 0; k < tn * pack_factor; k++) {
    result[k] = simd_sum(result[k]);
  }

  // Store the result
  if (simd_lid == 0) {
#pragma clang loop unroll(full)
    for (int k = 0; k < pack_factor; k++) {
    for (int k = 0; k < tn * pack_factor; k++) {
      y[k] = static_cast<T>(result[k]);
    }
  }
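In these hunks the old lines sit next to their replacements (for example, the previous constant num_simdgroups = 8 appears directly above the new value 2). A small arithmetic sketch of why the output tile size is unchanged, under the assumption of 4-bit weights (the example bit width is an illustration, not text from the diff):

// Illustrative arithmetic only, not part of the kernel.
constexpr int bits = 4;                                        // assumed example bit width
constexpr int pack_factor = 32 / bits;                         // 8 values per packed uint32
constexpr int tn = 32 / pack_factor;                           // 4 packed words per simdgroup
constexpr int cols_per_simdgroup = tn * pack_factor;           // 32 output columns
constexpr int cols_per_threadgroup = 2 * cols_per_simdgroup;   // new: 2 simdgroups * 32
static_assert(cols_per_threadgroup == 8 * 8,
              "same 64-column tile as the old 8 simdgroups * pack_factor layout");

The difference is how the columns are covered: each lane now strides through the input vector (the simd_lid offsets added to w, scales, biases, and x) and accumulates tn * pack_factor partial results, which are then combined with simd_sum, instead of each simdgroup handling one block of the input per iteration.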
mlx/backend/metal/kernels/reduce.h (new file)
@@ -0,0 +1,4 @@
#pragma once
#include "mlx/backend/metal/kernels/reduction/reduce_all.h"
#include "mlx/backend/metal/kernels/reduction/reduce_col.h"
#include "mlx/backend/metal/kernels/reduction/reduce_row.h"
mlx/backend/metal/kernels/reduce.metal (new file)
@@ -0,0 +1,293 @@
// Copyright © 2024 Apple Inc.

#include <metal_atomic>
#include <metal_simdgroup>

// clang-format off
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/atomic.h"
#include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/reduction/reduce_init.h"
#include "mlx/backend/metal/kernels/reduce.h"

#define instantiate_reduce_helper_floats(inst_f, name, op) \
  inst_f(name, float16, half, op) \
  inst_f(name, float32, float, op) \
  inst_f(name, bfloat16, bfloat16_t, op)

#define instantiate_reduce_helper_uints(inst_f, name, op) \
  inst_f(name, uint8, uint8_t, op) \
  inst_f(name, uint16, uint16_t, op) \
  inst_f(name, uint32, uint32_t, op)

#define instantiate_reduce_helper_ints(inst_f, name, op) \
  inst_f(name, int8, int8_t, op) \
  inst_f(name, int16, int16_t, op) \
  inst_f(name, int32, int32_t, op)

#define instantiate_reduce_helper_64b(inst_f, name, op) \
  inst_f(name, int64, int64_t, op) \
  inst_f(name, uint64, uint64_t, op)

#define instantiate_reduce_helper_types(inst_f, name, op) \
  instantiate_reduce_helper_floats(inst_f, name, op) \
  instantiate_reduce_helper_uints(inst_f, name, op) \
  instantiate_reduce_helper_ints(inst_f, name, op)

#define instantiate_reduce_ops(inst_f, type_f) \
  type_f(inst_f, sum, Sum) \
  type_f(inst_f, prod, Prod) \
  type_f(inst_f, min, Min) \
  type_f(inst_f, max, Max)

// Special case for bool reductions
#define instantiate_reduce_from_types_helper( \
    inst_f, name, tname, itype, otype, op) \
  inst_f(name##tname, itype, otype, op)

#define instantiate_reduce_from_types(inst_f, name, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, name, bool_, bool, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, name, uint8, uint8_t, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, name, uint16, uint16_t, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, name, uint32, uint32_t, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, name, uint64, uint64_t, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, name, int8, int8_t, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, name, int16, int16_t, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, name, int32, int32_t, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, name, int64, int64_t, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, name, float16, half, otype, op) \
  instantiate_reduce_from_types_helper( \
      inst_f, \
      name, \
      float32, \
      float, \
      otype, \
      op) \
  instantiate_reduce_from_types_helper( \
      inst_f, \
      name, \
      bfloat16, \
      bfloat16_t, \
      otype, \
      op)

#define instantiate_init_reduce(name, otype, op) \
  template [[host_name("i_reduce_" #name)]] [[kernel]] void \
  init_reduce<otype, op>( \
      device otype * out [[buffer(1)]], uint tid [[thread_position_in_grid]]);

#define instantiate_init_reduce_helper(name, tname, type, op) \
  instantiate_init_reduce(name##tname, type, op<type>)

instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_64b)

instantiate_init_reduce(andbool_, bool, And<bool>)
instantiate_init_reduce(orbool_, bool, Or<bool>)

#define instantiate_all_reduce(name, itype, otype, op) \
  template [[host_name("all_reduce_" #name)]] [[kernel]] void \
  all_reduce<itype, otype, op>( \
      const device itype* in [[buffer(0)]], \
      device mlx_atomic<otype>* out [[buffer(1)]], \
      const device size_t& in_size [[buffer(2)]], \
      uint gid [[thread_position_in_grid]], \
      uint lid [[thread_position_in_threadgroup]], \
      uint grid_size [[threads_per_grid]], \
      uint simd_per_group [[simdgroups_per_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);

#define instantiate_all_reduce_no_atomics(name, itype, otype, op) \
  template [[host_name("allNoAtomics_reduce_" #name)]] [[kernel]] void \
  all_reduce_no_atomics<itype, otype, op>( \
      const device itype* in [[buffer(0)]], \
      device otype* out [[buffer(1)]], \
      const device size_t& in_size [[buffer(2)]], \
      uint gid [[thread_position_in_grid]], \
      uint lid [[thread_position_in_threadgroup]], \
      uint grid_size [[threads_per_grid]], \
      uint simd_per_group [[simdgroups_per_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]], \
      uint thread_group_id [[threadgroup_position_in_grid]]);

#define instantiate_same_all_reduce_helper(name, tname, type, op) \
  instantiate_all_reduce(name##tname, type, type, op<type>)

#define instantiate_same_all_reduce_na_helper(name, tname, type, op) \
  instantiate_all_reduce_no_atomics(name##tname, type, type, op<type>)

instantiate_reduce_ops(instantiate_same_all_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_all_reduce_na_helper, instantiate_reduce_helper_64b)

instantiate_reduce_from_types(instantiate_all_reduce, and, bool, And<bool>)
instantiate_reduce_from_types(instantiate_all_reduce, or, bool, Or<bool>)

// special case bool with larger output type
instantiate_all_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>)

#define instantiate_col_reduce_general(name, itype, otype, op) \
  template [[host_name("colGeneral_reduce_" #name)]] [[kernel]] void \
  col_reduce_general<itype, otype, op>( \
      const device itype* in [[buffer(0)]], \
      device mlx_atomic<otype>* out [[buffer(1)]], \
      const constant size_t& reduction_size [[buffer(2)]], \
      const constant size_t& reduction_stride [[buffer(3)]], \
      const constant size_t& out_size [[buffer(4)]], \
      const constant int* shape [[buffer(5)]], \
      const constant size_t* strides [[buffer(6)]], \
      const constant int& ndim [[buffer(7)]], \
      threadgroup otype* local_data [[threadgroup(0)]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]], \
      uint3 lsize [[threads_per_threadgroup]]);

#define instantiate_col_reduce_general_no_atomics(name, itype, otype, op) \
  template \
  [[host_name("colGeneralNoAtomics_reduce_" #name)]] [[kernel]] void \
  col_reduce_general_no_atomics<itype, otype, op>( \
      const device itype* in [[buffer(0)]], \
      device otype* out [[buffer(1)]], \
      const constant size_t& reduction_size [[buffer(2)]], \
      const constant size_t& reduction_stride [[buffer(3)]], \
      const constant size_t& out_size [[buffer(4)]], \
      const constant int* shape [[buffer(5)]], \
      const constant size_t* strides [[buffer(6)]], \
      const constant int& ndim [[buffer(7)]], \
      threadgroup otype* local_data [[threadgroup(0)]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]], \
      uint3 gid [[thread_position_in_grid]], \
      uint3 lsize [[threads_per_threadgroup]], \
      uint3 gsize [[threads_per_grid]]);

#define instantiate_col_reduce_small(name, itype, otype, op) \
  template [[host_name("colSmall_reduce_" #name)]] [[kernel]] void \
  col_reduce_small<itype, otype, op>( \
      const device itype* in [[buffer(0)]], \
      device otype* out [[buffer(1)]], \
      const constant size_t& reduction_size [[buffer(2)]], \
      const constant size_t& reduction_stride [[buffer(3)]], \
      const constant size_t& out_size [[buffer(4)]], \
      const constant int* shape [[buffer(5)]], \
      const constant size_t* strides [[buffer(6)]], \
      const constant int& ndim [[buffer(7)]], \
      const constant size_t& non_col_reductions [[buffer(8)]], \
      const constant int* non_col_shapes [[buffer(9)]], \
      const constant size_t* non_col_strides [[buffer(10)]], \
      const constant int& non_col_ndim [[buffer(11)]], \
      uint tid [[thread_position_in_grid]]);

#define instantiate_same_col_reduce_helper(name, tname, type, op) \
  instantiate_col_reduce_small(name ##tname, type, type, op<type>) \
  instantiate_col_reduce_general(name ##tname, type, type, op<type>)

#define instantiate_same_col_reduce_na_helper(name, tname, type, op) \
  instantiate_col_reduce_small(name ##tname, type, type, op<type>) \
  instantiate_col_reduce_general_no_atomics(name ##tname, type, type, op<type>)

instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_col_reduce_na_helper, instantiate_reduce_helper_64b)

instantiate_col_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>)
instantiate_reduce_from_types(instantiate_col_reduce_general, and, bool, And<bool>)
instantiate_reduce_from_types(instantiate_col_reduce_general, or, bool, Or<bool>)

instantiate_col_reduce_small(sumbool_, bool, uint32_t, Sum<uint32_t>)
instantiate_reduce_from_types(instantiate_col_reduce_small, and, bool, And<bool>)
instantiate_reduce_from_types(instantiate_col_reduce_small, or, bool, Or<bool>)

#define instantiate_row_reduce_small(name, itype, otype, op) \
  template [[host_name("rowGeneralSmall_reduce_" #name)]] [[kernel]] void \
  row_reduce_general_small<itype, otype, op>( \
      const device itype* in [[buffer(0)]], \
      device otype* out [[buffer(1)]], \
      const constant size_t& reduction_size [[buffer(2)]], \
      const constant size_t& out_size [[buffer(3)]], \
      const constant size_t& non_row_reductions [[buffer(4)]], \
      const constant int* shape [[buffer(5)]], \
      const constant size_t* strides [[buffer(6)]], \
      const constant int& ndim [[buffer(7)]], \
      uint lid [[thread_position_in_grid]]); \
  template [[host_name("rowGeneralMed_reduce_" #name)]] [[kernel]] void \
  row_reduce_general_med<itype, otype, op>( \
      const device itype* in [[buffer(0)]], \
      device otype* out [[buffer(1)]], \
      const constant size_t& reduction_size [[buffer(2)]], \
      const constant size_t& out_size [[buffer(3)]], \
      const constant size_t& non_row_reductions [[buffer(4)]], \
      const constant int* shape [[buffer(5)]], \
      const constant size_t* strides [[buffer(6)]], \
      const constant int& ndim [[buffer(7)]], \
      uint tid [[threadgroup_position_in_grid]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_per_group [[dispatch_simdgroups_per_threadgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);

#define instantiate_row_reduce_general(name, itype, otype, op) \
  instantiate_row_reduce_small(name, itype, otype, op) \
  template \
  [[host_name("rowGeneral_reduce_" #name)]] [[kernel]] void \
  row_reduce_general<itype, otype, op>( \
      const device itype* in [[buffer(0)]], \
      device mlx_atomic<otype>* out [[buffer(1)]], \
      const constant size_t& reduction_size [[buffer(2)]], \
      const constant size_t& out_size [[buffer(3)]], \
      const constant size_t& non_row_reductions [[buffer(4)]], \
      const constant int* shape [[buffer(5)]], \
      const constant size_t* strides [[buffer(6)]], \
      const constant int& ndim [[buffer(7)]], \
      uint3 lid [[thread_position_in_threadgroup]], \
      uint3 lsize [[threads_per_threadgroup]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_per_group [[simdgroups_per_threadgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);

#define instantiate_row_reduce_general_no_atomics(name, itype, otype, op) \
  instantiate_row_reduce_small(name, itype, otype, op) \
  template \
  [[host_name("rowGeneralNoAtomics_reduce_" #name)]] [[kernel]] void \
  row_reduce_general_no_atomics<itype, otype, op>( \
      const device itype* in [[buffer(0)]], \
      device otype* out [[buffer(1)]], \
      const constant size_t& reduction_size [[buffer(2)]], \
      const constant size_t& out_size [[buffer(3)]], \
      const constant size_t& non_row_reductions [[buffer(4)]], \
      const constant int* shape [[buffer(5)]], \
      const constant size_t* strides [[buffer(6)]], \
      const constant int& ndim [[buffer(7)]], \
      uint3 lid [[thread_position_in_threadgroup]], \
      uint3 lsize [[threads_per_threadgroup]], \
      uint3 gsize [[threads_per_grid]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_per_group [[simdgroups_per_threadgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);

#define instantiate_same_row_reduce_helper(name, tname, type, op) \
  instantiate_row_reduce_general(name##tname, type, type, op<type>)

#define instantiate_same_row_reduce_na_helper(name, tname, type, op) \
  instantiate_row_reduce_general_no_atomics(name##tname, type, type, op<type>)

instantiate_reduce_ops(instantiate_same_row_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_row_reduce_na_helper, instantiate_reduce_helper_64b)

instantiate_reduce_from_types(instantiate_row_reduce_general, and, bool, And<bool>)
instantiate_reduce_from_types(instantiate_row_reduce_general, or, bool, Or<bool>)

instantiate_row_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>)
// clang-format on
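As a reading aid, roughly what one of these instantiations expands to (an illustrative expansion derived from the macro above, not text from the diff): instantiate_same_all_reduce_helper(sum, float32, float, Sum) becomes instantiate_all_reduce(sumfloat32, float, float, Sum<float>), which declares

template [[host_name("all_reduce_sumfloat32")]] [[kernel]] void
all_reduce<float, float, Sum<float>>(
    const device float* in [[buffer(0)]],
    device mlx_atomic<float>* out [[buffer(1)]],
    const device size_t& in_size [[buffer(2)]],
    uint gid [[thread_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint grid_size [[threads_per_grid]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]);

so the host side can look the kernel up by the "all_reduce_sumfloat32" host name.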
mlx/backend/metal/kernels/reduce_utils.h (new file)
@@ -0,0 +1,6 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include "mlx/backend/metal/kernels/atomic.h"
#include "mlx/backend/metal/kernels/reduction/ops.h"
@@ -1,32 +0,0 @@
// Copyright © 2023-2024 Apple Inc.

#include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/reduction/reduce_inst.h"
#include "mlx/backend/metal/kernels/reduction/utils.h"

using namespace metal;

///////////////////////////////////////////////////////////////////////////////
// Reduce init
///////////////////////////////////////////////////////////////////////////////

template <typename T, typename Op>
[[kernel]] void init_reduce(
    device T* out [[buffer(0)]],
    uint tid [[thread_position_in_grid]]) {
  out[tid] = Op::init;
}

#define instantiate_init_reduce(name, otype, op) \
  template [[host_name("i" #name)]] [[kernel]] void init_reduce<otype, op>( \
      device otype * out [[buffer(1)]], uint tid [[thread_position_in_grid]]);

#define instantiate_init_reduce_helper(name, tname, type, op) \
  instantiate_init_reduce(name##tname, type, op<type>)

// clang-format off
instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_init_reduce_helper, instantiate_reduce_helper_64b)

instantiate_init_reduce(andbool_, bool, And)
instantiate_init_reduce(orbool_, bool, Or) // clang-format on
@@ -5,9 +5,7 @@
#include <metal_atomic>
#include <metal_simdgroup>

#include "mlx/backend/metal/kernels/atomic.h"
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"
static constant constexpr const uint8_t simd_size = 32;

union bool4_or_uint {
  bool4 b;
@@ -21,6 +19,7 @@ struct None {
  }
};

template <typename U = bool>
struct And {
  bool simd_reduce(bool val) {
    return simd_all(val);
@@ -58,6 +57,7 @@ struct And {
  }
};

template <typename U = bool>
struct Or {
  bool simd_reduce(bool val) {
    return simd_any(val);

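The newly added template <typename U = bool> on And and Or lets the instantiation macros elsewhere in this diff spell the boolean reductions with an explicit type, e.g. And<bool> and Or<bool>. A standalone C++ sketch of the default-parameter pattern (the operator() body and names here are illustrative only, not the MLX definitions):

#include <cassert>

template <typename U = bool>
struct And {
  U operator()(U a, U b) const { return a && b; }  // illustrative reduction op
};

int main() {
  And<> a;      // uses the default U = bool
  And<bool> b;  // explicit, the spelling used by instantiate_init_reduce(andbool_, bool, And<bool>)
  assert(a(true, false) == b(true, false));
  return 0;
}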
@@ -1,11 +1,5 @@
// Copyright © 2023-2024 Apple Inc.

#include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/reduction/reduce_inst.h"
#include "mlx/backend/metal/kernels/reduction/utils.h"

using namespace metal;

///////////////////////////////////////////////////////////////////////////////
// All reduce helper
///////////////////////////////////////////////////////////////////////////////
@@ -139,50 +133,3 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
    out[thread_group_id] = total_val;
  }
}

#define instantiate_all_reduce(name, itype, otype, op) \
  template [[host_name("all_reduce_" #name)]] [[kernel]] void \
  all_reduce<itype, otype, op>( \
      const device itype* in [[buffer(0)]], \
      device mlx_atomic<otype>* out [[buffer(1)]], \
      const device size_t& in_size [[buffer(2)]], \
      uint gid [[thread_position_in_grid]], \
      uint lid [[thread_position_in_threadgroup]], \
      uint grid_size [[threads_per_grid]], \
      uint simd_per_group [[simdgroups_per_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);

#define instantiate_all_reduce_no_atomics(name, itype, otype, op) \
  template [[host_name("all_reduce_no_atomics_" #name)]] [[kernel]] void \
  all_reduce_no_atomics<itype, otype, op>( \
      const device itype* in [[buffer(0)]], \
      device otype* out [[buffer(1)]], \
      const device size_t& in_size [[buffer(2)]], \
      uint gid [[thread_position_in_grid]], \
      uint lid [[thread_position_in_threadgroup]], \
      uint grid_size [[threads_per_grid]], \
      uint simd_per_group [[simdgroups_per_threadgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]], \
      uint thread_group_id [[threadgroup_position_in_grid]]);

///////////////////////////////////////////////////////////////////////////////
// Instantiations
///////////////////////////////////////////////////////////////////////////////

#define instantiate_same_all_reduce_helper(name, tname, type, op) \
  instantiate_all_reduce(name##tname, type, type, op<type>)

#define instantiate_same_all_reduce_na_helper(name, tname, type, op) \
  instantiate_all_reduce_no_atomics(name##tname, type, type, op<type>)

// clang-format off
instantiate_reduce_ops(instantiate_same_all_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_all_reduce_na_helper, instantiate_reduce_helper_64b)

instantiate_reduce_from_types(instantiate_all_reduce, and, bool, And)
instantiate_reduce_from_types(instantiate_all_reduce, or, bool, Or)

// special case bool with larger output type
instantiate_all_reduce(sumbool_, bool, uint32_t, Sum<uint32_t>) // clang-format on
@@ -1,11 +1,5 @@
// Copyright © 2023-2024 Apple Inc.

#include "mlx/backend/metal/kernels/reduction/ops.h"
#include "mlx/backend/metal/kernels/reduction/reduce_inst.h"
#include "mlx/backend/metal/kernels/reduction/utils.h"

using namespace metal;

///////////////////////////////////////////////////////////////////////////////
// Small column reduce kernel
///////////////////////////////////////////////////////////////////////////////
@@ -52,23 +46,6 @@ template <typename T, typename U, typename Op>
  out[out_idx] = total_val;
}

#define instantiate_col_reduce_small(name, itype, otype, op) \
  template [[host_name("col_reduce_small_" #name)]] [[kernel]] void \
  col_reduce_small<itype, otype, op>( \
      const device itype* in [[buffer(0)]], \
      device otype* out [[buffer(1)]], \
      const constant size_t& reduction_size [[buffer(2)]], \
      const constant size_t& reduction_stride [[buffer(3)]], \
      const constant size_t& out_size [[buffer(4)]], \
      const constant int* shape [[buffer(5)]], \
      const constant size_t* strides [[buffer(6)]], \
      const constant int& ndim [[buffer(7)]], \
      const constant size_t& non_col_reductions [[buffer(8)]], \
      const constant int* non_col_shapes [[buffer(9)]], \
      const constant size_t* non_col_strides [[buffer(10)]], \
      const constant int& non_col_ndim [[buffer(11)]], \
      uint tid [[thread_position_in_grid]]);

///////////////////////////////////////////////////////////////////////////////
// Column reduce helper
///////////////////////////////////////////////////////////////////////////////
@@ -186,64 +163,3 @@ template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
      }
    }
  }

#define instantiate_col_reduce_general(name, itype, otype, op) \
  template [[host_name("col_reduce_general_" #name)]] [[kernel]] void \
  col_reduce_general<itype, otype, op>( \
      const device itype* in [[buffer(0)]], \
      device mlx_atomic<otype>* out [[buffer(1)]], \
      const constant size_t& reduction_size [[buffer(2)]], \
      const constant size_t& reduction_stride [[buffer(3)]], \
      const constant size_t& out_size [[buffer(4)]], \
      const constant int* shape [[buffer(5)]], \
      const constant size_t* strides [[buffer(6)]], \
      const constant int& ndim [[buffer(7)]], \
      threadgroup otype* local_data [[threadgroup(0)]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]], \
      uint3 lsize [[threads_per_threadgroup]]);

#define instantiate_col_reduce_general_no_atomics(name, itype, otype, op) \
  template \
  [[host_name("col_reduce_general_no_atomics_" #name)]] [[kernel]] void \
  col_reduce_general_no_atomics<itype, otype, op>( \
      const device itype* in [[buffer(0)]], \
      device otype* out [[buffer(1)]], \
      const constant size_t& reduction_size [[buffer(2)]], \
      const constant size_t& reduction_stride [[buffer(3)]], \
      const constant size_t& out_size [[buffer(4)]], \
      const constant int* shape [[buffer(5)]], \
      const constant size_t* strides [[buffer(6)]], \
      const constant int& ndim [[buffer(7)]], \
      threadgroup otype* local_data [[threadgroup(0)]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]], \
      uint3 gid [[thread_position_in_grid]], \
      uint3 lsize [[threads_per_threadgroup]], \
      uint3 gsize [[threads_per_grid]]);

///////////////////////////////////////////////////////////////////////////////
// Instantiations
///////////////////////////////////////////////////////////////////////////////

// clang-format off
#define instantiate_same_col_reduce_helper(name, tname, type, op) \
  instantiate_col_reduce_small(name ##tname, type, type, op<type>) \
  instantiate_col_reduce_general(name ##tname, type, type, op<type>) // clang-format on

// clang-format off
#define instantiate_same_col_reduce_na_helper(name, tname, type, op) \
  instantiate_col_reduce_small(name ##tname, type, type, op<type>) \
  instantiate_col_reduce_general_no_atomics(name ##tname, type, type, op<type>) // clang-format on

// clang-format off
instantiate_reduce_ops(instantiate_same_col_reduce_helper, instantiate_reduce_helper_types)
instantiate_reduce_ops(instantiate_same_col_reduce_na_helper, instantiate_reduce_helper_64b)

instantiate_col_reduce_general(sumbool_, bool, uint32_t, Sum<uint32_t>)
instantiate_reduce_from_types(instantiate_col_reduce_general, and, bool, And)
instantiate_reduce_from_types(instantiate_col_reduce_general, or, bool, Or)

instantiate_col_reduce_small(sumbool_, bool, uint32_t, Sum<uint32_t>)
instantiate_reduce_from_types(instantiate_col_reduce_small, and, bool, And)
instantiate_reduce_from_types(instantiate_col_reduce_small, or, bool, Or) // clang-format on
mlx/backend/metal/kernels/reduction/reduce_init.h (new file)
@@ -0,0 +1,8 @@
// Copyright © 2023-2024 Apple Inc.

template <typename T, typename Op>
[[kernel]] void init_reduce(
    device T* out [[buffer(0)]],
    uint tid [[thread_position_in_grid]]) {
  out[tid] = Op::init;
}
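Here Op::init is the reduction's identity element, so the kernel simply pre-fills the output before the main reduce pass accumulates into it. A minimal host-side analogue (plain C++; the ops and their identity values below are assumed examples, not quotes from ops.h):

#include <vector>

// Hypothetical ops mirroring the reduction ops' identity elements.
struct SumOp  { static constexpr float init = 0.0f; };
struct ProdOp { static constexpr float init = 1.0f; };

// CPU analogue of init_reduce: fill the output with the op's identity.
template <typename Op>
void init_reduce_ref(std::vector<float>& out) {
  for (auto& v : out) v = Op::init;
}

int main() {
  std::vector<float> out(8);
  init_reduce_ref<SumOp>(out);   // all zeros, ready for a sum reduction
  init_reduce_ref<ProdOp>(out);  // all ones, ready for a product reduction
  return 0;
}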
Some files were not shown because too many files have changed in this diff.