Mirror of https://github.com/ml-explore/mlx.git (synced 2025-09-10 21:37:50 +08:00)

Compare commits: v0.16.3 ... async_all_ (1 commit)

Commit: 3a1df968cf
@@ -144,7 +144,6 @@ jobs:
          name: Install dependencies
          command: |
            brew install python@<< parameters.python_version >>
            brew install openmpi
            python<< parameters.python_version >> -m venv env
            source env/bin/activate
            pip install --upgrade pip
.github/workflows/pull_request.yml (vendored, 2 changes)

@@ -17,4 +17,4 @@ jobs:
        pip install pre-commit black isort clang-format
    - name: Run lint
      run: |
        pre-commit run --all-files
        pre-commit run --all-files
@@ -10,14 +10,13 @@ MLX was developed with contributions from the following individuals:

- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`.
- Juarez Bochi: Fixed bug in cross attention.
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream`, safetensors support, `einsum`, and `einsum_path`.
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support.
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented pooling layers and ``Upsample``.
- Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops.
- Luca Arnaboldi: Added `Ceil` and `Floor` ops; implemented pickling, copy and deepcopy for mlx arrays.
- Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention`
- AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`. Comparison works for non array objects in `mlx.core.array`. Exception handling for invalid operations in `mlx.core.array`.
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
- Paul Paczuski: Improved stability of BCE loss calculation

<a href="https://github.com/ml-explore/mlx/graphs/contributors">
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
@@ -24,7 +24,7 @@ option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)

if(NOT MLX_VERSION)
  set(MLX_VERSION 0.16.3)
  set(MLX_VERSION 0.14.0)
endif()

# --------------------- Processor tests -------------------------

@@ -83,21 +83,24 @@ elseif (MLX_BUILD_METAL)
    OUTPUT_VARIABLE MACOS_VERSION
    COMMAND_ERROR_IS_FATAL ANY)

  if (${MACOS_VERSION} LESS 14.0)
    message(FATAL_ERROR "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON" )
  endif()
  message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}")

  set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18-beta.zip)
  # Get the metal version
  execute_process(
    COMMAND zsh "-c" "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal -E -x metal -P - | tail -1 | tr -d '\n'"
    OUTPUT_VARIABLE MLX_METAL_VERSION
    COMMAND_ERROR_IS_FATAL ANY)
  if (${MACOS_VERSION} GREATER_EQUAL 14.2)
    set(METAL_CPP_PATCH ${CMAKE_CURRENT_SOURCE_DIR}/cmake/metal.14.2.diff)
    set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14.2_iOS17.2.zip)
    set(MLX_METAL_VERSION METAL_3_1)
  elseif (${MACOS_VERSION} GREATER_EQUAL 14.0)
    set(METAL_CPP_PATCH ${CMAKE_CURRENT_SOURCE_DIR}/cmake/metal.14.0.diff)
    set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14_iOS17-beta.zip)
    set(MLX_METAL_VERSION METAL_3_0)
  else()
    message(FATAL_ERROR "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON" )
  endif()

  FetchContent_Declare(
    metal_cpp
    URL ${METAL_CPP_URL}
    PATCH_COMMAND /usr/bin/patch -N -i ${METAL_CPP_PATCH} || true
  )

  FetchContent_MakeAvailable(metal_cpp)

@@ -112,7 +115,7 @@ elseif (MLX_BUILD_METAL)
    ${FOUNDATION_LIB}
    ${QUARTZ_LIB})

  add_compile_definitions("MLX_METAL_VERSION=${MLX_METAL_VERSION}")
  add_compile_definitions(${MLX_METAL_VERSION})
endif()

if (MLX_BUILD_CPU)

@@ -166,26 +169,7 @@ endif()

find_package(MPI)
if (MPI_FOUND)
  execute_process(
    COMMAND zsh "-c" "mpirun --version"
    OUTPUT_VARIABLE MPI_VERSION
    ERROR_QUIET
  )
  if (${MPI_VERSION} MATCHES ".*Open MPI.*")
    target_include_directories(mlx PRIVATE ${MPI_INCLUDE_PATH})
  elseif (MPI_VERSION STREQUAL "")
    set(MPI_FOUND FALSE)
    message(
      WARNING
      "MPI found but mpirun is not available. Building without MPI."
    )
  else()
    set(MPI_FOUND FALSE)
    message(
      WARNING
      "MPI which is not OpenMPI found. Building without MPI."
    )
  endif()
endif()

add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)
@@ -185,7 +185,7 @@ def prelu(x: torch.Tensor) -> torch.Tensor:
|
||||
def mish(x: torch.Tensor) -> torch.Tensor:
|
||||
y = x
|
||||
for _ in range(100):
|
||||
y = torch.nn.functional.mish(y)
|
||||
return torch.nn.functional.mish(y)
|
||||
sync_if_needed(x)
|
||||
|
||||
|
||||
@@ -283,14 +283,6 @@ def topk(axis, x):
|
||||
sync_if_needed(x)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def step_function(x):
|
||||
y = x
|
||||
for i in range(100):
|
||||
y = torch.where(y < 0, 0, 1)
|
||||
sync_if_needed(x)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def selu(x):
|
||||
y = x
|
||||
@@ -454,11 +446,5 @@ if __name__ == "__main__":
|
||||
elif args.benchmark == "topk":
|
||||
print(bench(topk, axis, x))
|
||||
|
||||
elif args.benchmark == "step":
|
||||
print(bench(step_function, x))
|
||||
|
||||
elif args.benchmark == "selu":
|
||||
print(bench(selu, x))
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown benchmark `{args.benchmark}`.")
|
||||
raise ValueError("Unknown benchmark")
|
||||
|
@@ -16,9 +16,7 @@ def run_or_raise(*args, **kwargs):
|
||||
result = run(*args, capture_output=True, **kwargs)
|
||||
return float(result.stdout)
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
f"stdout: {result.stdout.decode()}\nstderr: {result.stderr.decode()}"
|
||||
)
|
||||
raise ValueError(f"stdout: {result.stdout}\nstderr: {result.stderr}")
|
||||
|
||||
|
||||
def compare(args):
|
||||
|
@@ -9,6 +9,7 @@ from time_utils import time_fn
|
||||
|
||||
|
||||
def bench_gelu():
|
||||
|
||||
def gelu(x):
|
||||
return x * (1 + mx.erf(x / math.sqrt(2))) / 2
|
||||
|
||||
@@ -50,6 +51,7 @@ def bench_gelu():
|
||||
|
||||
|
||||
def bench_layernorm():
|
||||
|
||||
weight = mx.random.uniform(shape=(4096,)).astype(mx.float16)
|
||||
bias = mx.random.uniform(shape=(4096,)).astype(mx.float16)
|
||||
mx.eval(weight, bias)
|
||||
|
@@ -54,6 +54,7 @@ def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
|
||||
|
||||
|
||||
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
|
||||
|
||||
scale = 1.0 / math.sqrt(kH * kH * C)
|
||||
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
|
||||
b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
|
||||
|
@@ -1,84 +0,0 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import time
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
|
||||
|
||||
def timeit(fn, its=100, args=[]):
|
||||
for _ in range(5):
|
||||
fn(*args)
|
||||
tic = time.perf_counter()
|
||||
for _ in range(its):
|
||||
fn(*args)
|
||||
toc = time.perf_counter()
|
||||
return 1e3 * (toc - tic) / its
|
||||
|
||||
|
||||
def time_little_einsum_path():
|
||||
subscripts = "ik,kj->ij"
|
||||
x = mx.ones((32, 32))
|
||||
y = mx.ones((32, 32))
|
||||
mx_time = timeit(mx.einsum_path, args=(subscripts, x, y))
|
||||
|
||||
x = np.array(x)
|
||||
y = np.array(y)
|
||||
np_time = timeit(np.einsum_path, args=(subscripts, x, y))
|
||||
print("Timing little einsum path...")
|
||||
print(f"MLX ... {mx_time:.3f} ms")
|
||||
print(f"NumPy... {np_time:.3f} ms")
|
||||
|
||||
|
||||
def time_big_einsum_path():
|
||||
chars = list("abcdefgh")
|
||||
char_to_dim = {c: v for v, c in enumerate(chars)}
|
||||
|
||||
num_inputs = 10
|
||||
inputs = []
|
||||
subscripts = []
|
||||
for _ in range(num_inputs):
|
||||
subscript = np.random.choice(chars, size=5, replace=False).tolist()
|
||||
subscripts.append("".join(subscript))
|
||||
inputs.append(np.ones(list(char_to_dim[c] for c in subscript)))
|
||||
subscripts = ",".join(subscripts)
|
||||
|
||||
np_time = timeit(np.einsum_path, args=(subscripts, *inputs))
|
||||
|
||||
inputs = [mx.array(x) for x in inputs]
|
||||
mx_time = timeit(mx.einsum_path, args=(subscripts, *inputs))
|
||||
print("Timing big einsum path...")
|
||||
print(f"MLX ... {mx_time:.3f} ms")
|
||||
print(f"NumPy... {np_time:.3f} ms")
|
||||
|
||||
|
||||
def time_attention():
|
||||
def regular_attention(x):
|
||||
# shape [batch, sequence, num_heads, head_dim]
|
||||
queries, keys, values = x, x, x
|
||||
scores = queries.transpose(0, 2, 1, 3) @ keys.transpose(0, 2, 3, 1)
|
||||
scores = mx.softmax(scores, axis=-1)
|
||||
output = (scores @ values.transpose(0, 2, 1, 3)).swapaxes(1, 2)
|
||||
mx.eval(output)
|
||||
|
||||
def einsum_attention(x):
|
||||
# shape [batch, sequence, num_heads, head_dim]
|
||||
queries, keys, values = x, x, x
|
||||
scores = mx.einsum("itjk,iujk->ijtu", queries, keys)
|
||||
scores = mx.softmax(scores, axis=-1)
|
||||
output = mx.einsum("ijtu,iujk->itjk", scores, values)
|
||||
mx.eval(output)
|
||||
|
||||
x = mx.random.uniform(shape=(8, 512, 32, 128))
|
||||
|
||||
regular_time = timeit(regular_attention, args=(x,))
|
||||
ein_time = timeit(einsum_attention, args=(x,))
|
||||
print("Timing einsum attention...")
|
||||
print(f"Regular ... {regular_time:.3f} ms")
|
||||
print(f"Einsum ... {ein_time:.3f} ms")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
time_little_einsum_path()
|
||||
time_big_einsum_path()
|
||||
time_attention()
|
@@ -3,8 +3,6 @@
|
||||
import matplotlib
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import sympy
|
||||
import torch
|
||||
from time_utils import measure_runtime
|
||||
|
||||
matplotlib.use("Agg")
|
||||
@@ -18,100 +16,41 @@ def bandwidth_gb(runtime_ms, system_size):
|
||||
return system_size * bytes_per_fft / runtime_ms * ms_per_s / bytes_per_gb
|
||||
|
||||
|
||||
def run_bench(system_size, fft_sizes, backend="mlx", dim=1):
|
||||
def fft_mlx(x):
|
||||
if dim == 1:
|
||||
out = mx.fft.fft(x)
|
||||
elif dim == 2:
|
||||
out = mx.fft.fft2(x)
|
||||
def run_bench(system_size):
|
||||
def fft(x):
|
||||
out = mx.fft.fft(x)
|
||||
mx.eval(out)
|
||||
return out
|
||||
|
||||
def fft_mps(x):
|
||||
if dim == 1:
|
||||
out = torch.fft.fft(x)
|
||||
elif dim == 2:
|
||||
out = torch.fft.fft2(x)
|
||||
torch.mps.synchronize()
|
||||
return out
|
||||
|
||||
bandwidths = []
|
||||
for n in fft_sizes:
|
||||
batch_size = system_size // n**dim
|
||||
shape = [batch_size] + [n for _ in range(dim)]
|
||||
if backend == "mlx":
|
||||
x_np = np.random.uniform(size=(system_size // n, n)).astype(np.complex64)
|
||||
x = mx.array(x_np)
|
||||
mx.eval(x)
|
||||
fft = fft_mlx
|
||||
elif backend == "mps":
|
||||
x_np = np.random.uniform(size=(system_size // n, n)).astype(np.complex64)
|
||||
x = torch.tensor(x_np, device="mps")
|
||||
torch.mps.synchronize()
|
||||
fft = fft_mps
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
for k in range(4, 12):
|
||||
n = 2**k
|
||||
x = mx.random.uniform(shape=(system_size // n, n)).astype(mx.float32)
|
||||
x = x.astype(mx.complex64)
|
||||
mx.eval(x)
|
||||
runtime_ms = measure_runtime(fft, x=x)
|
||||
bandwidth = bandwidth_gb(runtime_ms, np.prod(shape))
|
||||
print(n, bandwidth)
|
||||
bandwidths.append(bandwidth)
|
||||
bandwidths.append(bandwidth_gb(runtime_ms, system_size))
|
||||
|
||||
return np.array(bandwidths)
|
||||
return bandwidths
|
||||
|
||||
|
||||
def time_fft():
|
||||
x = np.array(range(2, 512))
|
||||
system_size = int(2**26)
|
||||
|
||||
print("MLX GPU")
|
||||
with mx.stream(mx.gpu):
|
||||
gpu_bandwidths = run_bench(system_size=system_size, fft_sizes=x)
|
||||
|
||||
print("MPS GPU")
|
||||
mps_bandwidths = run_bench(system_size=system_size, fft_sizes=x, backend="mps")
|
||||
|
||||
print("CPU")
|
||||
system_size = int(2**20)
|
||||
with mx.stream(mx.cpu):
|
||||
cpu_bandwidths = run_bench(system_size=system_size, fft_sizes=x)
|
||||
cpu_bandwidths = run_bench(system_size=int(2**22))
|
||||
|
||||
x = np.array(x)
|
||||
with mx.stream(mx.gpu):
|
||||
gpu_bandwidths = run_bench(system_size=int(2**29))
|
||||
|
||||
all_indices = x - x[0]
|
||||
radix_2to13 = (
|
||||
np.array([i for i in x if all(p <= 13 for p in sympy.primefactors(i))]) - x[0]
|
||||
)
|
||||
bluesteins = (
|
||||
np.array([i for i in x if any(p > 13 for p in sympy.primefactors(i))]) - x[0]
|
||||
)
|
||||
|
||||
for indices, name in [
|
||||
(all_indices, "All"),
|
||||
(radix_2to13, "Radix 2-13"),
|
||||
(bluesteins, "Bluestein's"),
|
||||
]:
|
||||
# plot bandwidths
|
||||
print(name)
|
||||
plt.scatter(x[indices], gpu_bandwidths[indices], color="green", label="GPU")
|
||||
plt.scatter(x[indices], mps_bandwidths[indices], color="blue", label="MPS")
|
||||
plt.scatter(x[indices], cpu_bandwidths[indices], color="red", label="CPU")
|
||||
plt.title(f"MLX FFT Benchmark -- {name}")
|
||||
plt.xlabel("N")
|
||||
plt.ylabel("Bandwidth (GB/s)")
|
||||
plt.legend()
|
||||
plt.savefig(f"{name}.png")
|
||||
plt.clf()
|
||||
|
||||
av_gpu_bandwidth = np.mean(gpu_bandwidths)
|
||||
av_mps_bandwidth = np.mean(mps_bandwidths)
|
||||
av_cpu_bandwidth = np.mean(cpu_bandwidths)
|
||||
print("Average bandwidths:")
|
||||
print("GPU:", av_gpu_bandwidth)
|
||||
print("MPS:", av_mps_bandwidth)
|
||||
print("CPU:", av_cpu_bandwidth)
|
||||
|
||||
portion_faster = len(np.where(gpu_bandwidths > mps_bandwidths)[0]) / len(x)
|
||||
print("Percent MLX faster than MPS: ", portion_faster * 100)
|
||||
# plot bandwidths
|
||||
x = [2**k for k in range(4, 12)]
|
||||
plt.scatter(x, gpu_bandwidths, color="green", label="GPU")
|
||||
plt.scatter(x, cpu_bandwidths, color="red", label="CPU")
|
||||
plt.title("MLX FFT Benchmark")
|
||||
plt.xlabel("N")
|
||||
plt.ylabel("Bandwidth (GB/s)")
|
||||
plt.legend()
|
||||
plt.savefig("fft_plot.png")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@@ -1,70 +0,0 @@
|
||||
import argparse
|
||||
|
||||
import matplotlib
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
from time_utils import measure_runtime
|
||||
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
def had(x):
|
||||
y = mx.hadamard_transform(x)
|
||||
mx.eval(y)
|
||||
|
||||
|
||||
def copy(x):
|
||||
y = x + 1.0
|
||||
mx.eval(y)
|
||||
|
||||
|
||||
def run(dtype):
|
||||
system_size = 2**26
|
||||
outputs = {}
|
||||
for test_fn in (had, copy):
|
||||
for m in [1, 12, 20, 28]:
|
||||
if test_fn == copy:
|
||||
key = "copy"
|
||||
elif m == 1:
|
||||
key = "had_2^k"
|
||||
else:
|
||||
key = "had_m*2^k"
|
||||
outputs.setdefault(key, {})
|
||||
for k in range(7, 14):
|
||||
n = m * 2**k
|
||||
if n > 2**15:
|
||||
continue
|
||||
x_np = np.random.normal(size=(system_size // n, n)).astype(dtype)
|
||||
x = mx.array(x_np)
|
||||
runtime_ms = measure_runtime(test_fn, x=x)
|
||||
bytes_per_gb = 1e9
|
||||
ms_per_s = 1e3
|
||||
bytes_per_had = np.dtype(x_np.dtype).itemsize * 2
|
||||
bandwidth_gb = (
|
||||
system_size * bytes_per_had / runtime_ms * ms_per_s / bytes_per_gb
|
||||
)
|
||||
print(n, bandwidth_gb)
|
||||
outputs[key][n] = bandwidth_gb
|
||||
|
||||
colors = {
|
||||
"copy": "black",
|
||||
"had_2^k": "steelblue",
|
||||
"had_m*2^k": "skyblue",
|
||||
}
|
||||
for key, output in outputs.items():
|
||||
plt.scatter(output.keys(), output.values(), color=colors[key], label=key)
|
||||
plt.title(f"MLX Hadamard Benchmark -- {dtype.__name__}")
|
||||
plt.xlabel("N")
|
||||
plt.ylabel("Bandwidth (GB/s)")
|
||||
plt.legend()
|
||||
plt.savefig(f"bench_{dtype.__name__}.png")
|
||||
plt.clf()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--fp16", action="store_true")
|
||||
args = parser.parse_args()
|
||||
dtype = np.float16 if args.fp16 else np.float32
|
||||
run(dtype)
|
@@ -1,62 +0,0 @@
|
||||
import argparse
|
||||
import math
|
||||
|
||||
import mlx.core as mx
|
||||
from time_utils import time_fn
|
||||
|
||||
MAX_SEQ = 300
|
||||
START_SEQ = 100
|
||||
SEQ_INCREMENT = 50
|
||||
|
||||
|
||||
def time_self_attention_primitives():
|
||||
mx.random.seed(3)
|
||||
B = 2
|
||||
H = 38
|
||||
D = 64
|
||||
for R in range(START_SEQ, MAX_SEQ, SEQ_INCREMENT):
|
||||
q = mx.random.uniform(shape=(B, H, R, D))
|
||||
k = mx.random.uniform(shape=(B, H, R, D))
|
||||
v = mx.random.uniform(shape=(B, H, R, D))
|
||||
scale = 1.0 / math.sqrt(float(D))
|
||||
mx.eval(q, k, v)
|
||||
|
||||
def sdpa_primitives(qs, ks, vs, alpha):
|
||||
s = (alpha * qs) @ ks.transpose(0, 1, 3, 2)
|
||||
p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
|
||||
o = p @ vs
|
||||
return o
|
||||
|
||||
time_fn(sdpa_primitives, q, k, v, scale)
|
||||
|
||||
|
||||
def time_self_attention_sdpa():
|
||||
mx.random.seed(3)
|
||||
B = 2
|
||||
H = 38
|
||||
D = 64
|
||||
for R in range(START_SEQ, MAX_SEQ, SEQ_INCREMENT):
|
||||
q = mx.random.uniform(shape=(B, H, R, D))
|
||||
k = mx.random.uniform(shape=(B, H, R, D))
|
||||
v = mx.random.uniform(shape=(B, H, R, D))
|
||||
scale = 1.0 / math.sqrt(float(D))
|
||||
mx.eval(q, k, v)
|
||||
|
||||
def sdpa_fused(qs, ks, vs, alpha):
|
||||
o = mx.fast.scaled_dot_product_attention(qs, ks, vs, scale=alpha)
|
||||
return o
|
||||
|
||||
time_fn(sdpa_fused, q, k, v, scale)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser("MLX benchmarks.")
|
||||
parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
|
||||
args = parser.parse_args()
|
||||
if args.gpu:
|
||||
mx.set_default_device(mx.gpu)
|
||||
else:
|
||||
mx.set_default_device(mx.cpu)
|
||||
|
||||
time_self_attention_sdpa()
|
||||
time_self_attention_primitives()
|
36
cmake/metal.14.0.diff
Normal file
36
cmake/metal.14.0.diff
Normal file
@@ -0,0 +1,36 @@
|
||||
diff -ur Metal/MTLEvent.hpp MetalNew/MTLEvent.hpp
|
||||
--- Metal/MTLEvent.hpp 2023-06-01 12:18:26
|
||||
+++ MetalNew/MTLEvent.hpp 2024-04-15 07:36:59
|
||||
@@ -62,6 +62,7 @@
|
||||
|
||||
uint64_t signaledValue() const;
|
||||
void setSignaledValue(uint64_t signaledValue);
|
||||
+ bool waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS);
|
||||
};
|
||||
|
||||
class SharedEventHandle : public NS::SecureCoding<SharedEventHandle>
|
||||
@@ -138,6 +139,11 @@
|
||||
_MTL_INLINE void MTL::SharedEvent::setSignaledValue(uint64_t signaledValue)
|
||||
{
|
||||
Object::sendMessage<void>(this, _MTL_PRIVATE_SEL(setSignaledValue_), signaledValue);
|
||||
+}
|
||||
+
|
||||
+// method: waitUntilSignaledValue
|
||||
+_MTL_INLINE bool MTL::SharedEvent::waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS) {
|
||||
+ return Object::sendMessage<bool>(this, _MTL_PRIVATE_SEL(waitUntilSignaledValue_timeoutMS_), signaledValue, timeoutMS);
|
||||
}
|
||||
|
||||
// static method: alloc
|
||||
diff -ur Metal/MTLHeaderBridge.hpp MetalNew/MTLHeaderBridge.hpp
|
||||
--- Metal/MTLHeaderBridge.hpp 2023-06-01 12:18:26
|
||||
+++ MetalNew/MTLHeaderBridge.hpp 2024-04-15 07:37:29
|
||||
@@ -1906,6 +1906,9 @@
|
||||
"setShouldMaximizeConcurrentCompilation:");
|
||||
_MTL_PRIVATE_DEF_SEL(setSignaledValue_,
|
||||
"setSignaledValue:");
|
||||
+_MTL_PRIVATE_DEF_SEL(
|
||||
+ waitUntilSignaledValue_timeoutMS_,
|
||||
+ "waitUntilSignaledValue:timeoutMS:");
|
||||
_MTL_PRIVATE_DEF_SEL(setSize_,
|
||||
"setSize:");
|
||||
_MTL_PRIVATE_DEF_SEL(setSlice_,
|
36
cmake/metal.14.2.diff
Normal file
36
cmake/metal.14.2.diff
Normal file
@@ -0,0 +1,36 @@
|
||||
diff -ur Metal/MTLEvent.hpp MetalNew/MTLEvent.hpp
|
||||
--- Metal/MTLEvent.hpp 2024-04-15 07:12:10
|
||||
+++ MetalNew/MTLEvent.hpp 2024-04-15 07:15:50
|
||||
@@ -62,6 +62,7 @@
|
||||
|
||||
uint64_t signaledValue() const;
|
||||
void setSignaledValue(uint64_t signaledValue);
|
||||
+ bool waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS);
|
||||
};
|
||||
|
||||
class SharedEventHandle : public NS::SecureCoding<SharedEventHandle>
|
||||
@@ -138,6 +139,11 @@
|
||||
_MTL_INLINE void MTL::SharedEvent::setSignaledValue(uint64_t signaledValue)
|
||||
{
|
||||
Object::sendMessage<void>(this, _MTL_PRIVATE_SEL(setSignaledValue_), signaledValue);
|
||||
+}
|
||||
+
|
||||
+// method: waitUntilSignaledValue
|
||||
+_MTL_INLINE bool MTL::SharedEvent::waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS) {
|
||||
+ return Object::sendMessage<bool>(this, _MTL_PRIVATE_SEL(waitUntilSignaledValue_timeoutMS_), signaledValue, timeoutMS);
|
||||
}
|
||||
|
||||
// static method: alloc
|
||||
diff -ur Metal/MTLHeaderBridge.hpp MetalNew/MTLHeaderBridge.hpp
|
||||
--- Metal/MTLHeaderBridge.hpp 2024-04-15 07:12:10
|
||||
+++ MetalNew/MTLHeaderBridge.hpp 2024-04-15 07:16:15
|
||||
@@ -1918,6 +1918,9 @@
|
||||
"setShouldMaximizeConcurrentCompilation:");
|
||||
_MTL_PRIVATE_DEF_SEL(setSignaledValue_,
|
||||
"setSignaledValue:");
|
||||
+_MTL_PRIVATE_DEF_SEL(
|
||||
+ waitUntilSignaledValue_timeoutMS_,
|
||||
+ "waitUntilSignaledValue:timeoutMS:");
|
||||
_MTL_PRIVATE_DEF_SEL(setSize_,
|
||||
"setSize:");
|
||||
_MTL_PRIVATE_DEF_SEL(setSlice_,
|
@@ -1,4 +1,3 @@
sphinx
breathe
sphinx-book-theme
mlx
@@ -83,15 +83,3 @@ def setup(app):

# -- Options for LaTeX output ------------------------------------------------

latex_documents = [(main_doc, "MLX.tex", "MLX Documentation", author, "manual")]
latex_elements = {
    "preamble": r"""
\usepackage{enumitem}
\setlistdepth{5}
\setlist[itemize,1]{label=$\bullet$}
\setlist[itemize,2]{label=$\bullet$}
\setlist[itemize,3]{label=$\bullet$}
\setlist[itemize,4]{label=$\bullet$}
\setlist[itemize,5]{label=$\bullet$}
\renewlist{itemize}{itemize}{5}
""",
}
@@ -486,8 +486,9 @@ below.
  std::ostringstream kname;
  kname << "axpby_" << "general_" << type_to_name(out);

  // Make sure the metal library is available
  d.register_library("mlx_ext");
  // Make sure the metal library is available and look for it
  // in the same folder as this executable if needed
  d.register_library("mlx_ext", metal::get_colocated_mtllib_path);

  // Make a kernel from this metal library
  auto kernel = d.get_kernel(kname.str(), "mlx_ext");
@@ -15,7 +15,7 @@ module to concisely define the model architecture.

Attention layer
^^^^^^^^^^^^^^^^

We will start with the Llama attention layer which notably uses the RoPE
We will start with the llama attention layer which notably uses the RoPE
positional encoding. [1]_ In addition, our attention layer will optionally use a
key/value cache that will be concatenated with the provided keys and values to
support efficient inference.
@@ -64,7 +64,7 @@ set:

Next, setup the problem parameters and load the data. To load the data, you need our
`mnist data loader
<https://github.com/ml-explore/mlx-examples/blob/main/mnist/mnist.py>`_, which
we will import as ``mnist``.
we will import as `mnist`.

.. code-block:: python
@@ -43,7 +43,6 @@ are the CPU and GPU.
   usage/function_transforms
   usage/compile
   usage/numpy
   usage/distributed
   usage/using_streams

.. toctree::

@@ -70,7 +69,6 @@ are the CPU and GPU.
   python/metal
   python/nn
   python/optimizers
   python/distributed
   python/tree_utils

.. toctree::
@@ -70,36 +70,36 @@ To build and install the MLX python library from source, first, clone MLX from

    git clone git@github.com:ml-explore/mlx.git mlx && cd mlx

Install `nanobind <https://nanobind.readthedocs.io/en/latest/>`_ with:

.. code-block:: shell

    pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4

Then simply build and install MLX using pip:

.. code-block:: shell

    CMAKE_BUILD_PARALLEL_LEVEL="" pip install .
    env CMAKE_BUILD_PARALLEL_LEVEL="" pip install .

For developing, install the package with development dependencies, and use an
editable install:
For developing use an editable install:

.. code-block:: shell

    CMAKE_BUILD_PARALLEL_LEVEL="" pip install -e ".[dev]"
    env CMAKE_BUILD_PARALLEL_LEVEL="" pip install -e .

Once the development dependencies are installed, you can build faster with:

.. code-block:: shell

    CMAKE_BUILD_PARALLEL_LEVEL="" python setup.py build_ext -j --inplace

Run the tests with:
To make sure the install is working run the tests with:

.. code-block:: shell

    pip install ".[testing]"
    python -m unittest discover python/tests

Optional: Install stubs to enable auto completions and type checking from your
IDE:
Optional: Install stubs to enable auto completions and type checking from your IDE:

.. code-block:: shell

    pip install ".[dev]"
    python setup.py generate_stubs

C++ API

@@ -186,8 +186,8 @@ should point to the path to the built metal library.

Binary Size Minimization
~~~~~~~~~~~~~~~~~~~~~~~~

To produce a smaller binary use the CMake flags ``CMAKE_BUILD_TYPE=MinSizeRel``
and ``BUILD_SHARED_LIBS=ON``.
To produce a smaller binary use the CMake flags `CMAKE_BUILD_TYPE=MinSizeRel`
and `BUILD_SHARED_LIBS=ON`.

The MLX CMake build has several additional options to make smaller binaries.
For example, if you don't need the CPU backend or support for safetensors and

@@ -195,7 +195,7 @@ GGUF, you can do:

.. code-block:: shell

    cmake .. \
    cmake ..
      -DCMAKE_BUILD_TYPE=MinSizeRel \
      -DBUILD_SHARED_LIBS=ON \
      -DMLX_BUILD_CPU=OFF \

@@ -203,7 +203,7 @@ GGUF, you can do:
      -DMLX_BUILD_GGUF=OFF \
      -DMLX_METAL_JIT=ON

THE ``MLX_METAL_JIT`` flag minimizes the size of the MLX Metal library which
THE `MLX_METAL_JIT` flag minimizes the size of the MLX Metal library which
contains pre-built GPU kernels. This substantially reduces the size of the
Metal library by run-time compiling kernels the first time they are used in MLX
on a given machine. Note run-time compilation incurs a cold-start cost which can
@@ -24,7 +24,6 @@ Array
   array.any
   array.argmax
   array.argmin
   array.conj
   array.cos
   array.cummax
   array.cummin

@@ -58,4 +57,3 @@ Array
   array.transpose
   array.T
   array.var
   array.view
@@ -1,19 +0,0 @@
.. _distributed:

.. currentmodule:: mlx.core.distributed

Distributed Communication
==========================

MLX provides a distributed communication package using MPI. The MPI library is
loaded at runtime; if MPI is available then distributed communication is also
made available.

.. autosummary::
   :toctree: _autosummary

   Group
   is_available
   init
   all_sum
   all_gather
@@ -9,9 +9,7 @@ Linear Algebra
   :toctree: _autosummary

   inv
   tri_inv
   norm
   cholesky
   cholesky_inv
   qr
   svd
@@ -17,8 +17,6 @@ simple functions.
   gelu_approx
   gelu_fast_approx
   glu
   hard_shrink
   hard_tanh
   hardswish
   leaky_relu
   log_sigmoid

@@ -31,7 +29,6 @@ simple functions.
   sigmoid
   silu
   softmax
   softmin
   softplus
   softshrink
   step
@@ -21,15 +21,10 @@ Layers
   Dropout3d
   Embedding
   GELU
   GLU
   GroupNorm
   GRU
   HardShrink
   HardTanh
   Hardswish
   InstanceNorm
   LayerNorm
   LeakyReLU
   Linear
   LSTM
   MaxPool1d

@@ -41,19 +36,13 @@ Layers
   QuantizedLinear
   RMSNorm
   ReLU
   ReLU6
   RNN
   RoPE
   SELU
   Sequential
   SiLU
   SinusoidalPositionalEncoding
   Softmin
   Softshrink
   Softsign
   Softmax
   Softplus
   Step
   Tanh
   Transformer
   Upsample
@@ -57,8 +57,6 @@ Operations
   diagonal
   divide
   divmod
   einsum
   einsum_path
   equal
   erf
   erfinv

@@ -74,7 +72,6 @@ Operations
   gather_qmm
   greater
   greater_equal
   hadamard_transform
   identity
   inner
   isclose

@@ -106,7 +103,6 @@ Operations
   minimum
   moveaxis
   multiply
   nan_to_num
   negative
   not_equal
   ones

@@ -160,7 +156,6 @@ Operations
   tril
   triu
   var
   view
   where
   zeros
   zeros_like
@@ -31,41 +31,6 @@ model's parameters and the **optimizer state**.

      # Compute the new parameters but also the optimizer state.
      mx.eval(model.parameters(), optimizer.state)

Saving and Loading
------------------

To serialize an optimizer, save its state. To load an optimizer, load and set
the saved state. Here's a simple example:

.. code-block:: python

    import mlx.core as mx
    from mlx.utils import tree_flatten, tree_unflatten
    import mlx.optimizers as optim

    optimizer = optim.Adam(learning_rate=1e-2)

    # Perform some updates with the optimizer
    model = {"w" : mx.zeros((5, 5))}
    grads = {"w" : mx.ones((5, 5))}
    optimizer.update(model, grads)

    # Save the state
    state = tree_flatten(optimizer.state)
    mx.save_safetensors("optimizer.safetensors", dict(state))

    # Later on, for example when loading from a checkpoint,
    # recreate the optimizer and load the state
    optimizer = optim.Adam(learning_rate=1e-2)

    state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))
    optimizer.state = state

Note, not every optimizer configuration parameter is saved in the state. For
example, for Adam the learning rate is saved but the ``betas`` and ``eps``
parameters are not. A good rule of thumb is if the parameter can be scheduled
then it will be included in the optimizer state.
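A quick way to check what is actually captured is to inspect the flattened state keys, sketched below (added for context, not part of the diff; the exact keys are illustrative):

.. code-block:: python

    import mlx.core as mx
    import mlx.optimizers as optim
    from mlx.utils import tree_flatten

    opt = optim.Adam(learning_rate=1e-2)
    opt.update({"w": mx.zeros((2, 2))}, {"w": mx.ones((2, 2))})
    # Schedulable values such as the learning rate show up in the state,
    # while betas and eps do not.
    print([k for k, _ in tree_flatten(opt.state)])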

.. toctree::

   optimizers/optimizer
@@ -44,4 +44,3 @@ we use a splittable version of Threefry, which is a counter-based PRNG.
   split
   truncated_normal
   uniform
   laplace
@@ -10,7 +10,6 @@ Transforms

   eval
   compile
   custom_function
   disable_compile
   enable_compile
   grad
@@ -1,166 +0,0 @@
.. _usage_distributed:

Distributed Communication
=========================

.. currentmodule:: mlx.core.distributed

MLX utilizes `MPI <https://en.wikipedia.org/wiki/Message_Passing_Interface>`_ to
provide distributed communication operations that allow the computational cost
of training or inference to be shared across many physical machines. You can
see a list of the supported operations in the :ref:`API docs<distributed>`.

.. note::
   A lot of operations may not be supported or not as fast as they should be.
   We are adding more and tuning the ones we have as we are figuring out the
   best way to do distributed computing on Macs using MLX.

Getting Started
---------------

MLX already comes with the ability to "talk" to MPI if it is installed on the
machine. The minimal distributed program in MLX is as simple as:

.. code:: python

    import mlx.core as mx

    world = mx.distributed.init()
    x = mx.distributed.all_sum(mx.ones(10))
    print(world.rank(), x)

The program above sums the array ``mx.ones(10)`` across all
distributed processes. If simply run with ``python``, however, only one
process is launched and no distributed communication takes place.

To launch the program in distributed mode we need to use ``mpirun`` or
``mpiexec`` depending on the MPI installation. The simplest possible way is the
following:

.. code:: shell

    $ mpirun -np 2 python test.py
    1 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
    0 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)

The above launches two processes on the same (local) machine and we can see
both standard output streams. The processes send the array of 1s to each other
and compute the sum which is printed. Launching with ``mpirun -np 4 ...`` would
print 4 etc.

Installing MPI
---------------

MPI can be installed with Homebrew, using the Anaconda package manager or
compiled from source. Most of our testing is done using ``openmpi`` installed
with the Anaconda package manager as follows:

.. code:: shell

    $ conda install openmpi

Installing with Homebrew may require specifying the location of ``libmpi.dyld``
so that MLX can find it and load it at runtime. This can simply be achieved by
passing the ``DYLD_LIBRARY_PATH`` environment variable to ``mpirun``.

.. code:: shell

    $ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python test.py

Setting up Remote Hosts
-----------------------

MPI can automatically connect to remote hosts and set up the communication over
the network if the remote hosts can be accessed via ssh. A good checklist to
debug connectivity issues is the following:

* ``ssh hostname`` works from all machines to all machines without asking for
  password or host confirmation
* ``mpirun`` is accessible on all machines. You can call ``mpirun`` using its
  full path to force all machines to use a specific path.
* Ensure that the ``hostname`` used by MPI is the one that you have configured
  in the ``.ssh/config`` files on all machines.

.. note::
   For an example hostname ``foo.bar.com`` MPI can use only ``foo`` as
   the hostname passed to ssh if the current hostname matches ``*.bar.com``.

An easy way to pass the host names to MPI is using a host file. A host file
looks like the following, where ``host1`` and ``host2`` should be the fully
qualified domain names or IPs for these hosts.

.. code::

    host1 slots=1
    host2 slots=1

When using MLX, it is very likely that you want to use 1 slot per host, ie one
process per host. The hostfile also needs to contain the current
host if you want to run on the local host. Passing the host file to
``mpirun`` is simply done using the ``--hostfile`` command line argument.

Training Example
----------------

In this section we will adapt an MLX training loop to support data parallel
distributed training. Namely, we will average the gradients across a set of
hosts before applying them to the model.

Our training loop looks like the following code snippet if we omit the model,
dataset and optimizer initialization.

.. code:: python

    model = ...
    optimizer = ...
    dataset = ...

    def step(model, x, y):
        loss, grads = loss_grad_fn(model, x, y)
        optimizer.update(model, grads)
        return loss

    for x, y in dataset:
        loss = step(model, x, y)
        mx.eval(loss, model.parameters())

All we have to do to average the gradients across machines is perform an
:func:`all_sum` and divide by the size of the :class:`Group`. Namely we
have to :func:`mlx.utils.tree_map` the gradients with following function.

.. code:: python

    def all_avg(x):
        return mx.distributed.all_sum(x) / mx.distributed.init().size()

Putting everything together our training loop step looks as follows with
everything else remaining the same.

.. code:: python

    from mlx.utils import tree_map

    def all_reduce_grads(grads):
        N = mx.distributed.init()
        if N == 1:
            return grads
        return tree_map(
            lambda x: mx.distributed.all_sum(x) / N,
            grads)

    def step(model, x, y):
        loss, grads = loss_grad_fn(model, x, y)
        grads = all_reduce_grads(grads)  # <--- This line was added
        optimizer.update(model, grads)
        return loss

Tuning All Reduce
-----------------

We are working on improving the performance of all reduce on MLX but for now
the two main things one can do to extract the most out of distributed training with MLX are:

1. Perform a few large reductions instead of many small ones to improve
   bandwidth and latency (see the sketch below)
2. Pass ``--mca btl_tcp_links 4`` to ``mpirun`` to configure it to use 4 tcp
   connections between each host to improve bandwidth
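To make the first suggestion concrete, here is a minimal sketch (added for context, not part of the diff) of batching many small gradient reductions into a single ``all_sum``; the helper name is hypothetical and it assumes all gradients share one dtype:

.. code:: python

    from itertools import accumulate

    import mlx.core as mx
    from mlx.utils import tree_flatten, tree_unflatten

    def all_avg_batched(grads):
        # Flatten the gradient tree into one buffer, reduce it with a single
        # all_sum, then restore the original shapes.
        flat = tree_flatten(grads)                        # list of (key, array) pairs
        sizes = [v.size for _, v in flat]
        buf = mx.concatenate([v.reshape(-1) for _, v in flat])
        buf = mx.distributed.all_sum(buf) / mx.distributed.init().size()
        pieces = mx.split(buf, list(accumulate(sizes))[:-1])
        return tree_unflatten(
            [(k, p.reshape(v.shape)) for (k, v), p in zip(flat, pieces)]
        )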
@@ -3,11 +3,7 @@

Conversion to NumPy and Other Frameworks
========================================

MLX array supports conversion between other frameworks with either:

* The `Python Buffer Protocol <https://docs.python.org/3/c-api/buffer.html>`_.
* `DLPack <https://dmlc.github.io/dlpack/latest/>`_.

MLX array implements the `Python Buffer Protocol <https://docs.python.org/3/c-api/buffer.html>`_.
Let's convert an array to NumPy and back.

.. code-block:: python
@@ -16,7 +16,7 @@ int main() {
  std::cout << global_group.rank() << " / " << global_group.size() << std::endl;

  array x = ones({10});
  array out = distributed::all_sum(x, global_group);
  array out = distributed::all_reduce_sum(x, global_group);

  std::cout << out << std::endl;
}
@@ -249,8 +249,9 @@ void Axpby::eval_gpu(
  kname << (contiguous_kernel ? "contiguous_" : "general_");
  kname << type_to_name(out);

  // Make sure the metal library is available
  d.register_library("mlx_ext");
  // Make sure the metal library is available and look for it
  // in the same folder as this executable if needed
  d.register_library("mlx_ext", metal::get_colocated_mtllib_path);

  // Make a kernel from this metal library
  auto kernel = d.get_kernel(kname.str(), "mlx_ext");
@@ -1,4 +1,4 @@
setuptools>=42
cmake>=3.24
mlx>=0.16.2
nanobind==2.0
mlx>=0.9.0
nanobind@git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
@@ -6,7 +6,6 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
|
||||
|
@@ -17,10 +17,6 @@ bool in_tracing() {
|
||||
return detail::InTracing::in_tracing();
|
||||
}
|
||||
|
||||
bool retain_graph() {
|
||||
return detail::RetainGraph::retain_graph();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
|
||||
@@ -106,7 +102,7 @@ void array::eval() {
|
||||
}
|
||||
|
||||
bool array::is_tracer() const {
|
||||
return array_desc_->is_tracer && in_tracing() || retain_graph();
|
||||
return array_desc_->is_tracer && in_tracing();
|
||||
}
|
||||
|
||||
void array::set_data(allocator::Buffer buffer, deleter_t d) {
|
||||
@@ -175,11 +171,10 @@ array::~array() {
|
||||
return;
|
||||
}
|
||||
|
||||
// Ignore arrays that might be detached during eval
|
||||
if (status() == array::Status::scheduled) {
|
||||
// Ignore arrays that will be detached
|
||||
if (status() != array::Status::unscheduled) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Break circular reference for non-detached arrays with siblings
|
||||
if (auto n = siblings().size(); n > 0) {
|
||||
bool do_detach = true;
|
||||
@@ -211,7 +206,7 @@ void array::ArrayDesc::init() {
|
||||
strides[i] = size;
|
||||
size *= shape[i];
|
||||
}
|
||||
for (const auto& in : inputs) {
|
||||
for (auto& in : inputs) {
|
||||
is_tracer |= in.is_tracer();
|
||||
}
|
||||
}
|
||||
@@ -236,7 +231,7 @@ array::ArrayDesc::ArrayDesc(
|
||||
|
||||
array::ArrayDesc::~ArrayDesc() {
|
||||
// When an array description is destroyed it will delete a bunch of arrays
|
||||
// that may also destroy their corresponding descriptions and so on and so
|
||||
// that may also destory their corresponding descriptions and so on and so
|
||||
// forth.
|
||||
//
|
||||
// This calls recursively the destructor and can result in stack overflow, we
|
||||
|
60
mlx/array.h
60
mlx/array.h
@@ -73,32 +73,32 @@ class array {
|
||||
this->array_desc_ = other.array_desc_;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
/** The size of the array's datatype in bytes. */
|
||||
size_t itemsize() const {
|
||||
return size_of(dtype());
|
||||
}
|
||||
};
|
||||
|
||||
/** The number of elements in the array. */
|
||||
size_t size() const {
|
||||
return array_desc_->size;
|
||||
}
|
||||
};
|
||||
|
||||
/** The number of bytes in the array. */
|
||||
size_t nbytes() const {
|
||||
return size() * itemsize();
|
||||
}
|
||||
};
|
||||
|
||||
/** The number of dimensions of the array. */
|
||||
size_t ndim() const {
|
||||
return array_desc_->shape.size();
|
||||
}
|
||||
};
|
||||
|
||||
/** The shape of the array as a vector of integers. */
|
||||
const std::vector<int>& shape() const {
|
||||
return array_desc_->shape;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Get the size of the corresponding dimension.
|
||||
@@ -107,12 +107,12 @@ class array {
|
||||
* bounds checking. */
|
||||
int shape(int dim) const {
|
||||
return shape().at(dim < 0 ? dim + ndim() : dim);
|
||||
}
|
||||
};
|
||||
|
||||
/** The strides of the array. */
|
||||
const std::vector<size_t>& strides() const {
|
||||
return array_desc_->strides;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Get the stride of the corresponding dimension.
|
||||
@@ -121,12 +121,12 @@ class array {
|
||||
* bounds checking. */
|
||||
size_t strides(int dim) const {
|
||||
return strides().at(dim < 0 ? dim + ndim() : dim);
|
||||
}
|
||||
};
|
||||
|
||||
/** Get the arrays data type. */
|
||||
Dtype dtype() const {
|
||||
return array_desc_->dtype;
|
||||
}
|
||||
};
|
||||
|
||||
/** Evaluate the array. */
|
||||
void eval();
|
||||
@@ -160,10 +160,10 @@ class array {
|
||||
|
||||
friend bool operator==(const ArrayIterator& a, const ArrayIterator& b) {
|
||||
return a.arr.id() == b.arr.id() && a.idx == b.idx;
|
||||
}
|
||||
};
|
||||
friend bool operator!=(const ArrayIterator& a, const ArrayIterator& b) {
|
||||
return !(a == b);
|
||||
}
|
||||
};
|
||||
|
||||
private:
|
||||
const array& arr;
|
||||
@@ -209,7 +209,7 @@ class array {
|
||||
allocator::Buffer buffer;
|
||||
deleter_t d;
|
||||
Data(allocator::Buffer buffer, deleter_t d = allocator::free)
|
||||
: buffer(buffer), d(d) {}
|
||||
: buffer(buffer), d(d) {};
|
||||
// Not copyable
|
||||
Data(const Data& d) = delete;
|
||||
Data& operator=(const Data& d) = delete;
|
||||
@@ -230,41 +230,42 @@ class array {
|
||||
/** The array's primitive. */
|
||||
Primitive& primitive() const {
|
||||
return *(array_desc_->primitive);
|
||||
}
|
||||
};
|
||||
|
||||
/** A shared pointer to the array's primitive. */
|
||||
std::shared_ptr<Primitive>& primitive_ptr() const {
|
||||
return array_desc_->primitive;
|
||||
}
|
||||
};
|
||||
|
||||
/** Check if the array has an attached primitive or is a leaf node. */
|
||||
bool has_primitive() const {
|
||||
return array_desc_->primitive != nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
/** The array's inputs. */
|
||||
const std::vector<array>& inputs() const {
|
||||
return array_desc_->inputs;
|
||||
}
|
||||
};
|
||||
|
||||
std::vector<array>& inputs() {
|
||||
return array_desc_->inputs;
|
||||
}
|
||||
|
||||
/** True indicates the arrays buffer is safe to reuse */
|
||||
bool is_donatable() const {
|
||||
return array_desc_.use_count() == 1 && (array_desc_->data.use_count() == 1);
|
||||
bool is_donatable(int known_instances = 1) const {
|
||||
return array_desc_.use_count() == known_instances &&
|
||||
(array_desc_->data.use_count() == 1);
|
||||
}
|
||||
|
||||
/** The array's siblings. */
|
||||
const std::vector<array>& siblings() const {
|
||||
return array_desc_->siblings;
|
||||
}
|
||||
};
|
||||
|
||||
/** The array's siblings. */
|
||||
std::vector<array>& siblings() {
|
||||
return array_desc_->siblings;
|
||||
}
|
||||
};
|
||||
|
||||
void set_siblings(std::vector<array> siblings, uint16_t position) {
|
||||
array_desc_->siblings = std::move(siblings);
|
||||
@@ -281,7 +282,7 @@ class array {
|
||||
outputs.push_back(*this);
|
||||
outputs.insert(outputs.end(), siblings().begin() + idx, siblings().end());
|
||||
return outputs;
|
||||
}
|
||||
};
|
||||
|
||||
/** Detach the array from the graph. */
|
||||
void detach();
|
||||
@@ -289,19 +290,19 @@ class array {
|
||||
/** Get the Flags bit-field. */
|
||||
const Flags& flags() const {
|
||||
return array_desc_->flags;
|
||||
}
|
||||
};
|
||||
|
||||
/** The size (in elements) of the underlying buffer the array points to. */
|
||||
size_t data_size() const {
|
||||
return array_desc_->data_size;
|
||||
}
|
||||
};
|
||||
|
||||
allocator::Buffer& buffer() {
|
||||
return array_desc_->data->buffer;
|
||||
}
|
||||
};
|
||||
const allocator::Buffer& buffer() const {
|
||||
return array_desc_->data->buffer;
|
||||
}
|
||||
};
|
||||
|
||||
// Return a copy of the shared pointer
|
||||
// to the array::Data struct
|
||||
@@ -312,20 +313,19 @@ class array {
|
||||
template <typename T>
|
||||
T* data() {
|
||||
return static_cast<T*>(array_desc_->data_ptr);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
const T* data() const {
|
||||
return static_cast<T*>(array_desc_->data_ptr);
|
||||
}
|
||||
};
|
||||
|
||||
enum Status { unscheduled, scheduled, available };
|
||||
|
||||
bool is_available() const {
|
||||
return status() == Status::available;
|
||||
}
|
||||
|
||||
Status status() const {
|
||||
const Status status() const {
|
||||
return array_desc_->status;
|
||||
}
|
||||
|
||||
|
@@ -1,9 +1,9 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include <Accelerate/Accelerate.h>
|
||||
#include <simd/vector.h>
|
||||
#include <vecLib/vDSP.h>
|
||||
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
@@ -2,7 +2,8 @@
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include <Accelerate/Accelerate.h>
|
||||
#include <vecLib/BNNS/bnns.h>
|
||||
#include <vecLib/cblas_new.h>
|
||||
|
||||
#include "mlx/backend/accelerate/utils.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
|
@@ -3,7 +3,8 @@
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
|
||||
#include <Accelerate/Accelerate.h>
|
||||
#include <vecLib/vDSP.h>
|
||||
#include <vecLib/vForce.h>
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/binary.h"
|
||||
@@ -36,7 +37,7 @@ DEFAULT(Ceil)
|
||||
DEFAULT(Concatenate)
|
||||
DEFAULT(Conjugate)
|
||||
DEFAULT(Copy)
|
||||
DEFAULT_MULTI(CustomTransforms)
|
||||
DEFAULT_MULTI(CustomVJP)
|
||||
DEFAULT_MULTI(Depends)
|
||||
DEFAULT_MULTI(DivMod)
|
||||
DEFAULT(NumberOfElements)
|
||||
@@ -50,7 +51,6 @@ DEFAULT(GatherMM)
|
||||
DEFAULT(GatherQMM)
|
||||
DEFAULT(Greater)
|
||||
DEFAULT(GreaterEqual)
|
||||
DEFAULT(Hadamard)
|
||||
DEFAULT(Less)
|
||||
DEFAULT(LessEqual)
|
||||
DEFAULT(Load)
|
||||
@@ -102,7 +102,7 @@ void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& b = inputs[1];
|
||||
|
||||
if (a.dtype() == float32) {
|
||||
binary_op<float>(
|
||||
binary(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
@@ -117,7 +117,7 @@ void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
vDSP_vadd((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n);
|
||||
});
|
||||
} else if (a.dtype() == int32) {
|
||||
binary_op<int>(
|
||||
binary(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
@@ -132,7 +132,7 @@ void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
vDSP_vaddi((const int*)a, 1, (const int*)b, 1, (int*)o, 1, n);
|
||||
});
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
binary(a, b, out, [](auto x, auto y) { return x + y; });
|
||||
}
|
||||
}
|
||||
|
||||
@@ -287,7 +287,7 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& b = inputs[1];
|
||||
|
||||
if (a.dtype() == int32) {
|
||||
binary_op<int>(
|
||||
binary(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
@@ -300,7 +300,7 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
vDSP_vdivi((const int*)b, 1, (const int*)a, 1, (int*)o, 1, n);
|
||||
});
|
||||
} else if (a.dtype() == float32) {
|
||||
binary_op<float>(
|
||||
binary(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
@@ -315,7 +315,7 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
vDSP_vdiv((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n);
|
||||
});
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
binary(a, b, out, [](auto x, auto y) { return x / y; });
|
||||
}
|
||||
}
|
||||
|
||||
@@ -326,8 +326,12 @@ void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
set_unary_output_data(in, out);
|
||||
auto size = in.data_size();
|
||||
vvexpf(out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
|
||||
} else if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, [](auto x) { return std::exp(x); });
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
throw std::invalid_argument(
|
||||
"[exp] Cannot exponentiate elements in array"
|
||||
" with non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -389,8 +393,12 @@ void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto size = in.data_size();
|
||||
vvlog1pf(
|
||||
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
|
||||
} else if (issubdtype(out.dtype(), inexact)) {
|
||||
unary_fp(in, out, [](auto x) { return std::log1p(x); });
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
throw std::invalid_argument(
|
||||
"[log1p] Cannot compute log of elements in array with"
|
||||
" non floating point type.");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -400,7 +408,7 @@ void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& b = inputs[1];
|
||||
|
||||
if (a.dtype() == float32) {
|
||||
binary_op<float>(
|
||||
binary(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
@@ -415,7 +423,7 @@ void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
vDSP_vmul((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n);
|
||||
});
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
binary(a, b, out, [](auto x, auto y) { return x * y; });
|
||||
}
|
||||
}
|
||||
|
||||
@@ -426,7 +434,7 @@ void Negative::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
set_unary_output_data(in, out);
|
||||
vDSP_vneg(in.data<float>(), 1, out.data<float>(), 1, in.data_size());
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
unary(in, out, [](auto x) { return -x; });
|
||||
}
|
||||
}
|
||||
|
||||
@@ -513,7 +521,7 @@ void Square::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto size = in.data_size();
|
||||
vDSP_vsq(in.data<float>(), 1, out.data<float>(), 1, size);
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
unary(in, out, [](auto x) { return x * x; });
|
||||
}
|
||||
}
|
||||
|
||||
@@ -539,7 +547,7 @@ void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& b = inputs[1];
|
||||
|
||||
if (a.dtype() == float32) {
|
||||
binary_op<float>(
|
||||
binary(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
@@ -557,7 +565,7 @@ void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
vDSP_vsub((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n);
|
||||
});
|
||||
} else if (a.dtype() == int32) {
|
||||
binary_op<int>(
|
||||
binary(
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
@@ -569,7 +577,7 @@ void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
},
|
||||
UseDefaultBinaryOp());
|
||||
} else {
|
||||
eval(inputs, out);
|
||||
binary(a, b, out, [](auto x, auto y) { return x - y; });
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -2,8 +2,8 @@
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include <Accelerate/Accelerate.h>
|
||||
#include <simd/vector.h>
|
||||
#include <vecLib/vDSP.h>
|
||||
|
||||
#include "mlx/backend/common/reduce.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
@@ -3,10 +3,7 @@
#include <cassert>
#include <limits>

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#include <arm_neon.h>
#endif

#include <simd/math.h>
#include <simd/vector.h>

@@ -56,26 +53,25 @@ inline simd_float16 simd_fast_exp(simd_float16 x) {
return (*(simd_float16*)&epart) * x;
}

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
/**
* The ARM neon equivalent of the fast exp above.
*/
inline float16x8_t neon_fast_exp(float16x8_t x) {
x = vmulq_f16(x, vdupq_n_f16(float16_t(1.442695f))); // multiply with log_2(e)
x = vmaxq_f16(x, vdupq_n_f16(float16_t(-14.f))); // clamp under with -14
x = vminq_f16(x, vdupq_n_f16(float16_t(14.f))); // clamp over with 14
x = vmulq_f16(x, vdupq_n_f16(1.442695)); // multiply with log_2(e)
x = vmaxq_f16(x, vdupq_n_f16(-14)); // clamp under with -14
x = vminq_f16(x, vdupq_n_f16(14)); // clamp over with 14

float16x8_t ipart = vrndmq_f16(vaddq_f16(x, vdupq_n_f16(float16_t(0.5f))));
float16x8_t ipart = vrndmq_f16(vaddq_f16(x, vdupq_n_f16(0.5)));
float16x8_t fpart = vsubq_f16(x, ipart);

x = vdupq_n_f16(float16_t(1.535336188319500e-4f));
x = vfmaq_f16(vdupq_n_f16(float16_t(1.339887440266574e-3f)), x, fpart);
x = vfmaq_f16(vdupq_n_f16(float16_t(1.339887440266574e-3f)), x, fpart);
x = vfmaq_f16(vdupq_n_f16(float16_t(9.618437357674640e-3f)), x, fpart);
x = vfmaq_f16(vdupq_n_f16(float16_t(5.550332471162809e-2f)), x, fpart);
x = vfmaq_f16(vdupq_n_f16(float16_t(2.402264791363012e-1f)), x, fpart);
x = vfmaq_f16(vdupq_n_f16(float16_t(6.931472028550421e-1f)), x, fpart);
x = vfmaq_f16(vdupq_n_f16(float16_t(1.000000000000000f)), x, fpart);
x = vdupq_n_f16(1.535336188319500e-4f);
x = vfmaq_f16(vdupq_n_f16(1.339887440266574e-3f), x, fpart);
x = vfmaq_f16(vdupq_n_f16(1.339887440266574e-3f), x, fpart);
x = vfmaq_f16(vdupq_n_f16(9.618437357674640e-3f), x, fpart);
x = vfmaq_f16(vdupq_n_f16(5.550332471162809e-2f), x, fpart);
x = vfmaq_f16(vdupq_n_f16(2.402264791363012e-1f), x, fpart);
x = vfmaq_f16(vdupq_n_f16(6.931472028550421e-1f), x, fpart);
x = vfmaq_f16(vdupq_n_f16(1.000000000000000f), x, fpart);

// generate 2**ipart in the floating point representation using integer
// bitshifting
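
For readers unfamiliar with this style of kernel, here is a minimal standalone C++ sketch of the same fast-exp idea (scalar, standard library only, not the MLX kernel itself): exp(x) is rewritten as 2^(x * log2 e), split into integer and fractional parts, the fractional part is approximated with a short Taylor polynomial, and 2^ipart is assembled by writing the IEEE-754 exponent bits directly. The clamp bound and coefficients below are illustrative.

// Scalar sketch of the fast exp used above (illustrative, not the MLX kernel).
#include <cstdint>
#include <cstring>
#include <cmath>
#include <cstdio>

float fast_exp_sketch(float x) {
  x *= 1.442695f;                              // convert to base 2: exp(x) = 2^(x*log2(e))
  x = std::fmax(std::fmin(x, 80.0f), -80.0f);  // clamp to keep the exponent in range
  float ipart = std::floor(x + 0.5f);
  float fpart = x - ipart;
  // Taylor-style polynomial for 2^fpart on roughly [-0.5, 0.5]; coefficients
  // are ln(2)^k / k!, evaluated with Horner's rule.
  float p = 1.535336188319500e-4f;
  p = p * fpart + 1.339887440266574e-3f;
  p = p * fpart + 9.618437357674640e-3f;
  p = p * fpart + 5.550332471162809e-2f;
  p = p * fpart + 2.402264791363012e-1f;
  p = p * fpart + 6.931472028550421e-1f;
  p = p * fpart + 1.0f;
  // Build 2^ipart by writing the biased exponent field of an IEEE-754 float.
  uint32_t bits = static_cast<uint32_t>(ipart + 127.0f) << 23;
  float scale;
  std::memcpy(&scale, &bits, sizeof(scale));
  return p * scale;
}

int main() {
  std::printf("%f vs %f\n", fast_exp_sketch(1.0f), std::exp(1.0f));
}
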
@@ -111,55 +107,6 @@ inline float16_t neon_reduce_add(float16x8_t x) {
|
||||
return vget_lane_f16(y, 0);
|
||||
}
|
||||
|
||||
template <typename T, typename VT>
|
||||
struct NeonFp16SimdOps {
|
||||
VT init(T a) {
|
||||
return vdupq_n_f16(a);
|
||||
}
|
||||
|
||||
VT load(const T* a) {
|
||||
return vld1q_f16(a);
|
||||
}
|
||||
|
||||
void store(T* dst, VT x) {
|
||||
vst1q_f16(dst, x);
|
||||
}
|
||||
|
||||
VT max(VT a, VT b) {
|
||||
return vmaxq_f16(a, b);
|
||||
}
|
||||
|
||||
VT exp(VT x) {
|
||||
return neon_fast_exp(x);
|
||||
}
|
||||
|
||||
VT add(VT a, VT b) {
|
||||
return vaddq_f16(a, b);
|
||||
}
|
||||
|
||||
VT sub(VT a, T b) {
|
||||
return vsubq_f16(a, vdupq_n_f16(b));
|
||||
}
|
||||
|
||||
VT mul(VT a, VT b) {
|
||||
return vmulq_f16(a, b);
|
||||
}
|
||||
|
||||
VT mul(VT a, T b) {
|
||||
return vmulq_f16(a, vdupq_n_f16(b));
|
||||
}
|
||||
|
||||
T reduce_max(VT x) {
|
||||
return neon_reduce_max(x);
|
||||
}
|
||||
|
||||
T reduce_add(VT x) {
|
||||
return neon_reduce_add(x);
|
||||
}
|
||||
};
|
||||
|
||||
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||
|
||||
template <typename T, typename VT>
|
||||
struct AccelerateSimdOps {
|
||||
VT init(T a) {
|
||||
@@ -176,7 +123,7 @@ struct AccelerateSimdOps {
|
||||
|
||||
VT max(VT a, VT b) {
|
||||
return simd_max(a, b);
|
||||
}
|
||||
};
|
||||
|
||||
VT exp(VT x) {
|
||||
return simd_fast_exp(x);
|
||||
@@ -207,6 +154,53 @@ struct AccelerateSimdOps {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename VT>
|
||||
struct NeonFp16SimdOps {
|
||||
VT init(T a) {
|
||||
return vdupq_n_f16(a);
|
||||
}
|
||||
|
||||
VT load(const T* a) {
|
||||
return vld1q_f16(a);
|
||||
}
|
||||
|
||||
void store(T* dst, VT x) {
|
||||
vst1q_f16(dst, x);
|
||||
}
|
||||
|
||||
VT max(VT a, VT b) {
|
||||
return vmaxq_f16(a, b);
|
||||
};
|
||||
|
||||
VT exp(VT x) {
|
||||
return neon_fast_exp(x);
|
||||
}
|
||||
|
||||
VT add(VT a, VT b) {
|
||||
return vaddq_f16(a, b);
|
||||
}
|
||||
|
||||
VT sub(VT a, T b) {
|
||||
return vsubq_f16(a, vdupq_n_f16(b));
|
||||
}
|
||||
|
||||
VT mul(VT a, VT b) {
|
||||
return vmulq_f16(a, b);
|
||||
}
|
||||
|
||||
VT mul(VT a, T b) {
|
||||
return vmulq_f16(a, vdupq_n_f16(b));
|
||||
}
|
||||
|
||||
T reduce_max(VT x) {
|
||||
return neon_reduce_max(x);
|
||||
}
|
||||
|
||||
T reduce_add(VT x) {
|
||||
return neon_reduce_add(x);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename AccT, typename VT, typename Ops, int N>
|
||||
void softmax(const array& in, array& out) {
|
||||
Ops ops;
|
||||
@@ -368,16 +362,12 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
AccelerateSimdOps<float, simd_float16>,
16>(in, out);
} else {
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
softmax<
float16_t,
float16_t,
float16x8_t,
NeonFp16SimdOps<float16_t, float16x8_t>,
8>(in, out);
#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
eval(inputs, out); // Redirect to common backend for consistency
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
}
break;
case bfloat16:

@@ -1,8 +1,8 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2023 Apple Inc.

#pragma once

#include <Accelerate/Accelerate.h>
#include <vecLib/BNNS/bnns.h>
#include "mlx/dtype.h"

namespace mlx::core {

@@ -42,15 +42,12 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/masked_mm.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce_utils.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
|
||||
|
@@ -196,20 +196,6 @@ void LogAddExp::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void LogicalAnd::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2); // LogicalAnd requires two input arrays
|
||||
auto& in1 = inputs[0];
|
||||
auto& in2 = inputs[1];
|
||||
binary(in1, in2, out, detail::LogicalAnd());
|
||||
}
|
||||
|
||||
void LogicalOr::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2); // LogicalOr requires two input arrays
|
||||
auto& in1 = inputs[0];
|
||||
auto& in2 = inputs[1];
|
||||
binary(in1, in2, out, detail::LogicalOr());
|
||||
}
|
||||
|
||||
void Maximum::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
|
@@ -66,7 +66,7 @@ void Copy::eval(const std::vector<array>& inputs, array& out) {
|
||||
out.copy_shared_buffer(inputs[0]);
|
||||
}
|
||||
|
||||
void CustomTransforms::eval(
|
||||
void CustomVJP::eval(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() > outputs.size());
|
||||
@@ -250,6 +250,49 @@ void Split::eval(
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<bool, int64_t, std::vector<int64_t>> Slice::prepare_slice(
|
||||
const array& in) {
|
||||
int64_t data_offset = 0;
|
||||
bool copy_needed = false;
|
||||
std::vector<int64_t> inp_strides(in.ndim(), 0);
|
||||
for (int i = 0; i < in.ndim(); ++i) {
|
||||
data_offset += start_indices_[i] * in.strides()[i];
|
||||
inp_strides[i] = in.strides()[i] * strides_[i];
|
||||
|
||||
copy_needed |= strides_[i] < 0;
|
||||
}
|
||||
|
||||
return std::make_tuple(copy_needed, data_offset, inp_strides);
|
||||
}
|
||||
|
||||
void Slice::shared_buffer_slice(
|
||||
const array& in,
|
||||
const std::vector<size_t>& out_strides,
|
||||
size_t data_offset,
|
||||
array& out) {
|
||||
// Compute row/col contiguity
|
||||
auto [data_size, is_row_contiguous, is_col_contiguous] =
|
||||
check_contiguity(out.shape(), out_strides);
|
||||
|
||||
auto flags = in.flags();
|
||||
flags.row_contiguous = is_row_contiguous;
|
||||
flags.col_contiguous = is_col_contiguous;
|
||||
|
||||
if (data_size == 1) {
|
||||
// Broadcasted scalar array is contiguous.
|
||||
flags.contiguous = true;
|
||||
} else if (data_size == in.data_size()) {
|
||||
// Means we sliced a broadcasted dimension so leave the "no holes" flag
|
||||
// alone.
|
||||
} else {
|
||||
// We sliced something. So either we are row or col contiguous or we
|
||||
// punched a hole.
|
||||
flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
|
||||
}
|
||||
|
||||
out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset);
|
||||
}
|
||||
|
||||
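
A standalone sketch (with hypothetical shapes and strides) of the arithmetic prepare_slice performs above: the element offset is the dot product of the start indices with the input strides, each output stride is the input stride scaled by the slice step, and negative steps force a copy.

// Illustrative sketch of the slice arithmetic used in prepare_slice above.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int64_t> in_strides = {12, 4, 1};  // e.g. a (2, 3, 4) row-major array
  std::vector<int> start = {1, 0, 2};
  std::vector<int> step = {1, 2, 1};

  int64_t offset = 0;
  std::vector<int64_t> out_strides(in_strides.size());
  bool copy_needed = false;
  for (size_t i = 0; i < in_strides.size(); ++i) {
    offset += start[i] * in_strides[i];
    out_strides[i] = in_strides[i] * step[i];
    copy_needed |= step[i] < 0;  // negative steps fall back to a copy
  }
  std::printf("offset=%lld strides=(%lld, %lld, %lld) copy=%d\n",
              (long long)offset, (long long)out_strides[0],
              (long long)out_strides[1], (long long)out_strides[2], copy_needed);
}
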
std::tuple<int64_t, std::vector<int64_t>> SliceUpdate::prepare_slice(
|
||||
const array& in) {
|
||||
int64_t data_offset = 0;
|
||||
|
@@ -205,8 +205,8 @@ void compiled_allocate_outputs(
|
||||
// - Donatable
|
||||
// - Correct size
|
||||
// - Not a constant
|
||||
if (in.flags().row_contiguous && in.size() == outputs[o].size() &&
|
||||
in.itemsize() == outputs[o].itemsize() && in.is_donatable() &&
|
||||
if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() &&
|
||||
in.is_donatable() &&
|
||||
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
|
||||
if (move_buffers) {
|
||||
outputs[o].move_shared_buffer(
|
||||
|
@@ -4,7 +4,6 @@
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
@@ -143,31 +142,29 @@ void copy_general(
|
||||
const std::vector<int>& data_shape,
|
||||
const std::vector<stride_t>& i_strides,
|
||||
int64_t i_offset) {
|
||||
auto [new_shape, new_strides] = collapse_contiguous_dims(
|
||||
data_shape, std::vector<std::vector<stride_t>>{i_strides});
|
||||
switch (new_shape.size()) {
|
||||
switch (src.ndim()) {
|
||||
case 1:
|
||||
copy_general_dim1<SrcT, DstT, stride_t>(
|
||||
src, dst, new_shape, new_strides[0], i_offset);
|
||||
src, dst, data_shape, i_strides, i_offset);
|
||||
return;
|
||||
case 2:
|
||||
copy_general_dim2<SrcT, DstT, stride_t>(
|
||||
src, dst, new_shape, new_strides[0], i_offset);
|
||||
src, dst, data_shape, i_strides, i_offset);
|
||||
return;
|
||||
case 3:
|
||||
copy_general_dim3<SrcT, DstT, stride_t>(
|
||||
src, dst, new_shape, new_strides[0], i_offset);
|
||||
src, dst, data_shape, i_strides, i_offset);
|
||||
return;
|
||||
case 4:
|
||||
copy_general_dim4<SrcT, DstT, stride_t>(
|
||||
src, dst, new_shape, new_strides[0], i_offset);
|
||||
src, dst, data_shape, i_strides, i_offset);
|
||||
return;
|
||||
}
|
||||
|
||||
auto src_ptr = src.data<SrcT>() + i_offset;
|
||||
auto dst_ptr = dst.data<DstT>();
|
||||
for (size_t i = 0; i < dst.size(); ++i) {
|
||||
stride_t src_elem = elem_to_loc(i, new_shape, new_strides[0]);
|
||||
stride_t src_elem = elem_to_loc(i, data_shape, i_strides);
|
||||
dst_ptr[i] = static_cast<DstT>(src_ptr[src_elem]);
|
||||
}
|
||||
}
|
||||
@@ -198,10 +195,10 @@ inline void copy_general_general_dims(
|
||||
const std::vector<int>& data_shape,
|
||||
const std::vector<stride_t>& i_strides,
|
||||
const std::vector<stride_t>& o_strides,
|
||||
int64_t i_offset,
|
||||
int64_t o_offset) {
|
||||
stride_t i_offset,
|
||||
stride_t o_offset) {
|
||||
if constexpr (D > 1) {
|
||||
int axis = data_shape.size() - D;
|
||||
int axis = src.ndim() - D;
|
||||
auto stride_src = i_strides[axis];
|
||||
auto stride_dst = o_strides[axis];
|
||||
auto N = data_shape[axis];
|
||||
@@ -212,7 +209,7 @@ inline void copy_general_general_dims(
|
||||
o_offset += stride_dst;
|
||||
}
|
||||
} else {
|
||||
int axis = data_shape.size() - 1;
|
||||
int axis = src.ndim() - 1;
|
||||
auto stride_src = i_strides[axis];
|
||||
auto stride_dst = o_strides[axis];
|
||||
auto N = data_shape[axis];
|
||||
@@ -233,76 +230,38 @@ void copy_general_general(
|
||||
const std::vector<int>& data_shape,
|
||||
const std::vector<stride_t>& i_strides,
|
||||
const std::vector<stride_t>& o_strides,
|
||||
int64_t i_offset,
|
||||
int64_t o_offset) {
|
||||
auto [new_shape, new_strides] = collapse_contiguous_dims(
|
||||
data_shape, std::vector<std::vector<stride_t>>{i_strides, o_strides});
|
||||
switch (new_shape.size()) {
|
||||
stride_t i_offset,
|
||||
stride_t o_offset) {
|
||||
switch (src.ndim()) {
|
||||
case 1:
|
||||
copy_general_general_dims<SrcT, DstT, stride_t, 1>(
|
||||
src,
|
||||
dst,
|
||||
new_shape,
|
||||
new_strides[0],
|
||||
new_strides[1],
|
||||
i_offset,
|
||||
o_offset);
|
||||
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
|
||||
return;
|
||||
case 2:
|
||||
copy_general_general_dims<SrcT, DstT, stride_t, 2>(
|
||||
src,
|
||||
dst,
|
||||
new_shape,
|
||||
new_strides[0],
|
||||
new_strides[1],
|
||||
i_offset,
|
||||
o_offset);
|
||||
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
|
||||
return;
|
||||
case 3:
|
||||
copy_general_general_dims<SrcT, DstT, stride_t, 3>(
|
||||
src,
|
||||
dst,
|
||||
new_shape,
|
||||
new_strides[0],
|
||||
new_strides[1],
|
||||
i_offset,
|
||||
o_offset);
|
||||
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
|
||||
return;
|
||||
case 4:
|
||||
copy_general_general_dims<SrcT, DstT, stride_t, 4>(
|
||||
src,
|
||||
dst,
|
||||
new_shape,
|
||||
new_strides[0],
|
||||
new_strides[1],
|
||||
i_offset,
|
||||
o_offset);
|
||||
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
|
||||
return;
|
||||
case 5:
|
||||
copy_general_general_dims<SrcT, DstT, stride_t, 5>(
|
||||
src,
|
||||
dst,
|
||||
new_shape,
|
||||
new_strides[0],
|
||||
new_strides[1],
|
||||
i_offset,
|
||||
o_offset);
|
||||
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
|
||||
return;
|
||||
}
|
||||
|
||||
int size = std::accumulate(
|
||||
new_shape.end() - 5, new_shape.end(), 1, std::multiplies<int>());
|
||||
data_shape.end() - 5, data_shape.end(), 1, std::multiplies<int>());
|
||||
for (int i = 0; i < src.size(); i += size) {
|
||||
stride_t src_offset = i_offset + elem_to_loc(i, new_shape, new_strides[0]);
|
||||
stride_t dst_offset = o_offset + elem_to_loc(i, new_shape, new_strides[1]);
|
||||
stride_t src_offset = i_offset + elem_to_loc(i, data_shape, i_strides);
|
||||
stride_t dst_offset = o_offset + elem_to_loc(i, dst.shape(), o_strides);
|
||||
copy_general_general_dims<SrcT, DstT, stride_t, 5>(
|
||||
src,
|
||||
dst,
|
||||
new_shape,
|
||||
new_strides[0],
|
||||
new_strides[1],
|
||||
src_offset,
|
||||
dst_offset);
|
||||
src, dst, data_shape, i_strides, o_strides, src_offset, dst_offset);
|
||||
}
|
||||
}
|
||||
|
||||
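
The copy routines above ultimately fall back to elem_to_loc for arbitrary layouts; the following is a small self-contained sketch of what that helper computes, with made-up shapes and strides for illustration.

// Sketch of the elem_to_loc helper the copy routines above rely on: it maps a
// flat row-major element index to an element offset under arbitrary strides.
#include <cstdio>
#include <vector>

size_t elem_to_loc_sketch(
    int elem,
    const std::vector<int>& shape,
    const std::vector<size_t>& strides) {
  size_t loc = 0;
  for (int i = shape.size() - 1; i >= 0; --i) {
    loc += (elem % shape[i]) * strides[i];
    elem /= shape[i];
  }
  return loc;
}

int main() {
  // A transposed (3, 2) view of a row-major (2, 3) buffer has strides {1, 3}.
  std::vector<int> shape = {3, 2};
  std::vector<size_t> strides = {1, 3};
  for (int i = 0; i < 6; ++i) {
    std::printf("%zu ", elem_to_loc_sketch(i, shape, strides));  // 0 3 1 4 2 5
  }
  std::printf("\n");
}
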
@@ -485,17 +444,8 @@ void copy_inplace(
|
||||
}
|
||||
}
|
||||
|
||||
template void copy_inplace<size_t>(
|
||||
const array& src,
|
||||
array& dst,
|
||||
const std::vector<int>& data_shape,
|
||||
const std::vector<size_t>& i_strides,
|
||||
const std::vector<size_t>& o_strides,
|
||||
int64_t i_offset,
|
||||
int64_t o_offset,
|
||||
CopyType ctype);
|
||||
|
||||
template void copy_inplace<int64_t>(
|
||||
template <>
|
||||
void copy_inplace<int64_t>(
|
||||
const array& src,
|
||||
array& dst,
|
||||
const std::vector<int>& data_shape,
|
||||
@@ -503,6 +453,24 @@ template void copy_inplace<int64_t>(
|
||||
const std::vector<int64_t>& o_strides,
|
||||
int64_t i_offset,
|
||||
int64_t o_offset,
|
||||
CopyType ctype);
|
||||
CopyType ctype) {
|
||||
switch (ctype) {
|
||||
case CopyType::General:
|
||||
case CopyType::GeneralGeneral:
|
||||
return copy_inplace_dispatch(
|
||||
src,
|
||||
dst,
|
||||
ctype,
|
||||
data_shape,
|
||||
i_strides,
|
||||
o_strides,
|
||||
i_offset,
|
||||
o_offset);
|
||||
|
||||
case CopyType::Scalar:
|
||||
case CopyType::Vector:
|
||||
return copy_inplace_dispatch(src, dst, ctype);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -5,6 +5,7 @@
|
||||
#else
|
||||
#include <cblas.h>
|
||||
#endif
|
||||
|
||||
#include <cstring>
|
||||
|
||||
#include "mlx/array.h"
|
||||
@@ -52,7 +53,7 @@ DEFAULT(Convolution)
|
||||
DEFAULT(Copy)
|
||||
DEFAULT(Cos)
|
||||
DEFAULT(Cosh)
|
||||
DEFAULT_MULTI(CustomTransforms)
|
||||
DEFAULT_MULTI(CustomVJP)
|
||||
DEFAULT_MULTI(Depends)
|
||||
DEFAULT(Divide)
|
||||
DEFAULT(NumberOfElements)
|
||||
@@ -68,7 +69,6 @@ DEFAULT(Full)
|
||||
DEFAULT(Gather)
|
||||
DEFAULT(Greater)
|
||||
DEFAULT(GreaterEqual)
|
||||
DEFAULT(Hadamard)
|
||||
DEFAULT(Less)
|
||||
DEFAULT(LessEqual)
|
||||
DEFAULT(Load)
|
||||
|
@@ -1,107 +0,0 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/hadamard.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// n = 2^k component
|
||||
template <typename T>
|
||||
void hadamard_n(array& out, int n, int m, float scale) {
|
||||
for (int b = 0; b < out.size() / n; b++) {
|
||||
size_t loc = b * n;
|
||||
T* data_ptr = out.data<T>() + loc;
|
||||
int h = 1;
|
||||
int n_over_2 = n / 2;
|
||||
while (h < n) {
|
||||
for (int i = 0; i < n / 2; i++) {
|
||||
int k = i & (h - 1);
|
||||
int j = ((i - k) << 1) + k;
|
||||
float x = *(data_ptr + j);
|
||||
float y = *(data_ptr + j + h);
|
||||
*(data_ptr + j) = x + y;
|
||||
*(data_ptr + j + h) = x - y;
|
||||
if (h == n_over_2) {
|
||||
*(data_ptr + j) *= scale;
|
||||
*(data_ptr + j + h) *= scale;
|
||||
}
|
||||
}
|
||||
h <<= 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// m component
|
||||
template <typename T>
|
||||
void hadamard_m(array& out, int n, int m, float scale) {
|
||||
auto h_matrices = hadamard_matrices();
|
||||
auto& matrix = h_matrices[m];
|
||||
auto start = 1;
|
||||
auto end = matrix.find('\n', start);
|
||||
std::vector<bool> hmat_vec;
|
||||
while (end != std::string_view::npos) {
|
||||
auto row = matrix.substr(start, end - start);
|
||||
for (int i = 0; i < row.length(); i++) {
|
||||
hmat_vec.push_back(row[i] == '+');
|
||||
}
|
||||
start = end + 1;
|
||||
end = matrix.find('\n', start);
|
||||
}
|
||||
|
||||
for (int b = 0; b < out.size() / m / n; b++) {
|
||||
size_t loc = b * n * m;
|
||||
T* data_ptr = out.data<T>() + loc;
|
||||
for (int i = 0; i < n; i++) {
|
||||
std::vector<float> out(m);
|
||||
for (int j = 0; j < m; j++) {
|
||||
for (int k = 0; k < m; k++) {
|
||||
float x = *(data_ptr + i + k * n);
|
||||
if (hmat_vec[k + j * m]) {
|
||||
out[j] += x;
|
||||
} else {
|
||||
out[j] -= x;
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int j = 0; j < m; j++) {
|
||||
*(data_ptr + i + j * n) = out[j] * scale;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void hadamard(array& out, int n, int m, float scale) {
|
||||
float n_scale = m > 1 ? 1.0 : scale;
|
||||
hadamard_n<T>(out, n, m, n_scale);
|
||||
if (m > 1) {
|
||||
hadamard_m<T>(out, n, m, scale);
|
||||
}
|
||||
}
|
||||
|
||||
void Hadamard::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
|
||||
// Copy input to output
|
||||
copy(in, out, CopyType::General);
|
||||
|
||||
int axis = out.ndim() - 1;
|
||||
auto [n, m] = decompose_hadamard(out.shape(axis));
|
||||
|
||||
switch (in.dtype()) {
|
||||
case float32:
|
||||
return hadamard<float>(out, n, m, scale_);
|
||||
case float16:
|
||||
return hadamard<float16_t>(out, n, m, scale_);
|
||||
case bfloat16:
|
||||
return hadamard<bfloat16_t>(out, n, m, scale_);
|
||||
default:
|
||||
throw std::invalid_argument("[hadamard] Unsupported type.");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -1,105 +0,0 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <map>
|
||||
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// From http://neilsloane.com/hadamard/
|
||||
constexpr std::string_view h12 = R"(
|
||||
+-++++++++++
|
||||
--+-+-+-+-+-
|
||||
+++-++----++
|
||||
+---+--+-++-
|
||||
+++++-++----
|
||||
+-+---+--+-+
|
||||
++--+++-++--
|
||||
+--++---+--+
|
||||
++----+++-++
|
||||
+--+-++---+-
|
||||
++++----+++-
|
||||
+-+--+-++---
|
||||
)";
|
||||
|
||||
constexpr std::string_view h20 = R"(
|
||||
+----+----++--++-++-
|
||||
-+----+---+++---+-++
|
||||
--+----+---+++-+-+-+
|
||||
---+----+---+++++-+-
|
||||
----+----++--++-++-+
|
||||
-+++++-----+--+++--+
|
||||
+-+++-+---+-+--+++--
|
||||
++-++--+---+-+--+++-
|
||||
+++-+---+---+-+--+++
|
||||
++++-----++--+-+--++
|
||||
--++-+-++-+-----++++
|
||||
---++-+-++-+---+-+++
|
||||
+---++-+-+--+--++-++
|
||||
++---++-+----+-+++-+
|
||||
-++---++-+----+++++-
|
||||
-+--+--++-+----+----
|
||||
+-+-----++-+----+---
|
||||
-+-+-+---+--+----+--
|
||||
--+-+++------+----+-
|
||||
+--+--++------+----+
|
||||
)";
|
||||
|
||||
constexpr std::string_view h28 = R"(
|
||||
+------++----++-+--+-+--++--
|
||||
-+-----+++-----+-+--+-+--++-
|
||||
--+-----+++---+-+-+----+--++
|
||||
---+-----+++---+-+-+-+--+--+
|
||||
----+-----+++---+-+-+++--+--
|
||||
-----+-----++++--+-+--++--+-
|
||||
------++----++-+--+-+--++--+
|
||||
--++++-+-------++--+++-+--+-
|
||||
---++++-+-----+-++--+-+-+--+
|
||||
+---+++--+----++-++--+-+-+--
|
||||
++---++---+----++-++--+-+-+-
|
||||
+++---+----+----++-++--+-+-+
|
||||
++++--------+-+--++-++--+-+-
|
||||
-++++--------+++--++--+--+-+
|
||||
-+-++-++--++--+--------++++-
|
||||
+-+-++--+--++--+--------++++
|
||||
-+-+-++--+--++--+----+---+++
|
||||
+-+-+-++--+--+---+---++---++
|
||||
++-+-+-++--+------+--+++---+
|
||||
-++-+-+-++--+------+-++++---
|
||||
+-++-+---++--+------+-++++--
|
||||
-++--++-+-++-+++----++------
|
||||
+-++--++-+-++-+++-----+-----
|
||||
++-++---+-+-++-+++-----+----
|
||||
-++-++-+-+-+-+--+++-----+---
|
||||
--++-++++-+-+----+++-----+--
|
||||
+--++-+-++-+-+----+++-----+-
|
||||
++--++-+-++-+-+----++------+
|
||||
)";
|
||||
|
||||
inline const std::map<int, std::string_view> hadamard_matrices() {
|
||||
return {{12, h12}, {20, h20}, {28, h28}};
|
||||
}
|
||||
|
||||
inline std::pair<int, int> decompose_hadamard(int n) {
|
||||
// n = m*2^k
|
||||
int m = 1;
|
||||
if (!is_power_of_2(n)) {
|
||||
auto h_matrices = hadamard_matrices();
|
||||
for (auto [factor, _] : h_matrices) {
|
||||
if (n % factor == 0) {
|
||||
m = factor;
|
||||
n /= factor;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (m == 1) {
|
||||
throw std::invalid_argument(
|
||||
"[hadamard] Only supports n = m*2^k where m in (1, 12, 20, 28).");
|
||||
}
|
||||
}
|
||||
return {n, m};
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
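
A small standalone sketch of the decomposition above, with an extra power-of-two check added here for safety: n is factored as m * 2^k with m drawn from the supported Hadamard sizes.

// Sketch of decompose_hadamard: e.g. 320 = 20 * 16 and 128 = 1 * 128.
#include <cstdio>
#include <stdexcept>
#include <utility>

std::pair<int, int> decompose_hadamard_sketch(int n) {
  int m = 1;
  auto is_pow2 = [](int v) { return v > 0 && (v & (v - 1)) == 0; };
  if (!is_pow2(n)) {
    for (int factor : {12, 20, 28}) {
      if (n % factor == 0 && is_pow2(n / factor)) {
        m = factor;
        n /= factor;
        break;
      }
    }
    if (m == 1) {
      throw std::invalid_argument("n must be m * 2^k with m in (1, 12, 20, 28)");
    }
  }
  return {n, m};
}

int main() {
  auto [n, m] = decompose_hadamard_sketch(320);
  std::printf("320 -> n=%d, m=%d\n", n, m);  // 320 -> n=16, m=20
}
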
@@ -10,106 +10,9 @@
|
||||
#include <lapack.h>
|
||||
#endif
|
||||
|
||||
// Wrapper to account for differences in
|
||||
// LAPACK implementations (basically how to pass the 'uplo' string to fortran).
|
||||
int strtri_wrapper(char uplo, char diag, float* matrix, int N) {
|
||||
int info;
|
||||
|
||||
#ifdef LAPACK_FORTRAN_STRLEN_END
|
||||
strtri_(
|
||||
/* uplo = */ &uplo,
|
||||
/* diag = */ &diag,
|
||||
/* N = */ &N,
|
||||
/* a = */ matrix,
|
||||
/* lda = */ &N,
|
||||
/* info = */ &info,
|
||||
/* uplo_len = */ static_cast<size_t>(1),
|
||||
/* diag_len = */ static_cast<size_t>(1));
|
||||
#else
|
||||
strtri_(
|
||||
/* uplo = */ &uplo,
|
||||
/* diag = */ &diag,
|
||||
/* N = */ &N,
|
||||
/* a = */ matrix,
|
||||
/* lda = */ &N,
|
||||
/* info = */ &info);
|
||||
#endif
|
||||
|
||||
return info;
|
||||
}
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void general_inv(array& inv, int N, int i) {
|
||||
int info;
|
||||
auto ipiv = array::Data{allocator::malloc_or_wait(sizeof(int) * N)};
|
||||
// Compute LU factorization.
|
||||
sgetrf_(
|
||||
/* m = */ &N,
|
||||
/* n = */ &N,
|
||||
/* a = */ inv.data<float>() + N * N * i,
|
||||
/* lda = */ &N,
|
||||
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
|
||||
/* info = */ &info);
|
||||
|
||||
if (info != 0) {
|
||||
std::stringstream ss;
|
||||
ss << "inverse_impl: LU factorization failed with error code " << info;
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
|
||||
static const int lwork_query = -1;
|
||||
float workspace_size = 0;
|
||||
|
||||
// Compute workspace size.
|
||||
sgetri_(
|
||||
/* m = */ &N,
|
||||
/* a = */ nullptr,
|
||||
/* lda = */ &N,
|
||||
/* ipiv = */ nullptr,
|
||||
/* work = */ &workspace_size,
|
||||
/* lwork = */ &lwork_query,
|
||||
/* info = */ &info);
|
||||
|
||||
if (info != 0) {
|
||||
std::stringstream ss;
|
||||
ss << "inverse_impl: LU workspace calculation failed with error code "
|
||||
<< info;
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
|
||||
const int lwork = workspace_size;
|
||||
auto scratch = array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
|
||||
|
||||
// Compute inverse.
|
||||
sgetri_(
|
||||
/* m = */ &N,
|
||||
/* a = */ inv.data<float>() + N * N * i,
|
||||
/* lda = */ &N,
|
||||
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
|
||||
/* work = */ static_cast<float*>(scratch.buffer.raw_ptr()),
|
||||
/* lwork = */ &lwork,
|
||||
/* info = */ &info);
|
||||
|
||||
if (info != 0) {
|
||||
std::stringstream ss;
|
||||
ss << "inverse_impl: inversion failed with error code " << info;
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
}
|
||||
|
||||
void tri_inv(array& inv, int N, int i, bool upper) {
|
||||
const char uplo = upper ? 'L' : 'U';
|
||||
const char diag = 'N';
|
||||
int info = strtri_wrapper(uplo, diag, inv.data<float>() + N * N * i, N);
|
||||
if (info != 0) {
|
||||
std::stringstream ss;
|
||||
ss << "inverse_impl: triangular inversion failed with error code " << info;
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
}
|
||||
|
||||
void inverse_impl(const array& a, array& inv, bool tri, bool upper) {
|
||||
void inverse_impl(const array& a, array& inv) {
|
||||
// Lapack uses the column-major convention. We take advantage of the following
|
||||
// identity to avoid transposing (see
|
||||
// https://math.stackexchange.com/a/340234):
|
||||
@@ -121,11 +24,63 @@ void inverse_impl(const array& a, array& inv, bool tri, bool upper) {
|
||||
const int N = a.shape(-1);
|
||||
const size_t num_matrices = a.size() / (N * N);
|
||||
|
||||
int info;
|
||||
auto ipiv = array::Data{allocator::malloc_or_wait(sizeof(int) * N)};
|
||||
|
||||
for (int i = 0; i < num_matrices; i++) {
|
||||
if (tri) {
|
||||
tri_inv(inv, N, i, upper);
|
||||
} else {
|
||||
general_inv(inv, N, i);
|
||||
// Compute LU factorization.
|
||||
sgetrf_(
|
||||
/* m = */ &N,
|
||||
/* n = */ &N,
|
||||
/* a = */ inv.data<float>() + N * N * i,
|
||||
/* lda = */ &N,
|
||||
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
|
||||
/* info = */ &info);
|
||||
|
||||
if (info != 0) {
|
||||
std::stringstream ss;
|
||||
ss << "inverse_impl: LU factorization failed with error code " << info;
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
|
||||
static const int lwork_query = -1;
|
||||
float workspace_size = 0;
|
||||
|
||||
// Compute workspace size.
|
||||
sgetri_(
|
||||
/* m = */ &N,
|
||||
/* a = */ nullptr,
|
||||
/* lda = */ &N,
|
||||
/* ipiv = */ nullptr,
|
||||
/* work = */ &workspace_size,
|
||||
/* lwork = */ &lwork_query,
|
||||
/* info = */ &info);
|
||||
|
||||
if (info != 0) {
|
||||
std::stringstream ss;
|
||||
ss << "inverse_impl: LU workspace calculation failed with error code "
|
||||
<< info;
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
|
||||
const int lwork = workspace_size;
|
||||
auto scratch =
|
||||
array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
|
||||
|
||||
// Compute inverse.
|
||||
sgetri_(
|
||||
/* m = */ &N,
|
||||
/* a = */ inv.data<float>() + N * N * i,
|
||||
/* lda = */ &N,
|
||||
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
|
||||
/* work = */ static_cast<float*>(scratch.buffer.raw_ptr()),
|
||||
/* lwork = */ &lwork,
|
||||
/* info = */ &info);
|
||||
|
||||
if (info != 0) {
|
||||
std::stringstream ss;
|
||||
ss << "inverse_impl: inversion failed with error code " << info;
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -134,7 +89,7 @@ void Inverse::eval(const std::vector<array>& inputs, array& output) {
|
||||
if (inputs[0].dtype() != float32) {
|
||||
throw std::runtime_error("[Inverse::eval] only supports float32.");
|
||||
}
|
||||
inverse_impl(inputs[0], output, tri_, upper_);
|
||||
inverse_impl(inputs[0], output);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -28,7 +28,6 @@ const char* get_kernel_preamble() {
|
||||
return R"preamble(
|
||||
$INCLUDES
|
||||
$CONTENT
|
||||
using namespace mlx::core;
|
||||
using namespace mlx::core::detail;
|
||||
)preamble";
|
||||
}
|
||||
|
@@ -108,105 +108,105 @@ struct Abs {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::abs(x);
|
||||
}
|
||||
};
|
||||
uint8_t operator()(uint8_t x) {
|
||||
return x;
|
||||
}
|
||||
};
|
||||
uint16_t operator()(uint16_t x) {
|
||||
return x;
|
||||
}
|
||||
};
|
||||
uint32_t operator()(uint32_t x) {
|
||||
return x;
|
||||
}
|
||||
};
|
||||
uint64_t operator()(uint64_t x) {
|
||||
return x;
|
||||
}
|
||||
};
|
||||
bool operator()(bool x) {
|
||||
return x;
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
struct ArcCos {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::acos(x);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
struct ArcCosh {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::acosh(x);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
struct ArcSin {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::asin(x);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
struct ArcSinh {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::asinh(x);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
struct ArcTan {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::atan(x);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
struct ArcTan2 {
|
||||
template <typename T>
|
||||
T operator()(T y, T x) {
|
||||
return std::atan2(y, x);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
struct ArcTanh {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::atanh(x);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
struct Ceil {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::ceil(x);
|
||||
}
|
||||
};
|
||||
int8_t operator()(int8_t x) {
|
||||
return x;
|
||||
}
|
||||
};
|
||||
int16_t operator()(int16_t x) {
|
||||
return x;
|
||||
}
|
||||
};
|
||||
int32_t operator()(int32_t x) {
|
||||
return x;
|
||||
}
|
||||
};
|
||||
int64_t operator()(int64_t x) {
|
||||
return x;
|
||||
}
|
||||
};
|
||||
uint8_t operator()(uint8_t x) {
|
||||
return x;
|
||||
}
|
||||
};
|
||||
uint16_t operator()(uint16_t x) {
|
||||
return x;
|
||||
}
|
||||
};
|
||||
uint32_t operator()(uint32_t x) {
|
||||
return x;
|
||||
}
|
||||
};
|
||||
uint64_t operator()(uint64_t x) {
|
||||
return x;
|
||||
}
|
||||
};
|
||||
bool operator()(bool x) {
|
||||
return x;
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
struct Conjugate {
|
||||
@@ -219,35 +219,35 @@ struct Cos {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::cos(x);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
struct Cosh {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::cosh(x);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
struct Erf {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return static_cast<T>(fast_erf(static_cast<float>(x)));
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
struct ErfInv {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return static_cast<T>(fast_erfinv(static_cast<float>(x)));
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
struct Exp {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return fast_exp(x);
|
||||
}
|
||||
};
|
||||
|
||||
complex64_t operator()(complex64_t x) {
|
||||
return std::exp(x);
|
||||
@@ -258,83 +258,83 @@ struct Expm1 {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return expm1(x);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
struct Floor {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::floor(x);
|
||||
}
|
||||
};
|
||||
int8_t operator()(int8_t x) {
|
||||
return x;
|
||||
}
|
||||
};
|
||||
int16_t operator()(int16_t x) {
|
||||
return x;
|
||||
}
|
||||
};
|
||||
int32_t operator()(int32_t x) {
|
||||
return x;
|
||||
}
|
||||
};
|
||||
int64_t operator()(int64_t x) {
|
||||
return x;
|
||||
}
|
||||
};
|
||||
uint8_t operator()(uint8_t x) {
|
||||
return x;
|
||||
}
|
||||
};
|
||||
uint16_t operator()(uint16_t x) {
|
||||
return x;
|
||||
}
|
||||
};
|
||||
uint32_t operator()(uint32_t x) {
|
||||
return x;
|
||||
}
|
||||
};
|
||||
uint64_t operator()(uint64_t x) {
|
||||
return x;
|
||||
}
|
||||
};
|
||||
bool operator()(bool x) {
|
||||
return x;
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
struct Log {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::log(x);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
struct Log2 {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::log2(x);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
struct Log10 {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::log10(x);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
struct Log1p {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return log1p(x);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
struct LogicalNot {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return !x;
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
struct Negative {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return -x;
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
struct Round {
|
||||
@@ -379,49 +379,49 @@ struct Sin {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::sin(x);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
struct Sinh {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::sinh(x);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
struct Square {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return x * x;
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
struct Sqrt {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::sqrt(x);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
struct Rsqrt {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return static_cast<decltype(x)>(1.0) / std::sqrt(x);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
struct Tan {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::tan(x);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
struct Tanh {
|
||||
template <typename T>
|
||||
T operator()(T x) {
|
||||
return std::tanh(x);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
struct Add {
|
||||
@@ -554,7 +554,7 @@ struct LogAddExp {
|
||||
? maxval
|
||||
: static_cast<decltype(x)>(
|
||||
maxval + std::log1p(fast_exp(minval - maxval)));
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
struct Multiply {
|
||||
@@ -602,14 +602,14 @@ struct LogicalAnd {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x && y;
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
struct LogicalOr {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x || y;
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
struct Select {
|
||||
@@ -623,35 +623,35 @@ struct BitwiseAnd {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x & y;
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
struct BitwiseOr {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x | y;
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
struct BitwiseXor {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x ^ y;
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
struct LeftShift {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x << y;
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
struct RightShift {
|
||||
template <typename T>
|
||||
T operator()(T x, T y) {
|
||||
return x >> y;
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace mlx::core::detail
|
||||
|
@@ -8,9 +8,9 @@
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/arange.h"
|
||||
#include "mlx/backend/common/binary.h"
|
||||
#include "mlx/backend/common/copy.h"
|
||||
#include "mlx/backend/common/ops.h"
|
||||
#include "mlx/backend/common/slicing.h"
|
||||
#include "mlx/backend/common/threefry.h"
|
||||
#include "mlx/backend/common/unary.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
@@ -313,6 +313,20 @@ void LogicalNot::eval(const std::vector<array>& inputs, array& out) {
|
||||
unary(in, out, detail::LogicalNot());
|
||||
}
|
||||
|
||||
void LogicalAnd::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2); // LogicalAnd requires two input arrays
|
||||
auto& in1 = inputs[0];
|
||||
auto& in2 = inputs[1];
|
||||
binary(in1, in2, out, detail::LogicalAnd());
|
||||
}
|
||||
|
||||
void LogicalOr::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 2); // LogicalOr requires two input arrays
|
||||
auto& in1 = inputs[0];
|
||||
auto& in2 = inputs[1];
|
||||
binary(in1, in2, out, detail::LogicalOr());
|
||||
}
|
||||
|
||||
void Negative::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
@@ -405,17 +419,7 @@ void Reshape::eval(const std::vector<array>& inputs, array& out) {
|
||||
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
|
||||
|
||||
if (copy_necessary) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
auto out_strides = make_contiguous_strides<size_t>(in.shape());
|
||||
copy_inplace<size_t>(
|
||||
in,
|
||||
out,
|
||||
in.shape(),
|
||||
in.strides(),
|
||||
out_strides,
|
||||
0,
|
||||
0,
|
||||
CopyType::General);
|
||||
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : CopyType::General);
|
||||
} else {
|
||||
shared_buffer_reshape(in, out_strides, out);
|
||||
}
|
||||
@@ -488,8 +492,7 @@ void Slice::eval(const std::vector<array>& inputs, array& out) {
|
||||
auto& in = inputs[0];
|
||||
|
||||
// Calculate out strides, initial offset and if copy needs to be made
|
||||
auto [copy_needed, data_offset, inp_strides] =
|
||||
prepare_slice(in, start_indices_, strides_);
|
||||
auto [copy_needed, data_offset, inp_strides] = prepare_slice(in);
|
||||
|
||||
// Do copy if needed
|
||||
if (copy_needed) {
|
||||
@@ -587,36 +590,4 @@ void Tanh::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
}
|
||||
|
||||
void View::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
auto ibytes = size_of(in.dtype());
|
||||
auto obytes = size_of(out.dtype());
|
||||
// Conditions for buffer copying (disjunction):
|
||||
// - type size is the same
|
||||
// - type size is smaller and the last axis is contiguous
|
||||
// - the entire array is row contiguous
|
||||
if (ibytes == obytes || obytes < ibytes && in.strides().back() == 1 ||
|
||||
in.flags().row_contiguous) {
|
||||
auto strides = in.strides();
|
||||
for (int i = 0; i < strides.size() - 1; ++i) {
|
||||
strides[i] *= ibytes;
|
||||
strides[i] /= obytes;
|
||||
}
|
||||
out.copy_shared_buffer(
|
||||
in, strides, in.flags(), in.data_size() * obytes / ibytes);
|
||||
} else {
|
||||
auto tmp = array(in.shape(), in.dtype(), nullptr, {});
|
||||
tmp.set_data(allocator::malloc_or_wait(tmp.nbytes()));
|
||||
copy_inplace(in, tmp, CopyType::General);
|
||||
|
||||
auto flags = out.flags();
|
||||
flags.contiguous = true;
|
||||
flags.row_contiguous = true;
|
||||
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
|
||||
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
|
||||
out.move_shared_buffer(tmp, out.strides(), flags, out.size());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
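
A standalone sketch (with made-up sizes) of the stride rescaling in View::eval_cpu above: when the element type shrinks or grows and no copy is needed, every stride except the innermost is converted from input-element units to output-element units.

// Sketch of the no-copy dtype reinterpretation, e.g. viewing int32 data as int8.
#include <cstdio>
#include <vector>

int main() {
  size_t ibytes = 4, obytes = 1;          // int32 viewed as int8
  std::vector<size_t> strides = {3, 1};   // a row-contiguous (2, 3) int32 array
  for (size_t i = 0; i + 1 < strides.size(); ++i) {
    strides[i] = strides[i] * ibytes / obytes;
  }
  // The innermost axis now indexes int8 elements, so each row spans 12 of them.
  std::printf("new strides: {%zu, %zu}\n", strides[0], strides[1]);  // {12, 1}
}
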
@@ -104,14 +104,48 @@ void reduce_dispatch_out(
|
||||
}
|
||||
case Reduce::Sum: {
|
||||
auto op = [](auto y, auto x) { (*y) = (*y) + x; };
|
||||
if (out.dtype() == int32) {
|
||||
// special case since the input type can be bool
|
||||
reduction_op<InT, int32_t>(in, out, axes, 0, op);
|
||||
} else {
|
||||
reduction_op<InT, InT>(in, out, axes, 0, op);
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
reduction_op<InT, bool>(in, out, axes, false, op);
|
||||
break;
|
||||
case uint8:
|
||||
reduction_op<InT, uint8_t>(in, out, axes, 0, op);
|
||||
break;
|
||||
case uint16:
|
||||
reduction_op<InT, uint16_t>(in, out, axes, 0, op);
|
||||
break;
|
||||
case uint32:
|
||||
reduction_op<InT, uint32_t>(in, out, axes, 0, op);
|
||||
break;
|
||||
case uint64:
|
||||
reduction_op<InT, uint64_t>(in, out, axes, 0, op);
|
||||
break;
|
||||
case int8:
|
||||
reduction_op<InT, int8_t>(in, out, axes, 0, op);
|
||||
break;
|
||||
case int16:
|
||||
reduction_op<InT, int16_t>(in, out, axes, 0, op);
|
||||
break;
|
||||
case int32:
|
||||
reduction_op<InT, int32_t>(in, out, axes, 0, op);
|
||||
break;
|
||||
case int64:
|
||||
reduction_op<InT, int64_t>(in, out, axes, 0, op);
|
||||
break;
|
||||
case float16:
|
||||
reduction_op<InT, float16_t>(in, out, axes, 0.0f, op);
|
||||
break;
|
||||
case float32:
|
||||
reduction_op<InT, float>(in, out, axes, 0.0f, op);
|
||||
break;
|
||||
case bfloat16:
|
||||
reduction_op<InT, bfloat16_t>(in, out, axes, 0.0f, op);
|
||||
break;
|
||||
case complex64:
|
||||
reduction_op<InT, complex64_t>(in, out, axes, complex64_t{0.0f}, op);
|
||||
break;
|
||||
}
|
||||
break;
|
||||
}
|
||||
} break;
|
||||
case Reduce::Prod: {
|
||||
auto op = [](auto y, auto x) { (*y) *= x; };
|
||||
reduction_op<InT, InT>(in, out, axes, 1, op);
|
||||
@@ -134,29 +168,6 @@ void reduce_dispatch_out(
|
||||
|
||||
} // namespace
|
||||
|
||||
void nd_loop(
|
||||
std::function<void(int)> callback,
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<size_t>& strides) {
|
||||
std::function<void(int, int)> loop_inner;
|
||||
loop_inner = [&](int dim, int offset) {
|
||||
if (dim < shape.size() - 1) {
|
||||
int size = shape[dim];
|
||||
size_t stride = strides[dim];
|
||||
for (int i = 0; i < size; i++) {
|
||||
loop_inner(dim + 1, offset + i * stride);
|
||||
}
|
||||
} else {
|
||||
int size = shape[dim];
|
||||
size_t stride = strides[dim];
|
||||
for (int i = 0; i < size; i++) {
|
||||
callback(offset + i * stride);
|
||||
}
|
||||
}
|
||||
};
|
||||
loop_inner(0, 0);
|
||||
}
|
||||
|
||||
void Reduce::eval(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
|
@@ -49,18 +49,47 @@ struct ReductionPlan {
|
||||
ReductionPlan(ReductionOpType type_) : type(type_) {}
|
||||
};
|
||||
|
||||
ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes);
|
||||
namespace {
|
||||
|
||||
// Helper for the ndimensional strided loop
|
||||
// Should this be in utils?
|
||||
void nd_loop(
|
||||
inline void nd_loop(
|
||||
std::function<void(int)> callback,
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<size_t>& strides);
|
||||
const std::vector<size_t>& strides) {
|
||||
std::function<void(int, int)> loop_inner;
|
||||
loop_inner = [&](int dim, int offset) {
|
||||
if (dim < shape.size() - 1) {
|
||||
int size = shape[dim];
|
||||
size_t stride = strides[dim];
|
||||
for (int i = 0; i < size; i++) {
|
||||
loop_inner(dim + 1, offset + i * stride);
|
||||
}
|
||||
} else {
|
||||
int size = shape[dim];
|
||||
size_t stride = strides[dim];
|
||||
for (int i = 0; i < size; i++) {
|
||||
callback(offset + i * stride);
|
||||
}
|
||||
}
|
||||
};
|
||||
loop_inner(0, 0);
|
||||
}
|
||||
|
||||
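
A usage sketch of nd_loop above, with a hypothetical column-major layout: the callback receives the memory offset of each element while the logical indices are visited in row-major order.

// Self-contained version of the strided n-dimensional loop.
#include <cstdio>
#include <functional>
#include <vector>

void nd_loop_sketch(
    std::function<void(int)> callback,
    const std::vector<int>& shape,
    const std::vector<size_t>& strides) {
  std::function<void(int, int)> inner = [&](int dim, int offset) {
    for (int i = 0; i < shape[dim]; i++) {
      int next = offset + static_cast<int>(i * strides[dim]);
      if (dim + 1 < (int)shape.size()) {
        inner(dim + 1, next);
      } else {
        callback(next);
      }
    }
  };
  inner(0, 0);
}

int main() {
  // A (2, 3) array stored column-major has strides {1, 2}.
  nd_loop_sketch([](int loc) { std::printf("%d ", loc); }, {2, 3}, {1, 2});
  std::printf("\n");  // prints: 0 2 4 1 3 5
}
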
std::pair<std::vector<int>, std::vector<size_t>> shapes_without_reduction_axes(
|
||||
const array& x,
|
||||
const std::vector<int>& axes);
|
||||
const std::vector<int>& axes) {
|
||||
std::vector<int> shape = x.shape();
|
||||
std::vector<size_t> strides = x.strides();
|
||||
|
||||
for (int i = axes.size() - 1; i >= 0; i--) {
|
||||
int a = axes[i];
|
||||
shape.erase(shape.begin() + a);
|
||||
strides.erase(strides.begin() + a);
|
||||
}
|
||||
|
||||
return std::make_pair(shape, strides);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
struct DefaultStridedReduce {
|
||||
@@ -94,6 +123,102 @@ struct DefaultContiguousReduce {
|
||||
}
|
||||
};
|
||||
|
||||
ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
|
||||
// The data is all there and we are reducing over everything
|
||||
if (x.size() == x.data_size() && axes.size() == x.ndim() &&
|
||||
x.flags().contiguous) {
|
||||
return ContiguousAllReduce;
|
||||
}
|
||||
|
||||
// Row contiguous input so the output is row contiguous
|
||||
if (x.flags().row_contiguous) {
|
||||
// Merge consecutive axes
|
||||
std::vector<int> shape = {x.shape(axes[0])};
|
||||
std::vector<size_t> strides = {x.strides()[axes[0]]};
|
||||
for (int i = 1; i < axes.size(); i++) {
|
||||
if (axes[i] - 1 == axes[i - 1]) {
|
||||
shape.back() *= x.shape(axes[i]);
|
||||
strides.back() = x.strides()[axes[i]];
|
||||
} else {
|
||||
shape.push_back(x.shape(axes[i]));
|
||||
strides.push_back(x.strides()[axes[i]]);
|
||||
}
|
||||
}
|
||||
|
||||
if (strides.back() == 1) {
|
||||
return ReductionPlan(ContiguousReduce, shape, strides);
|
||||
} else if (strides.back() > 1) {
|
||||
return ReductionPlan(ContiguousStridedReduce, shape, strides);
|
||||
}
|
||||
}
|
||||
|
||||
// Let's check if we can optimize our access patterns
|
||||
//
|
||||
// 1. We have a reduction axis with stride 1. Simply call
|
||||
// GeneralContiguousReduce and be done with it.
|
||||
// 2. We have transpositions and we are not reducing over the axis with
|
||||
// stride 1. However, we are reducing over an axis where everything is
|
||||
// contiguous in memory to the right of that axis. We can call strided
|
||||
// reduce and be done with it.
|
||||
// 3. We have weird transpositions and expands. Copy the strides to the
|
||||
// output, then call strided reduce.
|
||||
|
||||
// Sort reduction axes by stride in order to merge them and figure out if we
|
||||
// have a contiguous reduction.
|
||||
std::vector<std::pair<int, size_t>> reductions;
|
||||
for (auto a : axes) {
|
||||
reductions.push_back(std::make_pair(x.shape(a), x.strides()[a]));
|
||||
}
|
||||
std::sort(reductions.begin(), reductions.end(), [](auto a, auto b) {
|
||||
return a.second > b.second;
|
||||
});
|
||||
// Extract the two smallest and try to merge them in case the contiguous
|
||||
// reduction can be bigger than just the last axis.
|
||||
for (int i = reductions.size() - 1; i >= 1; i--) {
|
||||
auto a = reductions[i];
|
||||
auto b = reductions[i - 1];
|
||||
|
||||
// b.stride = a.shape * a.stride then a and b are contiguous
|
||||
if (b.second == a.first * a.second) {
|
||||
reductions.erase(reductions.begin() + i);
|
||||
reductions[i - 1] = std::make_pair(a.first * b.first, a.second);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<int> shape;
|
||||
std::vector<size_t> strides;
|
||||
for (auto r : reductions) {
|
||||
shape.push_back(r.first);
|
||||
strides.push_back(r.second);
|
||||
}
|
||||
|
||||
// We can call the contiguous reduction op for every weird way the input is
|
||||
// structured in the rest of the axes.
|
||||
if (strides.back() == 1) {
|
||||
return ReductionPlan(GeneralContiguousReduce, shape, strides);
|
||||
}
|
||||
|
||||
// Delegate to the general strided reduction op if the axes after
|
||||
// strides.back() are contiguous.
|
||||
if (strides.back() > 1) {
|
||||
int size = 1;
|
||||
for (int i = x.ndim() - 1; i >= 0; i--) {
|
||||
if (axes.back() == i) {
|
||||
continue;
|
||||
}
|
||||
if (x.strides()[i] != size) {
|
||||
break;
|
||||
}
|
||||
size *= x.shape(i);
|
||||
}
|
||||
if (size >= strides.back()) {
|
||||
return ReductionPlan(GeneralStridedReduce, shape, strides);
|
||||
}
|
||||
}
|
||||
|
||||
return ReductionPlan(GeneralReduce, shape, strides);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename OpS, typename OpC, typename Op>
|
||||
void reduction_op(
|
||||
const array& x,
|
||||
@@ -236,4 +361,6 @@ void reduction_op(
|
||||
reduction_op<T, U>(x, out, axes, init, ops, opc, op);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -1,118 +0,0 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/reduce.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
std::pair<std::vector<int>, std::vector<size_t>> shapes_without_reduction_axes(
|
||||
const array& x,
|
||||
const std::vector<int>& axes) {
|
||||
std::vector<int> shape = x.shape();
|
||||
std::vector<size_t> strides = x.strides();
|
||||
|
||||
for (int i = axes.size() - 1; i >= 0; i--) {
|
||||
int a = axes[i];
|
||||
shape.erase(shape.begin() + a);
|
||||
strides.erase(strides.begin() + a);
|
||||
}
|
||||
|
||||
return std::make_pair(shape, strides);
|
||||
}
|
||||
|
||||
ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
|
||||
// The data is all there and we are reducing over everything
|
||||
if (x.size() == x.data_size() && axes.size() == x.ndim() &&
|
||||
x.flags().contiguous) {
|
||||
return ContiguousAllReduce;
|
||||
}
|
||||
|
||||
// Row contiguous input so the output is row contiguous
|
||||
if (x.flags().row_contiguous) {
|
||||
// Merge consecutive axes
|
||||
std::vector<int> shape = {x.shape(axes[0])};
|
||||
std::vector<size_t> strides = {x.strides()[axes[0]]};
|
||||
for (int i = 1; i < axes.size(); i++) {
|
||||
if (axes[i] - 1 == axes[i - 1]) {
|
||||
shape.back() *= x.shape(axes[i]);
|
||||
strides.back() = x.strides()[axes[i]];
|
||||
} else {
|
||||
shape.push_back(x.shape(axes[i]));
|
||||
strides.push_back(x.strides()[axes[i]]);
|
||||
}
|
||||
}
|
||||
|
||||
if (strides.back() == 1) {
|
||||
return ReductionPlan(ContiguousReduce, shape, strides);
|
||||
} else if (strides.back() > 1) {
|
||||
return ReductionPlan(ContiguousStridedReduce, shape, strides);
|
||||
}
|
||||
}
|
||||
|
||||
// Let's check if we can optimize our access patterns
|
||||
//
|
||||
// 1. We have a reduction axis with stride 1. Simply call
|
||||
// GeneralContiguousReduce and be done with it.
|
||||
// 2. We have transpositions and we are not reducing over the axis with
|
||||
// stride 1. However, we are reducing over an axis where everything is
|
||||
// contiguous in memory to the right of that axis. We can call strided
|
||||
// reduce and be done with it.
|
||||
// 3. We have weird transpositions and expands. Copy the strides to the
|
||||
// output, then call strided reduce.
|
||||
|
||||
// Sort reduction axes by stride in order to merge them and figure out if we
|
||||
// have a contiguous reduction.
|
||||
std::vector<std::pair<int, size_t>> reductions;
|
||||
for (auto a : axes) {
|
||||
reductions.push_back(std::make_pair(x.shape(a), x.strides()[a]));
|
||||
}
|
||||
std::sort(reductions.begin(), reductions.end(), [](auto a, auto b) {
|
||||
return a.second > b.second;
|
||||
});
|
||||
// Extract the two smallest and try to merge them in case the contiguous
|
||||
// reduction can be bigger than just the last axis.
|
||||
for (int i = reductions.size() - 1; i >= 1; i--) {
|
||||
auto a = reductions[i];
|
||||
auto b = reductions[i - 1];
|
||||
|
||||
// b.stride = a.shape * a.stride then a and b are contiguous
|
||||
if (b.second == a.first * a.second) {
|
||||
reductions.erase(reductions.begin() + i);
|
||||
reductions[i - 1] = std::make_pair(a.first * b.first, a.second);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<int> shape;
|
||||
std::vector<size_t> strides;
|
||||
for (auto r : reductions) {
|
||||
shape.push_back(r.first);
|
||||
strides.push_back(r.second);
|
||||
}
|
||||
|
||||
// We can call the contiguous reduction op for every weird way the input is
|
||||
// structured in the rest of the axes.
|
||||
if (strides.back() == 1) {
|
||||
return ReductionPlan(GeneralContiguousReduce, shape, strides);
|
||||
}
|
||||
|
||||
// Delegate to the general strided reduction op if the axes after
|
||||
// strides.back() are contiguous.
|
||||
if (strides.back() > 1) {
|
||||
int size = 1;
|
||||
for (int i = x.ndim() - 1; i >= 0; i--) {
|
||||
if (axes.back() == i) {
|
||||
continue;
|
||||
}
|
||||
if (x.strides()[i] != size) {
|
||||
break;
|
||||
}
|
||||
size *= x.shape(i);
|
||||
}
|
||||
if (size >= strides.back()) {
|
||||
return ReductionPlan(GeneralStridedReduce, shape, strides);
|
||||
}
|
||||
}
|
||||
|
||||
return ReductionPlan(GeneralReduce, shape, strides);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -234,7 +234,7 @@ void scan_dispatch(
auto op = [](U* o, const U* y, const T* x) { *o = (*x < *y) ? *y : *x; };
auto init = (issubdtype(input.dtype(), floating))
? static_cast<U>(-std::numeric_limits<float>::infinity())
: std::numeric_limits<U>::min();
: std::numeric_limits<U>::max();
auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init);
auto opss = DefaultStridedScan<T, U, decltype(op)>(op, init);
scan_op<T, U>(opcs, opss, input, output, axis, reverse, inclusive);

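As background for the init values in the hunk above: an inclusive max-scan conventionally starts from the smallest value the accumulator type can hold (or -infinity for floats) so that the first element always replaces it. A minimal sketch:

// Inclusive cumulative-max scan with its identity element.
#include <cstdio>
#include <limits>
#include <vector>

int main() {
  std::vector<float> x = {1.0f, 3.0f, 2.0f, 5.0f, 4.0f};
  float running = -std::numeric_limits<float>::infinity();
  for (float v : x) {
    running = (v < running) ? running : v;
    std::printf("%g ", running);  // 1 3 3 5 5
  }
  std::printf("\n");
}
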
@@ -1,52 +0,0 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
std::tuple<bool, int64_t, std::vector<int64_t>> prepare_slice(
|
||||
const array& in,
|
||||
std::vector<int>& start_indices,
|
||||
std::vector<int>& strides) {
|
||||
int64_t data_offset = 0;
|
||||
bool copy_needed = false;
|
||||
std::vector<int64_t> inp_strides(in.ndim(), 0);
|
||||
for (int i = 0; i < in.ndim(); ++i) {
|
||||
data_offset += start_indices[i] * in.strides()[i];
|
||||
inp_strides[i] = in.strides()[i] * strides[i];
|
||||
|
||||
copy_needed |= strides[i] < 0;
|
||||
}
|
||||
|
||||
return std::make_tuple(copy_needed, data_offset, inp_strides);
|
||||
}
|
||||
|
||||
void shared_buffer_slice(
|
||||
const array& in,
|
||||
const std::vector<size_t>& out_strides,
|
||||
size_t data_offset,
|
||||
array& out) {
|
||||
// Compute row/col contiguity
|
||||
auto [data_size, is_row_contiguous, is_col_contiguous] =
|
||||
check_contiguity(out.shape(), out_strides);
|
||||
|
||||
auto flags = in.flags();
|
||||
flags.row_contiguous = is_row_contiguous;
|
||||
flags.col_contiguous = is_col_contiguous;
|
||||
|
||||
if (data_size == 1) {
|
||||
// Broadcasted scalar array is contiguous.
|
||||
flags.contiguous = true;
|
||||
} else if (data_size == in.data_size()) {
|
||||
// Means we sliced a broadcasted dimension so leave the "no holes" flag
|
||||
// alone.
|
||||
} else {
|
||||
// We sliced something. So either we are row or col contiguous or we
|
||||
// punched a hole.
|
||||
flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
|
||||
}
|
||||
|
||||
out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -1,20 +0,0 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/array.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
std::tuple<bool, int64_t, std::vector<int64_t>> prepare_slice(
|
||||
const array& in,
|
||||
std::vector<int>& start_indices,
|
||||
std::vector<int>& strides);
|
||||
|
||||
void shared_buffer_slice(
|
||||
const array& in,
|
||||
const std::vector<size_t>& out_strides,
|
||||
size_t data_offset,
|
||||
array& out);
|
||||
|
||||
} // namespace mlx::core
|
@@ -113,14 +113,14 @@ void sort(const array& in, array& out, int axis) {
|
||||
axis = axis < 0 ? axis + in.ndim() : axis;
|
||||
size_t n_rows = in.size() / in.shape(axis);
|
||||
|
||||
auto remaining_shape = out.shape();
|
||||
auto remaining_shape = in.shape();
|
||||
remaining_shape.erase(remaining_shape.begin() + axis);
|
||||
|
||||
auto remaining_strides = out.strides();
|
||||
auto remaining_strides = in.strides();
|
||||
remaining_strides.erase(remaining_strides.begin() + axis);
|
||||
|
||||
size_t axis_stride = out.strides()[axis];
|
||||
int axis_size = out.shape(axis);
|
||||
size_t axis_stride = in.strides()[axis];
|
||||
int axis_size = in.shape(axis);
|
||||
|
||||
// Perform sorting in place
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
@@ -143,42 +143,34 @@ void argsort(const array& in, array& out, int axis) {
|
||||
axis = axis < 0 ? axis + in.ndim() : axis;
|
||||
size_t n_rows = in.size() / in.shape(axis);
|
||||
|
||||
auto in_remaining_shape = in.shape();
|
||||
in_remaining_shape.erase(in_remaining_shape.begin() + axis);
|
||||
auto remaining_shape = in.shape();
|
||||
remaining_shape.erase(remaining_shape.begin() + axis);
|
||||
|
||||
auto in_remaining_strides = in.strides();
|
||||
in_remaining_strides.erase(in_remaining_strides.begin() + axis);
|
||||
auto remaining_strides = in.strides();
|
||||
remaining_strides.erase(remaining_strides.begin() + axis);
|
||||
|
||||
auto out_remaining_shape = out.shape();
|
||||
out_remaining_shape.erase(out_remaining_shape.begin() + axis);
|
||||
|
||||
auto out_remaining_strides = out.strides();
|
||||
out_remaining_strides.erase(out_remaining_strides.begin() + axis);
|
||||
|
||||
size_t in_stride = in.strides()[axis];
|
||||
size_t out_stride = out.strides()[axis];
|
||||
size_t axis_stride = in.strides()[axis];
|
||||
int axis_size = in.shape(axis);
|
||||
|
||||
// Perform sorting
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
size_t in_loc = elem_to_loc(i, in_remaining_shape, in_remaining_strides);
|
||||
size_t out_loc = elem_to_loc(i, out_remaining_shape, out_remaining_strides);
|
||||
const T* data_ptr = in.data<T>() + in_loc;
|
||||
IdxT* idx_ptr = out.data<IdxT>() + out_loc;
|
||||
size_t loc = elem_to_loc(i, remaining_shape, remaining_strides);
|
||||
const T* data_ptr = in.data<T>() + loc;
|
||||
IdxT* idx_ptr = out.data<IdxT>() + loc;
|
||||
|
||||
StridedIterator st_(idx_ptr, out_stride, 0);
|
||||
StridedIterator ed_(idx_ptr, out_stride, axis_size);
|
||||
StridedIterator st_(idx_ptr, axis_stride, 0);
|
||||
StridedIterator ed_(idx_ptr, axis_stride, axis_size);
|
||||
|
||||
// Initialize with iota
|
||||
std::iota(st_, ed_, IdxT(0));
|
||||
|
||||
// Sort according to vals
|
||||
StridedIterator st(idx_ptr, out_stride, 0);
|
||||
StridedIterator ed(idx_ptr, out_stride, axis_size);
|
||||
StridedIterator st(idx_ptr, axis_stride, 0);
|
||||
StridedIterator ed(idx_ptr, axis_stride, axis_size);
|
||||
|
||||
std::stable_sort(st, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
|
||||
auto v1 = data_ptr[a * in_stride];
|
||||
auto v2 = data_ptr[b * in_stride];
|
||||
std::stable_sort(st, ed, [data_ptr, axis_stride](IdxT a, IdxT b) {
|
||||
auto v1 = data_ptr[a * axis_stride];
|
||||
auto v2 = data_ptr[b * axis_stride];
|
||||
return v1 < v2 || (v1 == v2 && a < b);
|
||||
});
|
||||
}
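
The comparator above orders indices by their values and breaks ties by index, giving a stable argsort. A minimal self-contained sketch of the same inner loop, specialized to a contiguous (stride-1) row so StridedIterator is not needed; illustrative only, not taken from the MLX sources:

```cpp
// Illustrative only (not MLX code): the argsort inner loop for one row.
#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

int main() {
  std::vector<float> data = {3.0f, 1.0f, 2.0f};
  std::vector<uint32_t> idx(data.size());
  std::iota(idx.begin(), idx.end(), 0u); // initialize with iota: {0, 1, 2}
  std::stable_sort(idx.begin(), idx.end(), [&](uint32_t a, uint32_t b) {
    auto v1 = data[a];
    auto v2 = data[b];
    return v1 < v2 || (v1 == v2 && a < b); // ties keep the smaller index first
  });
  // idx is now {1, 2, 0}: the positions of the values in ascending order.
  return 0;
}
```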
@@ -29,15 +29,6 @@ inline size_t elem_to_loc(int elem, const array& a) {
|
||||
return elem_to_loc(elem, a.shape(), a.strides());
|
||||
}
|
||||
|
||||
template <typename stride_t>
|
||||
std::vector<stride_t> make_contiguous_strides(const std::vector<int>& shape) {
|
||||
std::vector<stride_t> strides(shape.size(), 1);
|
||||
for (int i = shape.size() - 1; i > 0; i--) {
|
||||
strides[i - 1] = strides[i] * shape[i];
|
||||
}
|
||||
return strides;
|
||||
}
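
For reference, a worked instance of what the removed make_contiguous_strides helper computes, written as a standalone sketch rather than MLX code, for an assumed row-major shape of {2, 3, 4}:

```cpp
// Illustrative only (not MLX code): contiguous strides for shape {2, 3, 4}.
#include <cstddef>
#include <vector>

int main() {
  std::vector<int> shape = {2, 3, 4};
  std::vector<size_t> strides(shape.size(), 1);
  for (int i = static_cast<int>(shape.size()) - 1; i > 0; i--) {
    strides[i - 1] = strides[i] * shape[i];
  }
  // strides == {12, 4, 1}: one step along axis 0 skips 3 * 4 = 12 elements.
  return 0;
}
```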
// Collapse dims that are contiguous to possibly route to a better kernel
|
||||
// e.g. for x = transpose(array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 2, 2}), {2, 0, 1})
|
||||
// should return {{2, 4}, {{1, 2}}}.
|
||||
|
@@ -18,7 +18,7 @@ function(make_jit_source SRC_FILE)
|
||||
${CMAKE_C_COMPILER}
|
||||
${PROJECT_SOURCE_DIR}
|
||||
${SRC_FILE}
|
||||
"-DMLX_METAL_VERSION=${MLX_METAL_VERSION}"
|
||||
"-D${MLX_METAL_VERSION}"
|
||||
DEPENDS make_compiled_preamble.sh
|
||||
kernels/${SRC_FILE}.h
|
||||
${ARGN}
|
||||
@@ -52,7 +52,6 @@ make_jit_source(
|
||||
)
|
||||
make_jit_source(scatter)
|
||||
make_jit_source(gather)
|
||||
make_jit_source(hadamard)
|
||||
|
||||
if (MLX_METAL_JIT)
|
||||
target_sources(
|
||||
@@ -65,11 +64,6 @@ if (MLX_METAL_JIT)
|
||||
make_jit_source(unary)
|
||||
make_jit_source(binary)
|
||||
make_jit_source(binary_two)
|
||||
make_jit_source(
|
||||
fft
|
||||
kernels/fft/radix.h
|
||||
kernels/fft/readwrite.h
|
||||
)
|
||||
make_jit_source(ternary)
|
||||
make_jit_source(softmax)
|
||||
make_jit_source(scan)
|
||||
@@ -113,8 +107,6 @@ if (MLX_METAL_JIT)
|
||||
kernels/steel/defines.h
|
||||
kernels/steel/conv/loaders/loader_general.h
|
||||
)
|
||||
make_jit_source(quantized)
|
||||
make_jit_source(gemv_masked)
|
||||
else()
|
||||
target_sources(
|
||||
mlx
|
||||
@@ -134,7 +126,6 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp
|
||||
@@ -144,13 +135,11 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/normalization.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
||||
)
|
||||
|
||||
if (NOT MLX_METAL_PATH)
|
||||
|
@@ -242,17 +242,8 @@ void MetalAllocator::free(Buffer buffer) {
|
||||
}
|
||||
|
||||
MetalAllocator& allocator() {
// By creating the |allocator_| on the heap, the destructor of MetalAllocator will
// not be called on exit and all the buffers will be leaked. This is necessary
// because releasing buffers can take more than 30sec when the program holds a
// lot of RAM (for example inferencing an LLM), and it would feel frozen to
// users when exiting.
// TODO(zcbenz): Consider using the `base::NoDestructor` class from Chromium
// when applying this pattern to more places, or when introducing sanitizers
// to MLX.
// https://source.chromium.org/chromium/chromium/src/+/main:base/no_destructor.h
static MetalAllocator* allocator_ = new MetalAllocator;
return *allocator_;
static MetalAllocator allocator_;
return allocator_;
}
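
The TODO above points at Chromium's base::NoDestructor. A rough sketch of that pattern, purely illustrative and not part of MLX, showing how the intentional leak could be wrapped instead of hand-rolled; the cached_name helper is a made-up example:

```cpp
// Illustrative only (not MLX code): a minimal NoDestructor-style wrapper that
// constructs T in static storage and never runs its destructor.
#include <new>
#include <string>
#include <utility>

template <typename T>
class NoDestructor {
 public:
  template <typename... Args>
  explicit NoDestructor(Args&&... args) {
    new (storage_) T(std::forward<Args>(args)...);
  }
  NoDestructor(const NoDestructor&) = delete;
  NoDestructor& operator=(const NoDestructor&) = delete;
  T& get() { return *reinterpret_cast<T*>(storage_); }

 private:
  alignas(T) unsigned char storage_[sizeof(T)];
};

std::string& cached_name() {
  // Same shape as allocator() above: leaked on purpose, so no slow teardown.
  static NoDestructor<std::string> name("mlx");
  return name.get();
}

int main() {
  return cached_name().size() == 3 ? 0 : 1;
}
```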
size_t set_cache_limit(size_t limit) {
|
||||
|
@@ -6,62 +6,20 @@
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
#define BINARY_GPU(func) \
|
||||
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
|
||||
binary_op_gpu(inputs, out, get_primitive_string(this)); \
|
||||
}
|
||||
|
||||
#define BINARY_GPU_MULTI(func) \
|
||||
void func::eval_gpu( \
|
||||
const std::vector<array>& inputs, std::vector<array>& outputs) { \
|
||||
binary_op_gpu(inputs, outputs, get_primitive_string(this)); \
|
||||
}
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
constexpr int MAX_BINARY_SPECIALIZED_DIMS = 5;
|
||||
|
||||
std::string get_kernel_name(
|
||||
BinaryOpType bopt,
|
||||
const std::string& op,
|
||||
const array& a,
|
||||
bool use_2d,
|
||||
int ndim) {
|
||||
std::ostringstream kname;
|
||||
switch (bopt) {
|
||||
case BinaryOpType::ScalarScalar:
|
||||
kname << "ss";
|
||||
break;
|
||||
case BinaryOpType::ScalarVector:
|
||||
kname << (use_2d ? "sv2" : "sv");
|
||||
break;
|
||||
case BinaryOpType::VectorScalar:
|
||||
kname << (use_2d ? "vs2" : "vs");
|
||||
break;
|
||||
case BinaryOpType::VectorVector:
|
||||
kname << (use_2d ? "vv2" : "vv");
|
||||
break;
|
||||
case BinaryOpType::General:
|
||||
kname << "g";
|
||||
if (ndim <= MAX_BINARY_SPECIALIZED_DIMS) {
|
||||
kname << ndim;
|
||||
} else {
|
||||
kname << "n";
|
||||
}
|
||||
break;
|
||||
}
|
||||
kname << op << type_to_name(a);
|
||||
return kname.str();
|
||||
}
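
get_kernel_name composes the placement code, the op string and the dtype suffix into one kernel name. A toy sketch of that composition; the concrete "Add" and "float32" strings are placeholders, not necessarily what get_primitive_string or type_to_name actually return:

```cpp
// Illustrative only (not MLX code): assembling a binary kernel name.
#include <iostream>
#include <sstream>
#include <string>

int main() {
  std::string placement = "vv2"; // VectorVector with the 2D grid variant
  std::string op = "Add";        // placeholder op string
  std::string type = "float32";  // placeholder dtype suffix
  std::ostringstream kname;
  kname << placement << op << type;
  std::cout << kname.str() << std::endl; // prints "vv2Addfloat32"
  return 0;
}
```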
|
||||
|
||||
void binary_op_gpu_inplace(
|
||||
void binary_op(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
const std::string& op,
|
||||
const Stream& s) {
|
||||
const std::string op) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, outputs[0], bopt, true);
|
||||
set_binary_op_output_data(a, b, outputs[1], bopt, true);
|
||||
|
||||
auto& out = outputs[0];
|
||||
if (out.size() == 0) {
|
||||
@@ -74,12 +32,39 @@ void binary_op_gpu_inplace(
|
||||
auto& strides_b = strides[1];
|
||||
auto& strides_out = strides[2];
|
||||
|
||||
bool use_2d = out.data_size() > UINT32_MAX;
|
||||
std::string kernel_name = get_kernel_name(bopt, op, a, use_2d, shape.size());
|
||||
std::string kernel_name;
|
||||
{
|
||||
std::ostringstream kname;
|
||||
switch (bopt) {
|
||||
case BinaryOpType::ScalarScalar:
|
||||
kname << "ss";
|
||||
break;
|
||||
case BinaryOpType::ScalarVector:
|
||||
kname << "sv";
|
||||
break;
|
||||
case BinaryOpType::VectorScalar:
|
||||
kname << "vs";
|
||||
break;
|
||||
case BinaryOpType::VectorVector:
|
||||
kname << "vv";
|
||||
break;
|
||||
case BinaryOpType::General:
|
||||
kname << "g";
|
||||
if (shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
|
||||
kname << shape.size();
|
||||
} else {
|
||||
kname << "n";
|
||||
}
|
||||
break;
|
||||
}
|
||||
kname << op << type_to_name(a);
|
||||
kernel_name = kname.str();
|
||||
}
|
||||
|
||||
auto& s = out.primitive().stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
auto kernel =
|
||||
get_binary_two_kernel(d, kernel_name, a.dtype(), outputs[0].dtype(), op);
|
||||
auto kernel = get_binary_two_kernel(d, kernel_name, a, outputs[0]);
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
@@ -123,11 +108,9 @@ void binary_op_gpu_inplace(
|
||||
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
} else {
|
||||
// Launch a 1D or 2D grid of threads
|
||||
// Launch a 1D grid of threads
|
||||
size_t nthreads = out.data_size();
|
||||
MTL::Size grid_dims = use_2d
|
||||
? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides())
|
||||
: MTL::Size(nthreads, 1, 1);
|
||||
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (thread_group_size > nthreads) {
|
||||
thread_group_size = nthreads;
|
||||
@@ -137,36 +120,15 @@ void binary_op_gpu_inplace(
|
||||
}
|
||||
}
|
||||
|
||||
void binary_op_gpu(
|
||||
void binary_op(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
const std::string& op,
|
||||
const Stream& s) {
|
||||
array& out,
|
||||
const std::string op) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, outputs[0], bopt, true);
|
||||
set_binary_op_output_data(a, b, outputs[1], bopt, true);
|
||||
binary_op_gpu_inplace(inputs, outputs, op, s);
|
||||
}
|
||||
|
||||
void binary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
const std::string& op) {
|
||||
auto& s = outputs[0].primitive().stream();
|
||||
binary_op_gpu(inputs, outputs, op, s);
|
||||
}
|
||||
|
||||
void binary_op_gpu_inplace(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const std::string& op,
|
||||
const Stream& s) {
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt, true);
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
@@ -177,11 +139,39 @@ void binary_op_gpu_inplace(
|
||||
auto& strides_b = strides[1];
|
||||
auto& strides_out = strides[2];
|
||||
|
||||
bool use_2d = out.data_size() > UINT32_MAX;
|
||||
std::string kernel_name = get_kernel_name(bopt, op, a, use_2d, shape.size());
|
||||
std::string kernel_name;
|
||||
{
|
||||
std::ostringstream kname;
|
||||
switch (bopt) {
|
||||
case BinaryOpType::ScalarScalar:
|
||||
kname << "ss";
|
||||
break;
|
||||
case BinaryOpType::ScalarVector:
|
||||
kname << "sv";
|
||||
break;
|
||||
case BinaryOpType::VectorScalar:
|
||||
kname << "vs";
|
||||
break;
|
||||
case BinaryOpType::VectorVector:
|
||||
kname << "vv";
|
||||
break;
|
||||
case BinaryOpType::General:
|
||||
kname << "g";
|
||||
if (shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
|
||||
kname << shape.size();
|
||||
} else {
|
||||
kname << "n";
|
||||
}
|
||||
break;
|
||||
}
|
||||
kname << op << type_to_name(a);
|
||||
kernel_name = kname.str();
|
||||
}
|
||||
|
||||
auto& s = out.primitive().stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
auto kernel = get_binary_kernel(d, kernel_name, a.dtype(), out.dtype(), op);
|
||||
auto kernel = get_binary_kernel(d, kernel_name, a, out);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
bool donate_a = a.data_shared_ptr() == nullptr;
|
||||
@@ -218,11 +208,10 @@ void binary_op_gpu_inplace(
|
||||
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
} else {
|
||||
// Launch a 1D or 2D grid of threads
|
||||
|
||||
size_t nthreads = out.data_size();
|
||||
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
|
||||
: MTL::Size(nthreads, 1, 1);
|
||||
// Launch a 1D grid of threads
|
||||
size_t nthreads =
|
||||
bopt == BinaryOpType::General ? out.size() : out.data_size();
|
||||
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (thread_group_size > nthreads) {
|
||||
thread_group_size = nthreads;
|
||||
@@ -232,65 +221,102 @@ void binary_op_gpu_inplace(
|
||||
}
|
||||
}
|
||||
|
||||
void binary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const std::string& op,
|
||||
const Stream& s) {
|
||||
assert(inputs.size() == 2);
|
||||
auto& a = inputs[0];
|
||||
auto& b = inputs[1];
|
||||
auto bopt = get_binary_op_type(a, b);
|
||||
set_binary_op_output_data(a, b, out, bopt, true);
|
||||
binary_op_gpu_inplace(inputs, out, op, s);
|
||||
void Add::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "add");
|
||||
}
|
||||
|
||||
void binary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const std::string& op) {
|
||||
auto& s = out.primitive().stream();
|
||||
binary_op_gpu(inputs, out, op, s);
|
||||
void ArcTan2::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "arctan2");
|
||||
}
|
||||
|
||||
BINARY_GPU(Add)
|
||||
BINARY_GPU(ArcTan2)
|
||||
BINARY_GPU(Divide)
|
||||
BINARY_GPU_MULTI(DivMod)
|
||||
BINARY_GPU(Remainder)
|
||||
BINARY_GPU(Equal)
|
||||
BINARY_GPU(Greater)
|
||||
BINARY_GPU(GreaterEqual)
|
||||
BINARY_GPU(Less)
|
||||
BINARY_GPU(LessEqual)
|
||||
BINARY_GPU(LogicalAnd)
|
||||
BINARY_GPU(LogicalOr)
|
||||
BINARY_GPU(LogAddExp)
|
||||
BINARY_GPU(Maximum)
|
||||
BINARY_GPU(Minimum)
|
||||
BINARY_GPU(Multiply)
|
||||
BINARY_GPU(NotEqual)
|
||||
BINARY_GPU(Power)
|
||||
BINARY_GPU(Subtract)
|
||||
|
||||
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
switch (op_) {
|
||||
case BitwiseBinary::And:
|
||||
binary_op_gpu(inputs, out, get_primitive_string(this));
|
||||
binary_op(inputs, out, "bitwise_and");
|
||||
break;
|
||||
case BitwiseBinary::Or:
|
||||
binary_op_gpu(inputs, out, get_primitive_string(this));
|
||||
binary_op(inputs, out, "bitwise_or");
|
||||
break;
|
||||
case BitwiseBinary::Xor:
|
||||
binary_op_gpu(inputs, out, get_primitive_string(this));
|
||||
binary_op(inputs, out, "bitwise_xor");
|
||||
break;
|
||||
case BitwiseBinary::LeftShift:
|
||||
binary_op_gpu(inputs, out, get_primitive_string(this));
|
||||
binary_op(inputs, out, "left_shift");
|
||||
break;
|
||||
case BitwiseBinary::RightShift:
|
||||
binary_op_gpu(inputs, out, get_primitive_string(this));
|
||||
binary_op(inputs, out, "right_shift");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void Divide::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "div");
|
||||
}
|
||||
|
||||
void DivMod::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
binary_op(inputs, outputs, "divmod");
|
||||
}
|
||||
|
||||
void Remainder::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "rem");
|
||||
}
|
||||
|
||||
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, equal_nan_ ? "naneq" : "eq");
|
||||
}
|
||||
|
||||
void Greater::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "ge");
|
||||
}
|
||||
|
||||
void GreaterEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "geq");
|
||||
}
|
||||
|
||||
void Less::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "le");
|
||||
}
|
||||
|
||||
void LessEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "leq");
|
||||
}
|
||||
|
||||
void LogicalAnd::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "land");
|
||||
}
|
||||
|
||||
void LogicalOr::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "lor");
|
||||
}
|
||||
|
||||
void LogAddExp::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "lae");
|
||||
}
|
||||
|
||||
void Maximum::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "max");
|
||||
}
|
||||
|
||||
void Minimum::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "min");
|
||||
}
|
||||
|
||||
void Multiply::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "mul");
|
||||
}
|
||||
|
||||
void NotEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "neq");
|
||||
}
|
||||
|
||||
void Power::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "pow");
|
||||
}
|
||||
|
||||
void Subtract::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
binary_op(inputs, out, "sub");
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -1,33 +0,0 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/array.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void binary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
const std::string& op,
|
||||
const Stream& s);
|
||||
|
||||
void binary_op_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const std::string& op,
|
||||
const Stream& s);
|
||||
|
||||
void binary_op_gpu_inplace(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs,
|
||||
const std::string& op,
|
||||
const Stream& s);
|
||||
|
||||
void binary_op_gpu_inplace(
|
||||
const std::vector<array>& inputs,
|
||||
array& out,
|
||||
const std::string& op,
|
||||
const Stream& s);
|
||||
|
||||
} // namespace mlx::core
|
@@ -56,15 +56,12 @@ inline void build_kernel(
|
||||
} else {
|
||||
add_indices = true;
|
||||
os << " device const " << get_type_string(x.dtype()) << "* " << xname
|
||||
<< " [[buffer(" << cnt++ << ")]]," << std::endl;
|
||||
<< " [[buffer(" << cnt++ << ")]]," << std::endl
|
||||
<< " constant const size_t* " << xname << "_strides [[buffer("
|
||||
<< cnt++ << ")]]," << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
if (add_indices) {
|
||||
os << " constant const size_t* in_strides [[buffer(" << cnt++
|
||||
<< ")]],\n";
|
||||
}
|
||||
|
||||
// Add the output arguments
|
||||
for (auto& x : outputs) {
|
||||
os << " device " << get_type_string(x.dtype()) << "* "
|
||||
@@ -113,17 +110,13 @@ inline void build_kernel(
|
||||
}
|
||||
|
||||
// Read the inputs in tmps
|
||||
int nc_in_count = 0;
|
||||
for (int i = 0; i < inputs.size(); ++i) {
|
||||
auto& x = inputs[i];
|
||||
for (auto& x : inputs) {
|
||||
auto& xname = namer.get_name(x);
|
||||
|
||||
if (is_constant(x)) {
|
||||
auto type_str = get_type_string(x.dtype());
|
||||
os << " auto tmp_" << xname << " = static_cast<"
|
||||
<< get_type_string(x.dtype()) << ">(";
|
||||
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = ";
|
||||
print_constant(os, x);
|
||||
os << ");" << std::endl;
|
||||
os << ";" << std::endl;
|
||||
} else if (is_scalar(x)) {
|
||||
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
|
||||
<< xname << "[0];" << std::endl;
|
||||
@@ -131,20 +124,17 @@ inline void build_kernel(
|
||||
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
|
||||
<< xname << "[index];" << std::endl;
|
||||
} else if (!dynamic_dims) {
|
||||
int offset = nc_in_count * ndim;
|
||||
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
|
||||
<< xname << "[";
|
||||
os << "index_0 * " << "in_strides[" << offset << "]";
|
||||
os << "index_0 * " << xname << "_strides[0]";
|
||||
for (int i = 1; i < ndim; i++) {
|
||||
os << " + index_" << i << " * " << "in_strides[" << offset + i << "]";
|
||||
os << " + index_" << i << " * " << xname << "_strides[" << i << "]";
|
||||
}
|
||||
os << "];" << std::endl;
|
||||
nc_in_count++;
|
||||
} else {
|
||||
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
|
||||
<< xname << "[elem_to_loc(index, output_shape, in_strides + "
|
||||
<< nc_in_count * ndim << ", ndim)];" << std::endl;
|
||||
nc_in_count++;
|
||||
<< xname << "[elem_to_loc(index, output_shape, " << xname
|
||||
<< "_strides, ndim)];" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -306,7 +296,6 @@ void Compiled::eval_gpu(
|
||||
// Put the inputs in
|
||||
int cnt = 0;
|
||||
int stride_idx = 1; // idx 0 is the output strides
|
||||
std::vector<size_t> in_strides;
|
||||
for (int i = 0; i < inputs.size(); i++) {
|
||||
if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
|
||||
continue;
|
||||
@@ -314,17 +303,13 @@ void Compiled::eval_gpu(
|
||||
auto& x = inputs[i];
|
||||
compute_encoder.set_input_array(x, cnt++);
|
||||
if (!contiguous && !is_scalar(x)) {
|
||||
in_strides.insert(
|
||||
in_strides.end(),
|
||||
strides[stride_idx].begin(),
|
||||
strides[stride_idx].end());
|
||||
compute_encoder->setBytes(
|
||||
strides[stride_idx].data(),
|
||||
strides[stride_idx].size() * sizeof(size_t),
|
||||
cnt++);
|
||||
stride_idx++;
|
||||
}
|
||||
}
|
||||
if (!in_strides.empty()) {
|
||||
compute_encoder->setBytes(
|
||||
in_strides.data(), in_strides.size() * sizeof(size_t), cnt++);
|
||||
}
|
||||
|
||||
compiled_allocate_outputs(
|
||||
inputs, outputs, inputs_, constant_ids_, contiguous, true);
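
The rewritten Compiled::eval_gpu packs the strides of every non-contiguous input into a single in_strides buffer, and the generated kernel indexes it at an offset of nc_in_count * ndim. A small standalone sketch of that layout, assuming two non-contiguous inputs, ndim == 3 and made-up stride values:

```cpp
// Illustrative only (not MLX code): the packed per-input stride layout.
#include <cstddef>
#include <vector>

int main() {
  int ndim = 3;
  std::vector<size_t> x_strides = {12, 4, 1};
  std::vector<size_t> y_strides = {1, 2, 6};
  std::vector<size_t> in_strides;
  in_strides.insert(in_strides.end(), x_strides.begin(), x_strides.end());
  in_strides.insert(in_strides.end(), y_strides.begin(), y_strides.end());
  // The k-th non-contiguous input reads its strides at in_strides[k * ndim + i],
  // the same offset the generated kernel computes as nc_in_count * ndim.
  size_t y_stride0 = in_strides[1 * ndim + 0]; // == 1
  return y_stride0 == 1 ? 0 : 1;
}
```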
|
||||
|
@@ -33,6 +33,9 @@ void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
}
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
if (ctype == CopyType::GeneralGeneral) {
|
||||
ctype = CopyType::General;
|
||||
}
|
||||
@@ -54,27 +57,22 @@ void copy_gpu_inplace(
|
||||
int64_t out_offset,
|
||||
CopyType ctype,
|
||||
const Stream& s) {
|
||||
if (out.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Try to collapse contiguous dims
|
||||
auto [shape, strides] = collapse_contiguous_dims(
|
||||
data_shape, std::vector{strides_in_pre, strides_out_pre});
|
||||
auto& strides_in_ = strides[0];
|
||||
auto& strides_out_ = strides[1];
|
||||
|
||||
bool use_2d = out.data_size() > UINT32_MAX;
|
||||
auto& d = metal::device(s.device);
|
||||
std::string kernel_name;
|
||||
{
|
||||
std::ostringstream kname;
|
||||
switch (ctype) {
|
||||
case CopyType::Scalar:
|
||||
kname << (use_2d ? "s2" : "s");
|
||||
kname << "s";
|
||||
break;
|
||||
case CopyType::Vector:
|
||||
kname << (use_2d ? "v2" : "v");
|
||||
kname << "v";
|
||||
break;
|
||||
case CopyType::General:
|
||||
kname << "g";
|
||||
@@ -140,8 +138,7 @@ void copy_gpu_inplace(
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
} else {
|
||||
size_t nthreads = out.data_size();
|
||||
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
|
||||
: MTL::Size(nthreads, 1, 1);
|
||||
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
|
||||
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
|
||||
if (thread_group_size > nthreads) {
|
||||
thread_group_size = nthreads;
|
||||
|
@@ -14,6 +14,7 @@
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
#include "mlx/backend/metal/metal_impl.h"
|
||||
#include "mlx/backend/metal/mps/gemm.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
|
||||
namespace fs = std::filesystem;
|
||||
@@ -29,29 +30,13 @@ constexpr int MAX_DISPATCHES_PER_ENCODER = 2;
|
||||
constexpr const char* default_mtllib_path = METAL_PATH;
|
||||
|
||||
constexpr auto get_metal_version() {
|
||||
#if (MLX_METAL_VERSION >= 320)
|
||||
return MTL::LanguageVersion3_2;
|
||||
#elif (MLX_METAL_VERSION >= 310)
|
||||
#if defined METAL_3_1
|
||||
return MTL::LanguageVersion3_1;
|
||||
#else
|
||||
return MTL::LanguageVersion3_0;
|
||||
#endif
|
||||
}
|
||||
|
||||
std::string get_colocated_mtllib_path(const std::string& lib_name) {
|
||||
Dl_info info;
|
||||
std::string mtllib_path;
|
||||
std::string lib_ext = lib_name + ".metallib";
|
||||
|
||||
int success = dladdr((void*)get_colocated_mtllib_path, &info);
|
||||
if (success) {
|
||||
auto mtllib = fs::path(info.dli_fname).remove_filename() / lib_ext;
|
||||
mtllib_path = mtllib.c_str();
|
||||
}
|
||||
|
||||
return mtllib_path;
|
||||
}
|
||||
|
||||
auto load_device() {
|
||||
auto devices = MTL::CopyAllDevices();
|
||||
auto device = static_cast<MTL::Device*>(devices->object(0))
|
||||
@@ -139,49 +124,6 @@ MTL::Library* load_library(
|
||||
|
||||
} // namespace
|
||||
|
||||
CommandEncoder::CommandEncoder(MTL::CommandBuffer* cbuf) : cbuf(cbuf) {
|
||||
enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
|
||||
enc->retain();
|
||||
}
|
||||
|
||||
CommandEncoder::~CommandEncoder() {
|
||||
enc->endEncoding();
|
||||
enc->release();
|
||||
}
|
||||
|
||||
void CommandEncoder::set_input_array(
|
||||
const array& a,
|
||||
int idx,
|
||||
int64_t offset /* = 0 */) {
|
||||
auto r_buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
|
||||
if (auto it = outputs.find(r_buf); it != outputs.end()) {
|
||||
// Insert a barrier
|
||||
enc->memoryBarrier(&r_buf, 1);
|
||||
|
||||
// Remove the output
|
||||
outputs.erase(it);
|
||||
}
|
||||
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
|
||||
auto base_offset = a.data<char>() -
|
||||
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
|
||||
base_offset += offset;
|
||||
enc->setBuffer(a_buf, base_offset, idx);
|
||||
}
|
||||
|
||||
void CommandEncoder::set_output_array(
|
||||
array& a,
|
||||
int idx,
|
||||
int64_t offset /* = 0 */) {
|
||||
// Add barriers before adding the output to the output set
|
||||
set_input_array(a, idx, offset);
|
||||
auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
|
||||
if (concurrent) {
|
||||
concurrent_outputs.insert(buf);
|
||||
} else {
|
||||
outputs.insert(buf);
|
||||
}
|
||||
}
|
||||
|
||||
void CommandEncoder::dispatchThreadgroups(
|
||||
MTL::Size grid_dims,
|
||||
MTL::Size group_dims) {
|
||||
@@ -311,9 +253,13 @@ void Device::register_library(
|
||||
}
|
||||
}
|
||||
|
||||
void Device::register_library(const std::string& lib_name) {
|
||||
void Device::register_library(
|
||||
const std::string& lib_name,
|
||||
const std::function<std::string(const std::string&)>& lib_path_func) {
|
||||
if (auto it = library_map_.find(lib_name); it == library_map_.end()) {
|
||||
register_library(lib_name, get_colocated_mtllib_path(lib_name));
|
||||
std::string new_lib_path = lib_path_func(lib_name);
|
||||
auto new_lib = load_library(device_, lib_name, new_lib_path.c_str());
|
||||
library_map_.insert({lib_name, new_lib});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -323,7 +269,7 @@ MTL::Library* Device::get_library_cache_(const std::string& lib_name) {
|
||||
if (auto it = library_map_.find(lib_name); it != library_map_.end()) {
|
||||
mtl_lib = it->second;
|
||||
} else { // Look for metallib alongside library
|
||||
register_library(lib_name, get_colocated_mtllib_path(lib_name));
|
||||
register_library(lib_name);
|
||||
mtl_lib = library_map_[lib_name];
|
||||
}
|
||||
|
||||
|
@@ -9,16 +9,38 @@
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
|
||||
#include <dlfcn.h>
|
||||
#include <filesystem>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/device.h"
|
||||
|
||||
namespace fs = std::filesystem;
|
||||
|
||||
namespace mlx::core::metal {
|
||||
|
||||
inline std::string get_colocated_mtllib_path(const std::string& lib_name) {
|
||||
Dl_info info;
|
||||
std::string mtllib_path;
|
||||
std::string lib_ext = lib_name + ".metallib";
|
||||
|
||||
int success = dladdr((void*)get_colocated_mtllib_path, &info);
|
||||
if (success) {
|
||||
auto mtllib = fs::path(info.dli_fname).remove_filename() / lib_ext;
|
||||
mtllib_path = mtllib.c_str();
|
||||
}
|
||||
|
||||
return mtllib_path;
|
||||
}
|
||||
|
||||
using MTLFCList =
|
||||
std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;
|
||||
|
||||
struct CommandEncoder {
|
||||
CommandEncoder(MTL::CommandBuffer* cbuf);
|
||||
CommandEncoder(MTL::CommandBuffer* cbuf) : cbuf(cbuf) {
|
||||
enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
|
||||
enc->retain();
|
||||
};
|
||||
CommandEncoder(const CommandEncoder&) = delete;
|
||||
CommandEncoder& operator=(const CommandEncoder&) = delete;
|
||||
|
||||
@@ -41,8 +63,34 @@ struct CommandEncoder {
|
||||
return enc;
|
||||
}
|
||||
|
||||
void set_input_array(const array& a, int idx, int64_t offset = 0);
|
||||
void set_output_array(array& a, int idx, int64_t offset = 0);
|
||||
void set_input_array(const array& a, int idx, int64_t offset = 0) {
|
||||
auto r_buf =
|
||||
static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
|
||||
if (auto it = outputs.find(r_buf); it != outputs.end()) {
|
||||
// Insert a barrier
|
||||
enc->memoryBarrier(&r_buf, 1);
|
||||
|
||||
// Remove the output
|
||||
outputs.erase(it);
|
||||
}
|
||||
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
|
||||
auto base_offset = a.data<char>() -
|
||||
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
|
||||
base_offset += offset;
|
||||
enc->setBuffer(a_buf, base_offset, idx);
|
||||
}
|
||||
|
||||
void set_output_array(array& a, int idx, int64_t offset = 0) {
|
||||
// Add barriers before adding the output to the output set
|
||||
set_input_array(a, idx, offset);
|
||||
auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
|
||||
if (concurrent) {
|
||||
concurrent_outputs.insert(buf);
|
||||
} else {
|
||||
outputs.insert(buf);
|
||||
}
|
||||
}
|
||||
|
||||
void dispatchThreadgroups(MTL::Size grid_dims, MTL::Size group_dims);
|
||||
void dispatchThreads(MTL::Size grid_dims, MTL::Size group_dims);
|
||||
|
||||
@@ -50,7 +98,10 @@ struct CommandEncoder {
|
||||
return ConcurrentContext(*this);
|
||||
}
|
||||
|
||||
~CommandEncoder();
|
||||
~CommandEncoder() {
|
||||
enc->endEncoding();
|
||||
enc->release();
|
||||
}
|
||||
|
||||
private:
|
||||
void maybe_split();
|
||||
@@ -85,8 +136,10 @@ class Device {
|
||||
void register_library(
|
||||
const std::string& lib_name,
|
||||
const std::string& lib_path);
|
||||
|
||||
void register_library(const std::string& lib_name);
|
||||
void register_library(
|
||||
const std::string& lib_name,
|
||||
const std::function<std::string(const std::string&)>& lib_path_func =
|
||||
get_colocated_mtllib_path);
|
||||
|
||||
MTL::Library* get_library(const std::string& name);
|
||||
|
||||
|
@@ -1,803 +1,106 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
#include <cassert>
|
||||
#include <complex>
|
||||
#include <map>
|
||||
#include <numeric>
|
||||
#include <set>
|
||||
|
||||
#include "mlx/3rdparty/pocketfft.h"
|
||||
#include "mlx/backend/metal/binary.h"
|
||||
// Copyright © 2023 Apple Inc.
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/slicing.h"
|
||||
#include "mlx/backend/metal/unary.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/utils.h"
|
||||
#include "mlx/mlx.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
using MTLFC = std::tuple<const void*, MTL::DataType, NS::UInteger>;
|
||||
|
||||
#define MAX_STOCKHAM_FFT_SIZE 4096
|
||||
#define MAX_RADER_FFT_SIZE 2048
|
||||
#define MAX_BLUESTEIN_FFT_SIZE 2048
|
||||
// Threadgroup memory batching improves throughput for small n
|
||||
#define MIN_THREADGROUP_MEM_SIZE 256
|
||||
// For strided reads/writes, coalesce at least this many complex64s
|
||||
#define MIN_COALESCE_WIDTH 4
|
||||
|
||||
inline const std::vector<int> supported_radices() {
|
||||
// Ordered by preference in decomposition.
|
||||
return {13, 11, 8, 7, 6, 5, 4, 3, 2};
|
||||
}
|
||||
|
||||
std::vector<int> prime_factors(int n) {
|
||||
int z = 2;
|
||||
std::vector<int> factors;
|
||||
while (z * z <= n) {
|
||||
if (n % z == 0) {
|
||||
factors.push_back(z);
|
||||
n /= z;
|
||||
} else {
|
||||
z++;
|
||||
}
|
||||
}
|
||||
if (n > 1) {
|
||||
factors.push_back(n);
|
||||
}
|
||||
return factors;
|
||||
}
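
prime_factors is plain trial division. A self-contained sketch, not MLX code, showing the result for n = 60:

```cpp
// Illustrative only (not MLX code): trial division as in prime_factors, n = 60.
#include <cassert>
#include <vector>

int main() {
  int n = 60;
  int z = 2;
  std::vector<int> factors;
  while (z * z <= n) {
    if (n % z == 0) {
      factors.push_back(z);
      n /= z;
    } else {
      z++;
    }
  }
  if (n > 1) {
    factors.push_back(n);
  }
  assert((factors == std::vector<int>{2, 2, 3, 5})); // 60 = 2 * 2 * 3 * 5
  return 0;
}
```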
|
||||
|
||||
struct FourStepParams {
|
||||
bool required = false;
|
||||
bool first_step = true;
|
||||
int n1 = 0;
|
||||
int n2 = 0;
|
||||
};
|
||||
|
||||
// Forward Declaration
|
||||
void fft_op(
|
||||
const array& in,
|
||||
array& out,
|
||||
size_t axis,
|
||||
bool inverse,
|
||||
bool real,
|
||||
const FourStepParams four_step_params,
|
||||
bool inplace,
|
||||
const Stream& s);
|
||||
|
||||
struct FFTPlan {
|
||||
int n = 0;
|
||||
// Number of steps for each radix in the Stockham decomposition
|
||||
std::vector<int> stockham;
|
||||
// Number of steps for each radix in the Rader decomposition
|
||||
std::vector<int> rader;
|
||||
// Rader factor, 1 if no rader factors
|
||||
int rader_n = 1;
|
||||
int bluestein_n = -1;
|
||||
// Four step FFT
|
||||
bool four_step = false;
|
||||
int n1 = 0;
|
||||
int n2 = 0;
|
||||
};
|
||||
|
||||
int next_fast_n(int n) {
|
||||
return next_power_of_2(n);
|
||||
}
|
||||
|
||||
std::vector<int> plan_stockham_fft(int n) {
|
||||
auto radices = supported_radices();
|
||||
std::vector<int> plan(radices.size(), 0);
|
||||
int orig_n = n;
|
||||
if (n == 1) {
|
||||
return plan;
|
||||
}
|
||||
for (int i = 0; i < radices.size(); i++) {
|
||||
int radix = radices[i];
|
||||
// Manually tuned radices for powers of 2
|
||||
if (is_power_of_2(orig_n) && orig_n < 512 && radix > 4) {
|
||||
continue;
|
||||
}
|
||||
while (n % radix == 0) {
|
||||
plan[i] += 1;
|
||||
n /= radix;
|
||||
if (n == 1) {
|
||||
return plan;
|
||||
}
|
||||
}
|
||||
}
|
||||
throw std::runtime_error("Unplannable");
|
||||
}
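
A worked instance of the greedy pass above over the radix order {13, 11, 8, 7, 6, 5, 4, 3, 2} from supported_radices(), for n = 60 (not a power of two, so no radices are skipped):

$$60 \;\xrightarrow{/6}\; 10 \;\xrightarrow{/5}\; 2 \;\xrightarrow{/2}\; 1 \quad\Longrightarrow\quad \text{plan} = \{0, 0, 0, 0, 1, 1, 0, 0, 1\},$$

i.e. one step each of radix 6, 5 and 2.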
|
||||
|
||||
FFTPlan plan_fft(int n) {
|
||||
auto radices = supported_radices();
|
||||
std::set<int> radices_set(radices.begin(), radices.end());
|
||||
|
||||
FFTPlan plan;
|
||||
plan.n = n;
|
||||
plan.rader = std::vector<int>(radices.size(), 0);
|
||||
auto factors = prime_factors(n);
|
||||
int remaining_n = n;
|
||||
|
||||
// Four Step FFT when N is too large for shared mem.
|
||||
if (n > MAX_STOCKHAM_FFT_SIZE && is_power_of_2(n)) {
|
||||
// For powers of two we have a fast, no-transpose four step implementation.
|
||||
plan.four_step = true;
|
||||
// Rough heuristic for choosing faster powers of two when we can
|
||||
plan.n2 = n > 65536 ? 1024 : 64;
|
||||
plan.n1 = n / plan.n2;
|
||||
return plan;
|
||||
} else if (n > MAX_STOCKHAM_FFT_SIZE) {
|
||||
// Otherwise we use a multi-upload Bluestein's algorithm
|
||||
plan.four_step = true;
|
||||
plan.bluestein_n = next_fast_n(2 * n - 1);
|
||||
return plan;
|
||||
}
|
||||
|
||||
for (int factor : factors) {
|
||||
// Make sure the factor is a supported radix
|
||||
if (radices_set.find(factor) == radices_set.end()) {
|
||||
// We only support a single Rader factor currently
|
||||
// TODO(alexbarron) investigate weirdness with large
|
||||
// Rader sizes -- possibly a compiler issue?
|
||||
if (plan.rader_n > 1 || n > MAX_RADER_FFT_SIZE) {
|
||||
plan.four_step = n > MAX_BLUESTEIN_FFT_SIZE;
|
||||
plan.bluestein_n = next_fast_n(2 * n - 1);
|
||||
plan.stockham = plan_stockham_fft(plan.bluestein_n);
|
||||
plan.rader = std::vector<int>(radices.size(), 0);
|
||||
return plan;
|
||||
}
|
||||
// See if we can use Rader's algorithm to Stockham decompose n - 1
|
||||
auto rader_factors = prime_factors(factor - 1);
|
||||
int last_factor = -1;
|
||||
for (int rf : rader_factors) {
|
||||
// We don't nest Rader's algorithm so if `factor - 1`
|
||||
// isn't Stockham decomposable we give up and do Bluestein's.
|
||||
if (radices_set.find(rf) == radices_set.end()) {
|
||||
plan.four_step = n > MAX_BLUESTEIN_FFT_SIZE;
|
||||
plan.bluestein_n = next_fast_n(2 * n - 1);
|
||||
plan.stockham = plan_stockham_fft(plan.bluestein_n);
|
||||
plan.rader = std::vector<int>(radices.size(), 0);
|
||||
return plan;
|
||||
}
|
||||
}
|
||||
plan.rader = plan_stockham_fft(factor - 1);
|
||||
plan.rader_n = factor;
|
||||
remaining_n /= factor;
|
||||
}
|
||||
}
|
||||
|
||||
plan.stockham = plan_stockham_fft(remaining_n);
|
||||
return plan;
|
||||
}
|
||||
|
||||
int compute_elems_per_thread(FFTPlan plan) {
|
||||
// Heuristics for selecting an efficient number
|
||||
// of threads to use for a particular mixed-radix FFT.
|
||||
auto n = plan.n;
|
||||
|
||||
std::vector<int> steps;
|
||||
auto radices = supported_radices();
|
||||
steps.insert(steps.end(), plan.stockham.begin(), plan.stockham.end());
|
||||
steps.insert(steps.end(), plan.rader.begin(), plan.rader.end());
|
||||
std::set<int> used_radices;
|
||||
for (int i = 0; i < steps.size(); i++) {
|
||||
int radix = radices[i % radices.size()];
|
||||
if (steps[i] > 0) {
|
||||
used_radices.insert(radix);
|
||||
}
|
||||
}
|
||||
|
||||
// Manual tuning for 7/11/13
|
||||
if (used_radices.find(7) != used_radices.end() &&
|
||||
(used_radices.find(11) != used_radices.end() ||
|
||||
used_radices.find(13) != used_radices.end())) {
|
||||
return 7;
|
||||
} else if (
|
||||
used_radices.find(11) != used_radices.end() &&
|
||||
used_radices.find(13) != used_radices.end()) {
|
||||
return 11;
|
||||
}
|
||||
|
||||
// TODO(alexbarron) Some really weird stuff is going on
|
||||
// for certain `elems_per_thread` on large composite n.
|
||||
// Possibly a compiler issue?
|
||||
if (n == 3159)
|
||||
return 13;
|
||||
if (n == 3645)
|
||||
return 5;
|
||||
if (n == 3969)
|
||||
return 7;
|
||||
if (n == 1982)
|
||||
return 5;
|
||||
|
||||
if (used_radices.size() == 1) {
|
||||
return *(used_radices.begin());
|
||||
}
|
||||
if (used_radices.size() == 2) {
|
||||
if (used_radices.find(11) != used_radices.end() ||
|
||||
used_radices.find(13) != used_radices.end()) {
|
||||
return std::accumulate(used_radices.begin(), used_radices.end(), 0) / 2;
|
||||
}
|
||||
std::vector<int> radix_vec(used_radices.begin(), used_radices.end());
|
||||
return radix_vec[1];
|
||||
}
|
||||
// In all other cases use the second smallest radix.
|
||||
std::vector<int> radix_vec(used_radices.begin(), used_radices.end());
|
||||
return radix_vec[1];
|
||||
}
|
||||
|
||||
// Rader
|
||||
int mod_exp(int x, int y, int n) {
|
||||
int out = 1;
|
||||
while (y) {
|
||||
if (y & 1) {
|
||||
out = out * x % n;
|
||||
}
|
||||
y >>= 1;
|
||||
x = x * x % n;
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
int primitive_root(int n) {
|
||||
auto factors = prime_factors(n - 1);
|
||||
|
||||
for (int r = 2; r < n - 1; r++) {
|
||||
bool found = true;
|
||||
for (int factor : factors) {
|
||||
if (mod_exp(r, (n - 1) / factor, n) == 1) {
|
||||
found = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (found) {
|
||||
return r;
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
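
A worked instance of the search above for n = 7: the prime factors of n − 1 = 6 are {2, 3}, and mod_exp does the checks by repeated squaring,

$$2^{6/2} = 8 \equiv 1 \pmod{7}\ \text{(reject 2)}, \qquad 3^{6/2} = 27 \equiv 6, \quad 3^{6/3} = 9 \equiv 2 \pmod{7},$$

so primitive_root(7) returns 3.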
|
||||
|
||||
std::tuple<array, array, array> compute_raders_constants(
|
||||
int rader_n,
|
||||
const Stream& s) {
|
||||
int proot = primitive_root(rader_n);
|
||||
// Fermat's little theorem
|
||||
int inv = mod_exp(proot, rader_n - 2, rader_n);
|
||||
std::vector<short> g_q(rader_n - 1);
|
||||
std::vector<short> g_minus_q(rader_n - 1);
|
||||
for (int i = 0; i < rader_n - 1; i++) {
|
||||
g_q[i] = mod_exp(proot, i, rader_n);
|
||||
g_minus_q[i] = mod_exp(inv, i, rader_n);
|
||||
}
|
||||
array g_q_arr(g_q.begin(), {rader_n - 1});
|
||||
array g_minus_q_arr(g_minus_q.begin(), {rader_n - 1});
|
||||
|
||||
std::vector<std::complex<float>> b_q(rader_n - 1);
|
||||
for (int i = 0; i < rader_n - 1; i++) {
|
||||
float pi_i = (float)g_minus_q[i] * -2.0 * M_PI / rader_n;
|
||||
b_q[i] = std::exp(std::complex<float>(0, pi_i));
|
||||
}
|
||||
|
||||
array b_q_fft({rader_n - 1}, complex64, nullptr, {});
|
||||
b_q_fft.set_data(allocator::malloc_or_wait(b_q_fft.nbytes()));
|
||||
auto b_q_fft_ptr =
|
||||
reinterpret_cast<std::complex<float>*>(b_q_fft.data<complex64_t>());
|
||||
std::ptrdiff_t item_size = b_q_fft.itemsize();
|
||||
size_t fft_size = rader_n - 1;
|
||||
// This FFT is always small (<4096, batch 1) so save some overhead
|
||||
// and do it on the CPU
|
||||
pocketfft::c2c(
|
||||
/* shape= */ {fft_size},
|
||||
/* stride_in= */ {item_size},
|
||||
/* stride_out= */ {item_size},
|
||||
/* axes= */ {0},
|
||||
/* forward= */ true,
|
||||
/* data_in= */ b_q.data(),
|
||||
/* data_out= */ b_q_fft_ptr,
|
||||
/* scale= */ 1.0f);
|
||||
return std::make_tuple(b_q_fft, g_q_arr, g_minus_q_arr);
|
||||
}
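
For context, the standard Rader identity these constants implement: for a prime p with primitive root g, relabelling the non-zero indices as powers of g turns the length-p DFT into a cyclic convolution of length p − 1,

$$X_{g^{-m}} = x_0 + \sum_{q=0}^{p-2} x_{g^{q}}\, e^{-2\pi i\, g^{\,q-m}/p}, \qquad m = 0, \dots, p-2,$$

i.e. $a_q = x_{g^{q}}$ (gathered with g_q) convolved with $b_q = e^{-2\pi i\, g^{-q}/p}$ (built from g_minus_q above), which is why b_q can be FFT'd once up front on the CPU.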
|
||||
|
||||
// Bluestein
|
||||
std::pair<array, array> compute_bluestein_constants(int n, int bluestein_n) {
|
||||
// We need to calculate the Bluestein twiddle factors
|
||||
// in double precision for the overall numerical stability
|
||||
// of Bluestein's FFT algorithm to be acceptable.
|
||||
//
|
||||
// Metal doesn't support float64, so instead we
|
||||
// manually implement the required operations on cpu.
|
||||
//
|
||||
// In numpy:
|
||||
// w_k = np.exp(-1j * np.pi / N * (np.arange(-N + 1, N) ** 2))
|
||||
// w_q = np.fft.fft(1/w_k)
|
||||
// return w_k, w_q
|
||||
int length = 2 * n - 1;
|
||||
|
||||
std::vector<std::complex<float>> w_k_vec(n);
|
||||
std::vector<std::complex<float>> w_q_vec(bluestein_n, 0);
|
||||
|
||||
for (int i = -n + 1; i < n; i++) {
|
||||
double theta = pow(i, 2) * M_PI / (double)n;
|
||||
w_q_vec[i + n - 1] = std::exp(std::complex<double>(0, theta));
|
||||
if (i >= 0) {
|
||||
w_k_vec[i] = std::exp(std::complex<double>(0, -theta));
|
||||
}
|
||||
}
|
||||
|
||||
array w_k({n}, complex64, nullptr, {});
|
||||
w_k.set_data(allocator::malloc_or_wait(w_k.nbytes()));
|
||||
std::copy(w_k_vec.begin(), w_k_vec.end(), w_k.data<complex64_t>());
|
||||
|
||||
array w_q({bluestein_n}, complex64, nullptr, {});
|
||||
w_q.set_data(allocator::malloc_or_wait(w_q.nbytes()));
|
||||
auto w_q_ptr =
|
||||
reinterpret_cast<std::complex<float>*>(w_q.data<complex64_t>());
|
||||
|
||||
std::ptrdiff_t item_size = w_q.itemsize();
|
||||
size_t fft_size = bluestein_n;
|
||||
pocketfft::c2c(
|
||||
/* shape= */ {fft_size},
|
||||
/* stride_in= */ {item_size},
|
||||
/* stride_out= */ {item_size},
|
||||
/* axes= */ {0},
|
||||
/* forward= */ true,
|
||||
/* data_in= */ w_q_vec.data(),
|
||||
/* data_out= */ w_q_ptr,
|
||||
/* scale= */ 1.0f);
|
||||
return std::make_tuple(w_k, w_q);
|
||||
}
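
For context, the Bluestein identity behind w_k and w_q: substituting $nk = \tfrac{n^2 + k^2 - (k-n)^2}{2}$ into the DFT gives

$$X_k = e^{-\pi i k^2/N} \sum_{n=0}^{N-1} \bigl(x_n\, e^{-\pi i n^2/N}\bigr)\, e^{\,\pi i (k-n)^2/N},$$

a linear convolution with the chirp $e^{\pi i m^2/N}$. w_k holds the chirp applied to the input (and the final prefactor), and w_q is the FFT of the padded reciprocal chirp (the np.fft.fft(1/w_k) in the comment), so the convolution can run at the fast length bluestein_n.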
|
||||
|
||||
void multi_upload_bluestein_fft(
|
||||
const array& in,
|
||||
array& out,
|
||||
size_t axis,
|
||||
bool inverse,
|
||||
bool real,
|
||||
FFTPlan& plan,
|
||||
std::vector<array> copies,
|
||||
const Stream& s) {
|
||||
// TODO(alexbarron) Implement fused kernels for multi upload Bluestein's
|
||||
// algorithm
|
||||
int n = inverse ? out.shape(axis) : in.shape(axis);
|
||||
auto [w_k, w_q] = compute_bluestein_constants(n, plan.bluestein_n);
|
||||
|
||||
// Broadcast w_q and w_k to the batch size
|
||||
std::vector<size_t> b_strides(in.ndim(), 0);
|
||||
b_strides[axis] = 1;
|
||||
array w_k_broadcast({}, complex64, nullptr, {});
|
||||
array w_q_broadcast({}, complex64, nullptr, {});
|
||||
w_k_broadcast.copy_shared_buffer(w_k, b_strides, {}, w_k.data_size());
|
||||
w_q_broadcast.copy_shared_buffer(w_q, b_strides, {}, w_q.data_size());
|
||||
|
||||
auto temp_shape = inverse ? out.shape() : in.shape();
|
||||
array temp(temp_shape, complex64, nullptr, {});
|
||||
array temp1(temp_shape, complex64, nullptr, {});
|
||||
|
||||
if (real && !inverse) {
|
||||
// Convert float32->complex64
|
||||
copy_gpu(in, temp, CopyType::General, s);
|
||||
} else if (real && inverse) {
|
||||
int back_offset = n % 2 == 0 ? 2 : 1;
|
||||
auto slice_shape = in.shape();
|
||||
slice_shape[axis] -= back_offset;
|
||||
array slice_temp(slice_shape, complex64, nullptr, {});
|
||||
array conj_temp(in.shape(), complex64, nullptr, {});
|
||||
copies.push_back(slice_temp);
|
||||
copies.push_back(conj_temp);
|
||||
|
||||
std::vector<int> rstarts(in.ndim(), 0);
|
||||
std::vector<int> rstrides(in.ndim(), 1);
|
||||
rstarts[axis] = in.shape(axis) - back_offset;
|
||||
rstrides[axis] = -1;
|
||||
unary_op_gpu({in}, conj_temp, "Conjugate", s);
|
||||
slice_gpu(in, slice_temp, rstarts, rstrides, s);
|
||||
concatenate_gpu({conj_temp, slice_temp}, temp, (int)axis, s);
|
||||
} else if (inverse) {
|
||||
unary_op_gpu({in}, temp, "Conjugate", s);
|
||||
} else {
|
||||
temp.copy_shared_buffer(in);
|
||||
}
|
||||
|
||||
binary_op_gpu({temp, w_k_broadcast}, temp1, "Multiply", s);
|
||||
|
||||
std::vector<std::pair<int, int>> pads;
|
||||
auto padded_shape = out.shape();
|
||||
padded_shape[axis] = plan.bluestein_n;
|
||||
array pad_temp(padded_shape, complex64, nullptr, {});
|
||||
pad_gpu(temp1, array(complex64_t{0.0f, 0.0f}), pad_temp, {(int)axis}, {0}, s);
|
||||
|
||||
array pad_temp1(padded_shape, complex64, nullptr, {});
|
||||
fft_op(
|
||||
pad_temp,
|
||||
pad_temp1,
|
||||
axis,
|
||||
/*inverse=*/false,
|
||||
/*real=*/false,
|
||||
FourStepParams(),
|
||||
/*inplace=*/false,
|
||||
s);
|
||||
|
||||
binary_op_gpu_inplace({pad_temp1, w_q_broadcast}, pad_temp, "Multiply", s);
|
||||
|
||||
fft_op(
|
||||
pad_temp,
|
||||
pad_temp1,
|
||||
axis,
|
||||
/* inverse= */ true,
|
||||
/* real= */ false,
|
||||
FourStepParams(),
|
||||
/*inplace=*/true,
|
||||
s);
|
||||
|
||||
int offset = plan.bluestein_n - (2 * n - 1);
|
||||
std::vector<int> starts(in.ndim(), 0);
|
||||
std::vector<int> strides(in.ndim(), 1);
|
||||
starts[axis] = plan.bluestein_n - offset - n;
|
||||
slice_gpu(pad_temp1, temp, starts, strides, s);
|
||||
|
||||
binary_op_gpu_inplace({temp, w_k_broadcast}, temp1, "Multiply", s);
|
||||
|
||||
if (real && !inverse) {
|
||||
std::vector<int> rstarts(in.ndim(), 0);
|
||||
std::vector<int> rstrides(in.ndim(), 1);
|
||||
slice_gpu(temp1, out, rstarts, strides, s);
|
||||
} else if (real && inverse) {
|
||||
std::vector<size_t> b_strides(in.ndim(), 0);
|
||||
auto inv_n = array({1.0f / n}, {1}, float32);
|
||||
array temp_float(out.shape(), out.dtype(), nullptr, {});
|
||||
copies.push_back(temp_float);
|
||||
copies.push_back(inv_n);
|
||||
|
||||
copy_gpu(temp1, temp_float, CopyType::General, s);
|
||||
binary_op_gpu({temp_float, inv_n}, out, "Multiply", s);
|
||||
} else if (inverse) {
|
||||
auto inv_n = array({1.0f / n}, {1}, complex64);
|
||||
unary_op_gpu({temp1}, temp, "Conjugate", s);
|
||||
binary_op_gpu({temp, inv_n}, out, "Multiply", s);
|
||||
copies.push_back(inv_n);
|
||||
} else {
|
||||
out.copy_shared_buffer(temp1);
|
||||
}
|
||||
|
||||
copies.push_back(w_k);
|
||||
copies.push_back(w_q);
|
||||
copies.push_back(w_k_broadcast);
|
||||
copies.push_back(w_q_broadcast);
|
||||
copies.push_back(temp);
|
||||
copies.push_back(temp1);
|
||||
copies.push_back(pad_temp);
|
||||
copies.push_back(pad_temp1);
|
||||
}
|
||||
|
||||
void four_step_fft(
|
||||
const array& in,
|
||||
array& out,
|
||||
size_t axis,
|
||||
bool inverse,
|
||||
bool real,
|
||||
FFTPlan& plan,
|
||||
std::vector<array> copies,
|
||||
const Stream& s) {
|
||||
void FFT::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& s = out.primitive().stream();
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
if (plan.bluestein_n == -1) {
|
||||
// Fast no transpose implementation for powers of 2.
|
||||
FourStepParams four_step_params = {
|
||||
/* required= */ true, /* first_step= */ true, plan.n1, plan.n2};
|
||||
auto temp_shape = (real && inverse) ? out.shape() : in.shape();
|
||||
array temp(temp_shape, complex64, nullptr, {});
|
||||
fft_op(
|
||||
in, temp, axis, inverse, real, four_step_params, /*inplace=*/false, s);
|
||||
four_step_params.first_step = false;
|
||||
fft_op(
|
||||
temp, out, axis, inverse, real, four_step_params, /*inplace=*/false, s);
|
||||
copies.push_back(temp);
|
||||
} else {
|
||||
multi_upload_bluestein_fft(in, out, axis, inverse, real, plan, copies, s);
|
||||
}
|
||||
}
|
||||
auto& in = inputs[0];
|
||||
|
||||
void fft_op(
|
||||
const array& in,
|
||||
array& out,
|
||||
size_t axis,
|
||||
bool inverse,
|
||||
bool real,
|
||||
const FourStepParams four_step_params,
|
||||
bool inplace,
|
||||
const Stream& s) {
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
size_t n = out.dtype() == float32 ? out.shape(axis) : in.shape(axis);
|
||||
if (n == 1) {
|
||||
out.copy_shared_buffer(in);
|
||||
return;
|
||||
if (axes_.size() == 0 || axes_.size() > 1 || inverse_ ||
|
||||
in.dtype() != complex64 || out.dtype() != complex64) {
|
||||
// Could also fallback to CPU implementation here.
|
||||
throw std::runtime_error(
|
||||
"GPU FFT is only implemented for 1D, forward, complex FFTs.");
|
||||
}
|
||||
|
||||
if (four_step_params.required) {
|
||||
// Four Step FFT decomposes into two FFTs: n1 on columns, n2 on rows
|
||||
n = four_step_params.first_step ? four_step_params.n1 : four_step_params.n2;
|
||||
size_t n = in.shape(axes_[0]);
|
||||
|
||||
if (!is_power_of_2(n) || n > 2048 || n < 4) {
|
||||
throw std::runtime_error(
|
||||
"GPU FFT is only implemented for the powers of 2 from 4 -> 2048");
|
||||
}
|
||||
|
||||
// Make sure that the array is contiguous and has stride 1 in the FFT dim
|
||||
std::vector<array> copies;
|
||||
auto check_input = [&axis, &copies, &s](const array& x) {
|
||||
auto check_input = [this, &copies, &s](const array& x) {
|
||||
// TODO: Pass the strides to the kernel so
|
||||
// we can avoid the copy when x is not contiguous.
|
||||
bool no_copy = x.strides()[axis] == 1 &&
|
||||
(x.flags().row_contiguous || x.flags().col_contiguous);
|
||||
bool no_copy = x.strides()[axes_[0]] == 1 && x.flags().row_contiguous ||
|
||||
x.flags().col_contiguous;
|
||||
if (no_copy) {
|
||||
return x;
|
||||
} else {
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
std::vector<size_t> strides;
|
||||
size_t cur_stride = x.shape(axis);
|
||||
for (int a = 0; a < x.ndim(); a++) {
|
||||
if (a == axis) {
|
||||
size_t cur_stride = x.shape(axes_[0]);
|
||||
for (int axis = 0; axis < x.ndim(); axis++) {
|
||||
if (axis == axes_[0]) {
|
||||
strides.push_back(1);
|
||||
} else {
|
||||
strides.push_back(cur_stride);
|
||||
cur_stride *= x.shape(a);
|
||||
cur_stride *= x.shape(axis);
|
||||
}
|
||||
}
|
||||
|
||||
auto flags = x.flags();
|
||||
auto [data_size, is_row_contiguous, is_col_contiguous] =
|
||||
check_contiguity(x.shape(), strides);
|
||||
|
||||
flags.col_contiguous = is_row_contiguous;
|
||||
flags.row_contiguous = is_col_contiguous;
|
||||
flags.contiguous = data_size == x_copy.size();
|
||||
size_t f_stride = 1;
|
||||
size_t b_stride = 1;
|
||||
flags.col_contiguous = true;
|
||||
flags.row_contiguous = true;
|
||||
for (int i = 0, ri = x.ndim() - 1; i < x.ndim(); ++i, --ri) {
|
||||
flags.col_contiguous &= (strides[i] == f_stride || x.shape(i) == 1);
|
||||
f_stride *= x.shape(i);
|
||||
flags.row_contiguous &= (strides[ri] == b_stride || x.shape(ri) == 1);
|
||||
b_stride *= x.shape(ri);
|
||||
}
|
||||
// This is probably over-conservative
|
||||
flags.contiguous = false;
|
||||
|
||||
x_copy.set_data(
|
||||
allocator::malloc_or_wait(x.nbytes()), data_size, strides, flags);
|
||||
allocator::malloc_or_wait(x.nbytes()), x.data_size(), strides, flags);
|
||||
copy_gpu_inplace(x, x_copy, CopyType::GeneralGeneral, s);
|
||||
copies.push_back(x_copy);
|
||||
return x_copy;
|
||||
}
|
||||
};
|
||||
const array& in_contiguous = check_input(in);
|
||||
|
||||
// real to complex: n -> (n/2)+1
|
||||
// complex to real: (n/2)+1 -> n
|
||||
auto out_strides = in_contiguous.strides();
|
||||
size_t out_data_size = in_contiguous.data_size();
|
||||
if (in.shape(axis) != out.shape(axis)) {
|
||||
for (int i = 0; i < out_strides.size(); i++) {
|
||||
if (out_strides[i] != 1) {
|
||||
out_strides[i] = out_strides[i] / in.shape(axis) * out.shape(axis);
|
||||
}
|
||||
}
|
||||
out_data_size = out_data_size / in.shape(axis) * out.shape(axis);
|
||||
}
|
||||
|
||||
auto plan = plan_fft(n);
|
||||
if (plan.four_step) {
|
||||
four_step_fft(in, out, axis, inverse, real, plan, copies, s);
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||
return;
|
||||
}
|
||||
const array& in_contiguous = check_input(inputs[0]);
|
||||
|
||||
// TODO: allow donation here
|
||||
if (!inplace) {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(out.nbytes()),
|
||||
out_data_size,
|
||||
out_strides,
|
||||
in_contiguous.flags());
|
||||
}
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(out.nbytes()),
|
||||
in_contiguous.data_size(),
|
||||
in_contiguous.strides(),
|
||||
in_contiguous.flags());
|
||||
|
||||
auto radices = supported_radices();
|
||||
int fft_size = plan.bluestein_n > 0 ? plan.bluestein_n : n;
|
||||
|
||||
// Setup function constants
|
||||
bool power_of_2 = is_power_of_2(fft_size);
|
||||
|
||||
auto make_int = [](int* a, int i) {
|
||||
return std::make_tuple(a, MTL::DataType::DataTypeInt, i);
|
||||
};
|
||||
auto make_bool = [](bool* a, int i) {
|
||||
return std::make_tuple(a, MTL::DataType::DataTypeBool, i);
|
||||
};
|
||||
|
||||
std::vector<MTLFC> func_consts = {
|
||||
make_bool(&inverse, 0), make_bool(&power_of_2, 1)};
|
||||
|
||||
// Start of radix/rader step constants
|
||||
int index = 4;
|
||||
for (int i = 0; i < plan.stockham.size(); i++) {
|
||||
func_consts.push_back(make_int(&plan.stockham[i], index));
|
||||
index += 1;
|
||||
}
|
||||
for (int i = 0; i < plan.rader.size(); i++) {
|
||||
func_consts.push_back(make_int(&plan.rader[i], index));
|
||||
index += 1;
|
||||
}
|
||||
int elems_per_thread = compute_elems_per_thread(plan);
|
||||
func_consts.push_back(make_int(&elems_per_thread, 2));
|
||||
|
||||
int rader_m = n / plan.rader_n;
|
||||
func_consts.push_back(make_int(&rader_m, 3));
|
||||
|
||||
// The overall number of FFTs we're going to compute for this input
|
||||
int size = out.dtype() == float32 ? out.size() : in.size();
|
||||
if (real && inverse && four_step_params.required) {
|
||||
size = out.size();
|
||||
}
|
||||
int total_batch_size = size / n;
|
||||
int threads_per_fft = (fft_size + elems_per_thread - 1) / elems_per_thread;
|
||||
|
||||
// We batch among threadgroups for improved efficiency when n is small
|
||||
int threadgroup_batch_size = std::max(MIN_THREADGROUP_MEM_SIZE / fft_size, 1);
|
||||
if (four_step_params.required) {
|
||||
// Require a threadgroup batch size of at least 4 for four step FFT
|
||||
// so we can coalesce the memory accesses.
|
||||
threadgroup_batch_size =
|
||||
std::max(threadgroup_batch_size, MIN_COALESCE_WIDTH);
|
||||
}
|
||||
int threadgroup_mem_size = next_power_of_2(threadgroup_batch_size * fft_size);
|
||||
// FFTs up to 2^20 are currently supported
|
||||
assert(threadgroup_mem_size <= MAX_STOCKHAM_FFT_SIZE);
|
||||
|
||||
// ceil divide
|
||||
int batch_size =
|
||||
(total_batch_size + threadgroup_batch_size - 1) / threadgroup_batch_size;
|
||||
|
||||
if (real && !four_step_params.required) {
|
||||
// We can perform 2 RFFTs at once so the batch size is halved.
|
||||
batch_size = (batch_size + 2 - 1) / 2;
|
||||
}
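
A worked instance of the batching arithmetic above, assuming a plain complex transform (no real, four-step or Bluestein path) with fft_size = n = 64 and total_batch_size = 1000:

$$\text{threadgroup\_batch\_size} = \max(256 / 64,\, 1) = 4, \qquad \text{threadgroup\_mem\_size} = \operatorname{next\_power\_of\_2}(4 \cdot 64) = 256, \qquad \text{batch\_size} = \lceil 1000 / 4 \rceil = 250.$$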
|
||||
int out_buffer_size = out.size();
|
||||
// We use n / 4 threads by default since radix-4
|
||||
// is the largest single threaded radix butterfly
|
||||
// we currently implement.
|
||||
size_t m = n / 4;
|
||||
size_t batch = in.size() / in.shape(axes_[0]);
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto in_type_str = in.dtype() == float32 ? "float" : "float2";
|
||||
auto out_type_str = out.dtype() == float32 ? "float" : "float2";
|
||||
// Only required by four step
|
||||
int step = -1;
|
||||
{
|
||||
std::ostringstream kname;
|
||||
std::string inv_string = inverse ? "true" : "false";
|
||||
std::string real_string = real ? "true" : "false";
|
||||
std::string func_name;
|
||||
if (plan.bluestein_n > 0) {
|
||||
kname << "bluestein_fft_mem_" << threadgroup_mem_size << "_"
|
||||
<< in_type_str << "_" << out_type_str;
|
||||
func_name = "bluestein_fft";
|
||||
} else if (plan.rader_n > 1) {
|
||||
kname << "rader_fft_mem_" << threadgroup_mem_size << "_" << in_type_str
|
||||
<< "_" << out_type_str;
|
||||
func_name = "rader_fft";
|
||||
} else if (four_step_params.required) {
|
||||
step = four_step_params.first_step ? 0 : 1;
|
||||
kname << "four_step_mem_" << threadgroup_mem_size << "_" << in_type_str
|
||||
<< "_" << out_type_str << "_" << step << "_" << real_string;
|
||||
func_name = "four_step_fft";
|
||||
} else {
|
||||
kname << "fft_mem_" << threadgroup_mem_size << "_" << in_type_str << "_"
|
||||
<< out_type_str;
|
||||
func_name = "fft";
|
||||
}
|
||||
std::string base_name = kname.str();
|
||||
// We use a specialized kernel for each FFT size
|
||||
kname << "_n" << fft_size << "_inv_" << inverse;
|
||||
std::string hash_name = kname.str();
|
||||
auto template_def = func_name == "four_step_fft" ? get_template_definition(
|
||||
base_name,
|
||||
func_name,
|
||||
threadgroup_mem_size,
|
||||
in_type_str,
|
||||
out_type_str,
|
||||
step,
|
||||
real)
|
||||
: get_template_definition(
|
||||
base_name,
|
||||
func_name,
|
||||
threadgroup_mem_size,
|
||||
in_type_str,
|
||||
out_type_str);
|
||||
auto kernel =
|
||||
get_fft_kernel(d, base_name, hash_name, func_consts, template_def);
|
||||
kname << "fft_" << n;
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
|
||||
bool donated = in.data_shared_ptr() == nullptr;
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
compute_encoder.set_input_array(in_contiguous, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
|
||||
if (plan.bluestein_n > 0) {
|
||||
// Precomputed twiddle factors for Bluestein's
|
||||
auto [w_k, w_q] = compute_bluestein_constants(n, plan.bluestein_n);
|
||||
copies.push_back(w_q);
|
||||
copies.push_back(w_k);
|
||||
|
||||
compute_encoder.set_input_array(w_q, 2); // w_q
|
||||
compute_encoder.set_input_array(w_k, 3); // w_k
|
||||
compute_encoder->setBytes(&n, sizeof(int), 4);
|
||||
compute_encoder->setBytes(&plan.bluestein_n, sizeof(int), 5);
|
||||
compute_encoder->setBytes(&total_batch_size, sizeof(int), 6);
|
||||
} else if (plan.rader_n > 1) {
|
||||
auto [b_q, g_q, g_minus_q] = compute_raders_constants(plan.rader_n, s);
|
||||
copies.push_back(b_q);
|
||||
copies.push_back(g_q);
|
||||
copies.push_back(g_minus_q);
|
||||
|
||||
compute_encoder.set_input_array(b_q, 2);
|
||||
compute_encoder.set_input_array(g_q, 3);
|
||||
compute_encoder.set_input_array(g_minus_q, 4);
|
||||
compute_encoder->setBytes(&n, sizeof(int), 5);
|
||||
compute_encoder->setBytes(&total_batch_size, sizeof(int), 6);
|
||||
compute_encoder->setBytes(&plan.rader_n, sizeof(int), 7);
|
||||
} else if (four_step_params.required) {
|
||||
compute_encoder->setBytes(&four_step_params.n1, sizeof(int), 2);
|
||||
compute_encoder->setBytes(&four_step_params.n2, sizeof(int), 3);
|
||||
compute_encoder->setBytes(&total_batch_size, sizeof(int), 4);
|
||||
} else {
|
||||
compute_encoder->setBytes(&n, sizeof(int), 2);
|
||||
compute_encoder->setBytes(&total_batch_size, sizeof(int), 3);
|
||||
}
|
||||
|
||||
auto group_dims = MTL::Size(1, threadgroup_batch_size, threads_per_fft);
|
||||
auto grid_dims =
|
||||
MTL::Size(batch_size, threadgroup_batch_size, threads_per_fft);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
auto group_dims = MTL::Size(1, m, 1);
|
||||
auto grid_dims = MTL::Size(batch, m, 1);
|
||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||
}
|
||||
|
||||
void fft_op(
|
||||
const array& in,
|
||||
array& out,
|
||||
size_t axis,
|
||||
bool inverse,
|
||||
bool real,
|
||||
bool inplace,
|
||||
const Stream& s) {
|
||||
fft_op(in, out, axis, inverse, real, FourStepParams(), inplace, s);
|
||||
}
|
||||
|
||||
void nd_fft_op(
|
||||
const array& in,
|
||||
array& out,
|
||||
const std::vector<size_t>& axes,
|
||||
bool inverse,
|
||||
bool real,
|
||||
const Stream& s) {
|
||||
// Perform ND FFT on GPU as a series of 1D FFTs
|
||||
auto temp_shape = inverse ? in.shape() : out.shape();
|
||||
array temp1(temp_shape, complex64, nullptr, {});
|
||||
array temp2(temp_shape, complex64, nullptr, {});
|
||||
std::vector<array> temp_arrs = {temp1, temp2};
|
||||
for (int i = axes.size() - 1; i >= 0; i--) {
|
||||
int reverse_index = axes.size() - i - 1;
|
||||
// For 5D and above, we don't want to reallocate our two temporary arrays
|
||||
bool inplace = reverse_index >= 3 && i != 0;
|
||||
// Opposite order for fft vs ifft
|
||||
int index = inverse ? reverse_index : i;
|
||||
size_t axis = axes[index];
|
||||
// Mirror np.fft.(i)rfftn and perform a real transform
|
||||
// only on the final axis.
|
||||
bool step_real = (real && index == axes.size() - 1);
|
||||
int step_shape = inverse ? out.shape(axis) : in.shape(axis);
|
||||
const array& in_arr = i == axes.size() - 1 ? in : temp_arrs[1 - i % 2];
|
||||
array& out_arr = i == 0 ? out : temp_arrs[i % 2];
|
||||
fft_op(in_arr, out_arr, axis, inverse, step_real, inplace, s);
|
||||
}
|
||||
|
||||
auto& d = metal::device(s.device);
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[temp_arrs](MTL::CommandBuffer*) mutable { temp_arrs.clear(); });
|
||||
}
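// Hedged walk-through (illustrative, not from the diff): for a forward 3D FFT over
// axes {0, 1, 2} the loop above visits the last axis first and ping-pongs the
// intermediates as in -> temp1 -> temp2 -> out; the inverse transform walks the
// axes in the opposite order, and the real transform (mirroring np.fft.(i)rfftn)
// is only ever applied on the final axis.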
|
||||
|
||||
void FFT::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& s = stream();
|
||||
auto& in = inputs[0];
|
||||
|
||||
if (axes_.size() > 1) {
|
||||
nd_fft_op(in, out, axes_, inverse_, real_, s);
|
||||
} else {
|
||||
fft_op(in, out, axes_[0], inverse_, real_, /*inplace=*/false, s);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -1,203 +0,0 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <map>
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/common/hadamard.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/jit/includes.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
constexpr int MAX_HADAMARD_THREADS_PER_GROUP = 256;
|
||||
constexpr int MAX_HADAMARD_BYTES = 32768; // 32KB
|
||||
|
||||
std::string gen_hadamard_codelet(int m) {
|
||||
// Generate a O(m^2) hadamard codelet for a given M
|
||||
// using the hadamard matrices above
|
||||
//
|
||||
// e.g. m = 2
|
||||
// METAL_FUNC void hadamard_m(thread float *x) {
|
||||
// float tmp[2];
|
||||
// tmp[0] = + x[0] + x[1];
|
||||
// tmp[1] = + x[0] - x[1];
|
||||
// for (int i = 0; i < 2; i++) { x[i] = tmp[i]; }
|
||||
// }
|
||||
//
|
||||
auto h_matrices = hadamard_matrices();
|
||||
auto& matrix = h_matrices[m];
|
||||
|
||||
std::ostringstream source;
|
||||
source << "METAL_FUNC void hadamard_radix_m(thread float *x) {" << std::endl;
|
||||
if (m == 1) {
|
||||
source << "}" << std::endl;
|
||||
return source.str();
|
||||
}
|
||||
source << " float tmp[" << m << "];" << std::endl;
|
||||
auto start = 1;
|
||||
auto end = matrix.find('\n', start);
|
||||
|
||||
int index = 0;
|
||||
while (end != std::string_view::npos) {
|
||||
source << " tmp[" << index << "] = ";
|
||||
auto row = matrix.substr(start, end - start);
|
||||
for (int i = 0; i < row.length(); i++) {
|
||||
source << " " << row[i] << " x[" << i << "]";
|
||||
}
|
||||
source << ";" << std::endl;
|
||||
start = end + 1;
|
||||
end = matrix.find('\n', start);
|
||||
index++;
|
||||
}
|
||||
source << " for (int i = 0; i < " << m << "; i++) { x[i] = tmp[i]; }"
|
||||
<< std::endl;
|
||||
source << "}" << std::endl;
|
||||
return source.str();
|
||||
}
|
||||
|
||||
void launch_hadamard(
|
||||
const array& in,
|
||||
array& out,
|
||||
int batch_size,
|
||||
int threads_per,
|
||||
const std::string kernel_name,
|
||||
float scale,
|
||||
const Stream& s) {
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
const auto& lib_name = kernel_name.substr(1);
|
||||
auto lib = d.get_library(lib_name);
|
||||
auto kernel = d.get_kernel(kernel_name, lib);
|
||||
assert(threads_per <= kernel->maxTotalThreadsPerThreadgroup());
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
compute_encoder->setBytes(&scale, sizeof(float), 2);
|
||||
|
||||
MTL::Size group_dims = MTL::Size(1, threads_per, 1);
|
||||
MTL::Size grid_dims = MTL::Size(batch_size, threads_per, 1);
|
||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& s = stream();
|
||||
|
||||
auto& in = inputs[0];
|
||||
|
||||
std::vector<array> copies;
|
||||
// Only support the last axis for now
|
||||
int axis = in.ndim() - 1;
|
||||
auto check_input = [&copies, &s](const array& x) {
|
||||
// TODO(alexbarron) pass strides to kernel to relax this constraint
|
||||
bool no_copy = x.flags().row_contiguous;
|
||||
if (no_copy) {
|
||||
return x;
|
||||
} else {
|
||||
copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
|
||||
copy_gpu(x, copies.back(), CopyType::General, s);
|
||||
return copies.back();
|
||||
}
|
||||
};
|
||||
const array& in_contiguous = check_input(in);
|
||||
|
||||
if (in_contiguous.is_donatable()) {
|
||||
out.move_shared_buffer(in_contiguous);
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
}
|
||||
|
||||
auto [n, m] = decompose_hadamard(in.shape(axis));
|
||||
|
||||
if (n * (int)size_of(in.dtype()) > MAX_HADAMARD_BYTES) {
|
||||
throw std::invalid_argument(
|
||||
"[hadamard] For n = m*2^k, 2^k > 8192 for FP32 or 2^k > 16384 for FP16/BF16 NYI");
|
||||
}
|
||||
|
||||
int max_radix = std::min(n, 16);
|
||||
// Use read_width 2 for m = 28 to avoid register spilling
|
||||
int read_width = (n == 2 || m == 28) ? 2 : 4;
|
||||
|
||||
std::ostringstream kname;
|
||||
kname << "hadamard_" << n * m << "_" << type_to_name(out);
|
||||
auto kernel_name = kname.str();
|
||||
auto& d = metal::device(s.device);
|
||||
const auto& lib_name = kernel_name;
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel_source;
|
||||
auto codelet = gen_hadamard_codelet(m);
|
||||
kernel_source << metal::utils() << codelet << metal::hadamard();
|
||||
kernel_source << get_template_definition(
|
||||
"n" + kernel_name,
|
||||
"hadamard_n",
|
||||
get_type_string(in.dtype()),
|
||||
n,
|
||||
max_radix,
|
||||
read_width);
|
||||
kernel_source << get_template_definition(
|
||||
"m" + kernel_name,
|
||||
"hadamard_m",
|
||||
get_type_string(in.dtype()),
|
||||
n,
|
||||
m,
|
||||
read_width);
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
|
||||
int batch_size = in.size() / n;
|
||||
int threads_per = n / max_radix;
|
||||
|
||||
if (m > 1) {
|
||||
// When m is greater than 1, we decompose the
|
||||
// computation into two uploads to the GPU:
|
||||
//
|
||||
// e.g. len(x) = 12*4 = 48, m = 12, n = 4
|
||||
//
|
||||
// y = h48 @ x
|
||||
//
|
||||
// Upload 1:
|
||||
// tmp = a.reshape(12, 4) @ h4
|
||||
//
|
||||
// Upload 2:
|
||||
// y = h12 @ tmp
|
||||
array temp(in.shape(), in.dtype(), nullptr, {});
|
||||
temp.set_data(allocator::malloc_or_wait(temp.nbytes()));
|
||||
copies.push_back(temp);
|
||||
|
||||
launch_hadamard(
|
||||
in_contiguous,
|
||||
temp,
|
||||
batch_size,
|
||||
threads_per,
|
||||
"n" + kernel_name,
|
||||
1.0,
|
||||
s);
|
||||
|
||||
// Metal sometimes reports 256 max threads per group for hadamard_m kernel
|
||||
threads_per = std::min(n / read_width, MAX_HADAMARD_THREADS_PER_GROUP);
|
||||
batch_size = in.size() / m / read_width / threads_per;
|
||||
launch_hadamard(
|
||||
temp, out, batch_size, threads_per, "m" + kernel_name, scale_, s);
|
||||
} else {
|
||||
launch_hadamard(
|
||||
in_contiguous,
|
||||
out,
|
||||
batch_size,
|
||||
threads_per,
|
||||
"n" + kernel_name,
|
||||
scale_,
|
||||
s);
|
||||
}
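  // Worked numbers for the example in the comment above (illustrative only):
  // with len(x) = 48, m = 12, n = 4 the first launch runs the "n" kernel over
  // batch_size = 48 / 4 = 12 rows with threads_per = n / max_radix = 1, and the
  // second launch applies the m = 12 codelet with
  // threads_per = min(n / read_width, 256) = 1 and batch_size = 48 / 12 / 4 / 1 = 1.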
|
||||
|
||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
||||
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -293,18 +293,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
out.shape().data(), out.shape().size() * sizeof(int), 3);
|
||||
compute_encoder->setBytes(
|
||||
out.strides().data(), out.strides().size() * sizeof(size_t), 4);
|
||||
|
||||
size_t out_ndim = out.ndim();
|
||||
compute_encoder->setBytes(&out_ndim, sizeof(out_ndim), 5);
|
||||
if (upd_ndim <= 1) {
|
||||
// Placeholder so Metal doesn't complain
|
||||
int shape_ = 0;
|
||||
compute_encoder->setBytes(&shape_, sizeof(int), 6);
|
||||
} else {
|
||||
compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 6);
|
||||
}
|
||||
compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 7);
|
||||
compute_encoder->setBytes(&upd_size, sizeof(size_t), 8);
|
||||
compute_encoder->setBytes(&upd_size, sizeof(size_t), 5);
|
||||
|
||||
// Set index buffers
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
|
87
mlx/backend/metal/jit/binary.h
Normal file
@@ -0,0 +1,87 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
constexpr std::string_view binary_kernels = R"(
|
||||
template [[host_name("ss{0}")]] [[kernel]]
|
||||
void binary_ss<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("vs{0}")]] [[kernel]]
|
||||
void binary_vs<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("sv{0}")]] [[kernel]]
|
||||
void binary_sv<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("vv{0}")]] [[kernel]]
|
||||
void binary_vv<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("g4{0}")]] [[kernel]] void
|
||||
binary_g_nd<{1}, {2}, {3}, 4>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
constant const int shape[4],
|
||||
constant const size_t a_strides[4],
|
||||
constant const size_t b_strides[4],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
template [[host_name("g5{0}")]] [[kernel]] void
|
||||
binary_g_nd<{1}, {2}, {3}, 5>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
constant const int shape[5],
|
||||
constant const size_t a_strides[5],
|
||||
constant const size_t b_strides[5],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
|
||||
template [[host_name("g1{0}")]] [[kernel]] void
|
||||
binary_g_nd1<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
constant const size_t& a_stride,
|
||||
constant const size_t& b_stride,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("g2{0}")]] [[kernel]] void
|
||||
binary_g_nd2<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
constant const size_t a_strides[2],
|
||||
constant const size_t b_strides[2],
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]);
|
||||
template [[host_name("g3{0}")]] [[kernel]] void
|
||||
binary_g_nd3<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
constant const size_t a_strides[3],
|
||||
constant const size_t b_strides[3],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
|
||||
template [[host_name("gn{0}")]] [[kernel]]
|
||||
void binary_g<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
constant const int* shape,
|
||||
constant const size_t* a_strides,
|
||||
constant const size_t* b_strides,
|
||||
constant const int& ndim,
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
)";
|
98
mlx/backend/metal/jit/binary_two.h
Normal file
@@ -0,0 +1,98 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
constexpr std::string_view binary_two_kernels = R"(
|
||||
template [[host_name("ss{0}")]] [[kernel]]
|
||||
void binary_ss<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
device {2}* d,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("vs{0}")]] [[kernel]]
|
||||
void binary_vs<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
device {2}* d,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("sv{0}")]] [[kernel]]
|
||||
void binary_sv<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
device {2}* d,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("vv{0}")]] [[kernel]]
|
||||
void binary_vv<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
device {2}* d,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
|
||||
template [[host_name("g4{0}")]] [[kernel]] void
|
||||
binary_g_nd<{1}, {2}, {3}, 4>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
device {2}* d,
|
||||
constant const int shape[4],
|
||||
constant const size_t a_strides[4],
|
||||
constant const size_t b_strides[4],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
template [[host_name("g5{0}")]] [[kernel]] void
|
||||
binary_g_nd<{1}, {2}, {3}, 5>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
device {2}* d,
|
||||
constant const int shape[5],
|
||||
constant const size_t a_strides[5],
|
||||
constant const size_t b_strides[5],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
|
||||
template [[host_name("g1{0}")]] [[kernel]] void
|
||||
binary_g_nd1<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
device {2}* d,
|
||||
constant const size_t& a_stride,
|
||||
constant const size_t& b_stride,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("g2{0}")]] [[kernel]] void
|
||||
binary_g_nd2<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
device {2}* d,
|
||||
constant const size_t a_strides[2],
|
||||
constant const size_t b_strides[2],
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]);
|
||||
template [[host_name("g3{0}")]] [[kernel]] void
|
||||
binary_g_nd3<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
device {2}* d,
|
||||
constant const size_t a_strides[3],
|
||||
constant const size_t b_strides[3],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
|
||||
template [[host_name("gn{0}")]] [[kernel]]
|
||||
void binary_g<{1}, {2}, {3}>(
|
||||
device const {1}* a,
|
||||
device const {1}* b,
|
||||
device {2}* c,
|
||||
device {2}* d,
|
||||
constant const int* shape,
|
||||
constant const size_t* a_strides,
|
||||
constant const size_t* b_strides,
|
||||
constant const int& ndim,
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
)";
|
@@ -1,25 +0,0 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
constexpr std::string_view gemv_masked_kernel = R"(
|
||||
template [[host_name("{name}")]] [[kernel]] void
|
||||
gemv_{trans}masked<{itype}, {outm_t}, {opm_t}, {bm}, {bn}, {sm}, {sn}, {tm}, {tn}, {nc}>(
|
||||
const device {itype}* mat [[buffer(0)]],
|
||||
const device {itype}* in_vec [[buffer(1)]],
|
||||
device {itype}* out_vec [[buffer(3)]],
|
||||
const constant int& in_vec_size [[buffer(4)]],
|
||||
const constant int& out_vec_size [[buffer(5)]],
|
||||
const constant int& marix_ld [[buffer(6)]],
|
||||
const constant int& batch_ndim [[buffer(9)]],
|
||||
const constant int* batch_shape [[buffer(10)]],
|
||||
const constant size_t* vector_batch_stride [[buffer(11)]],
|
||||
const constant size_t* matrix_batch_stride [[buffer(12)]],
|
||||
const device {outm_t}* out_mask [[buffer(20)]],
|
||||
const device {opm_t}* mat_mask [[buffer(21)]],
|
||||
const device {opm_t}* vec_mask [[buffer(22)]],
|
||||
const constant int* mask_strides [[buffer(23)]],
|
||||
const constant size_t* mask_batch_strides [[buffer(24)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
)";
|
@@ -17,9 +17,6 @@ const char* unary();
|
||||
const char* binary();
|
||||
const char* binary_two();
|
||||
const char* copy();
|
||||
const char* fft();
|
||||
const char* hadamard();
|
||||
const char* quantized();
|
||||
const char* ternary();
|
||||
const char* scan();
|
||||
const char* softmax();
|
||||
@@ -33,6 +30,5 @@ const char* steel_gemm_splitk();
|
||||
const char* conv();
|
||||
const char* steel_conv();
|
||||
const char* steel_conv_general();
|
||||
const char* gemv_masked();
|
||||
|
||||
} // namespace mlx::core::metal
|
||||
|
@@ -38,24 +38,12 @@ constexpr std::string_view scatter_kernels = R"(
|
||||
device mlx_atomic<{1}>* out [[buffer(2)]],
|
||||
const constant int* out_shape [[buffer(3)]],
|
||||
const constant size_t* out_strides [[buffer(4)]],
|
||||
const constant size_t& out_ndim [[buffer(5)]],
|
||||
const constant int* upd_shape [[buffer(6)]],
|
||||
const constant size_t& upd_ndim [[buffer(7)]],
|
||||
const constant size_t& upd_size [[buffer(8)]],
|
||||
const constant size_t& upd_size [[buffer(5)]],
|
||||
{5}
|
||||
uint2 gid [[thread_position_in_grid]]) {{
|
||||
const array<const device {2}*, {4}> idx_buffers = {{ {6} }};
|
||||
return scatter_1d_index_impl<{1}, {2}, {3}, {4}>(
|
||||
updates,
|
||||
out,
|
||||
out_shape,
|
||||
out_strides,
|
||||
out_ndim,
|
||||
upd_shape,
|
||||
upd_ndim,
|
||||
upd_size,
|
||||
idx_buffers,
|
||||
gid);
|
||||
updates, out, out_shape, out_strides, upd_size, idx_buffers, gid);
|
||||
}}
|
||||
|
||||
[[kernel]] void scatter{0}_{4}(
|
||||
|
81
mlx/backend/metal/jit/sort.h
Normal file
@@ -0,0 +1,81 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
constexpr std::string_view block_sort_kernels = R"(
|
||||
template [[host_name("carg_{0}")]] [[kernel]] void
|
||||
block_sort<{1}, {2}, true, {3}, {4}>(
|
||||
const device {1}* inp [[buffer(0)]],
|
||||
device {2}* out [[buffer(1)]],
|
||||
const constant int& size_sorted_axis [[buffer(2)]],
|
||||
const constant int& stride_sorted_axis [[buffer(3)]],
|
||||
const constant int& stride_segment_axis [[buffer(4)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
template [[host_name("ncarg_{0}")]] [[kernel]] void
|
||||
block_sort_nc<{1}, {2}, true, {3}, {4}>(
|
||||
const device {1}* inp [[buffer(0)]],
|
||||
device {2}* out [[buffer(1)]],
|
||||
const constant int& size_sorted_axis [[buffer(2)]],
|
||||
const constant int& stride_sorted_axis [[buffer(3)]],
|
||||
const constant int& nc_dim [[buffer(4)]],
|
||||
const device int* nc_shape [[buffer(5)]],
|
||||
const device size_t* nc_strides [[buffer(6)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
template [[host_name("c_{0}")]] [[kernel]] void
|
||||
block_sort<{1}, {2}, false, {3}, {4}>(
|
||||
const device {1}* inp [[buffer(0)]],
|
||||
device {2}* out [[buffer(1)]],
|
||||
const constant int& size_sorted_axis [[buffer(2)]],
|
||||
const constant int& stride_sorted_axis [[buffer(3)]],
|
||||
const constant int& stride_segment_axis [[buffer(4)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
template [[host_name("nc_{0}")]] [[kernel]] void
|
||||
block_sort_nc<{1}, {2}, false, {3}, {4}>(
|
||||
const device {1}* inp [[buffer(0)]],
|
||||
device {2}* out [[buffer(1)]],
|
||||
const constant int& size_sorted_axis [[buffer(2)]],
|
||||
const constant int& stride_sorted_axis [[buffer(3)]],
|
||||
const constant int& nc_dim [[buffer(4)]],
|
||||
const device int* nc_shape [[buffer(5)]],
|
||||
const device size_t* nc_strides [[buffer(6)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
)";
|
||||
|
||||
constexpr std::string_view multiblock_sort_kernels = R"(
|
||||
template [[host_name("sort_{0}")]] [[kernel]] void
|
||||
mb_block_sort<{1}, {2}, true, {3}, {4}>(
|
||||
const device {1}* inp [[buffer(0)]],
|
||||
device {1}* out_vals [[buffer(1)]],
|
||||
device {2}* out_idxs [[buffer(2)]],
|
||||
const constant int& size_sorted_axis [[buffer(3)]],
|
||||
const constant int& stride_sorted_axis [[buffer(4)]],
|
||||
const constant int& nc_dim [[buffer(5)]],
|
||||
const device int* nc_shape [[buffer(6)]],
|
||||
const device size_t* nc_strides [[buffer(7)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
template [[host_name("partition_{0}")]] [[kernel]] void
|
||||
mb_block_partition<{1}, {2}, true, {3}, {4}>(
|
||||
device {2}* block_partitions [[buffer(0)]],
|
||||
const device {1}* dev_vals [[buffer(1)]],
|
||||
const device {2}* dev_idxs [[buffer(2)]],
|
||||
const constant int& size_sorted_axis [[buffer(3)]],
|
||||
const constant int& merge_tiles [[buffer(4)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]],
|
||||
uint3 tgp_dims [[threads_per_threadgroup]]);
|
||||
template [[host_name("merge_{0}")]] [[kernel]] void
|
||||
mb_block_merge<{1}, {2}, true, {3}, {4}>(
|
||||
const device {2}* block_partitions [[buffer(0)]],
|
||||
const device {1}* dev_vals_in [[buffer(1)]],
|
||||
const device {2}* dev_idxs_in [[buffer(2)]],
|
||||
device {1}* dev_vals_out [[buffer(3)]],
|
||||
device {2}* dev_idxs_out [[buffer(4)]],
|
||||
const constant int& size_sorted_axis [[buffer(5)]],
|
||||
const constant int& merge_tiles [[buffer(6)]],
|
||||
const constant int& num_tiles [[buffer(7)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint3 lid [[thread_position_in_threadgroup]]);
|
||||
)";
|
80
mlx/backend/metal/jit/ternary.h
Normal file
@@ -0,0 +1,80 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
constexpr std::string_view ternary_kernels = R"(
|
||||
template [[host_name("v_{0}")]] [[kernel]] void ternary_v<{1}, {2}>(
|
||||
device const bool* a,
|
||||
device const {1}* b,
|
||||
device const {1}* c,
|
||||
device {1}* d,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
|
||||
template [[host_name("g_{0}")]] [[kernel]] void ternary_g<{1}, {2}>(
|
||||
device const bool* a,
|
||||
device const {1}* b,
|
||||
device const {1}* c,
|
||||
device {1}* d,
|
||||
constant const int* shape,
|
||||
constant const size_t* a_strides,
|
||||
constant const size_t* b_strides,
|
||||
constant const size_t* c_strides,
|
||||
constant const int& ndim,
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
|
||||
template [[host_name("g1_{0}")]] [[kernel]] void
|
||||
ternary_g_nd1<{1}, {2}>(
|
||||
device const bool* a,
|
||||
device const {1}* b,
|
||||
device const {1}* c,
|
||||
device {1}* d,
|
||||
constant const size_t& a_strides,
|
||||
constant const size_t& b_strides,
|
||||
constant const size_t& c_strides,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
template [[host_name("g2_{0}")]] [[kernel]] void
|
||||
ternary_g_nd2<{1}, {2}>(
|
||||
device const bool* a,
|
||||
device const {1}* b,
|
||||
device const {1}* c,
|
||||
device {1}* d,
|
||||
constant const size_t a_strides[2],
|
||||
constant const size_t b_strides[2],
|
||||
constant const size_t c_strides[2],
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]);
|
||||
template [[host_name("g3_{0}")]] [[kernel]] void
|
||||
ternary_g_nd3<{1}, {2}>(
|
||||
device const bool* a,
|
||||
device const {1}* b,
|
||||
device const {1}* c,
|
||||
device {1}* d,
|
||||
constant const size_t a_strides[3],
|
||||
constant const size_t b_strides[3],
|
||||
constant const size_t c_strides[3],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
template [[host_name("g4_{0}")]] [[kernel]] void
|
||||
ternary_g_nd<{1}, {2}, 4>(
|
||||
device const bool* a,
|
||||
device const {1}* b,
|
||||
device const {1}* c,
|
||||
device {1}* d,
|
||||
constant const int shape[4],
|
||||
constant const size_t a_strides[4],
|
||||
constant const size_t b_strides[4],
|
||||
constant const size_t c_strides[4],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
template [[host_name("g5_{0}")]] [[kernel]] void
|
||||
ternary_g_nd<{1}, {2}, 5>(
|
||||
device const bool* a,
|
||||
device const {1}* b,
|
||||
device const {1}* c,
|
||||
device {1}* d,
|
||||
constant const int shape[5],
|
||||
constant const size_t a_strides[5],
|
||||
constant const size_t b_strides[5],
|
||||
constant const size_t c_strides[5],
|
||||
uint3 index [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
)";
|
16
mlx/backend/metal/jit/unary.h
Normal file
@@ -0,0 +1,16 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
constexpr std::string_view unary_kernels = R"(
|
||||
template [[host_name("v{0}")]] [[kernel]] void unary_v<{1}, {2}>(
|
||||
device const {1}* in,
|
||||
device {1}* out,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
|
||||
template [[host_name("g{0}")]] [[kernel]] void unary_g<{1}, {2}>(
|
||||
device const {1}* in,
|
||||
device {1}* out,
|
||||
device const int* in_shape,
|
||||
device const size_t* in_strides,
|
||||
device const int& ndim,
|
||||
uint index [[thread_position_in_grid]]);
|
||||
)";
|
@@ -1,16 +1,20 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
#include <map>
|
||||
#include <fmt/format.h>
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/metal/jit/arange.h"
|
||||
#include "mlx/backend/metal/jit/binary.h"
|
||||
#include "mlx/backend/metal/jit/binary_two.h"
|
||||
#include "mlx/backend/metal/jit/copy.h"
|
||||
#include "mlx/backend/metal/jit/gemv_masked.h"
|
||||
#include "mlx/backend/metal/jit/includes.h"
|
||||
#include "mlx/backend/metal/jit/reduce.h"
|
||||
#include "mlx/backend/metal/jit/scan.h"
|
||||
#include "mlx/backend/metal/jit/softmax.h"
|
||||
#include "mlx/backend/metal/jit/sort.h"
|
||||
#include "mlx/backend/metal/jit/steel_conv.h"
|
||||
#include "mlx/backend/metal/jit/steel_gemm.h"
|
||||
#include "mlx/backend/metal/jit/ternary.h"
|
||||
#include "mlx/backend/metal/jit/unary.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
|
||||
@@ -43,81 +47,38 @@ MTL::ComputePipelineState* get_arange_kernel(
|
||||
MTL::ComputePipelineState* get_unary_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
Dtype out_type,
|
||||
const std::string op) {
|
||||
const array& out) {
|
||||
std::string lib_name = kernel_name.substr(1);
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel_source;
|
||||
auto u_def = get_template_definition(
|
||||
"v" + lib_name, "unary_v", get_type_string(out_type), op);
|
||||
auto u2_def = get_template_definition(
|
||||
"v2" + lib_name, "unary_v2", get_type_string(out_type), op);
|
||||
auto g_def = get_template_definition(
|
||||
"g" + lib_name, "unary_g", get_type_string(out_type), op);
|
||||
kernel_source << metal::utils() << metal::unary_ops() << metal::unary()
|
||||
<< u_def << u2_def << g_def;
|
||||
<< fmt::format(
|
||||
unary_kernels,
|
||||
lib_name,
|
||||
get_type_string(out.dtype()),
|
||||
op_name(out));
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
void add_binary_kernels(
|
||||
const std::string lib_name,
|
||||
Dtype in_type,
|
||||
Dtype out_type,
|
||||
const std::string op,
|
||||
std::ostringstream& kernel_source) {
|
||||
const std::map<std::string, std::string> kernel_types = {
|
||||
{"ss", "binary_ss"},
|
||||
{"vs", "binary_vs"},
|
||||
{"sv", "binary_sv"},
|
||||
{"vv", "binary_vv"},
|
||||
{"vs2", "binary_vs2"},
|
||||
{"sv2", "binary_sv2"},
|
||||
{"vv2", "binary_vv2"},
|
||||
{"g1", "binary_g_nd1"},
|
||||
{"g2", "binary_g_nd2"},
|
||||
{"g3", "binary_g_nd3"},
|
||||
{"g4", "binary_g_nd"},
|
||||
{"g5", "binary_g_nd"},
|
||||
{"gn", "binary_g"},
|
||||
};
|
||||
for (auto [name, func] : kernel_types) {
|
||||
std::string template_def;
|
||||
if (name == "g4" || name == "g5") {
|
||||
int dim = std::stoi(name.substr(1));
|
||||
template_def = get_template_definition(
|
||||
name + lib_name,
|
||||
func,
|
||||
get_type_string(in_type),
|
||||
get_type_string(out_type),
|
||||
op,
|
||||
dim);
|
||||
} else {
|
||||
template_def = get_template_definition(
|
||||
name + lib_name,
|
||||
func,
|
||||
get_type_string(in_type),
|
||||
get_type_string(out_type),
|
||||
op);
|
||||
}
|
||||
kernel_source << template_def;
|
||||
}
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_binary_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
Dtype in_type,
|
||||
Dtype out_type,
|
||||
const std::string op) {
|
||||
const array& in,
|
||||
const array& out) {
|
||||
std::string lib_name = kernel_name.substr(2);
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::binary_ops() << metal::binary();
|
||||
add_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
|
||||
kernel_source << metal::utils() << metal::binary_ops() << metal::binary()
|
||||
<< fmt::format(
|
||||
binary_kernels,
|
||||
lib_name,
|
||||
get_type_string(in.dtype()),
|
||||
get_type_string(out.dtype()),
|
||||
op_name(out));
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
@@ -126,16 +87,20 @@ MTL::ComputePipelineState* get_binary_kernel(
|
||||
MTL::ComputePipelineState* get_binary_two_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
Dtype in_type,
|
||||
Dtype out_type,
|
||||
const std::string op) {
|
||||
const array& in,
|
||||
const array& out) {
|
||||
std::string lib_name = kernel_name.substr(2);
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::binary_ops()
|
||||
<< metal::binary_two();
|
||||
add_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
|
||||
<< metal::binary_two()
|
||||
<< fmt::format(
|
||||
binary_two_kernels,
|
||||
lib_name,
|
||||
get_type_string(in.dtype()),
|
||||
get_type_string(out.dtype()),
|
||||
op_name(out));
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
@@ -144,35 +109,17 @@ MTL::ComputePipelineState* get_binary_two_kernel(
|
||||
MTL::ComputePipelineState* get_ternary_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
Dtype type,
|
||||
const std::string op) {
|
||||
const array& out) {
|
||||
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel_source;
|
||||
const std::map<std::string, std::string> kernel_types = {
|
||||
{"v", "ternary_v"},
|
||||
{"v2", "ternary_v2"},
|
||||
{"g", "ternary_g"},
|
||||
{"g1", "ternary_g_nd1"},
|
||||
{"g2", "ternary_g_nd2"},
|
||||
{"g3", "ternary_g_nd3"},
|
||||
{"g4", "ternary_g_nd"},
|
||||
{"g5", "ternary_g_nd"},
|
||||
};
|
||||
kernel_source << metal::utils() << metal::ternary_ops() << metal::ternary();
|
||||
for (auto [name, func] : kernel_types) {
|
||||
std::string template_def;
|
||||
if (name == "g4" || name == "g5") {
|
||||
int dim = std::stoi(name.substr(1));
|
||||
template_def = get_template_definition(
|
||||
name + "_" + lib_name, func, get_type_string(type), op, dim);
|
||||
} else {
|
||||
template_def = get_template_definition(
|
||||
name + "_" + lib_name, func, get_type_string(type), op);
|
||||
}
|
||||
kernel_source << template_def;
|
||||
}
|
||||
kernel_source << metal::utils() << metal::ternary_ops() << metal::ternary()
|
||||
<< fmt::format(
|
||||
ternary_kernels,
|
||||
lib_name,
|
||||
get_type_string(out.dtype()),
|
||||
op_name(out));
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
@@ -223,14 +170,11 @@ MTL::ComputePipelineState* get_scan_kernel(
|
||||
const std::string& kernel_name,
|
||||
bool reverse,
|
||||
bool inclusive,
|
||||
const std::string& reduce_type,
|
||||
const array& in,
|
||||
const array& out) {
|
||||
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
std::string op_name = "Cum" + reduce_type;
|
||||
op_name[3] = toupper(op_name[3]);
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::scan()
|
||||
<< fmt::format(
|
||||
@@ -238,7 +182,7 @@ MTL::ComputePipelineState* get_scan_kernel(
|
||||
lib_name,
|
||||
get_type_string(in.dtype()),
|
||||
get_type_string(out.dtype()),
|
||||
op_name,
|
||||
op_name(out),
|
||||
inclusive,
|
||||
reverse);
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
@@ -257,29 +201,14 @@ MTL::ComputePipelineState* get_sort_kernel(
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel_source;
|
||||
auto in_type = get_type_string(in.dtype());
|
||||
auto out_type = get_type_string(out.dtype());
|
||||
kernel_source << metal::utils() << metal::sort();
|
||||
for (bool is_argsort : {true, false}) {
|
||||
std::string bool_string = is_argsort ? "true" : "false";
|
||||
std::string func_string = is_argsort ? "carg_" : "c_";
|
||||
kernel_source << get_template_definition(
|
||||
func_string + lib_name,
|
||||
"block_sort",
|
||||
in_type,
|
||||
out_type,
|
||||
bool_string,
|
||||
bn,
|
||||
tn);
|
||||
kernel_source << get_template_definition(
|
||||
"n" + func_string + lib_name,
|
||||
"block_sort_nc",
|
||||
in_type,
|
||||
out_type,
|
||||
bool_string,
|
||||
bn,
|
||||
tn);
|
||||
}
|
||||
kernel_source << metal::utils() << metal::sort()
|
||||
<< fmt::format(
|
||||
block_sort_kernels,
|
||||
lib_name,
|
||||
get_type_string(in.dtype()),
|
||||
get_type_string(out.dtype()),
|
||||
bn,
|
||||
tn);
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
@@ -296,21 +225,14 @@ MTL::ComputePipelineState* get_mb_sort_kernel(
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::sort();
|
||||
std::vector<std::pair<std::string, std::string>> kernel_types = {
|
||||
{"sort_", "mb_block_sort"},
|
||||
{"partition_", "mb_block_partition"},
|
||||
{"merge_", "mb_block_merge"}};
|
||||
for (auto [name, func] : kernel_types) {
|
||||
kernel_source << get_template_definition(
|
||||
name + lib_name,
|
||||
func,
|
||||
get_type_string(in.dtype()),
|
||||
get_type_string(idx.dtype()),
|
||||
"true",
|
||||
bn,
|
||||
tn);
|
||||
}
|
||||
kernel_source << metal::utils() << metal::sort()
|
||||
<< fmt::format(
|
||||
multiblock_sort_kernels,
|
||||
lib_name,
|
||||
get_type_string(in.dtype()),
|
||||
get_type_string(idx.dtype()),
|
||||
bn,
|
||||
tn);
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
@@ -337,14 +259,11 @@ MTL::ComputePipelineState* get_reduce_init_kernel(
|
||||
MTL::ComputePipelineState* get_reduce_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const std::string& op_name,
|
||||
const array& in,
|
||||
const array& out) {
|
||||
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
std::string op_type = op_name;
|
||||
op_type[0] = std::toupper(op_name[0]);
|
||||
bool non_atomic = out.dtype() == int64 || out.dtype() == uint64;
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce()
|
||||
@@ -354,7 +273,7 @@ MTL::ComputePipelineState* get_reduce_kernel(
|
||||
lib_name,
|
||||
get_type_string(in.dtype()),
|
||||
get_type_string(out.dtype()),
|
||||
op_type);
|
||||
op_name(out));
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
@@ -503,49 +422,6 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_gemv_masked_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array& out,
|
||||
const std::optional<array>& mask_out,
|
||||
const std::optional<array>& mask_op,
|
||||
bool transpose_mat,
|
||||
int bm,
|
||||
int bn,
|
||||
int sm,
|
||||
int sn,
|
||||
int tm,
|
||||
int tn,
|
||||
bool contiguous) {
|
||||
const auto& lib_name = kernel_name;
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel_source;
|
||||
auto out_mask_type = mask_out.has_value()
|
||||
? get_type_string((*mask_out).dtype())
|
||||
: "nomask_t";
|
||||
auto op_mask_type =
|
||||
mask_op.has_value() ? get_type_string((*mask_op).dtype()) : "nomask_t";
|
||||
kernel_source << metal::utils() << metal::gemv_masked()
|
||||
<< fmt::format(
|
||||
gemv_masked_kernel,
|
||||
"name"_a = lib_name,
|
||||
"itype"_a = get_type_string(out.dtype()),
|
||||
"outm_t"_a = out_mask_type,
|
||||
"opm_t"_a = op_mask_type,
|
||||
"bm"_a = bm,
|
||||
"bn"_a = bn,
|
||||
"sm"_a = sm,
|
||||
"sn"_a = sn,
|
||||
"tm"_a = tm,
|
||||
"tn"_a = tn,
|
||||
"trans"_a = transpose_mat ? "t_" : "",
|
||||
"nc"_a = contiguous ? "0" : "1");
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_steel_conv_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
@@ -607,36 +483,4 @@ MTL::ComputePipelineState* get_steel_conv_general_kernel(
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
MTL::ComputePipelineState* get_fft_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const std::string& hash_name,
|
||||
const metal::MTLFCList& func_consts,
|
||||
const std::string& template_def) {
|
||||
const auto& lib_name = kernel_name;
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel_source;
|
||||
std::string kernel_string;
|
||||
kernel_source << metal::fft() << template_def;
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
|
||||
}
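// Note (observation on the code above, hedged): unlike the other getters, this one
// also forwards hash_name and func_consts to get_kernel, so the returned pipeline
// can be specialized per call via Metal function constants; the FFT code earlier in
// this diff builds hash_name from the FFT size and the inverse flag.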
|
||||
|
||||
MTL::ComputePipelineState* get_quantized_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const std::string& template_def) {
|
||||
const auto& lib_name = kernel_name;
|
||||
auto lib = d.get_library(lib_name);
|
||||
if (lib == nullptr) {
|
||||
std::ostringstream kernel_source;
|
||||
kernel_source << metal::utils() << metal::gemm() << metal::quantized()
|
||||
<< template_def;
|
||||
lib = d.get_library(lib_name, kernel_source.str());
|
||||
}
|
||||
return d.get_kernel(kernel_name, lib);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -1,7 +1,5 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <fmt/format.h>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
|
||||
@@ -15,28 +13,24 @@ MTL::ComputePipelineState* get_arange_kernel(
|
||||
MTL::ComputePipelineState* get_unary_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
Dtype out_type,
|
||||
const std::string op);
|
||||
const array& out);
|
||||
|
||||
MTL::ComputePipelineState* get_binary_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
Dtype in_type,
|
||||
Dtype out_type,
|
||||
const std::string op);
|
||||
const array& in,
|
||||
const array& out);
|
||||
|
||||
MTL::ComputePipelineState* get_binary_two_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
Dtype in_type,
|
||||
Dtype out_type,
|
||||
const std::string op);
|
||||
const array& in,
|
||||
const array& out);
|
||||
|
||||
MTL::ComputePipelineState* get_ternary_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
Dtype type,
|
||||
const std::string op);
|
||||
const array& out);
|
||||
|
||||
MTL::ComputePipelineState* get_copy_kernel(
|
||||
metal::Device& d,
|
||||
@@ -55,7 +49,6 @@ MTL::ComputePipelineState* get_scan_kernel(
|
||||
const std::string& kernel_name,
|
||||
bool reverse,
|
||||
bool inclusive,
|
||||
const std::string& reduce_type,
|
||||
const array& in,
|
||||
const array& out);
|
||||
|
||||
@@ -83,7 +76,6 @@ MTL::ComputePipelineState* get_reduce_init_kernel(
|
||||
MTL::ComputePipelineState* get_reduce_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const std::string& op_name,
|
||||
const array& in,
|
||||
const array& out);
|
||||
|
||||
@@ -151,21 +143,6 @@ MTL::ComputePipelineState* get_steel_conv_kernel(
|
||||
int n_channel_specialization,
|
||||
bool small_filter);
|
||||
|
||||
MTL::ComputePipelineState* get_gemv_masked_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const array& out,
|
||||
const std::optional<array>& mask_out,
|
||||
const std::optional<array>& mask_op,
|
||||
bool transpose_mat,
|
||||
int bm,
|
||||
int bn,
|
||||
int sm,
|
||||
int sn,
|
||||
int tm,
|
||||
int tn,
|
||||
bool contiguous);
|
||||
|
||||
MTL::ComputePipelineState* get_steel_conv_general_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
@@ -176,38 +153,4 @@ MTL::ComputePipelineState* get_steel_conv_general_kernel(
|
||||
int wm,
|
||||
int wn);
|
||||
|
||||
MTL::ComputePipelineState* get_fft_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const std::string& hash_name,
|
||||
const metal::MTLFCList& func_consts,
|
||||
const std::string& template_def);
|
||||
|
||||
MTL::ComputePipelineState* get_quantized_kernel(
|
||||
metal::Device& d,
|
||||
const std::string& kernel_name,
|
||||
const std::string& template_def);
|
||||
|
||||
// Create a GPU kernel template definition for JIT compilation
|
||||
template <typename... Args>
|
||||
std::string
|
||||
get_template_definition(std::string name, std::string func, Args... args) {
|
||||
std::ostringstream s;
|
||||
s << func << "<";
|
||||
bool first = true;
|
||||
auto add_arg = [&s, &first](const auto& arg) {
|
||||
if (!first) {
|
||||
s << ", ";
|
||||
}
|
||||
first = false;
|
||||
s << arg;
|
||||
};
|
||||
(add_arg(args), ...);
|
||||
s << ">";
|
||||
std::string base_string = R"(
|
||||
template [[host_name("{0}")]] [[kernel]] decltype({1}) {1};
|
||||
)";
|
||||
return fmt::format(base_string, name, s.str());
|
||||
}
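// Hedged usage sketch (the op name "Abs" and kernel name are illustrative, not
// taken from the diff):
//   get_template_definition("vabs_float", "unary_v", "float", "Abs");
// expands to
//   template [[host_name("vabs_float")]] [[kernel]] decltype(unary_v<float, Abs>) unary_v<float, Abs>;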
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -1,15 +1,66 @@
|
||||
set(
|
||||
BASE_HEADERS
|
||||
HEADERS
|
||||
bf16.h
|
||||
bf16_math.h
|
||||
complex.h
|
||||
defines.h
|
||||
expm1f.h
|
||||
utils.h
|
||||
steel/conv/params.h
|
||||
)
|
||||
|
||||
set(
|
||||
KERNELS
|
||||
"arg_reduce"
|
||||
"conv"
|
||||
"fft"
|
||||
"gemv"
|
||||
"quantized"
|
||||
"random"
|
||||
"rms_norm"
|
||||
"layer_norm"
|
||||
"rope"
|
||||
"scaled_dot_product_attention"
|
||||
)
|
||||
|
||||
if (NOT MLX_METAL_JIT)
|
||||
set(
|
||||
KERNELS
|
||||
${KERNELS}
|
||||
"arange"
|
||||
"binary"
|
||||
"binary_two"
|
||||
"unary"
|
||||
"ternary"
|
||||
"copy"
|
||||
"softmax"
|
||||
"sort"
|
||||
"scan"
|
||||
"reduce"
|
||||
)
|
||||
set(
|
||||
HEADERS
|
||||
${HEADERS}
|
||||
atomic.h
|
||||
arange.h
|
||||
unary_ops.h
|
||||
unary.h
|
||||
binary_ops.h
|
||||
binary.h
|
||||
ternary.h
|
||||
copy.h
|
||||
softmax.h
|
||||
sort.h
|
||||
scan.h
|
||||
reduction/ops.h
|
||||
reduction/reduce_init.h
|
||||
reduction/reduce_all.h
|
||||
reduction/reduce_col.h
|
||||
reduction/reduce_row.h
|
||||
)
|
||||
endif()
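# Hedged note (assumption, not stated in the diff): with MLX_METAL_JIT enabled the
# kernels guarded above are left out of mlx.metallib and are instead compiled at
# runtime from the source strings in mlx/backend/metal/jit/*.h via the get_*_kernel
# helpers in kernels.cpp.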
|
||||
|
||||
function(build_kernel_base TARGET SRCFILE DEPS)
|
||||
set(METAL_FLAGS -Wall -Wextra -fno-fast-math)
|
||||
set(METAL_FLAGS -Wall -Wextra -fno-fast-math -D${MLX_METAL_VERSION})
|
||||
if(MLX_METAL_DEBUG)
|
||||
set(METAL_FLAGS ${METAL_FLAGS}
|
||||
-gline-tables-only
|
||||
@@ -21,7 +72,7 @@ function(build_kernel_base TARGET SRCFILE DEPS)
|
||||
-c ${SRCFILE}
|
||||
-I${PROJECT_SOURCE_DIR}
|
||||
-o ${TARGET}.air
|
||||
DEPENDS ${SRCFILE} ${DEPS} ${BASE_HEADERS}
|
||||
DEPENDS ${SRCFILE} ${DEPS}
|
||||
OUTPUT ${TARGET}.air
|
||||
COMMENT "Building ${TARGET}.air"
|
||||
VERBATIM
|
||||
@@ -30,100 +81,49 @@ endfunction(build_kernel_base)
|
||||
|
||||
function(build_kernel KERNEL)
|
||||
set(SRCFILE ${CMAKE_CURRENT_SOURCE_DIR}/${KERNEL}.metal)
|
||||
cmake_path(GET KERNEL STEM TARGET)
|
||||
build_kernel_base(${TARGET} ${SRCFILE} "${ARGN}")
|
||||
set(KERNEL_AIR ${TARGET}.air ${KERNEL_AIR} PARENT_SCOPE)
|
||||
build_kernel_base(${KERNEL} ${SRCFILE} "${HEADERS}")
|
||||
endfunction(build_kernel)
|
||||
|
||||
build_kernel(arg_reduce)
|
||||
build_kernel(conv steel/conv/params.h)
|
||||
build_kernel(gemv steel/utils.h)
|
||||
build_kernel(layer_norm)
|
||||
build_kernel(random)
|
||||
build_kernel(rms_norm)
|
||||
build_kernel(rope)
|
||||
build_kernel(
|
||||
scaled_dot_product_attention
|
||||
scaled_dot_product_attention_params.h
|
||||
steel/defines.h
|
||||
steel/gemm/transforms.h
|
||||
steel/utils.h
|
||||
)
|
||||
|
||||
set(
|
||||
STEEL_HEADERS
|
||||
steel/defines.h
|
||||
steel/utils.h
|
||||
steel/conv/conv.h
|
||||
steel/conv/loader.h
|
||||
steel/conv/loaders/loader_channel_l.h
|
||||
steel/conv/loaders/loader_channel_n.h
|
||||
steel/conv/loaders/loader_general.h
|
||||
steel/conv/kernels/steel_conv.h
|
||||
steel/conv/kernels/steel_conv_general.h
|
||||
steel/gemm/gemm.h
|
||||
steel/gemm/mma.h
|
||||
steel/gemm/loader.h
|
||||
steel/gemm/transforms.h
|
||||
steel/gemm/kernels/steel_gemm_fused.h
|
||||
steel/gemm/kernels/steel_gemm_masked.h
|
||||
steel/gemm/kernels/steel_gemm_splitk.h
|
||||
)
|
||||
foreach(KERNEL ${KERNELS})
|
||||
build_kernel(${KERNEL})
|
||||
set(KERNEL_AIR ${KERNEL}.air ${KERNEL_AIR})
|
||||
endforeach()
|
||||
|
||||
if (NOT MLX_METAL_JIT)
|
||||
build_kernel(arange arange.h)
|
||||
build_kernel(binary binary.h binary_ops.h)
|
||||
build_kernel(binary_two binary_two.h)
|
||||
build_kernel(copy copy.h)
|
||||
build_kernel(
|
||||
fft
|
||||
fft.h
|
||||
fft/radix.h
|
||||
fft/readwrite.h
|
||||
)
|
||||
build_kernel(
|
||||
reduce
|
||||
atomic.h
|
||||
reduction/ops.h
|
||||
reduction/reduce_init.h
|
||||
reduction/reduce_all.h
|
||||
reduction/reduce_col.h
|
||||
reduction/reduce_row.h
|
||||
)
|
||||
build_kernel(
|
||||
quantized
|
||||
quantized.h
|
||||
${STEEL_HEADERS}
|
||||
)
|
||||
build_kernel(scan scan.h)
|
||||
build_kernel(softmax softmax.h)
|
||||
build_kernel(sort sort.h)
|
||||
build_kernel(ternary ternary.h ternary_ops.h)
|
||||
build_kernel(unary unary.h unary_ops.h)
|
||||
build_kernel(
|
||||
steel/conv/kernels/steel_conv
|
||||
${STEEL_HEADERS}
|
||||
)
|
||||
build_kernel(
|
||||
steel/conv/kernels/steel_conv_general
|
||||
${STEEL_HEADERS}
|
||||
)
|
||||
build_kernel(
|
||||
steel/gemm/kernels/steel_gemm_fused
|
||||
${STEEL_HEADERS}
|
||||
)
|
||||
build_kernel(
|
||||
steel/gemm/kernels/steel_gemm_masked
|
||||
${STEEL_HEADERS}
|
||||
)
|
||||
build_kernel(
|
||||
steel/gemm/kernels/steel_gemm_splitk
|
||||
${STEEL_HEADERS}
|
||||
)
|
||||
build_kernel(gemv_masked steel/utils.h)
|
||||
set(
|
||||
STEEL_KERNELS
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/steel/conv/kernels/steel_conv.metal
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/steel/conv/kernels/steel_conv_general.metal
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/steel/gemm/kernels/steel_gemm_fused.metal
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/steel/gemm/kernels/steel_gemm_masked.metal
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/steel/gemm/kernels/steel_gemm_splitk.metal
|
||||
)
|
||||
set(
|
||||
STEEL_HEADERS
|
||||
steel/defines.h
|
||||
steel/utils.h
|
||||
steel/conv/conv.h
|
||||
steel/conv/loader.h
|
||||
steel/conv/loaders/loader_channel_l.h
|
||||
steel/conv/loaders/loader_channel_n.h
|
||||
steel/conv/loaders/loader_general.h
|
||||
steel/conv/kernels/steel_conv.h
|
||||
steel/conv/kernels/steel_conv_general.h
|
||||
steel/gemm/gemm.h
|
||||
steel/gemm/mma.h
|
||||
steel/gemm/loader.h
|
||||
steel/gemm/transforms.h
|
||||
steel/gemm/kernels/steel_gemm_fused.h
|
||||
steel/gemm/kernels/steel_gemm_masked.h
|
||||
steel/gemm/kernels/steel_gemm_splitk.h
|
||||
)
|
||||
foreach(KERNEL ${STEEL_KERNELS})
|
||||
cmake_path(GET KERNEL STEM TARGET)
|
||||
build_kernel_base(${TARGET} ${KERNEL} "${STEEL_HEADERS}")
|
||||
set(KERNEL_AIR ${TARGET}.air ${KERNEL_AIR})
|
||||
endforeach()
|
||||
endif()
|
||||
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT ${MLX_METAL_PATH}/mlx.metallib
|
||||
COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o ${MLX_METAL_PATH}/mlx.metallib
|
||||
|
@@ -6,7 +6,7 @@
|
||||
|
||||
using namespace metal;
|
||||
|
||||
#if (MLX_METAL_VERSION >= 310) || (__METAL_VERSION__ >= 310)
|
||||
#if defined METAL_3_1 || (__METAL_VERSION__ >= 310)
|
||||
|
||||
typedef bfloat bfloat16_t;
|
||||
|
||||
|
@@ -369,7 +369,7 @@ instantiate_metal_math_funcs(
|
||||
return static_cast<otype>(__metal_simd_xor(static_cast<ctype>(data))); \
|
||||
}
|
||||
|
||||
#if (MLX_METAL_VERSION >= 310) || (__METAL_VERSION__ >= 310)
|
||||
#if defined METAL_3_1 || (__METAL_VERSION__ >= 310)
|
||||
|
||||
#define bfloat16_to_uint16(x) as_type<uint16_t>(x)
|
||||
#define uint16_to_bfloat16(x) as_type<bfloat16_t>(x)
|
||||
|
@@ -36,39 +36,6 @@ template <typename T, typename U, typename Op>
|
||||
c[index] = Op()(a[index], b[index]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void binary_sv2(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
size_t offset = index.x + grid_dim.x * size_t(index.y);
|
||||
c[offset] = Op()(a[0], b[offset]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void binary_vs2(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
size_t offset = index.x + grid_dim.x * size_t(index.y);
|
||||
c[offset] = Op()(a[offset], b[0]);
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void binary_vv2(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
size_t offset = index.x + grid_dim.x * size_t(index.y);
|
||||
c[offset] = Op()(a[offset], b[offset]);
|
||||
}
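// Hedged note on the removed *_2 variants (assumption, not stated in the diff): the
// uint2 index combined with a size_t offset lets a two-dimensional grid address
// buffers whose element count would overflow a single 32-bit thread index.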
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void binary_g_nd1(
|
||||
device const T* a,
|
||||
|
@@ -4,94 +4,148 @@
|
||||
#include <metal_math>
|
||||
|
||||
// clang-format off
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
#include "mlx/backend/metal/kernels/binary_ops.h"
|
||||
#include "mlx/backend/metal/kernels/binary.h"
|
||||
|
||||
#define instantiate_binary_all(op, tname, itype, otype) \
|
||||
instantiate_kernel("ss" #op #tname, binary_ss, itype, otype, op) \
|
||||
instantiate_kernel("sv" #op #tname, binary_sv, itype, otype, op) \
|
||||
instantiate_kernel("vs" #op #tname, binary_vs, itype, otype, op) \
|
||||
instantiate_kernel("vv" #op #tname, binary_vv, itype, otype, op) \
|
||||
instantiate_kernel("sv2" #op #tname, binary_sv2, itype, otype, op) \
|
||||
instantiate_kernel("vs2" #op #tname, binary_vs2, itype, otype, op) \
|
||||
instantiate_kernel("vv2" #op #tname, binary_vv2, itype, otype, op) \
|
||||
instantiate_kernel("gn" #op #tname, binary_g, itype, otype, op) \
|
||||
instantiate_kernel("g1" #op #tname, binary_g_nd1, itype, otype, op) \
|
||||
instantiate_kernel("g2" #op #tname, binary_g_nd2, itype, otype, op) \
|
||||
instantiate_kernel("g3" #op #tname, binary_g_nd3, itype, otype, op) \
|
||||
instantiate_kernel("g4" #op #tname, binary_g_nd, itype, otype, op, 4) \
|
||||
instantiate_kernel("g5" #op #tname, binary_g_nd, itype, otype, op, 5)
|
||||
#define instantiate_binary(name, itype, otype, op, bopt) \
|
||||
template \
|
||||
[[host_name(name)]] [[kernel]] void binary_##bopt<itype, otype, op>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
uint index [[thread_position_in_grid]]);
|
||||
|
||||
#define instantiate_binary_integer(op) \
|
||||
instantiate_binary_all(op, uint8, uint8_t, uint8_t) \
|
||||
instantiate_binary_all(op, uint16, uint16_t, uint16_t) \
|
||||
instantiate_binary_all(op, uint32, uint32_t, uint32_t) \
|
||||
instantiate_binary_all(op, uint64, uint64_t, uint64_t) \
|
||||
instantiate_binary_all(op, int8, int8_t, int8_t) \
|
||||
instantiate_binary_all(op, int16, int16_t, int16_t) \
|
||||
instantiate_binary_all(op, int32, int32_t, int32_t) \
|
||||
instantiate_binary_all(op, int64, int64_t, int64_t)
|
||||
#define instantiate_binary_g_dim(name, itype, otype, op, dims) \
|
||||
template [[host_name("g" #dims name)]] [[kernel]] void \
|
||||
binary_g_nd<itype, otype, op, dims>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
constant const int shape[dims], \
|
||||
constant const size_t a_strides[dims], \
|
||||
constant const size_t b_strides[dims], \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
|
||||
#define instantiate_binary_float(op) \
|
||||
instantiate_binary_all(op, float16, half, half) \
|
||||
instantiate_binary_all(op, float32, float, float) \
|
||||
instantiate_binary_all(op, bfloat16, bfloat16_t, bfloat16_t)
|
||||
#define instantiate_binary_g_nd(name, itype, otype, op) \
|
||||
template [[host_name("g1" name)]] [[kernel]] void \
|
||||
binary_g_nd1<itype, otype, op>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
constant const size_t& a_stride, \
|
||||
constant const size_t& b_stride, \
|
||||
uint index [[thread_position_in_grid]]); \
|
||||
template [[host_name("g2" name)]] [[kernel]] void \
|
||||
binary_g_nd2<itype, otype, op>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
constant const size_t a_strides[2], \
|
||||
constant const size_t b_strides[2], \
|
||||
uint2 index [[thread_position_in_grid]], \
|
||||
uint2 grid_dim [[threads_per_grid]]); \
|
||||
template [[host_name("g3" name)]] [[kernel]] void \
|
||||
binary_g_nd3<itype, otype, op>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
constant const size_t a_strides[3], \
|
||||
constant const size_t b_strides[3], \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]); \
|
||||
instantiate_binary_g_dim(name, itype, otype, op, 4) \
|
||||
instantiate_binary_g_dim(name, itype, otype, op, 5)
|
||||
|
||||
#define instantiate_binary_types(op) \
|
||||
instantiate_binary_all(op, bool_, bool, bool) \
|
||||
instantiate_binary_integer(op) \
|
||||
instantiate_binary_all(op, complex64, complex64_t, complex64_t) \
|
||||
instantiate_binary_float(op)
|
||||
#define instantiate_binary_g(name, itype, otype, op) \
|
||||
template [[host_name("gn" name)]] [[kernel]] void binary_g<itype, otype, op>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
constant const int* shape, \
|
||||
constant const size_t* a_strides, \
|
||||
constant const size_t* b_strides, \
|
||||
constant const int& ndim, \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
|
||||
#define instantiate_binary_types_bool(op) \
|
||||
instantiate_binary_all(op, bool_, bool, bool) \
|
||||
instantiate_binary_all(op, uint8, uint8_t, bool) \
|
||||
instantiate_binary_all(op, uint16, uint16_t, bool) \
|
||||
instantiate_binary_all(op, uint32, uint32_t, bool) \
|
||||
instantiate_binary_all(op, uint64, uint64_t, bool) \
|
||||
instantiate_binary_all(op, int8, int8_t, bool) \
|
||||
instantiate_binary_all(op, int16, int16_t, bool) \
|
||||
instantiate_binary_all(op, int32, int32_t, bool) \
|
||||
instantiate_binary_all(op, int64, int64_t, bool) \
|
||||
instantiate_binary_all(op, float16, half, bool) \
|
||||
instantiate_binary_all(op, float32, float, bool) \
|
||||
instantiate_binary_all(op, bfloat16, bfloat16_t, bool) \
|
||||
instantiate_binary_all(op, complex64, complex64_t, bool)
|
||||
#define instantiate_binary_all(name, tname, itype, otype, op) \
|
||||
instantiate_binary("ss" #name #tname, itype, otype, op, ss) \
|
||||
instantiate_binary("sv" #name #tname, itype, otype, op, sv) \
|
||||
instantiate_binary("vs" #name #tname, itype, otype, op, vs) \
|
||||
instantiate_binary("vv" #name #tname, itype, otype, op, vv) \
|
||||
instantiate_binary_g(#name #tname, itype, otype, op) \
|
||||
instantiate_binary_g_nd(#name #tname, itype, otype, op)
|
||||
|
||||
instantiate_binary_types(Add)
|
||||
instantiate_binary_types(Divide)
|
||||
instantiate_binary_types_bool(Equal)
|
||||
instantiate_binary_types_bool(Greater)
|
||||
instantiate_binary_types_bool(GreaterEqual)
|
||||
instantiate_binary_types_bool(Less)
|
||||
instantiate_binary_types_bool(LessEqual)
|
||||
instantiate_binary_types_bool(NotEqual)
|
||||
instantiate_binary_float(LogAddExp)
|
||||
instantiate_binary_types(Maximum)
|
||||
instantiate_binary_types(Minimum)
|
||||
instantiate_binary_types(Multiply)
|
||||
instantiate_binary_types(Subtract)
|
||||
instantiate_binary_types(Power)
|
||||
instantiate_binary_types(Remainder)
|
||||
instantiate_binary_float(ArcTan2)
|
||||
#define instantiate_binary_integer(name, op) \
|
||||
instantiate_binary_all(name, uint8, uint8_t, uint8_t, op) \
|
||||
instantiate_binary_all(name, uint16, uint16_t, uint16_t, op) \
|
||||
instantiate_binary_all(name, uint32, uint32_t, uint32_t, op) \
|
||||
instantiate_binary_all(name, uint64, uint64_t, uint64_t, op) \
|
||||
instantiate_binary_all(name, int8, int8_t, int8_t, op) \
|
||||
instantiate_binary_all(name, int16, int16_t, int16_t, op) \
|
||||
instantiate_binary_all(name, int32, int32_t, int32_t, op) \
|
||||
instantiate_binary_all(name, int64, int64_t, int64_t, op)
|
||||
|
||||
#define instantiate_binary_float(name, op) \
|
||||
instantiate_binary_all(name, float16, half, half, op) \
|
||||
instantiate_binary_all(name, float32, float, float, op) \
|
||||
instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op)
|
||||
|
||||
#define instantiate_binary_types(name, op) \
|
||||
instantiate_binary_all(name, bool_, bool, bool, op) \
|
||||
instantiate_binary_integer(name, op) \
|
||||
instantiate_binary_all(name, complex64, complex64_t, complex64_t, op) \
|
||||
instantiate_binary_float(name, op)
|
||||
|
||||
#define instantiate_binary_types_bool(name, op) \
|
||||
instantiate_binary_all(name, bool_, bool, bool, op) \
|
||||
instantiate_binary_all(name, uint8, uint8_t, bool, op) \
|
||||
instantiate_binary_all(name, uint16, uint16_t, bool, op) \
|
||||
instantiate_binary_all(name, uint32, uint32_t, bool, op) \
|
||||
instantiate_binary_all(name, uint64, uint64_t, bool, op) \
|
||||
instantiate_binary_all(name, int8, int8_t, bool, op) \
|
||||
instantiate_binary_all(name, int16, int16_t, bool, op) \
|
||||
instantiate_binary_all(name, int32, int32_t, bool, op) \
|
||||
instantiate_binary_all(name, int64, int64_t, bool, op) \
|
||||
instantiate_binary_all(name, float16, half, bool, op) \
|
||||
instantiate_binary_all(name, float32, float, bool, op) \
|
||||
instantiate_binary_all(name, bfloat16, bfloat16_t, bool, op) \
|
||||
instantiate_binary_all(name, complex64, complex64_t, bool, op)
|
||||
|
||||
instantiate_binary_types(add, Add)
|
||||
instantiate_binary_types(div, Divide)
|
||||
instantiate_binary_types_bool(eq, Equal)
|
||||
instantiate_binary_types_bool(ge, Greater)
|
||||
instantiate_binary_types_bool(geq, GreaterEqual)
|
||||
instantiate_binary_types_bool(le, Less)
|
||||
instantiate_binary_types_bool(leq, LessEqual)
|
||||
instantiate_binary_types_bool(neq, NotEqual)
|
||||
instantiate_binary_float(lae, LogAddExp)
|
||||
instantiate_binary_types(max, Maximum)
|
||||
instantiate_binary_types(min, Minimum)
|
||||
instantiate_binary_types(mul, Multiply)
|
||||
instantiate_binary_types(sub, Subtract)
|
||||
instantiate_binary_types(pow, Power)
|
||||
instantiate_binary_types(rem, Remainder)
|
||||
instantiate_binary_float(arctan2, ArcTan2)
|
||||
|
||||
// NaNEqual only needed for floating point types with boolean output
|
||||
instantiate_binary_all(NaNEqual, float16, half, bool)
|
||||
instantiate_binary_all(NaNEqual, float32, float, bool)
|
||||
instantiate_binary_all(NaNEqual, bfloat16, bfloat16_t, bool)
|
||||
instantiate_binary_all(NaNEqual, complex64, complex64_t, bool)
|
||||
instantiate_binary_all(naneq, float16, half, bool, NaNEqual)
|
||||
instantiate_binary_all(naneq, float32, float, bool, NaNEqual)
|
||||
instantiate_binary_all(naneq, bfloat16, bfloat16_t, bool, NaNEqual)
|
||||
instantiate_binary_all(naneq, complex64, complex64_t, bool, NaNEqual)
|
||||
|
||||
instantiate_binary_all(LogicalOr, bool_, bool, bool)
|
||||
instantiate_binary_all(LogicalAnd, bool_, bool, bool)
|
||||
instantiate_binary_all(lor, bool_, bool, bool, LogicalOr)
|
||||
instantiate_binary_all(land, bool_, bool, bool, LogicalAnd)
|
||||
|
||||
// Bitwise ops only need integer types and bool (except for l/r shift)
|
||||
instantiate_binary_integer(BitwiseAnd)
|
||||
instantiate_binary_all(BitwiseAnd, bool_, bool, bool)
|
||||
instantiate_binary_integer(BitwiseOr)
|
||||
instantiate_binary_all(BitwiseOr, bool_, bool, bool)
|
||||
instantiate_binary_integer(BitwiseXor)
|
||||
instantiate_binary_all(BitwiseXor, bool_, bool, bool)
|
||||
instantiate_binary_integer(LeftShift)
|
||||
instantiate_binary_integer(RightShift) // clang-format on
|
||||
instantiate_binary_integer(bitwise_and, BitwiseAnd)
|
||||
instantiate_binary_all(bitwise_and, bool_, bool, bool, BitwiseAnd)
|
||||
instantiate_binary_integer(bitwise_or, BitwiseOr)
|
||||
instantiate_binary_all(bitwise_or, bool_, bool, bool, BitwiseOr)
|
||||
instantiate_binary_integer(bitwise_xor, BitwiseXor)
|
||||
instantiate_binary_all(bitwise_xor, bool_, bool, bool, BitwiseXor)
|
||||
instantiate_binary_integer(left_shift, LeftShift)
|
||||
instantiate_binary_integer(right_shift, RightShift) // clang-format on
|
||||
|
@@ -48,48 +48,6 @@ template <typename T, typename U, typename Op>
|
||||
d[index] = out[1];
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void binary_sv2(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
device U* d,
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
size_t offset = index.x + grid_dim.x * size_t(index.y);
|
||||
auto out = Op()(a[0], b[offset]);
|
||||
c[offset] = out[0];
|
||||
d[offset] = out[1];
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void binary_vs2(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
device U* d,
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
size_t offset = index.x + grid_dim.x * size_t(index.y);
|
||||
auto out = Op()(a[offset], b[0]);
|
||||
c[offset] = out[0];
|
||||
d[offset] = out[1];
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void binary_vv2(
|
||||
device const T* a,
|
||||
device const T* b,
|
||||
device U* c,
|
||||
device U* d,
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
size_t offset = index.x + grid_dim.x * size_t(index.y);
|
||||
auto out = Op()(a[offset], b[offset]);
|
||||
c[offset] = out[0];
|
||||
d[offset] = out[1];
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename Op>
|
||||
[[kernel]] void binary_g_nd1(
|
||||
device const T* a,
|
||||
|
@@ -7,37 +7,99 @@
|
||||
#include "mlx/backend/metal/kernels/binary_ops.h"
|
||||
#include "mlx/backend/metal/kernels/binary_two.h"
|
||||
|
||||
#define instantiate_binary_all(op, tname, itype, otype) \
|
||||
instantiate_kernel("ss" #op #tname, binary_ss, itype, otype, op) \
|
||||
instantiate_kernel("sv" #op #tname, binary_sv, itype, otype, op) \
|
||||
instantiate_kernel("vs" #op #tname, binary_vs, itype, otype, op) \
|
||||
instantiate_kernel("vv" #op #tname, binary_vv, itype, otype, op) \
|
||||
instantiate_kernel("sv2" #op #tname, binary_sv2, itype, otype, op) \
|
||||
instantiate_kernel("vs2" #op #tname, binary_vs2, itype, otype, op) \
|
||||
instantiate_kernel("vv2" #op #tname, binary_vv2, itype, otype, op) \
|
||||
instantiate_kernel("gn" #op #tname, binary_g, itype, otype, op) \
|
||||
instantiate_kernel("g1" #op #tname, binary_g_nd1, itype, otype, op) \
|
||||
instantiate_kernel("g2" #op #tname, binary_g_nd2, itype, otype, op) \
|
||||
instantiate_kernel("g3" #op #tname, binary_g_nd3, itype, otype, op) \
|
||||
instantiate_kernel("g4" #op #tname, binary_g_nd, itype, otype, op, 4) \
|
||||
instantiate_kernel("g5" #op #tname, binary_g_nd, itype, otype, op, 5)
|
||||
#define instantiate_binary(name, itype, otype, op, bopt) \
|
||||
template [[host_name(name)]] [[kernel]] void \
|
||||
binary_##bopt<itype, otype, op>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
device otype* d, \
|
||||
uint index [[thread_position_in_grid]]);
|
||||
|
||||
#define instantiate_binary_float(op) \
|
||||
instantiate_binary_all(op, float16, half, half) \
|
||||
instantiate_binary_all(op, float32, float, float) \
|
||||
instantiate_binary_all(op, bfloat16, bfloat16_t, bfloat16_t)
|
||||
#define instantiate_binary_g_dim(name, itype, otype, op, dims) \
|
||||
template [[host_name("g" #dims name)]] [[kernel]] void \
|
||||
binary_g_nd<itype, otype, op, dims>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
device otype* d, \
|
||||
constant const int shape[dims], \
|
||||
constant const size_t a_strides[dims], \
|
||||
constant const size_t b_strides[dims], \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
|
||||
#define instantiate_binary_types(op) \
|
||||
instantiate_binary_all(op, bool_, bool, bool) \
|
||||
instantiate_binary_all(op, uint8, uint8_t, uint8_t) \
|
||||
instantiate_binary_all(op, uint16, uint16_t, uint16_t) \
|
||||
instantiate_binary_all(op, uint32, uint32_t, uint32_t) \
|
||||
instantiate_binary_all(op, uint64, uint64_t, uint64_t) \
|
||||
instantiate_binary_all(op, int8, int8_t, int8_t) \
|
||||
instantiate_binary_all(op, int16, int16_t, int16_t) \
|
||||
instantiate_binary_all(op, int32, int32_t, int32_t) \
|
||||
instantiate_binary_all(op, int64, int64_t, int64_t) \
|
||||
instantiate_binary_all(op, complex64, complex64_t, complex64_t) \
|
||||
instantiate_binary_float(op)
|
||||
#define instantiate_binary_g_nd(name, itype, otype, op) \
|
||||
template [[host_name("g1" name)]] [[kernel]] void \
|
||||
binary_g_nd1<itype, otype, op>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
device otype* d, \
|
||||
constant const size_t& a_stride, \
|
||||
constant const size_t& b_stride, \
|
||||
uint index [[thread_position_in_grid]]); \
|
||||
template [[host_name("g2" name)]] [[kernel]] void \
|
||||
binary_g_nd2<itype, otype, op>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
device otype* d, \
|
||||
constant const size_t a_strides[2], \
|
||||
constant const size_t b_strides[2], \
|
||||
uint2 index [[thread_position_in_grid]], \
|
||||
uint2 grid_dim [[threads_per_grid]]); \
|
||||
template [[host_name("g3" name)]] [[kernel]] void \
|
||||
binary_g_nd3<itype, otype, op>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
device otype* d, \
|
||||
constant const size_t a_strides[3], \
|
||||
constant const size_t b_strides[3], \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]); \
|
||||
instantiate_binary_g_dim(name, itype, otype, op, 4) \
|
||||
instantiate_binary_g_dim(name, itype, otype, op, 5)
|
||||
|
||||
instantiate_binary_types(DivMod) // clang-format on
|
||||
#define instantiate_binary_g(name, itype, otype, op) \
|
||||
template [[host_name("gn" name)]] [[kernel]] void \
|
||||
binary_g<itype, otype, op>( \
|
||||
device const itype* a, \
|
||||
device const itype* b, \
|
||||
device otype* c, \
|
||||
device otype* d, \
|
||||
constant const int* shape, \
|
||||
constant const size_t* a_strides, \
|
||||
constant const size_t* b_strides, \
|
||||
constant const int& ndim, \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]);
|
||||
|
||||
#define instantiate_binary_all(name, tname, itype, otype, op) \
|
||||
instantiate_binary("ss" #name #tname, itype, otype, op, ss) \
|
||||
instantiate_binary("sv" #name #tname, itype, otype, op, sv) \
|
||||
instantiate_binary("vs" #name #tname, itype, otype, op, vs) \
|
||||
instantiate_binary("vv" #name #tname, itype, otype, op, vv) \
|
||||
instantiate_binary_g(#name #tname, itype, otype, op) \
|
||||
instantiate_binary_g_nd(#name #tname, itype, otype, op)
|
||||
|
||||
#define instantiate_binary_float(name, op) \
|
||||
instantiate_binary_all(name, float16, half, half, op) \
|
||||
instantiate_binary_all(name, float32, float, float, op) \
|
||||
instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op)
|
||||
|
||||
#define instantiate_binary_types(name, op) \
|
||||
instantiate_binary_all(name, bool_, bool, bool, op) \
|
||||
instantiate_binary_all(name, uint8, uint8_t, uint8_t, op) \
|
||||
instantiate_binary_all(name, uint16, uint16_t, uint16_t, op) \
|
||||
instantiate_binary_all(name, uint32, uint32_t, uint32_t, op) \
|
||||
instantiate_binary_all(name, uint64, uint64_t, uint64_t, op) \
|
||||
instantiate_binary_all(name, int8, int8_t, int8_t, op) \
|
||||
instantiate_binary_all(name, int16, int16_t, int16_t, op) \
|
||||
instantiate_binary_all(name, int32, int32_t, int32_t, op) \
|
||||
instantiate_binary_all(name, int64, int64_t, int64_t, op) \
|
||||
instantiate_binary_all(name, complex64, complex64_t, complex64_t, op) \
|
||||
instantiate_binary_float(name, op)
|
||||
|
||||
instantiate_binary_types(divmod, DivMod) // clang-format on
|
||||
|
@@ -344,12 +344,12 @@ winograd_conv_2d_weight_transform(
|
||||
const short sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
|
||||
|
||||
// Initialize G matrix
|
||||
simdgroup_matrix<float, 8, 8> G;
|
||||
simdgroup_matrix<T, 8, 8> G;
|
||||
G.thread_elements()[0] = WGT::wt_transform[sm][sn];
|
||||
G.thread_elements()[1] = WGT::wt_transform[sm][sn + 1];
|
||||
|
||||
// Initialize Gt matrix
|
||||
simdgroup_matrix<float, 8, 8> Gt;
|
||||
simdgroup_matrix<T, 8, 8> Gt;
|
||||
Gt.thread_elements()[0] = WGT::wt_transform[sn][sm];
|
||||
Gt.thread_elements()[1] = WGT::wt_transform[sn + 1][sm];
|
||||
|
||||
@@ -381,15 +381,15 @@ winograd_conv_2d_weight_transform(
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// Do transform and store the result
|
||||
for (int c = 0; c < BC; ++c) {
|
||||
simdgroup_matrix<float, 8, 8> g;
|
||||
simdgroup_matrix<T, 8, 8> g;
|
||||
g.thread_elements()[0] =
|
||||
sm < R && sn < R ? Ws[simd_group_id][sm][sn][c] : T(0);
|
||||
g.thread_elements()[1] =
|
||||
sm < R && sn + 1 < R ? Ws[simd_group_id][sm][sn + 1][c] : T(0);
|
||||
|
||||
simdgroup_matrix<float, 8, 8> g_out = (G * g) * Gt;
|
||||
wt_out_0[c * O] = static_cast<T>(g_out.thread_elements()[0]);
|
||||
wt_out_1[c * O] = static_cast<T>(g_out.thread_elements()[1]);
|
||||
simdgroup_matrix<T, 8, 8> g_out = (G * g) * Gt;
|
||||
wt_out_0[c * O] = g_out.thread_elements()[0];
|
||||
wt_out_1[c * O] = g_out.thread_elements()[1];
|
||||
}
|
||||
|
||||
wt_in += BC;
|
||||
@@ -433,12 +433,12 @@ winograd_conv_2d_input_transform(
|
||||
const short sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
|
||||
|
||||
// Initialize B matrix
|
||||
simdgroup_matrix<float, 8, 8> B;
|
||||
simdgroup_matrix<T, 8, 8> B;
|
||||
B.thread_elements()[0] = WGT::in_transform[sm][sn];
|
||||
B.thread_elements()[1] = WGT::in_transform[sm][sn + 1];
|
||||
|
||||
// Initialize Bt matrix
|
||||
simdgroup_matrix<float, 8, 8> Bt;
|
||||
simdgroup_matrix<T, 8, 8> Bt;
|
||||
Bt.thread_elements()[0] = WGT::in_transform[sn][sm];
|
||||
Bt.thread_elements()[1] = WGT::in_transform[sn + 1][sm];
|
||||
|
||||
@@ -493,13 +493,13 @@ winograd_conv_2d_input_transform(
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// Do transform and store the result
|
||||
for (int c = simd_group_id; c < BC; c += N_SIMD_GROUPS) {
|
||||
simdgroup_matrix<float, 8, 8> I;
|
||||
simdgroup_matrix<T, 8, 8> I;
|
||||
I.thread_elements()[0] = Is[sm][sn][c];
|
||||
I.thread_elements()[1] = Is[sm][sn + 1][c];
|
||||
|
||||
simdgroup_matrix<float, 8, 8> I_out = (Bt * I) * B;
|
||||
inp_out_0[c] = static_cast<T>(I_out.thread_elements()[0]);
|
||||
inp_out_1[c] = static_cast<T>(I_out.thread_elements()[1]);
|
||||
simdgroup_matrix<T, 8, 8> I_out = (Bt * I) * B;
|
||||
inp_out_0[c] = I_out.thread_elements()[0];
|
||||
inp_out_1[c] = I_out.thread_elements()[1];
|
||||
}
|
||||
|
||||
inp_in += BC;
|
||||
@@ -543,12 +543,12 @@ winograd_conv_2d_output_transform(
|
||||
const short sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
|
||||
|
||||
// Initialize A matrix
|
||||
simdgroup_matrix<float, 8, 8> B;
|
||||
simdgroup_matrix<T, 8, 8> B;
|
||||
B.thread_elements()[0] = WGT::out_transform[sm][sn];
|
||||
B.thread_elements()[1] = WGT::out_transform[sm][sn + 1];
|
||||
|
||||
// Initialize At matrix
|
||||
simdgroup_matrix<float, 8, 8> Bt;
|
||||
simdgroup_matrix<T, 8, 8> Bt;
|
||||
Bt.thread_elements()[0] = WGT::out_transform[sn][sm];
|
||||
Bt.thread_elements()[1] = WGT::out_transform[sn + 1][sm];
|
||||
|
||||
@@ -597,16 +597,16 @@ winograd_conv_2d_output_transform(
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// Do transform and store the result
|
||||
for (int c = simd_group_id; c < BO; c += N_SIMD_GROUPS) {
|
||||
simdgroup_matrix<float, 8, 8> O_mat;
|
||||
simdgroup_matrix<T, 8, 8> O_mat;
|
||||
O_mat.thread_elements()[0] = out_in_0[c];
|
||||
O_mat.thread_elements()[1] = out_in_1[c];
|
||||
|
||||
simdgroup_matrix<float, 8, 8> O_out = (Bt * (O_mat * B));
|
||||
simdgroup_matrix<T, 8, 8> O_out = (Bt * (O_mat * B));
|
||||
if ((sm < M) && (sn < M)) {
|
||||
Os[sm][sn][c] = static_cast<T>(O_out.thread_elements()[0]);
|
||||
Os[sm][sn][c] = O_out.thread_elements()[0];
|
||||
}
|
||||
if ((sm < M) && ((sn + 1) < M)) {
|
||||
Os[sm][sn + 1][c] = static_cast<T>(O_out.thread_elements()[1]);
|
||||
Os[sm][sn + 1][c] = O_out.thread_elements()[1];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -650,5 +650,4 @@ winograd_conv_2d_output_transform(
|
||||
|
||||
// clang-format off
|
||||
instantiate_winograd_conv_2d(float32, float);
|
||||
instantiate_winograd_conv_2d(bfloat16, bfloat16_t);
|
||||
instantiate_winograd_conv_2d(float16, half); // clang-format on
|
||||
|
@@ -16,26 +16,6 @@ template <typename T, typename U>
|
||||
dst[index] = static_cast<U>(src[index]);
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
[[kernel]] void copy_s2(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
size_t offset = index.x + grid_dim.x * size_t(index.y);
|
||||
dst[offset] = static_cast<U>(src[0]);
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
[[kernel]] void copy_v2(
|
||||
device const T* src [[buffer(0)]],
|
||||
device U* dst [[buffer(1)]],
|
||||
uint2 index [[thread_position_in_grid]],
|
||||
uint2 grid_dim [[threads_per_grid]]) {
|
||||
size_t offset = index.x + grid_dim.x * size_t(index.y);
|
||||
dst[offset] = static_cast<U>(src[offset]);
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
[[kernel]] void copy_g_nd1(
|
||||
device const T* src [[buffer(0)]],
|
||||
|
@@ -5,23 +5,95 @@
|
||||
#include "mlx/backend/metal/kernels/bf16.h"
|
||||
#include "mlx/backend/metal/kernels/copy.h"
|
||||
|
||||
#define instantiate_copy(name, itype, otype, ctype) \
|
||||
template [[host_name(name)]] [[kernel]] void copy_##ctype<itype, otype>( \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
uint index [[thread_position_in_grid]]);
|
||||
|
||||
#define instantiate_copy_g_dim(name, itype, otype, dims) \
|
||||
template [[host_name("g" #dims "_" name)]] [[kernel]] void \
|
||||
copy_g_nd<itype, otype, dims>( \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int* src_shape [[buffer(2)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]); \
|
||||
template [[host_name("gg" #dims "_" name)]] [[kernel]] void \
|
||||
copy_gg_nd<itype, otype, dims>( \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int* src_shape [[buffer(2)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
constant const int64_t* dst_strides [[buffer(4)]], \
|
||||
uint3 index [[thread_position_in_grid]]);
|
||||
|
||||
#define instantiate_copy_g_nd(name, itype, otype) \
|
||||
template [[host_name("g1_" name)]] [[kernel]] void copy_g_nd1<itype, otype>( \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int64_t& src_stride [[buffer(3)]], \
|
||||
uint index [[thread_position_in_grid]]); \
|
||||
template [[host_name("g2_" name)]] [[kernel]] void copy_g_nd2<itype, otype>( \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
uint2 index [[thread_position_in_grid]], \
|
||||
uint2 grid_dim [[threads_per_grid]]); \
|
||||
template [[host_name("g3_" name)]] [[kernel]] void copy_g_nd3<itype, otype>( \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]); \
|
||||
template [[host_name("gg1_" name )]] [[kernel]] void \
|
||||
copy_gg_nd1<itype, otype>( \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int64_t& src_stride [[buffer(3)]], \
|
||||
constant const int64_t& dst_stride [[buffer(4)]], \
|
||||
uint index [[thread_position_in_grid]]); \
|
||||
template [[host_name("gg2_" name)]] [[kernel]] void \
|
||||
copy_gg_nd2<itype, otype>( \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
constant const int64_t* dst_strides [[buffer(4)]], \
|
||||
uint2 index [[thread_position_in_grid]]); \
|
||||
template [[host_name("gg3_" name)]] [[kernel]] void \
|
||||
copy_gg_nd3<itype, otype>( \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
constant const int64_t* dst_strides [[buffer(4)]], \
|
||||
uint3 index [[thread_position_in_grid]]); \
|
||||
instantiate_copy_g_dim(name, itype, otype, 4) \
|
||||
instantiate_copy_g_dim(name, itype, otype, 5)
|
||||
|
||||
#define instantiate_copy_g(name, itype, otype) \
|
||||
template [[host_name("g_" name)]] [[kernel]] void copy_g<itype, otype>( \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int* src_shape [[buffer(2)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
constant const int& ndim [[buffer(5)]], \
|
||||
uint3 index [[thread_position_in_grid]], \
|
||||
uint3 grid_dim [[threads_per_grid]]); \
|
||||
template [[host_name("gg_" name)]] [[kernel]] void copy_gg<itype, otype>( \
|
||||
device const itype* src [[buffer(0)]], \
|
||||
device otype* dst [[buffer(1)]], \
|
||||
constant const int* src_shape [[buffer(2)]], \
|
||||
constant const int64_t* src_strides [[buffer(3)]], \
|
||||
constant const int64_t* dst_strides [[buffer(4)]], \
|
||||
constant const int& ndim [[buffer(5)]], \
|
||||
uint3 index [[thread_position_in_grid]]);
|
||||
|
||||
#define instantiate_copy_all(tname, itype, otype) \
|
||||
instantiate_kernel("s_copy" #tname, copy_s, itype, otype) \
|
||||
instantiate_kernel("v_copy" #tname, copy_v, itype, otype) \
|
||||
instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype) \
|
||||
instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype) \
|
||||
instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype) \
|
||||
instantiate_kernel("g2_copy" #tname, copy_g_nd2, itype, otype) \
|
||||
instantiate_kernel("g3_copy" #tname, copy_g_nd3, itype, otype) \
|
||||
instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, itype, otype) \
|
||||
instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype) \
|
||||
instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype) \
|
||||
instantiate_kernel("g4_copy" #tname, copy_g_nd, itype, otype, 4) \
|
||||
instantiate_kernel("g5_copy" #tname, copy_g_nd, itype, otype, 5) \
|
||||
instantiate_kernel("gg4_copy" #tname, copy_gg_nd, itype, otype, 4) \
|
||||
instantiate_kernel("gg5_copy" #tname, copy_gg_nd, itype, otype, 5) \
|
||||
instantiate_kernel("g_copy" #tname, copy_g, itype, otype) \
|
||||
instantiate_kernel("gg_copy" #tname, copy_gg, itype, otype)
|
||||
instantiate_copy("s_copy" #tname, itype, otype, s) \
|
||||
instantiate_copy("v_copy" #tname, itype, otype, v) \
|
||||
instantiate_copy_g("copy" #tname, itype, otype) \
|
||||
instantiate_copy_g_nd("copy" #tname, itype, otype)
|
||||
|
||||
#define instantiate_copy_itype(itname, itype) \
|
||||
instantiate_copy_all(itname ##bool_, itype, bool) \
|
||||
|
@@ -13,11 +13,3 @@ static MTL_CONST constexpr int REDUCE_N_READS = 16;
static MTL_CONST constexpr int SOFTMAX_N_READS = 4;
static MTL_CONST constexpr int RMS_N_READS = 4;
static MTL_CONST constexpr int RMS_LOOPED_LIMIT = 4096;

// Instantiate a templated kernel.
// Extra args are used as template parameters:
// e.g. instantiate_kernel(binary_int, binary, a, b) ->
//      [[host_name(binary_int)]] [kernel] binary<a, b>
#define instantiate_kernel(name, func, ...) \
  template [[host_name(                     \
      name)]] [[kernel]] decltype(func<__VA_ARGS__>) func<__VA_ARGS__>;
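// For illustration only (not part of the original sources): with the binary.metal
// macros elsewhere in this diff, instantiate_binary_all(Add, float32, float, float)
// emits, among others,
//   instantiate_kernel("vvAddfloat32", binary_vv, float, float, Add)
// which this macro expands to:
//   template [[host_name("vvAddfloat32")]] [[kernel]]
//   decltype(binary_vv<float, float, Add>) binary_vv<float, float, Add>;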
@@ -83,7 +83,6 @@ float expm1f(float a) {
  r = expm1f_scaled_unchecked(a, 1.0f);
  /* handle severe overflow and underflow */
  if (abs(a - 1.0f) > 88.0f) {
    r = pow(2, a);
    r = fma(r, r, -1.0f);
  }
  return r;

@@ -1,486 +0,0 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
// Metal FFT using Stockham's algorithm
|
||||
//
|
||||
// References:
|
||||
// - VkFFT (https://github.com/DTolm/VkFFT)
|
||||
// - Eric Bainville's excellent page (http://www.bealto.com/gpu-fft.html)
|
||||
|
||||
#include <metal_common>
|
||||
|
||||
#include "mlx/backend/metal/kernels/fft/radix.h"
|
||||
#include "mlx/backend/metal/kernels/fft/readwrite.h"
|
||||
#include "mlx/backend/metal/kernels/steel/defines.h"
|
||||
|
||||
using namespace metal;
|
||||
|
||||
#define MAX_RADIX 13
|
||||
// Reached when elems_per_thread_ = 6, max_radix = 13
|
||||
// and some threads have to do 3 radix 6s requiring 18 float2s.
|
||||
#define MAX_OUTPUT_SIZE 18
|
||||
|
||||
// Specialize for a particular value of N at runtime
|
||||
STEEL_CONST bool inv_ [[function_constant(0)]];
|
||||
STEEL_CONST bool is_power_of_2_ [[function_constant(1)]];
|
||||
STEEL_CONST int elems_per_thread_ [[function_constant(2)]];
|
||||
// rader_m = n / rader_n
|
||||
STEEL_CONST int rader_m_ [[function_constant(3)]];
|
||||
// Stockham steps
|
||||
STEEL_CONST int radix_13_steps_ [[function_constant(4)]];
|
||||
STEEL_CONST int radix_11_steps_ [[function_constant(5)]];
|
||||
STEEL_CONST int radix_8_steps_ [[function_constant(6)]];
|
||||
STEEL_CONST int radix_7_steps_ [[function_constant(7)]];
|
||||
STEEL_CONST int radix_6_steps_ [[function_constant(8)]];
|
||||
STEEL_CONST int radix_5_steps_ [[function_constant(9)]];
|
||||
STEEL_CONST int radix_4_steps_ [[function_constant(10)]];
|
||||
STEEL_CONST int radix_3_steps_ [[function_constant(11)]];
|
||||
STEEL_CONST int radix_2_steps_ [[function_constant(12)]];
|
||||
// Rader steps
|
||||
STEEL_CONST int rader_13_steps_ [[function_constant(13)]];
|
||||
STEEL_CONST int rader_11_steps_ [[function_constant(14)]];
|
||||
STEEL_CONST int rader_8_steps_ [[function_constant(15)]];
|
||||
STEEL_CONST int rader_7_steps_ [[function_constant(16)]];
|
||||
STEEL_CONST int rader_6_steps_ [[function_constant(17)]];
|
||||
STEEL_CONST int rader_5_steps_ [[function_constant(18)]];
|
||||
STEEL_CONST int rader_4_steps_ [[function_constant(19)]];
|
||||
STEEL_CONST int rader_3_steps_ [[function_constant(20)]];
|
||||
STEEL_CONST int rader_2_steps_ [[function_constant(21)]];
|
||||
|
||||
// See "radix.h" for radix codelets
|
||||
typedef void (*RadixFunc)(thread float2*, thread float2*);
|
||||
|
||||
// Perform a single radix n butterfly with appropriate twiddles
|
||||
template <int radix, RadixFunc radix_func>
|
||||
METAL_FUNC void radix_butterfly(
|
||||
int i,
|
||||
int p,
|
||||
thread float2* x,
|
||||
thread short* indices,
|
||||
thread float2* y) {
|
||||
// i: the index in the overall DFT that we're processing.
|
||||
// p: the size of the DFTs we're merging at this step.
|
||||
// m: how many threads are working on this DFT.
|
||||
int k, j;
|
||||
|
||||
// Use faster bitwise operations when working with powers of two
|
||||
constexpr bool radix_p_2 = (radix & (radix - 1)) == 0;
|
||||
if (radix_p_2 && is_power_of_2_) {
|
||||
constexpr short power = __builtin_ctz(radix);
|
||||
k = i & (p - 1);
|
||||
j = ((i - k) << power) + k;
|
||||
} else {
|
||||
k = i % p;
|
||||
j = (i / p) * radix * p + k;
|
||||
}
|
||||
|
||||
// Apply twiddles
|
||||
if (p > 1) {
|
||||
float2 twiddle_1 = get_twiddle(k, radix * p);
|
||||
float2 twiddle = twiddle_1;
|
||||
x[1] = complex_mul(x[1], twiddle);
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (int t = 2; t < radix; t++) {
|
||||
twiddle = complex_mul(twiddle, twiddle_1);
|
||||
x[t] = complex_mul(x[t], twiddle);
|
||||
}
|
||||
}
|
||||
|
||||
radix_func(x, y);
|
||||
|
||||
STEEL_PRAGMA_UNROLL
|
||||
for (int t = 0; t < radix; t++) {
|
||||
indices[t] = j + t * p;
|
||||
}
|
||||
}
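# Illustrative host-side sketch (not part of the Metal sources): for power-of-two
# `radix` and `p`, the bitwise fast path in radix_butterfly above computes the same
# (k, j) pair as the generic modulo/divide path.
def butterfly_indices(i, p, radix):
    k = i % p
    j = (i // p) * radix * p + k
    return k, j

def butterfly_indices_pow2(i, p, radix):
    power = radix.bit_length() - 1   # __builtin_ctz(radix) for a power of two
    k = i & (p - 1)                  # i % p when p is a power of two
    j = ((i - k) << power) + k       # (i // p) * radix * p + k
    return k, j

assert all(
    butterfly_indices(i, p, r) == butterfly_indices_pow2(i, p, r)
    for r in (2, 4, 8)
    for p in (1, 2, 4, 8, 16)
    for i in range(256)
)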
// Perform all the radix steps required for a
|
||||
// particular radix size n.
|
||||
template <int radix, RadixFunc radix_func>
|
||||
METAL_FUNC void radix_n_steps(
|
||||
int i,
|
||||
thread int* p,
|
||||
int m,
|
||||
int n,
|
||||
int num_steps,
|
||||
thread float2* inputs,
|
||||
thread short* indices,
|
||||
thread float2* values,
|
||||
threadgroup float2* buf) {
|
||||
int m_r = n / radix;
|
||||
// When combining different sized radices, we have to do
|
||||
// multiple butterflies in a single thread.
|
||||
// E.g. n = 28 = 4 * 7
|
||||
// 4 threads, 7 elems_per_thread
|
||||
// All threads do 1 radix7 butterfly.
|
||||
// 3 threads do 2 radix4 butterflies.
|
||||
// 1 thread does 1 radix4 butterfly.
|
||||
int max_radices_per_thread = (elems_per_thread_ + radix - 1) / radix;
|
||||
|
||||
int index = 0;
|
||||
int r_index = 0;
|
||||
for (int s = 0; s < num_steps; s++) {
|
||||
for (int t = 0; t < max_radices_per_thread; t++) {
|
||||
index = i + t * m;
|
||||
if (index < m_r) {
|
||||
for (int r = 0; r < radix; r++) {
|
||||
inputs[r] = buf[index + r * m_r];
|
||||
}
|
||||
radix_butterfly<radix, radix_func>(
|
||||
index, *p, inputs, indices + t * radix, values + t * radix);
|
||||
}
|
||||
}
|
||||
|
||||
// Wait until all threads have read their inputs into thread local mem
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
for (int t = 0; t < max_radices_per_thread; t++) {
|
||||
index = i + t * m;
|
||||
if (index < m_r) {
|
||||
for (int r = 0; r < radix; r++) {
|
||||
r_index = t * radix + r;
|
||||
buf[indices[r_index]] = values[r_index];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Wait until all threads have written back to threadgroup mem
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
*p *= radix;
|
||||
}
|
||||
}
|
||||
|
||||
#define RADIX_STEP(radix, radix_func, num_steps) \
|
||||
radix_n_steps<radix, radix_func>( \
|
||||
fft_idx, p, m, n, num_steps, inputs, indices, values, buf);
|
||||
|
||||
template <bool rader = false>
|
||||
METAL_FUNC void
|
||||
perform_fft(int fft_idx, thread int* p, int m, int n, threadgroup float2* buf) {
|
||||
float2 inputs[MAX_RADIX];
|
||||
short indices[MAX_OUTPUT_SIZE];
|
||||
float2 values[MAX_OUTPUT_SIZE];
|
||||
|
||||
RADIX_STEP(2, radix2, rader ? rader_2_steps_ : radix_2_steps_);
|
||||
RADIX_STEP(3, radix3, rader ? rader_3_steps_ : radix_3_steps_);
|
||||
RADIX_STEP(4, radix4, rader ? rader_4_steps_ : radix_4_steps_);
|
||||
RADIX_STEP(5, radix5, rader ? rader_5_steps_ : radix_5_steps_);
|
||||
RADIX_STEP(6, radix6, rader ? rader_6_steps_ : radix_6_steps_);
|
||||
RADIX_STEP(7, radix7, rader ? rader_7_steps_ : radix_7_steps_);
|
||||
RADIX_STEP(8, radix8, rader ? rader_8_steps_ : radix_8_steps_);
|
||||
RADIX_STEP(11, radix11, rader ? rader_11_steps_ : radix_11_steps_);
|
||||
RADIX_STEP(13, radix13, rader ? rader_13_steps_ : radix_13_steps_);
|
||||
}
|
||||
|
||||
// Each FFT is computed entirely in shared GPU memory.
|
||||
//
|
||||
// N is decomposed into radix-n DFTs:
|
||||
// e.g. 128 = 2 * 4 * 4 * 4
|
||||
template <int tg_mem_size, typename in_T, typename out_T>
|
||||
[[kernel]] void fft(
|
||||
const device in_T* in [[buffer(0)]],
|
||||
device out_T* out [[buffer(1)]],
|
||||
constant const int& n,
|
||||
constant const int& batch_size,
|
||||
uint3 elem [[thread_position_in_grid]],
|
||||
uint3 grid [[threads_per_grid]]) {
|
||||
threadgroup float2 shared_in[tg_mem_size];
|
||||
|
||||
thread ReadWriter<in_T, out_T> read_writer = ReadWriter<in_T, out_T>(
|
||||
in,
|
||||
&shared_in[0],
|
||||
out,
|
||||
n,
|
||||
batch_size,
|
||||
elems_per_thread_,
|
||||
elem,
|
||||
grid,
|
||||
inv_);
|
||||
|
||||
if (read_writer.out_of_bounds()) {
|
||||
return;
|
||||
};
|
||||
read_writer.load();
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
int p = 1;
|
||||
int fft_idx = elem.z; // Thread index in DFT
|
||||
int m = grid.z; // Threads per DFT
|
||||
int tg_idx = elem.y * n; // Index of this DFT in threadgroup
|
||||
threadgroup float2* buf = &shared_in[tg_idx];
|
||||
|
||||
perform_fft(fft_idx, &p, m, n, buf);
|
||||
|
||||
read_writer.write();
|
||||
}
|
||||
|
||||
template <int tg_mem_size, typename in_T, typename out_T>
|
||||
[[kernel]] void rader_fft(
|
||||
const device in_T* in [[buffer(0)]],
|
||||
device out_T* out [[buffer(1)]],
|
||||
const device float2* raders_b_q [[buffer(2)]],
|
||||
const device short* raders_g_q [[buffer(3)]],
|
||||
const device short* raders_g_minus_q [[buffer(4)]],
|
||||
constant const int& n,
|
||||
constant const int& batch_size,
|
||||
constant const int& rader_n,
|
||||
uint3 elem [[thread_position_in_grid]],
|
||||
uint3 grid [[threads_per_grid]]) {
|
||||
// Use Rader's algorithm to compute fast FFTs
// when a prime factor `p` of `n` is greater than 13 but
// has `p - 1` Stockham-decomposable into prime factors <= 13.
//
// E.g. n = 102
//        = 2 * 3 * 17
//      . = 2 * 3 * RADER(16)
//      . = 2 * 3 * RADER(4 * 4)
//
// In numpy:
//   x_perm = x[g_q]
//   y = np.fft.fft(x_perm) * b_q
//   z = np.fft.ifft(y) + x[0]
//   out = z[g_minus_q]
//   out[0] = x[1:].sum()
//
// Where g_q and g_minus_q are permutations formed
// by the multiplicative group modulo N using a
// primitive root of N, and b_q is a precomputed constant.
// See https://en.wikipedia.org/wiki/Rader%27s_FFT_algorithm
//
// Rader's uses fewer operations than Bluestein's and so
// is more accurate. It's also faster in most cases.
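# Rough NumPy sketch (illustration only) of the textbook Rader recipe outlined above.
# The g_q / g_minus_q / b_q layout here follows the standard convention and is not
# necessarily the exact buffer layout this kernel uses.
import numpy as np

def rader_fft_reference(x, g):
    # DFT of a prime-length signal x, where g is a primitive root mod len(x).
    N = len(x)
    q = np.arange(N - 1)
    g_q = np.array([pow(g, int(e), N) for e in q])          # g^q mod N
    g_minus_q = np.array([pow(g, -int(e), N) for e in q])   # g^(-q) mod N
    b_q = np.exp(-2j * np.pi * g_minus_q / N)                # W^(g^-q), W = exp(-2j*pi/N)
    # The (N - 1)-point cyclic convolution is carried out with an FFT, as in this kernel.
    conv = np.fft.ifft(np.fft.fft(x[g_q]) * np.fft.fft(b_q))
    out = np.empty(N, dtype=complex)
    out[0] = x.sum()
    out[g_minus_q] = x[0] + conv
    return out

# Example check (3 is a primitive root mod 17):
#   x = np.random.rand(17); assert np.allclose(rader_fft_reference(x, 3), np.fft.fft(x))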
threadgroup float2 shared_in[tg_mem_size];
|
||||
|
||||
thread ReadWriter<in_T, out_T> read_writer = ReadWriter<in_T, out_T>(
|
||||
in,
|
||||
&shared_in[0],
|
||||
out,
|
||||
n,
|
||||
batch_size,
|
||||
elems_per_thread_,
|
||||
elem,
|
||||
grid,
|
||||
inv_);
|
||||
|
||||
if (read_writer.out_of_bounds()) {
|
||||
return;
|
||||
};
|
||||
read_writer.load();
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// The number of the threads we're using for each DFT
|
||||
int m = grid.z;
|
||||
|
||||
int fft_idx = elem.z;
|
||||
int tg_idx = elem.y * n;
|
||||
threadgroup float2* buf = &shared_in[tg_idx];
|
||||
|
||||
// rader_m = n / rader_n;
|
||||
int rader_m = rader_m_;
|
||||
|
||||
// We have to load two x_0s for each thread since sometimes
|
||||
// elems_per_thread_ crosses a boundary.
|
||||
// E.g. with n = 34, rader_n = 17, elems_per_thread_ = 4
|
||||
// 0 0 0 0 1 1 1 1 2 2 2 2 3 3 3 3 4 4 4 4 5 5 5 5 6 6 6 6 7 7 7 7 8 8
|
||||
// 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
|
||||
short x_0_index =
|
||||
metal::min(fft_idx * elems_per_thread_ / (rader_n - 1), rader_m - 1);
|
||||
float2 x_0[2] = {buf[x_0_index], buf[x_0_index + 1]};
|
||||
|
||||
// Do the Rader permutation in shared memory
|
||||
float2 temp[MAX_RADIX];
|
||||
int max_index = n - rader_m - 1;
|
||||
for (int e = 0; e < elems_per_thread_; e++) {
|
||||
short index = metal::min(fft_idx * elems_per_thread_ + e, max_index);
|
||||
short g_q = raders_g_q[index / rader_m];
|
||||
temp[e] = buf[rader_m + (g_q - 1) * rader_m + index % rader_m];
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
for (int e = 0; e < elems_per_thread_; e++) {
|
||||
short index = metal::min(fft_idx * elems_per_thread_ + e, max_index);
|
||||
buf[index + rader_m] = temp[e];
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Rader FFT on x[rader_m:]
|
||||
int p = 1;
|
||||
perform_fft</*rader=*/true>(fft_idx, &p, m, n - rader_m, buf + rader_m);
|
||||
|
||||
// x_1 + ... + x_n is computed for us in the first FFT step so
|
||||
// we save it in the first rader_m indices of the array for later.
|
||||
int x_sum_index = metal::min(fft_idx, rader_m - 1);
|
||||
buf[x_sum_index] = buf[rader_m + x_sum_index * (rader_n - 1)];
|
||||
|
||||
float2 inv = {1.0f, -1.0f};
|
||||
for (int e = 0; e < elems_per_thread_; e++) {
|
||||
short index = metal::min(fft_idx * elems_per_thread_ + e, max_index);
|
||||
short interleaved_index =
|
||||
index / rader_m + (index % rader_m) * (rader_n - 1);
|
||||
temp[e] = complex_mul(
|
||||
buf[rader_m + interleaved_index],
|
||||
raders_b_q[interleaved_index % (rader_n - 1)]);
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
for (int e = 0; e < elems_per_thread_; e++) {
|
||||
short index = metal::min(fft_idx * elems_per_thread_ + e, max_index);
|
||||
buf[rader_m + index] = temp[e] * inv;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Rader IFFT on x[rader_m:]
|
||||
p = 1;
|
||||
perform_fft</*rader=*/true>(fft_idx, &p, m, n - rader_m, buf + rader_m);
|
||||
|
||||
float2 rader_inv_factor = {1.0f / (rader_n - 1), -1.0f / (rader_n - 1)};
|
||||
|
||||
for (int e = 0; e < elems_per_thread_; e++) {
|
||||
short index = metal::min(fft_idx * elems_per_thread_ + e, n - rader_m - 1);
|
||||
short diff_index = index / (rader_n - 1) - x_0_index;
|
||||
temp[e] = buf[rader_m + index] * rader_inv_factor + x_0[diff_index];
|
||||
}
|
||||
|
||||
// Use the sum of elements that was computed in the first FFT
|
||||
float2 x_sum = buf[x_0_index] + x_0[0];
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
for (int e = 0; e < elems_per_thread_; e++) {
|
||||
short index = metal::min(fft_idx * elems_per_thread_ + e, max_index);
|
||||
short g_q_index = index % (rader_n - 1);
|
||||
short g_q = raders_g_minus_q[g_q_index];
|
||||
short out_index = index - g_q_index + g_q + (index / (rader_n - 1));
|
||||
buf[out_index] = temp[e];
|
||||
}
|
||||
|
||||
buf[x_0_index * rader_n] = x_sum;
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
p = rader_n;
|
||||
perform_fft(fft_idx, &p, m, n, buf);
|
||||
|
||||
read_writer.write();
|
||||
}
|
||||
|
||||
template <int tg_mem_size, typename in_T, typename out_T>
|
||||
[[kernel]] void bluestein_fft(
|
||||
const device in_T* in [[buffer(0)]],
|
||||
device out_T* out [[buffer(1)]],
|
||||
const device float2* w_q [[buffer(2)]],
|
||||
const device float2* w_k [[buffer(3)]],
|
||||
constant const int& length,
|
||||
constant const int& n,
|
||||
constant const int& batch_size,
|
||||
uint3 elem [[thread_position_in_grid]],
|
||||
uint3 grid [[threads_per_grid]]) {
|
||||
// Computes arbitrary length FFTs with Bluestein's algorithm
//
// In numpy:
//   bluestein_n = next_power_of_2(2*n - 1)
//   out = w_k * np.fft.ifft(np.fft.fft(w_k * in, bluestein_n) * w_q)
//
// Where w_k and w_q are precomputed on CPU in high precision as:
//   w_k = np.exp(-1j * np.pi / n * (np.arange(-n + 1, n) ** 2))
//   w_q = np.fft.fft(1/w_k[-n:])
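# Standalone NumPy sketch (illustration only) of the Bluestein / chirp-z idea described
# above; the exact w_k / w_q slicing precomputed by the host code may differ from this
# textbook form.
import numpy as np

def bluestein_fft_reference(x):
    n = len(x)
    k = np.arange(n)
    w = np.exp(-1j * np.pi * k * k / n)       # the chirp, the same idea as w_k above
    m = 1 << int(2 * n - 1).bit_length()      # power-of-2 FFT length >= 2n - 1
    a = np.zeros(m, dtype=complex)
    a[:n] = x * w
    b = np.zeros(m, dtype=complex)
    b[:n] = np.conj(w)
    b[m - n + 1:] = np.conj(w[1:][::-1])      # wrap the negative lags of the chirp
    # Arbitrary-length DFT expressed as a linear convolution done with power-of-2 FFTs.
    return w * np.fft.ifft(np.fft.fft(a) * np.fft.fft(b))[:n]

# x = np.random.rand(11); assert np.allclose(bluestein_fft_reference(x), np.fft.fft(x))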
threadgroup float2 shared_in[tg_mem_size];
|
||||
|
||||
thread ReadWriter<in_T, out_T> read_writer = ReadWriter<in_T, out_T>(
|
||||
in,
|
||||
&shared_in[0],
|
||||
out,
|
||||
n,
|
||||
batch_size,
|
||||
elems_per_thread_,
|
||||
elem,
|
||||
grid,
|
||||
inv_);
|
||||
|
||||
if (read_writer.out_of_bounds()) {
|
||||
return;
|
||||
};
|
||||
read_writer.load_padded(length, w_k);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
int p = 1;
|
||||
int fft_idx = elem.z; // Thread index in DFT
|
||||
int m = grid.z; // Threads per DFT
|
||||
int tg_idx = elem.y * n; // Index of this DFT in threadgroup
|
||||
threadgroup float2* buf = &shared_in[tg_idx];
|
||||
|
||||
// fft
|
||||
perform_fft(fft_idx, &p, m, n, buf);
|
||||
|
||||
float2 inv = float2(1.0f, -1.0f);
|
||||
for (int t = 0; t < elems_per_thread_; t++) {
|
||||
int index = fft_idx + t * m;
|
||||
buf[index] = complex_mul(buf[index], w_q[index]) * inv;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// ifft
|
||||
p = 1;
|
||||
perform_fft(fft_idx, &p, m, n, buf);
|
||||
|
||||
read_writer.write_padded(length, w_k);
|
||||
}
|
||||
|
||||
template <
|
||||
int tg_mem_size,
|
||||
typename in_T,
|
||||
typename out_T,
|
||||
int step,
|
||||
bool real = false>
|
||||
[[kernel]] void four_step_fft(
|
||||
const device in_T* in [[buffer(0)]],
|
||||
device out_T* out [[buffer(1)]],
|
||||
constant const int& n1,
|
||||
constant const int& n2,
|
||||
constant const int& batch_size,
|
||||
uint3 elem [[thread_position_in_grid]],
|
||||
uint3 grid [[threads_per_grid]]) {
|
||||
// Fast four step FFT implementation for powers of 2.
|
||||
int overall_n = n1 * n2;
|
||||
int n = step == 0 ? n1 : n2;
|
||||
int stride = step == 0 ? n2 : n1;
|
||||
|
||||
// The number of the threads we're using for each DFT
|
||||
int m = grid.z;
|
||||
int fft_idx = elem.z;
|
||||
|
||||
threadgroup float2 shared_in[tg_mem_size];
|
||||
threadgroup float2* buf = &shared_in[elem.y * n];
|
||||
|
||||
using read_writer_t = ReadWriter<in_T, out_T, step, real>;
|
||||
read_writer_t read_writer = read_writer_t(
|
||||
in,
|
||||
&shared_in[0],
|
||||
out,
|
||||
n,
|
||||
batch_size,
|
||||
elems_per_thread_,
|
||||
elem,
|
||||
grid,
|
||||
inv_);
|
||||
|
||||
if (read_writer.out_of_bounds()) {
|
||||
return;
|
||||
};
|
||||
read_writer.load_strided(stride, overall_n);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
int p = 1;
|
||||
perform_fft(fft_idx, &p, m, n, buf);
|
||||
|
||||
read_writer.write_strided(stride, overall_n);
|
||||
}
|
@@ -1,67 +1,199 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
// Metal FFT using Stockham's algorithm
|
||||
//
|
||||
// References:
|
||||
// - VkFFT (https://github.com/DTolm/VkFFT)
|
||||
// - Eric Bainville's excellent page (http://www.bealto.com/gpu-fft.html)
|
||||
|
||||
#include <metal_common>
|
||||
#include <metal_math>
|
||||
|
||||
#include "mlx/backend/metal/kernels/defines.h"
|
||||
#include "mlx/backend/metal/kernels/fft.h"
|
||||
#include "mlx/backend/metal/kernels/utils.h"
|
||||
|
||||
#define instantiate_fft(tg_mem_size, in_T, out_T) \
|
||||
instantiate_kernel( \
|
||||
"fft_mem_" #tg_mem_size "_" #in_T "_" #out_T, \
|
||||
fft, \
|
||||
tg_mem_size, \
|
||||
in_T, \
|
||||
out_T)
|
||||
using namespace metal;
|
||||
|
||||
#define instantiate_rader(tg_mem_size, in_T, out_T) \
|
||||
instantiate_kernel( \
|
||||
"rader_fft_mem_" #tg_mem_size "_" #in_T "_" #out_T, \
|
||||
rader_fft, \
|
||||
tg_mem_size, \
|
||||
in_T, \
|
||||
out_T)
|
||||
float2 complex_mul(float2 a, float2 b) {
|
||||
float2 c;
|
||||
c.x = a.x * b.x - a.y * b.y;
|
||||
c.y = a.x * b.y + a.y * b.x;
|
||||
return c;
|
||||
}
|
||||
|
||||
#define instantiate_bluestein(tg_mem_size, in_T, out_T) \
|
||||
instantiate_kernel( \
|
||||
"bluestein_fft_mem_" #tg_mem_size "_" #in_T "_" #out_T, \
|
||||
bluestein_fft, \
|
||||
tg_mem_size, \
|
||||
in_T, \
|
||||
out_T)
|
||||
float2 get_twiddle(int k, int p) {
|
||||
float theta = -1.0f * k * M_PI_F / (2 * p);
|
||||
|
||||
#define instantiate_four_step(tg_mem_size, in_T, out_T, step, real) \
|
||||
instantiate_kernel( \
|
||||
"four_step_mem_" #tg_mem_size "_" #in_T "_" #out_T "_" #step "_" #real, \
|
||||
four_step_fft, \
|
||||
tg_mem_size, \
|
||||
in_T, \
|
||||
out_T, \
|
||||
step, \
|
||||
real)
|
||||
float2 twiddle;
|
||||
twiddle.x = metal::fast::cos(theta);
|
||||
twiddle.y = metal::fast::sin(theta);
|
||||
return twiddle;
|
||||
}
|
||||
|
||||
// single threaded radix2 implementation
|
||||
void radix2(
|
||||
int i,
|
||||
int p,
|
||||
int m,
|
||||
threadgroup float2* read_buf,
|
||||
threadgroup float2* write_buf) {
|
||||
float2 x_0 = read_buf[i];
|
||||
float2 x_1 = read_buf[i + m];
|
||||
|
||||
// The index within this sub-DFT
|
||||
int k = i & (p - 1);
|
||||
|
||||
float2 twiddle = get_twiddle(k, p);
|
||||
|
||||
float2 z = complex_mul(x_1, twiddle);
|
||||
|
||||
float2 y_0 = x_0 + z;
|
||||
float2 y_1 = x_0 - z;
|
||||
|
||||
int j = (i << 1) - k;
|
||||
|
||||
write_buf[j] = y_0;
|
||||
write_buf[j + p] = y_1;
|
||||
}
|
||||
|
||||
// single threaded radix4 implementation
|
||||
void radix4(
|
||||
int i,
|
||||
int p,
|
||||
int m,
|
||||
threadgroup float2* read_buf,
|
||||
threadgroup float2* write_buf) {
|
||||
float2 x_0 = read_buf[i];
|
||||
float2 x_1 = read_buf[i + m];
|
||||
float2 x_2 = read_buf[i + 2 * m];
|
||||
float2 x_3 = read_buf[i + 3 * m];
|
||||
|
||||
// The index within this sub-DFT
|
||||
int k = i & (p - 1);
|
||||
|
||||
float2 twiddle = get_twiddle(k, p);
|
||||
// e^a * e^b = e^(a + b)
|
||||
float2 twiddle_2 = complex_mul(twiddle, twiddle);
|
||||
float2 twiddle_3 = complex_mul(twiddle, twiddle_2);
|
||||
|
||||
x_1 = complex_mul(x_1, twiddle);
|
||||
x_2 = complex_mul(x_2, twiddle_2);
|
||||
x_3 = complex_mul(x_3, twiddle_3);
|
||||
|
||||
float2 minus_i;
|
||||
minus_i.x = 0;
|
||||
minus_i.y = -1;
|
||||
|
||||
// Hard coded twiddle factors for DFT4
|
||||
float2 z_0 = x_0 + x_2;
|
||||
float2 z_1 = x_0 - x_2;
|
||||
float2 z_2 = x_1 + x_3;
|
||||
float2 z_3 = complex_mul(x_1 - x_3, minus_i);
|
||||
|
||||
float2 y_0 = z_0 + z_2;
|
||||
float2 y_1 = z_1 + z_3;
|
||||
float2 y_2 = z_0 - z_2;
|
||||
float2 y_3 = z_1 - z_3;
|
||||
|
||||
int j = ((i - k) << 2) + k;
|
||||
|
||||
write_buf[j] = y_0;
|
||||
write_buf[j + p] = y_1;
|
||||
write_buf[j + 2 * p] = y_2;
|
||||
write_buf[j + 3 * p] = y_3;
|
||||
}
|
||||
|
||||
// Each FFT is computed entirely in shared GPU memory.
|
||||
//
|
||||
// N is decomposed into radix-2 and radix-4 DFTs:
|
||||
// e.g. 128 = 2 * 4 * 4 * 4
|
||||
//
|
||||
// At each step we use n / 4 threads, each performing
|
||||
// a single-threaded radix-4 or radix-2 DFT.
|
||||
//
|
||||
// We provide the number of radix-2 and radix-4
|
||||
// steps at compile time for a ~20% performance boost.
|
||||
template <size_t n, size_t radix_2_steps, size_t radix_4_steps>
|
||||
[[kernel]] void fft(
|
||||
const device float2* in [[buffer(0)]],
|
||||
device float2* out [[buffer(1)]],
|
||||
uint3 thread_position_in_grid [[thread_position_in_grid]],
|
||||
uint3 threads_per_grid [[threads_per_grid]]) {
|
||||
// Index of the DFT in batch
|
||||
int batch_idx = thread_position_in_grid.x * n;
|
||||
// The index in the DFT we're working on
|
||||
int i = thread_position_in_grid.y;
|
||||
// The number of the threads we're using for each DFT
|
||||
int m = threads_per_grid.y;
|
||||
|
||||
// Allocate 2 shared memory buffers for Stockham.
|
||||
// We alternate reading from one and writing to the other at each radix step.
|
||||
threadgroup float2 shared_in[n];
|
||||
threadgroup float2 shared_out[n];
|
||||
|
||||
// Pointers to facilitate Stockham buffer swapping
|
||||
threadgroup float2* read_buf = shared_in;
|
||||
threadgroup float2* write_buf = shared_out;
|
||||
threadgroup float2* tmp;
|
||||
|
||||
// Copy input into shared memory
|
||||
shared_in[i] = in[batch_idx + i];
|
||||
shared_in[i + m] = in[batch_idx + i + m];
|
||||
shared_in[i + 2 * m] = in[batch_idx + i + 2 * m];
|
||||
shared_in[i + 3 * m] = in[batch_idx + i + 3 * m];
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
int p = 1;
|
||||
|
||||
for (size_t r = 0; r < radix_2_steps; r++) {
|
||||
radix2(i, p, m * 2, read_buf, write_buf);
|
||||
radix2(i + m, p, m * 2, read_buf, write_buf);
|
||||
p *= 2;
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Stockham switch of buffers
|
||||
tmp = write_buf;
|
||||
write_buf = read_buf;
|
||||
read_buf = tmp;
|
||||
}
|
||||
|
||||
for (size_t r = 0; r < radix_4_steps; r++) {
|
||||
radix4(i, p, m, read_buf, write_buf);
|
||||
p *= 4;
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Stockham switch of buffers
|
||||
tmp = write_buf;
|
||||
write_buf = read_buf;
|
||||
read_buf = tmp;
|
||||
}
|
||||
|
||||
// Copy shared memory to output
|
||||
out[batch_idx + i] = read_buf[i];
|
||||
out[batch_idx + i + m] = read_buf[i + m];
|
||||
out[batch_idx + i + 2 * m] = read_buf[i + 2 * m];
|
||||
out[batch_idx + i + 3 * m] = read_buf[i + 3 * m];
|
||||
}
|
||||
|
||||
#define instantiate_fft(name, n, radix_2_steps, radix_4_steps) \
|
||||
template [[host_name("fft_" #name)]] [[kernel]] void \
|
||||
fft<n, radix_2_steps, radix_4_steps>( \
|
||||
const device float2* in [[buffer(0)]], \
|
||||
device float2* out [[buffer(1)]], \
|
||||
uint3 thread_position_in_grid [[thread_position_in_grid]], \
|
||||
uint3 threads_per_grid [[threads_per_grid]]);
|
||||
|
||||
// Explicitly define kernels for each power of 2.
|
||||
// clang-format off
|
||||
#define instantiate_ffts(tg_mem_size) \
|
||||
instantiate_fft(tg_mem_size, float2, float2) \
|
||||
instantiate_fft(tg_mem_size, float, float2) \
|
||||
instantiate_fft(tg_mem_size, float2, float) \
|
||||
instantiate_rader(tg_mem_size, float2, float2) \
|
||||
instantiate_rader(tg_mem_size, float, float2) \
|
||||
instantiate_rader(tg_mem_size, float2, float) \
|
||||
instantiate_bluestein(tg_mem_size, float2, float2) \
|
||||
instantiate_bluestein(tg_mem_size, float, float2) \
|
||||
instantiate_bluestein(tg_mem_size, float2, float) \
|
||||
instantiate_four_step(tg_mem_size, float2, float2, 0, /*real=*/false) \
|
||||
instantiate_four_step(tg_mem_size, float2, float2, 1, /*real=*/false) \
|
||||
instantiate_four_step(tg_mem_size, float, float2, 0, /*real=*/true) \
|
||||
instantiate_four_step(tg_mem_size, float2, float2, 1, /*real=*/true) \
|
||||
instantiate_four_step(tg_mem_size, float2, float2, 0, /*real=*/true) \
|
||||
instantiate_four_step(tg_mem_size, float2, float, 1, /*real=*/true)
|
||||
|
||||
// It's substantially faster to statically define the
|
||||
// threadgroup memory size rather than using
|
||||
// `setThreadgroupMemoryLength` on the compute encoder.
|
||||
// For non-power of 2 sizes we round up the shared memory.
|
||||
instantiate_ffts(256)
|
||||
instantiate_ffts(512)
|
||||
instantiate_ffts(1024)
|
||||
instantiate_ffts(2048)
|
||||
// 4096 is the max that will fit into 32KB of threadgroup memory.
|
||||
instantiate_ffts(4096) // clang-format on
|
||||
instantiate_fft(4, /* n= */ 4, /* radix_2_steps= */ 0, /* radix_4_steps= */ 1)
|
||||
instantiate_fft(8, 8, 1, 1) instantiate_fft(16, 16, 0, 2)
|
||||
instantiate_fft(32, 32, 1, 2) instantiate_fft(64, 64, 0, 3)
|
||||
instantiate_fft(128, 128, 1, 3) instantiate_fft(256, 256, 0, 4)
|
||||
instantiate_fft(512, 512, 1, 4)
|
||||
instantiate_fft(1024, 1024, 0, 5)
|
||||
// 2048 is the max that will fit into 32KB of threadgroup memory.
|
||||
// TODO: implement 4 step FFT for larger n.
|
||||
instantiate_fft(2048, 2048, 1, 5) // clang-format on
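# Illustrative check (not part of the sources): each instantiation above picks
# (radix_2_steps, radix_4_steps) so that n == 2**radix_2_steps * 4**radix_4_steps.
for n, r2, r4 in [(4, 0, 1), (8, 1, 1), (16, 0, 2), (32, 1, 2), (64, 0, 3),
                  (128, 1, 3), (256, 0, 4), (512, 1, 4), (1024, 0, 5), (2048, 1, 5)]:
    assert n == 2 ** r2 * 4 ** r4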

@@ -1,328 +0,0 @@
// Copyright © 2024 Apple Inc.

/* Radix kernels

We provide optimized, single threaded Radix codelets
for n=2,3,4,5,6,7,8,10,11,12,13.

For n=2,3,4,5,6 we hand write the codelets.
For n=8,10,12 we combine smaller codelets.
For n=7,11,13 we use Rader's algorithm which decomposes
them into (n-1)=6,10,12 codelets. */
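
For reference, the identity behind Rader's decomposition (standard form; $g$ is a primitive root modulo the prime $p$):

$$
X_0 = \sum_{n=0}^{p-1} x_n, \qquad
X_{g^{-q} \bmod p} = x_0 + \sum_{m=0}^{p-2} x_{g^m \bmod p}\, e^{-2\pi i\, g^{m-q}/p}, \quad q = 0, \dots, p-2.
$$

The sum over $m$ is a cyclic convolution of length $p-1$, which the codelets below evaluate with a forward FFT, a pointwise multiply by the precomputed factors labelled `b_q`, and an inverse FFT.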

#pragma once

#include <metal_common>
#include <metal_math>
#include <metal_stdlib>

METAL_FUNC float2 complex_mul(float2 a, float2 b) {
  return float2(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x);
}

// Complex mul followed by conjugate
METAL_FUNC float2 complex_mul_conj(float2 a, float2 b) {
  return float2(a.x * b.x - a.y * b.y, -a.x * b.y - a.y * b.x);
}

// Compute an FFT twiddle factor
METAL_FUNC float2 get_twiddle(int k, int p) {
  float theta = -2.0f * k * M_PI_F / p;

  float2 twiddle = {metal::fast::cos(theta), metal::fast::sin(theta)};
  return twiddle;
}
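
Reading a `float2` as the complex number $a_x + i\,a_y$, these helpers compute

$$
\texttt{complex\_mul}(a, b) = ab, \qquad
\texttt{complex\_mul\_conj}(a, b) = \overline{ab}, \qquad
\texttt{get\_twiddle}(k, p) = e^{-2\pi i k / p} = \cos\theta + i\sin\theta,\ \ \theta = -\tfrac{2\pi k}{p},
$$

where the twiddle factors rotate partial results between butterfly stages.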

METAL_FUNC void radix2(thread float2* x, thread float2* y) {
  y[0] = x[0] + x[1];
  y[1] = x[0] - x[1];
}

METAL_FUNC void radix3(thread float2* x, thread float2* y) {
  float pi_2_3 = -0.8660254037844387;

  float2 a_1 = x[1] + x[2];
  float2 a_2 = x[1] - x[2];

  y[0] = x[0] + a_1;
  float2 b_1 = x[0] - 0.5 * a_1;
  float2 b_2 = pi_2_3 * a_2;

  float2 b_2_j = {-b_2.y, b_2.x};
  y[1] = b_1 + b_2_j;
  y[2] = b_1 - b_2_j;
}
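
A quick check of the constant above: with $\omega = e^{-2\pi i/3}$ the 3-point DFT reduces to

$$
y_0 = x_0 + x_1 + x_2, \qquad
y_{1,2} = x_0 - \tfrac{1}{2}(x_1 + x_2) \mp i\,\tfrac{\sqrt{3}}{2}\,(x_1 - x_2),
$$

so `pi_2_3` is $-\sin(\pi/3) = -\tfrac{\sqrt{3}}{2}$, and `b_2_j` applies the factor $i$ by swapping components and negating the new real part.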

METAL_FUNC void radix4(thread float2* x, thread float2* y) {
  float2 z_0 = x[0] + x[2];
  float2 z_1 = x[0] - x[2];
  float2 z_2 = x[1] + x[3];
  float2 z_3 = x[1] - x[3];
  float2 z_3_i = {z_3.y, -z_3.x};

  y[0] = z_0 + z_2;
  y[1] = z_1 + z_3_i;
  y[2] = z_0 - z_2;
  y[3] = z_1 - z_3_i;
}

METAL_FUNC void radix5(thread float2* x, thread float2* y) {
  float2 root_5_4 = 0.5590169943749475;
  float2 sin_2pi_5 = 0.9510565162951535;
  float2 sin_1pi_5 = 0.5877852522924731;

  float2 a_1 = x[1] + x[4];
  float2 a_2 = x[2] + x[3];
  float2 a_3 = x[1] - x[4];
  float2 a_4 = x[2] - x[3];

  float2 a_5 = a_1 + a_2;
  float2 a_6 = root_5_4 * (a_1 - a_2);
  float2 a_7 = x[0] - a_5 / 4;
  float2 a_8 = a_7 + a_6;
  float2 a_9 = a_7 - a_6;
  float2 a_10 = sin_2pi_5 * a_3 + sin_1pi_5 * a_4;
  float2 a_11 = sin_1pi_5 * a_3 - sin_2pi_5 * a_4;
  float2 a_10_j = {a_10.y, -a_10.x};
  float2 a_11_j = {a_11.y, -a_11.x};

  y[0] = x[0] + a_5;
  y[1] = a_8 + a_10_j;
  y[2] = a_9 + a_11_j;
  y[3] = a_9 - a_11_j;
  y[4] = a_8 - a_10_j;
}

METAL_FUNC void radix6(thread float2* x, thread float2* y) {
  float sin_pi_3 = 0.8660254037844387;
  float2 a_1 = x[2] + x[4];
  float2 a_2 = x[0] - a_1 / 2;
  float2 a_3 = sin_pi_3 * (x[2] - x[4]);
  float2 a_4 = x[5] + x[1];
  float2 a_5 = x[3] - a_4 / 2;
  float2 a_6 = sin_pi_3 * (x[5] - x[1]);
  float2 a_7 = x[0] + a_1;

  float2 a_3_i = {a_3.y, -a_3.x};
  float2 a_6_i = {a_6.y, -a_6.x};
  float2 a_8 = a_2 + a_3_i;
  float2 a_9 = a_2 - a_3_i;
  float2 a_10 = x[3] + a_4;
  float2 a_11 = a_5 + a_6_i;
  float2 a_12 = a_5 - a_6_i;

  y[0] = a_7 + a_10;
  y[1] = a_8 - a_11;
  y[2] = a_9 + a_12;
  y[3] = a_7 - a_10;
  y[4] = a_8 + a_11;
  y[5] = a_9 - a_12;
}

METAL_FUNC void radix7(thread float2* x, thread float2* y) {
  // Rader's algorithm
  float2 inv = {1 / 6.0, -1 / 6.0};

  // fft
  float2 in1[6] = {x[1], x[3], x[2], x[6], x[4], x[5]};
  radix6(in1, y + 1);

  y[0] = y[1] + x[0];

  // b_q
  y[1] = complex_mul_conj(y[1], float2(-1, 0));
  y[2] = complex_mul_conj(y[2], float2(2.44013336, -1.02261879));
  y[3] = complex_mul_conj(y[3], float2(2.37046941, -1.17510629));
  y[4] = complex_mul_conj(y[4], float2(0, -2.64575131));
  y[5] = complex_mul_conj(y[5], float2(2.37046941, 1.17510629));
  y[6] = complex_mul_conj(y[6], float2(-2.44013336, -1.02261879));

  // ifft
  radix6(y + 1, x + 1);

  y[1] = x[1] * inv + x[0];
  y[5] = x[2] * inv + x[0];
  y[4] = x[3] * inv + x[0];
  y[6] = x[4] * inv + x[0];
  y[2] = x[5] * inv + x[0];
  y[3] = x[6] * inv + x[0];
}
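
In this and the other Rader codelets (radix11 and radix13 below), the input permutation follows powers of a primitive root modulo the prime, and the outputs are scattered by powers of its inverse. The inverse FFT of the pointwise product is obtained from a second forward FFT: `complex_mul_conj` conjugates the products on the way in, and the componentwise multiply by `inv = {1/(n-1), -1/(n-1)}` conjugates and rescales on the way out, using $\mathrm{IDFT}(z) = \overline{\mathrm{DFT}(\bar z)}/N$.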

METAL_FUNC void radix8(thread float2* x, thread float2* y) {
  float cos_pi_4 = 0.7071067811865476;
  float2 w_0 = {cos_pi_4, -cos_pi_4};
  float2 w_1 = {-cos_pi_4, -cos_pi_4};
  float2 temp[8] = {x[0], x[2], x[4], x[6], x[1], x[3], x[5], x[7]};
  radix4(temp, x);
  radix4(temp + 4, x + 4);

  y[0] = x[0] + x[4];
  y[4] = x[0] - x[4];
  float2 x_5 = complex_mul(x[5], w_0);
  y[1] = x[1] + x_5;
  y[5] = x[1] - x_5;
  float2 x_6 = {x[6].y, -x[6].x};
  y[2] = x[2] + x_6;
  y[6] = x[2] - x_6;
  float2 x_7 = complex_mul(x[7], w_1);
  y[3] = x[3] + x_7;
  y[7] = x[3] - x_7;
}

template <bool raders_perm>
METAL_FUNC void radix10(thread float2* x, thread float2* y) {
  float2 w[4];
  w[0] = {0.8090169943749475, -0.5877852522924731};
  w[1] = {0.30901699437494745, -0.9510565162951535};
  w[2] = {-w[1].x, w[1].y};
  w[3] = {-w[0].x, w[0].y};

  if (raders_perm) {
    float2 temp[10] = {
        x[0], x[3], x[4], x[8], x[2], x[1], x[7], x[9], x[6], x[5]};
    radix5(temp, x);
    radix5(temp + 5, x + 5);
  } else {
    float2 temp[10] = {
        x[0], x[2], x[4], x[6], x[8], x[1], x[3], x[5], x[7], x[9]};
    radix5(temp, x);
    radix5(temp + 5, x + 5);
  }

  y[0] = x[0] + x[5];
  y[5] = x[0] - x[5];
  for (int t = 1; t < 5; t++) {
    float2 a = complex_mul(x[t + 5], w[t - 1]);
    y[t] = x[t] + a;
    y[t + 5] = x[t] - a;
  }
}

METAL_FUNC void radix11(thread float2* x, thread float2* y) {
  // Raders Algorithm
  float2 inv = {1 / 10.0, -1 / 10.0};

  // fft
  radix10<true>(x + 1, y + 1);

  y[0] = y[1] + x[0];

  // b_q
  y[1] = complex_mul_conj(y[1], float2(-1, 0));
  y[2] = complex_mul_conj(y[2], float2(0.955301878, -3.17606649));
  y[3] = complex_mul_conj(y[3], float2(2.63610556, 2.01269656));
  y[4] = complex_mul_conj(y[4], float2(2.54127802, 2.13117479));
  y[5] = complex_mul_conj(y[5], float2(2.07016210, 2.59122150));
  y[6] = complex_mul_conj(y[6], float2(0, -3.31662479));
  y[7] = complex_mul_conj(y[7], float2(2.07016210, -2.59122150));
  y[8] = complex_mul_conj(y[8], float2(-2.54127802, 2.13117479));
  y[9] = complex_mul_conj(y[9], float2(2.63610556, -2.01269656));
  y[10] = complex_mul_conj(y[10], float2(-0.955301878, -3.17606649));

  // ifft
  radix10<false>(y + 1, x + 1);

  y[1] = x[1] * inv + x[0];
  y[6] = x[2] * inv + x[0];
  y[3] = x[3] * inv + x[0];
  y[7] = x[4] * inv + x[0];
  y[9] = x[5] * inv + x[0];
  y[10] = x[6] * inv + x[0];
  y[5] = x[7] * inv + x[0];
  y[8] = x[8] * inv + x[0];
  y[4] = x[9] * inv + x[0];
  y[2] = x[10] * inv + x[0];
}

template <bool raders_perm>
METAL_FUNC void radix12(thread float2* x, thread float2* y) {
  float2 w[6];
  float sin_pi_3 = 0.8660254037844387;
  w[0] = {sin_pi_3, -0.5};
  w[1] = {0.5, -sin_pi_3};
  w[2] = {0, -1};
  w[3] = {-0.5, -sin_pi_3};
  w[4] = {-sin_pi_3, -0.5};

  if (raders_perm) {
    float2 temp[12] = {
        x[0],
        x[3],
        x[2],
        x[11],
        x[8],
        x[9],
        x[1],
        x[7],
        x[5],
        x[10],
        x[4],
        x[6]};
    radix6(temp, x);
    radix6(temp + 6, x + 6);
  } else {
    float2 temp[12] = {
        x[0],
        x[2],
        x[4],
        x[6],
        x[8],
        x[10],
        x[1],
        x[3],
        x[5],
        x[7],
        x[9],
        x[11]};
    radix6(temp, x);
    radix6(temp + 6, x + 6);
  }

  y[0] = x[0] + x[6];
  y[6] = x[0] - x[6];
  for (int t = 1; t < 6; t++) {
    float2 a = complex_mul(x[t + 6], w[t - 1]);
    y[t] = x[t] + a;
    y[t + 6] = x[t] - a;
  }
}

METAL_FUNC void radix13(thread float2* x, thread float2* y) {
  // Raders Algorithm
  float2 inv = {1 / 12.0, -1 / 12.0};

  // fft
  radix12<true>(x + 1, y + 1);

  y[0] = y[1] + x[0];

  // b_q
  y[1] = complex_mul_conj(y[1], float2(-1, 0));
  y[2] = complex_mul_conj(y[2], float2(3.07497206, -1.88269669));
  y[3] = complex_mul_conj(y[3], float2(3.09912468, 1.84266823));
  y[4] = complex_mul_conj(y[4], float2(3.45084438, -1.04483161));
  y[5] = complex_mul_conj(y[5], float2(0.91083583, 3.48860690));
  y[6] = complex_mul_conj(y[6], float2(-3.60286363, 0.139189267));
  y[7] = complex_mul_conj(y[7], float2(3.60555128, 0));
  y[8] = complex_mul_conj(y[8], float2(3.60286363, 0.139189267));
  y[9] = complex_mul_conj(y[9], float2(0.91083583, -3.48860690));
  y[10] = complex_mul_conj(y[10], float2(-3.45084438, -1.04483161));
  y[11] = complex_mul_conj(y[11], float2(3.09912468, -1.84266823));
  y[12] = complex_mul_conj(y[12], float2(-3.07497206, -1.88269669));

  // ifft
  radix12<false>(y + 1, x + 1);

  y[1] = x[1] * inv + x[0];
  y[7] = x[2] * inv + x[0];
  y[10] = x[3] * inv + x[0];
  y[5] = x[4] * inv + x[0];
  y[9] = x[5] * inv + x[0];
  y[11] = x[6] * inv + x[0];
  y[12] = x[7] * inv + x[0];
  y[6] = x[8] * inv + x[0];
  y[3] = x[9] * inv + x[0];
  y[8] = x[10] * inv + x[0];
  y[4] = x[11] * inv + x[0];
  y[2] = x[12] * inv + x[0];
}