Compare commits

..

1 Commits

Author SHA1 Message Date
Angelos Katharopoulos
3a1df968cf Add async all reduce and donation 2024-05-29 22:54:32 -07:00
231 changed files with 6618 additions and 15793 deletions

View File

@@ -144,7 +144,6 @@ jobs:
name: Install dependencies
command: |
brew install python@<< parameters.python_version >>
brew install openmpi
python<< parameters.python_version >> -m venv env
source env/bin/activate
pip install --upgrade pip

View File

@@ -17,4 +17,4 @@ jobs:
pip install pre-commit black isort clang-format
- name: Run lint
run: |
pre-commit run --all-files
pre-commit run --all-files

View File

@@ -10,14 +10,13 @@ MLX was developed with contributions from the following individuals:
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`.
- Juarez Bochi: Fixed bug in cross attention.
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream`, safetensors support, `einsum`, and `einsum_path`.
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support.
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented pooling layers and ``Upsample``.
- Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops.
- Luca Arnaboldi: Added `Ceil` and `Floor` ops; implemented pickling, copy and deepcopy for mlx arrays.
- Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention`
- AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`. Comparison works for non array objects in `mlx.core.array`. Exception handling for invalid operations in `mlx.core.array`.
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
- Paul Paczuski: Improved stability of BCE loss calculation
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />

View File

@@ -24,7 +24,7 @@ option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
if(NOT MLX_VERSION)
set(MLX_VERSION 0.16.3)
set(MLX_VERSION 0.14.0)
endif()
# --------------------- Processor tests -------------------------
@@ -83,21 +83,24 @@ elseif (MLX_BUILD_METAL)
OUTPUT_VARIABLE MACOS_VERSION
COMMAND_ERROR_IS_FATAL ANY)
if (${MACOS_VERSION} LESS 14.0)
message(FATAL_ERROR "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON" )
endif()
message(STATUS "Building with SDK for macOS version ${MACOS_VERSION}")
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18-beta.zip)
# Get the metal version
execute_process(
COMMAND zsh "-c" "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal -E -x metal -P - | tail -1 | tr -d '\n'"
OUTPUT_VARIABLE MLX_METAL_VERSION
COMMAND_ERROR_IS_FATAL ANY)
if (${MACOS_VERSION} GREATER_EQUAL 14.2)
set(METAL_CPP_PATCH ${CMAKE_CURRENT_SOURCE_DIR}/cmake/metal.14.2.diff)
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14.2_iOS17.2.zip)
set(MLX_METAL_VERSION METAL_3_1)
elseif (${MACOS_VERSION} GREATER_EQUAL 14.0)
set(METAL_CPP_PATCH ${CMAKE_CURRENT_SOURCE_DIR}/cmake/metal.14.0.diff)
set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14_iOS17-beta.zip)
set(MLX_METAL_VERSION METAL_3_0)
else()
message(FATAL_ERROR "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON" )
endif()
FetchContent_Declare(
metal_cpp
URL ${METAL_CPP_URL}
PATCH_COMMAND /usr/bin/patch -N -i ${METAL_CPP_PATCH} || true
)
FetchContent_MakeAvailable(metal_cpp)
@@ -112,7 +115,7 @@ elseif (MLX_BUILD_METAL)
${FOUNDATION_LIB}
${QUARTZ_LIB})
add_compile_definitions("MLX_METAL_VERSION=${MLX_METAL_VERSION}")
add_compile_definitions(${MLX_METAL_VERSION})
endif()
if (MLX_BUILD_CPU)
@@ -166,26 +169,7 @@ endif()
find_package(MPI)
if (MPI_FOUND)
execute_process(
COMMAND zsh "-c" "mpirun --version"
OUTPUT_VARIABLE MPI_VERSION
ERROR_QUIET
)
if (${MPI_VERSION} MATCHES ".*Open MPI.*")
target_include_directories(mlx PRIVATE ${MPI_INCLUDE_PATH})
elseif (MPI_VERSION STREQUAL "")
set(MPI_FOUND FALSE)
message(
WARNING
"MPI found but mpirun is not available. Building without MPI."
)
else()
set(MPI_FOUND FALSE)
message(
WARNING
"MPI which is not OpenMPI found. Building without MPI."
)
endif()
endif()
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)

View File

@@ -185,7 +185,7 @@ def prelu(x: torch.Tensor) -> torch.Tensor:
def mish(x: torch.Tensor) -> torch.Tensor:
y = x
for _ in range(100):
y = torch.nn.functional.mish(y)
return torch.nn.functional.mish(y)
sync_if_needed(x)
@@ -283,14 +283,6 @@ def topk(axis, x):
sync_if_needed(x)
@torch.no_grad()
def step_function(x):
y = x
for i in range(100):
y = torch.where(y < 0, 0, 1)
sync_if_needed(x)
@torch.no_grad()
def selu(x):
y = x
@@ -454,11 +446,5 @@ if __name__ == "__main__":
elif args.benchmark == "topk":
print(bench(topk, axis, x))
elif args.benchmark == "step":
print(bench(step_function, x))
elif args.benchmark == "selu":
print(bench(selu, x))
else:
raise ValueError(f"Unknown benchmark `{args.benchmark}`.")
raise ValueError("Unknown benchmark")

View File

@@ -16,9 +16,7 @@ def run_or_raise(*args, **kwargs):
result = run(*args, capture_output=True, **kwargs)
return float(result.stdout)
except ValueError:
raise ValueError(
f"stdout: {result.stdout.decode()}\nstderr: {result.stderr.decode()}"
)
raise ValueError(f"stdout: {result.stdout}\nstderr: {result.stderr}")
def compare(args):

View File

@@ -9,6 +9,7 @@ from time_utils import time_fn
def bench_gelu():
def gelu(x):
return x * (1 + mx.erf(x / math.sqrt(2))) / 2
@@ -50,6 +51,7 @@ def bench_gelu():
def bench_layernorm():
weight = mx.random.uniform(shape=(4096,)).astype(mx.float16)
bias = mx.random.uniform(shape=(4096,)).astype(mx.float16)
mx.eval(weight, bias)

View File

@@ -54,6 +54,7 @@ def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
scale = 1.0 / math.sqrt(kH * kH * C)
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(

View File

@@ -1,84 +0,0 @@
# Copyright © 2024 Apple Inc.
import time
import mlx.core as mx
import numpy as np
def timeit(fn, its=100, args=[]):
for _ in range(5):
fn(*args)
tic = time.perf_counter()
for _ in range(its):
fn(*args)
toc = time.perf_counter()
return 1e3 * (toc - tic) / its
def time_little_einsum_path():
subscripts = "ik,kj->ij"
x = mx.ones((32, 32))
y = mx.ones((32, 32))
mx_time = timeit(mx.einsum_path, args=(subscripts, x, y))
x = np.array(x)
y = np.array(y)
np_time = timeit(np.einsum_path, args=(subscripts, x, y))
print("Timing little einsum path...")
print(f"MLX ... {mx_time:.3f} ms")
print(f"NumPy... {np_time:.3f} ms")
def time_big_einsum_path():
chars = list("abcdefgh")
char_to_dim = {c: v for v, c in enumerate(chars)}
num_inputs = 10
inputs = []
subscripts = []
for _ in range(num_inputs):
subscript = np.random.choice(chars, size=5, replace=False).tolist()
subscripts.append("".join(subscript))
inputs.append(np.ones(list(char_to_dim[c] for c in subscript)))
subscripts = ",".join(subscripts)
np_time = timeit(np.einsum_path, args=(subscripts, *inputs))
inputs = [mx.array(x) for x in inputs]
mx_time = timeit(mx.einsum_path, args=(subscripts, *inputs))
print("Timing big einsum path...")
print(f"MLX ... {mx_time:.3f} ms")
print(f"NumPy... {np_time:.3f} ms")
def time_attention():
def regular_attention(x):
# shape [batch, sequence, num_heads, head_dim]
queries, keys, values = x, x, x
scores = queries.transpose(0, 2, 1, 3) @ keys.transpose(0, 2, 3, 1)
scores = mx.softmax(scores, axis=-1)
output = (scores @ values.transpose(0, 2, 1, 3)).swapaxes(1, 2)
mx.eval(output)
def einsum_attention(x):
# shape [batch, sequence, num_heads, head_dim]
queries, keys, values = x, x, x
scores = mx.einsum("itjk,iujk->ijtu", queries, keys)
scores = mx.softmax(scores, axis=-1)
output = mx.einsum("ijtu,iujk->itjk", scores, values)
mx.eval(output)
x = mx.random.uniform(shape=(8, 512, 32, 128))
regular_time = timeit(regular_attention, args=(x,))
ein_time = timeit(einsum_attention, args=(x,))
print("Timing einsum attention...")
print(f"Regular ... {regular_time:.3f} ms")
print(f"Einsum ... {ein_time:.3f} ms")
if __name__ == "__main__":
time_little_einsum_path()
time_big_einsum_path()
time_attention()

View File

@@ -3,8 +3,6 @@
import matplotlib
import mlx.core as mx
import numpy as np
import sympy
import torch
from time_utils import measure_runtime
matplotlib.use("Agg")
@@ -18,100 +16,41 @@ def bandwidth_gb(runtime_ms, system_size):
return system_size * bytes_per_fft / runtime_ms * ms_per_s / bytes_per_gb
def run_bench(system_size, fft_sizes, backend="mlx", dim=1):
def fft_mlx(x):
if dim == 1:
out = mx.fft.fft(x)
elif dim == 2:
out = mx.fft.fft2(x)
def run_bench(system_size):
def fft(x):
out = mx.fft.fft(x)
mx.eval(out)
return out
def fft_mps(x):
if dim == 1:
out = torch.fft.fft(x)
elif dim == 2:
out = torch.fft.fft2(x)
torch.mps.synchronize()
return out
bandwidths = []
for n in fft_sizes:
batch_size = system_size // n**dim
shape = [batch_size] + [n for _ in range(dim)]
if backend == "mlx":
x_np = np.random.uniform(size=(system_size // n, n)).astype(np.complex64)
x = mx.array(x_np)
mx.eval(x)
fft = fft_mlx
elif backend == "mps":
x_np = np.random.uniform(size=(system_size // n, n)).astype(np.complex64)
x = torch.tensor(x_np, device="mps")
torch.mps.synchronize()
fft = fft_mps
else:
raise NotImplementedError()
for k in range(4, 12):
n = 2**k
x = mx.random.uniform(shape=(system_size // n, n)).astype(mx.float32)
x = x.astype(mx.complex64)
mx.eval(x)
runtime_ms = measure_runtime(fft, x=x)
bandwidth = bandwidth_gb(runtime_ms, np.prod(shape))
print(n, bandwidth)
bandwidths.append(bandwidth)
bandwidths.append(bandwidth_gb(runtime_ms, system_size))
return np.array(bandwidths)
return bandwidths
def time_fft():
x = np.array(range(2, 512))
system_size = int(2**26)
print("MLX GPU")
with mx.stream(mx.gpu):
gpu_bandwidths = run_bench(system_size=system_size, fft_sizes=x)
print("MPS GPU")
mps_bandwidths = run_bench(system_size=system_size, fft_sizes=x, backend="mps")
print("CPU")
system_size = int(2**20)
with mx.stream(mx.cpu):
cpu_bandwidths = run_bench(system_size=system_size, fft_sizes=x)
cpu_bandwidths = run_bench(system_size=int(2**22))
x = np.array(x)
with mx.stream(mx.gpu):
gpu_bandwidths = run_bench(system_size=int(2**29))
all_indices = x - x[0]
radix_2to13 = (
np.array([i for i in x if all(p <= 13 for p in sympy.primefactors(i))]) - x[0]
)
bluesteins = (
np.array([i for i in x if any(p > 13 for p in sympy.primefactors(i))]) - x[0]
)
for indices, name in [
(all_indices, "All"),
(radix_2to13, "Radix 2-13"),
(bluesteins, "Bluestein's"),
]:
# plot bandwidths
print(name)
plt.scatter(x[indices], gpu_bandwidths[indices], color="green", label="GPU")
plt.scatter(x[indices], mps_bandwidths[indices], color="blue", label="MPS")
plt.scatter(x[indices], cpu_bandwidths[indices], color="red", label="CPU")
plt.title(f"MLX FFT Benchmark -- {name}")
plt.xlabel("N")
plt.ylabel("Bandwidth (GB/s)")
plt.legend()
plt.savefig(f"{name}.png")
plt.clf()
av_gpu_bandwidth = np.mean(gpu_bandwidths)
av_mps_bandwidth = np.mean(mps_bandwidths)
av_cpu_bandwidth = np.mean(cpu_bandwidths)
print("Average bandwidths:")
print("GPU:", av_gpu_bandwidth)
print("MPS:", av_mps_bandwidth)
print("CPU:", av_cpu_bandwidth)
portion_faster = len(np.where(gpu_bandwidths > mps_bandwidths)[0]) / len(x)
print("Percent MLX faster than MPS: ", portion_faster * 100)
# plot bandwidths
x = [2**k for k in range(4, 12)]
plt.scatter(x, gpu_bandwidths, color="green", label="GPU")
plt.scatter(x, cpu_bandwidths, color="red", label="CPU")
plt.title("MLX FFT Benchmark")
plt.xlabel("N")
plt.ylabel("Bandwidth (GB/s)")
plt.legend()
plt.savefig("fft_plot.png")
if __name__ == "__main__":

View File

@@ -1,70 +0,0 @@
import argparse
import matplotlib
import mlx.core as mx
import numpy as np
from time_utils import measure_runtime
matplotlib.use("Agg")
import matplotlib.pyplot as plt
def had(x):
y = mx.hadamard_transform(x)
mx.eval(y)
def copy(x):
y = x + 1.0
mx.eval(y)
def run(dtype):
system_size = 2**26
outputs = {}
for test_fn in (had, copy):
for m in [1, 12, 20, 28]:
if test_fn == copy:
key = "copy"
elif m == 1:
key = "had_2^k"
else:
key = "had_m*2^k"
outputs.setdefault(key, {})
for k in range(7, 14):
n = m * 2**k
if n > 2**15:
continue
x_np = np.random.normal(size=(system_size // n, n)).astype(dtype)
x = mx.array(x_np)
runtime_ms = measure_runtime(test_fn, x=x)
bytes_per_gb = 1e9
ms_per_s = 1e3
bytes_per_had = np.dtype(x_np.dtype).itemsize * 2
bandwidth_gb = (
system_size * bytes_per_had / runtime_ms * ms_per_s / bytes_per_gb
)
print(n, bandwidth_gb)
outputs[key][n] = bandwidth_gb
colors = {
"copy": "black",
"had_2^k": "steelblue",
"had_m*2^k": "skyblue",
}
for key, output in outputs.items():
plt.scatter(output.keys(), output.values(), color=colors[key], label=key)
plt.title(f"MLX Hadamard Benchmark -- {dtype.__name__}")
plt.xlabel("N")
plt.ylabel("Bandwidth (GB/s)")
plt.legend()
plt.savefig(f"bench_{dtype.__name__}.png")
plt.clf()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--fp16", action="store_true")
args = parser.parse_args()
dtype = np.float16 if args.fp16 else np.float32
run(dtype)

View File

@@ -1,62 +0,0 @@
import argparse
import math
import mlx.core as mx
from time_utils import time_fn
MAX_SEQ = 300
START_SEQ = 100
SEQ_INCREMENT = 50
def time_self_attention_primitives():
mx.random.seed(3)
B = 2
H = 38
D = 64
for R in range(START_SEQ, MAX_SEQ, SEQ_INCREMENT):
q = mx.random.uniform(shape=(B, H, R, D))
k = mx.random.uniform(shape=(B, H, R, D))
v = mx.random.uniform(shape=(B, H, R, D))
scale = 1.0 / math.sqrt(float(D))
mx.eval(q, k, v)
def sdpa_primitives(qs, ks, vs, alpha):
s = (alpha * qs) @ ks.transpose(0, 1, 3, 2)
p = mx.softmax(s.astype(mx.float32), axis=-1).astype(s.dtype)
o = p @ vs
return o
time_fn(sdpa_primitives, q, k, v, scale)
def time_self_attention_sdpa():
mx.random.seed(3)
B = 2
H = 38
D = 64
for R in range(START_SEQ, MAX_SEQ, SEQ_INCREMENT):
q = mx.random.uniform(shape=(B, H, R, D))
k = mx.random.uniform(shape=(B, H, R, D))
v = mx.random.uniform(shape=(B, H, R, D))
scale = 1.0 / math.sqrt(float(D))
mx.eval(q, k, v)
def sdpa_fused(qs, ks, vs, alpha):
o = mx.fast.scaled_dot_product_attention(qs, ks, vs, scale=alpha)
return o
time_fn(sdpa_fused, q, k, v, scale)
if __name__ == "__main__":
parser = argparse.ArgumentParser("MLX benchmarks.")
parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
args = parser.parse_args()
if args.gpu:
mx.set_default_device(mx.gpu)
else:
mx.set_default_device(mx.cpu)
time_self_attention_sdpa()
time_self_attention_primitives()

36
cmake/metal.14.0.diff Normal file
View File

@@ -0,0 +1,36 @@
diff -ur Metal/MTLEvent.hpp MetalNew/MTLEvent.hpp
--- Metal/MTLEvent.hpp 2023-06-01 12:18:26
+++ MetalNew/MTLEvent.hpp 2024-04-15 07:36:59
@@ -62,6 +62,7 @@
uint64_t signaledValue() const;
void setSignaledValue(uint64_t signaledValue);
+ bool waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS);
};
class SharedEventHandle : public NS::SecureCoding<SharedEventHandle>
@@ -138,6 +139,11 @@
_MTL_INLINE void MTL::SharedEvent::setSignaledValue(uint64_t signaledValue)
{
Object::sendMessage<void>(this, _MTL_PRIVATE_SEL(setSignaledValue_), signaledValue);
+}
+
+// method: waitUntilSignaledValue
+_MTL_INLINE bool MTL::SharedEvent::waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS) {
+ return Object::sendMessage<bool>(this, _MTL_PRIVATE_SEL(waitUntilSignaledValue_timeoutMS_), signaledValue, timeoutMS);
}
// static method: alloc
diff -ur Metal/MTLHeaderBridge.hpp MetalNew/MTLHeaderBridge.hpp
--- Metal/MTLHeaderBridge.hpp 2023-06-01 12:18:26
+++ MetalNew/MTLHeaderBridge.hpp 2024-04-15 07:37:29
@@ -1906,6 +1906,9 @@
"setShouldMaximizeConcurrentCompilation:");
_MTL_PRIVATE_DEF_SEL(setSignaledValue_,
"setSignaledValue:");
+_MTL_PRIVATE_DEF_SEL(
+ waitUntilSignaledValue_timeoutMS_,
+ "waitUntilSignaledValue:timeoutMS:");
_MTL_PRIVATE_DEF_SEL(setSize_,
"setSize:");
_MTL_PRIVATE_DEF_SEL(setSlice_,

36
cmake/metal.14.2.diff Normal file
View File

@@ -0,0 +1,36 @@
diff -ur Metal/MTLEvent.hpp MetalNew/MTLEvent.hpp
--- Metal/MTLEvent.hpp 2024-04-15 07:12:10
+++ MetalNew/MTLEvent.hpp 2024-04-15 07:15:50
@@ -62,6 +62,7 @@
uint64_t signaledValue() const;
void setSignaledValue(uint64_t signaledValue);
+ bool waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS);
};
class SharedEventHandle : public NS::SecureCoding<SharedEventHandle>
@@ -138,6 +139,11 @@
_MTL_INLINE void MTL::SharedEvent::setSignaledValue(uint64_t signaledValue)
{
Object::sendMessage<void>(this, _MTL_PRIVATE_SEL(setSignaledValue_), signaledValue);
+}
+
+// method: waitUntilSignaledValue
+_MTL_INLINE bool MTL::SharedEvent::waitUntilSignaledValue(uint64_t signaledValue, uint64_t timeoutMS) {
+ return Object::sendMessage<bool>(this, _MTL_PRIVATE_SEL(waitUntilSignaledValue_timeoutMS_), signaledValue, timeoutMS);
}
// static method: alloc
diff -ur Metal/MTLHeaderBridge.hpp MetalNew/MTLHeaderBridge.hpp
--- Metal/MTLHeaderBridge.hpp 2024-04-15 07:12:10
+++ MetalNew/MTLHeaderBridge.hpp 2024-04-15 07:16:15
@@ -1918,6 +1918,9 @@
"setShouldMaximizeConcurrentCompilation:");
_MTL_PRIVATE_DEF_SEL(setSignaledValue_,
"setSignaledValue:");
+_MTL_PRIVATE_DEF_SEL(
+ waitUntilSignaledValue_timeoutMS_,
+ "waitUntilSignaledValue:timeoutMS:");
_MTL_PRIVATE_DEF_SEL(setSize_,
"setSize:");
_MTL_PRIVATE_DEF_SEL(setSlice_,

View File

@@ -1,4 +1,3 @@
sphinx
breathe
sphinx-book-theme
mlx

View File

@@ -83,15 +83,3 @@ def setup(app):
# -- Options for LaTeX output ------------------------------------------------
latex_documents = [(main_doc, "MLX.tex", "MLX Documentation", author, "manual")]
latex_elements = {
"preamble": r"""
\usepackage{enumitem}
\setlistdepth{5}
\setlist[itemize,1]{label=$\bullet$}
\setlist[itemize,2]{label=$\bullet$}
\setlist[itemize,3]{label=$\bullet$}
\setlist[itemize,4]{label=$\bullet$}
\setlist[itemize,5]{label=$\bullet$}
\renewlist{itemize}{itemize}{5}
""",
}

View File

@@ -486,8 +486,9 @@ below.
std::ostringstream kname;
kname << "axpby_" << "general_" << type_to_name(out);
// Make sure the metal library is available
d.register_library("mlx_ext");
// Make sure the metal library is available and look for it
// in the same folder as this executable if needed
d.register_library("mlx_ext", metal::get_colocated_mtllib_path);
// Make a kernel from this metal library
auto kernel = d.get_kernel(kname.str(), "mlx_ext");

View File

@@ -15,7 +15,7 @@ module to concisely define the model architecture.
Attention layer
^^^^^^^^^^^^^^^^
We will start with the Llama attention layer which notably uses the RoPE
We will start with the llama attention layer which notably uses the RoPE
positional encoding. [1]_ In addition, our attention layer will optionally use a
key/value cache that will be concatenated with the provided keys and values to
support efficient inference.

View File

@@ -64,7 +64,7 @@ set:
Next, setup the problem parameters and load the data. To load the data, you need our
`mnist data loader
<https://github.com/ml-explore/mlx-examples/blob/main/mnist/mnist.py>`_, which
we will import as ``mnist``.
we will import as `mnist`.
.. code-block:: python

View File

@@ -43,7 +43,6 @@ are the CPU and GPU.
usage/function_transforms
usage/compile
usage/numpy
usage/distributed
usage/using_streams
.. toctree::
@@ -70,7 +69,6 @@ are the CPU and GPU.
python/metal
python/nn
python/optimizers
python/distributed
python/tree_utils
.. toctree::

View File

@@ -70,36 +70,36 @@ To build and install the MLX python library from source, first, clone MLX from
git clone git@github.com:ml-explore/mlx.git mlx && cd mlx
Install `nanobind <https://nanobind.readthedocs.io/en/latest/>`_ with:
.. code-block:: shell
pip install git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4
Then simply build and install MLX using pip:
.. code-block:: shell
CMAKE_BUILD_PARALLEL_LEVEL="" pip install .
env CMAKE_BUILD_PARALLEL_LEVEL="" pip install .
For developing, install the package with development dependencies, and use an
editable install:
For developing use an editable install:
.. code-block:: shell
CMAKE_BUILD_PARALLEL_LEVEL="" pip install -e ".[dev]"
env CMAKE_BUILD_PARALLEL_LEVEL="" pip install -e .
Once the development dependencies are installed, you can build faster with:
.. code-block:: shell
CMAKE_BUILD_PARALLEL_LEVEL="" python setup.py build_ext -j --inplace
Run the tests with:
To make sure the install is working run the tests with:
.. code-block:: shell
pip install ".[testing]"
python -m unittest discover python/tests
Optional: Install stubs to enable auto completions and type checking from your
IDE:
Optional: Install stubs to enable auto completions and type checking from your IDE:
.. code-block:: shell
pip install ".[dev]"
python setup.py generate_stubs
C++ API
@@ -186,8 +186,8 @@ should point to the path to the built metal library.
Binary Size Minimization
~~~~~~~~~~~~~~~~~~~~~~~~
To produce a smaller binary use the CMake flags ``CMAKE_BUILD_TYPE=MinSizeRel``
and ``BUILD_SHARED_LIBS=ON``.
To produce a smaller binary use the CMake flags `CMAKE_BUILD_TYPE=MinSizeRel`
and `BUILD_SHARED_LIBS=ON`.
The MLX CMake build has several additional options to make smaller binaries.
For example, if you don't need the CPU backend or support for safetensors and
@@ -195,7 +195,7 @@ GGUF, you can do:
.. code-block:: shell
cmake .. \
cmake ..
-DCMAKE_BUILD_TYPE=MinSizeRel \
-DBUILD_SHARED_LIBS=ON \
-DMLX_BUILD_CPU=OFF \
@@ -203,7 +203,7 @@ GGUF, you can do:
-DMLX_BUILD_GGUF=OFF \
-DMLX_METAL_JIT=ON
THE ``MLX_METAL_JIT`` flag minimizes the size of the MLX Metal library which
THE `MLX_METAL_JIT` flag minimizes the size of the MLX Metal library which
contains pre-built GPU kernels. This substantially reduces the size of the
Metal library by run-time compiling kernels the first time they are used in MLX
on a given machine. Note run-time compilation incurs a cold-start cost which can

View File

@@ -24,7 +24,6 @@ Array
array.any
array.argmax
array.argmin
array.conj
array.cos
array.cummax
array.cummin
@@ -58,4 +57,3 @@ Array
array.transpose
array.T
array.var
array.view

View File

@@ -1,19 +0,0 @@
.. _distributed:
.. currentmodule:: mlx.core.distributed
Distributed Communication
==========================
MLX provides a distributed communication package using MPI. The MPI library is
loaded at runtime; if MPI is available then distributed communication is also
made available.
.. autosummary::
:toctree: _autosummary
Group
is_available
init
all_sum
all_gather

View File

@@ -9,9 +9,7 @@ Linear Algebra
:toctree: _autosummary
inv
tri_inv
norm
cholesky
cholesky_inv
qr
svd

View File

@@ -17,8 +17,6 @@ simple functions.
gelu_approx
gelu_fast_approx
glu
hard_shrink
hard_tanh
hardswish
leaky_relu
log_sigmoid
@@ -31,7 +29,6 @@ simple functions.
sigmoid
silu
softmax
softmin
softplus
softshrink
step

View File

@@ -21,15 +21,10 @@ Layers
Dropout3d
Embedding
GELU
GLU
GroupNorm
GRU
HardShrink
HardTanh
Hardswish
InstanceNorm
LayerNorm
LeakyReLU
Linear
LSTM
MaxPool1d
@@ -41,19 +36,13 @@ Layers
QuantizedLinear
RMSNorm
ReLU
ReLU6
RNN
RoPE
SELU
Sequential
SiLU
SinusoidalPositionalEncoding
Softmin
Softshrink
Softsign
Softmax
Softplus
Step
Tanh
Transformer
Upsample

View File

@@ -57,8 +57,6 @@ Operations
diagonal
divide
divmod
einsum
einsum_path
equal
erf
erfinv
@@ -74,7 +72,6 @@ Operations
gather_qmm
greater
greater_equal
hadamard_transform
identity
inner
isclose
@@ -106,7 +103,6 @@ Operations
minimum
moveaxis
multiply
nan_to_num
negative
not_equal
ones
@@ -160,7 +156,6 @@ Operations
tril
triu
var
view
where
zeros
zeros_like

View File

@@ -31,41 +31,6 @@ model's parameters and the **optimizer state**.
# Compute the new parameters but also the optimizer state.
mx.eval(model.parameters(), optimizer.state)
Saving and Loading
------------------
To serialize an optimizer, save its state. To load an optimizer, load and set
the saved state. Here's a simple example:
.. code-block:: python
import mlx.core as mx
from mlx.utils import tree_flatten, tree_unflatten
import mlx.optimizers as optim
optimizer = optim.Adam(learning_rate=1e-2)
# Perform some updates with the optimizer
model = {"w" : mx.zeros((5, 5))}
grads = {"w" : mx.ones((5, 5))}
optimizer.update(model, grads)
# Save the state
state = tree_flatten(optimizer.state)
mx.save_safetensors("optimizer.safetensors", dict(state))
# Later on, for example when loading from a checkpoint,
# recreate the optimizer and load the state
optimizer = optim.Adam(learning_rate=1e-2)
state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))
optimizer.state = state
Note, not every optimizer configuation parameter is saved in the state. For
example, for Adam the learning rate is saved but the ``betas`` and ``eps``
parameters are not. A good rule of thumb is if the parameter can be scheduled
then it will be included in the optimizer state.
.. toctree::
optimizers/optimizer

View File

@@ -44,4 +44,3 @@ we use a splittable version of Threefry, which is a counter-based PRNG.
split
truncated_normal
uniform
laplace

View File

@@ -10,7 +10,6 @@ Transforms
eval
compile
custom_function
disable_compile
enable_compile
grad

View File

@@ -1,166 +0,0 @@
.. _usage_distributed:
Distributed Communication
=========================
.. currentmodule:: mlx.core.distributed
MLX utilizes `MPI <https://en.wikipedia.org/wiki/Message_Passing_Interface>`_ to
provide distributed communication operations that allow the computational cost
of training or inference to be shared across many physical machines. You can
see a list of the supported operations in the :ref:`API docs<distributed>`.
.. note::
A lot of operations may not be supported or not as fast as they should be.
We are adding more and tuning the ones we have as we are figuring out the
best way to do distributed computing on Macs using MLX.
Getting Started
---------------
MLX already comes with the ability to "talk" to MPI if it is installed on the
machine. The minimal distributed program in MLX is as simple as:
.. code:: python
import mlx.core as mx
world = mx.distributed.init()
x = mx.distributed.all_sum(mx.ones(10))
print(world.rank(), x)
The program above sums the array ``mx.ones(10)`` across all
distributed processes. If simply run with ``python``, however, only one
process is launched and no distributed communication takes place.
To launch the program in distributed mode we need to use ``mpirun`` or
``mpiexec`` depending on the MPI installation. The simplest possible way is the
following:
.. code:: shell
$ mpirun -np 2 python test.py
1 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
0 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
The above launches two processes on the same (local) machine and we can see
both standard output streams. The processes send the array of 1s to each other
and compute the sum which is printed. Launching with ``mpirun -np 4 ...`` would
print 4 etc.
Installing MPI
---------------
MPI can be installed with Homebrew, using the Anaconda package manager or
compiled from source. Most of our testing is done using ``openmpi`` installed
with the Anaconda package manager as follows:
.. code:: shell
$ conda install openmpi
Installing with Homebrew may require specifying the location of ``libmpi.dyld``
so that MLX can find it and load it at runtime. This can simply be achieved by
passing the ``DYLD_LIBRARY_PATH`` environment variable to ``mpirun``.
.. code:: shell
$ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python test.py
Setting up Remote Hosts
-----------------------
MPI can automatically connect to remote hosts and set up the communication over
the network if the remote hosts can be accessed via ssh. A good checklist to
debug connectivity issues is the following:
* ``ssh hostname`` works from all machines to all machines without asking for
password or host confirmation
* ``mpirun`` is accessible on all machines. You can call ``mpirun`` using its
full path to force all machines to use a specific path.
* Ensure that the ``hostname`` used by MPI is the one that you have configured
in the ``.ssh/config`` files on all machines.
.. note::
For an example hostname ``foo.bar.com`` MPI can use only ``foo`` as
the hostname passed to ssh if the current hostname matches ``*.bar.com``.
An easy way to pass the host names to MPI is using a host file. A host file
looks like the following, where ``host1`` and ``host2`` should be the fully
qualified domain names or IPs for these hosts.
.. code::
host1 slots=1
host2 slots=1
When using MLX, it is very likely that you want to use 1 slot per host, ie one
process per host. The hostfile also needs to contain the current
host if you want to run on the local host. Passing the host file to
``mpirun`` is simply done using the ``--hostfile`` command line argument.
Training Example
----------------
In this section we will adapt an MLX training loop to support data parallel
distributed training. Namely, we will average the gradients across a set of
hosts before applying them to the model.
Our training loop looks like the following code snippet if we omit the model,
dataset and optimizer initialization.
.. code:: python
model = ...
optimizer = ...
dataset = ...
def step(model, x, y):
loss, grads = loss_grad_fn(model, x, y)
optimizer.update(model, grads)
return loss
for x, y in dataset:
loss = step(model, x, y)
mx.eval(loss, model.parameters())
All we have to do to average the gradients across machines is perform an
:func:`all_sum` and divide by the size of the :class:`Group`. Namely we
have to :func:`mlx.utils.tree_map` the gradients with following function.
.. code:: python
def all_avg(x):
return mx.distributed.all_sum(x) / mx.distributed.init().size()
Putting everything together our training loop step looks as follows with
everything else remaining the same.
.. code:: python
from mlx.utils import tree_map
def all_reduce_grads(grads):
N = mx.distributed.init()
if N == 1:
return grads
return tree_map(
lambda x: mx.distributed.all_sum(x) / N,
grads)
def step(model, x, y):
loss, grads = loss_grad_fn(model, x, y)
grads = all_reduce_grads(grads) # <--- This line was added
optimizer.update(model, grads)
return loss
Tuning All Reduce
-----------------
We are working on improving the performance of all reduce on MLX but for now
the two main things one can do to extract the most out of distributed training with MLX are:
1. Perform a few large reductions instead of many small ones to improve
bandwidth and latency
2. Pass ``--mca btl_tcp_links 4`` to ``mpirun`` to configure it to use 4 tcp
connections between each host to improve bandwidth

View File

@@ -3,11 +3,7 @@
Conversion to NumPy and Other Frameworks
========================================
MLX array supports conversion between other frameworks with either:
* The `Python Buffer Protocol <https://docs.python.org/3/c-api/buffer.html>`_.
* `DLPack <https://dmlc.github.io/dlpack/latest/>`_.
MLX array implements the `Python Buffer Protocol <https://docs.python.org/3/c-api/buffer.html>`_.
Let's convert an array to NumPy and back.
.. code-block:: python

View File

@@ -16,7 +16,7 @@ int main() {
std::cout << global_group.rank() << " / " << global_group.size() << std::endl;
array x = ones({10});
array out = distributed::all_sum(x, global_group);
array out = distributed::all_reduce_sum(x, global_group);
std::cout << out << std::endl;
}

View File

@@ -249,8 +249,9 @@ void Axpby::eval_gpu(
kname << (contiguous_kernel ? "contiguous_" : "general_");
kname << type_to_name(out);
// Make sure the metal library is available
d.register_library("mlx_ext");
// Make sure the metal library is available and look for it
// in the same folder as this executable if needed
d.register_library("mlx_ext", metal::get_colocated_mtllib_path);
// Make a kernel from this metal library
auto kernel = d.get_kernel(kname.str(), "mlx_ext");

View File

@@ -1,4 +1,4 @@
setuptools>=42
cmake>=3.24
mlx>=0.16.2
nanobind==2.0
mlx>=0.9.0
nanobind@git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4

View File

@@ -6,7 +6,6 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/compile.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/dtype.cpp
${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp

View File

@@ -17,10 +17,6 @@ bool in_tracing() {
return detail::InTracing::in_tracing();
}
bool retain_graph() {
return detail::RetainGraph::retain_graph();
}
} // namespace
array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
@@ -106,7 +102,7 @@ void array::eval() {
}
bool array::is_tracer() const {
return array_desc_->is_tracer && in_tracing() || retain_graph();
return array_desc_->is_tracer && in_tracing();
}
void array::set_data(allocator::Buffer buffer, deleter_t d) {
@@ -175,11 +171,10 @@ array::~array() {
return;
}
// Ignore arrays that might be detached during eval
if (status() == array::Status::scheduled) {
// Ignore arrays that will be detached
if (status() != array::Status::unscheduled) {
return;
}
// Break circular reference for non-detached arrays with siblings
if (auto n = siblings().size(); n > 0) {
bool do_detach = true;
@@ -211,7 +206,7 @@ void array::ArrayDesc::init() {
strides[i] = size;
size *= shape[i];
}
for (const auto& in : inputs) {
for (auto& in : inputs) {
is_tracer |= in.is_tracer();
}
}
@@ -236,7 +231,7 @@ array::ArrayDesc::ArrayDesc(
array::ArrayDesc::~ArrayDesc() {
// When an array description is destroyed it will delete a bunch of arrays
// that may also destroy their corresponding descriptions and so on and so
// that may also destory their corresponding descriptions and so on and so
// forth.
//
// This calls recursively the destructor and can result in stack overflow, we

View File

@@ -73,32 +73,32 @@ class array {
this->array_desc_ = other.array_desc_;
}
return *this;
}
};
/** The size of the array's datatype in bytes. */
size_t itemsize() const {
return size_of(dtype());
}
};
/** The number of elements in the array. */
size_t size() const {
return array_desc_->size;
}
};
/** The number of bytes in the array. */
size_t nbytes() const {
return size() * itemsize();
}
};
/** The number of dimensions of the array. */
size_t ndim() const {
return array_desc_->shape.size();
}
};
/** The shape of the array as a vector of integers. */
const std::vector<int>& shape() const {
return array_desc_->shape;
}
};
/**
* Get the size of the corresponding dimension.
@@ -107,12 +107,12 @@ class array {
* bounds checking. */
int shape(int dim) const {
return shape().at(dim < 0 ? dim + ndim() : dim);
}
};
/** The strides of the array. */
const std::vector<size_t>& strides() const {
return array_desc_->strides;
}
};
/**
* Get the stride of the corresponding dimension.
@@ -121,12 +121,12 @@ class array {
* bounds checking. */
size_t strides(int dim) const {
return strides().at(dim < 0 ? dim + ndim() : dim);
}
};
/** Get the arrays data type. */
Dtype dtype() const {
return array_desc_->dtype;
}
};
/** Evaluate the array. */
void eval();
@@ -160,10 +160,10 @@ class array {
friend bool operator==(const ArrayIterator& a, const ArrayIterator& b) {
return a.arr.id() == b.arr.id() && a.idx == b.idx;
}
};
friend bool operator!=(const ArrayIterator& a, const ArrayIterator& b) {
return !(a == b);
}
};
private:
const array& arr;
@@ -209,7 +209,7 @@ class array {
allocator::Buffer buffer;
deleter_t d;
Data(allocator::Buffer buffer, deleter_t d = allocator::free)
: buffer(buffer), d(d) {}
: buffer(buffer), d(d) {};
// Not copyable
Data(const Data& d) = delete;
Data& operator=(const Data& d) = delete;
@@ -230,41 +230,42 @@ class array {
/** The array's primitive. */
Primitive& primitive() const {
return *(array_desc_->primitive);
}
};
/** A shared pointer to the array's primitive. */
std::shared_ptr<Primitive>& primitive_ptr() const {
return array_desc_->primitive;
}
};
/** Check if the array has an attached primitive or is a leaf node. */
bool has_primitive() const {
return array_desc_->primitive != nullptr;
}
};
/** The array's inputs. */
const std::vector<array>& inputs() const {
return array_desc_->inputs;
}
};
std::vector<array>& inputs() {
return array_desc_->inputs;
}
/** True indicates the arrays buffer is safe to reuse */
bool is_donatable() const {
return array_desc_.use_count() == 1 && (array_desc_->data.use_count() == 1);
bool is_donatable(int known_instances = 1) const {
return array_desc_.use_count() == known_instances &&
(array_desc_->data.use_count() == 1);
}
/** The array's siblings. */
const std::vector<array>& siblings() const {
return array_desc_->siblings;
}
};
/** The array's siblings. */
std::vector<array>& siblings() {
return array_desc_->siblings;
}
};
void set_siblings(std::vector<array> siblings, uint16_t position) {
array_desc_->siblings = std::move(siblings);
@@ -281,7 +282,7 @@ class array {
outputs.push_back(*this);
outputs.insert(outputs.end(), siblings().begin() + idx, siblings().end());
return outputs;
}
};
/** Detach the array from the graph. */
void detach();
@@ -289,19 +290,19 @@ class array {
/** Get the Flags bit-field. */
const Flags& flags() const {
return array_desc_->flags;
}
};
/** The size (in elements) of the underlying buffer the array points to. */
size_t data_size() const {
return array_desc_->data_size;
}
};
allocator::Buffer& buffer() {
return array_desc_->data->buffer;
}
};
const allocator::Buffer& buffer() const {
return array_desc_->data->buffer;
}
};
// Return a copy of the shared pointer
// to the array::Data struct
@@ -312,20 +313,19 @@ class array {
template <typename T>
T* data() {
return static_cast<T*>(array_desc_->data_ptr);
}
};
template <typename T>
const T* data() const {
return static_cast<T*>(array_desc_->data_ptr);
}
};
enum Status { unscheduled, scheduled, available };
bool is_available() const {
return status() == Status::available;
}
Status status() const {
const Status status() const {
return array_desc_->status;
}

View File

@@ -1,9 +1,9 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2023 Apple Inc.
#include <cassert>
#include <Accelerate/Accelerate.h>
#include <simd/vector.h>
#include <vecLib/vDSP.h>
#include "mlx/backend/common/copy.h"
#include "mlx/primitives.h"

View File

@@ -2,7 +2,8 @@
#include <cassert>
#include <Accelerate/Accelerate.h>
#include <vecLib/BNNS/bnns.h>
#include <vecLib/cblas_new.h>
#include "mlx/backend/accelerate/utils.h"
#include "mlx/backend/common/copy.h"

View File

@@ -3,7 +3,8 @@
#include <cassert>
#include <cmath>
#include <Accelerate/Accelerate.h>
#include <vecLib/vDSP.h>
#include <vecLib/vForce.h>
#include "mlx/allocator.h"
#include "mlx/backend/common/binary.h"
@@ -36,7 +37,7 @@ DEFAULT(Ceil)
DEFAULT(Concatenate)
DEFAULT(Conjugate)
DEFAULT(Copy)
DEFAULT_MULTI(CustomTransforms)
DEFAULT_MULTI(CustomVJP)
DEFAULT_MULTI(Depends)
DEFAULT_MULTI(DivMod)
DEFAULT(NumberOfElements)
@@ -50,7 +51,6 @@ DEFAULT(GatherMM)
DEFAULT(GatherQMM)
DEFAULT(Greater)
DEFAULT(GreaterEqual)
DEFAULT(Hadamard)
DEFAULT(Less)
DEFAULT(LessEqual)
DEFAULT(Load)
@@ -102,7 +102,7 @@ void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& b = inputs[1];
if (a.dtype() == float32) {
binary_op<float>(
binary(
a,
b,
out,
@@ -117,7 +117,7 @@ void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
vDSP_vadd((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n);
});
} else if (a.dtype() == int32) {
binary_op<int>(
binary(
a,
b,
out,
@@ -132,7 +132,7 @@ void Add::eval_cpu(const std::vector<array>& inputs, array& out) {
vDSP_vaddi((const int*)a, 1, (const int*)b, 1, (int*)o, 1, n);
});
} else {
eval(inputs, out);
binary(a, b, out, [](auto x, auto y) { return x + y; });
}
}
@@ -287,7 +287,7 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& b = inputs[1];
if (a.dtype() == int32) {
binary_op<int>(
binary(
a,
b,
out,
@@ -300,7 +300,7 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
vDSP_vdivi((const int*)b, 1, (const int*)a, 1, (int*)o, 1, n);
});
} else if (a.dtype() == float32) {
binary_op<float>(
binary(
a,
b,
out,
@@ -315,7 +315,7 @@ void Divide::eval_cpu(const std::vector<array>& inputs, array& out) {
vDSP_vdiv((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n);
});
} else {
eval(inputs, out);
binary(a, b, out, [](auto x, auto y) { return x / y; });
}
}
@@ -326,8 +326,12 @@ void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
set_unary_output_data(in, out);
auto size = in.data_size();
vvexpf(out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
} else if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, [](auto x) { return std::exp(x); });
} else {
eval(inputs, out);
throw std::invalid_argument(
"[exp] Cannot exponentiate elements in array"
" with non floating point type.");
}
}
@@ -389,8 +393,12 @@ void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
auto size = in.data_size();
vvlog1pf(
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
} else if (issubdtype(out.dtype(), inexact)) {
unary_fp(in, out, [](auto x) { return std::log1p(x); });
} else {
eval(inputs, out);
throw std::invalid_argument(
"[log1p] Cannot compute log of elements in array with"
" non floating point type.");
}
}
@@ -400,7 +408,7 @@ void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& b = inputs[1];
if (a.dtype() == float32) {
binary_op<float>(
binary(
a,
b,
out,
@@ -415,7 +423,7 @@ void Multiply::eval_cpu(const std::vector<array>& inputs, array& out) {
vDSP_vmul((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n);
});
} else {
eval(inputs, out);
binary(a, b, out, [](auto x, auto y) { return x * y; });
}
}
@@ -426,7 +434,7 @@ void Negative::eval_cpu(const std::vector<array>& inputs, array& out) {
set_unary_output_data(in, out);
vDSP_vneg(in.data<float>(), 1, out.data<float>(), 1, in.data_size());
} else {
eval(inputs, out);
unary(in, out, [](auto x) { return -x; });
}
}
@@ -513,7 +521,7 @@ void Square::eval_cpu(const std::vector<array>& inputs, array& out) {
auto size = in.data_size();
vDSP_vsq(in.data<float>(), 1, out.data<float>(), 1, size);
} else {
eval(inputs, out);
unary(in, out, [](auto x) { return x * x; });
}
}
@@ -539,7 +547,7 @@ void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& b = inputs[1];
if (a.dtype() == float32) {
binary_op<float>(
binary(
a,
b,
out,
@@ -557,7 +565,7 @@ void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
vDSP_vsub((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n);
});
} else if (a.dtype() == int32) {
binary_op<int>(
binary(
a,
b,
out,
@@ -569,7 +577,7 @@ void Subtract::eval_cpu(const std::vector<array>& inputs, array& out) {
},
UseDefaultBinaryOp());
} else {
eval(inputs, out);
binary(a, b, out, [](auto x, auto y) { return x - y; });
}
}

View File

@@ -2,8 +2,8 @@
#include <cassert>
#include <Accelerate/Accelerate.h>
#include <simd/vector.h>
#include <vecLib/vDSP.h>
#include "mlx/backend/common/reduce.h"
#include "mlx/primitives.h"

View File

@@ -3,10 +3,7 @@
#include <cassert>
#include <limits>
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#include <arm_neon.h>
#endif
#include <simd/math.h>
#include <simd/vector.h>
@@ -56,26 +53,25 @@ inline simd_float16 simd_fast_exp(simd_float16 x) {
return (*(simd_float16*)&epart) * x;
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
/**
* The ARM neon equivalent of the fast exp above.
*/
inline float16x8_t neon_fast_exp(float16x8_t x) {
x = vmulq_f16(x, vdupq_n_f16(float16_t(1.442695f))); // multiply with log_2(e)
x = vmaxq_f16(x, vdupq_n_f16(float16_t(-14.f))); // clamp under with -14
x = vminq_f16(x, vdupq_n_f16(float16_t(14.f))); // clamp over with 14
x = vmulq_f16(x, vdupq_n_f16(1.442695)); // multiply with log_2(e)
x = vmaxq_f16(x, vdupq_n_f16(-14)); // clamp under with -14
x = vminq_f16(x, vdupq_n_f16(14)); // clamp over with 14
float16x8_t ipart = vrndmq_f16(vaddq_f16(x, vdupq_n_f16(float16_t(0.5f))));
float16x8_t ipart = vrndmq_f16(vaddq_f16(x, vdupq_n_f16(0.5)));
float16x8_t fpart = vsubq_f16(x, ipart);
x = vdupq_n_f16(float16_t(1.535336188319500e-4f));
x = vfmaq_f16(vdupq_n_f16(float16_t(1.339887440266574e-3f)), x, fpart);
x = vfmaq_f16(vdupq_n_f16(float16_t(1.339887440266574e-3f)), x, fpart);
x = vfmaq_f16(vdupq_n_f16(float16_t(9.618437357674640e-3f)), x, fpart);
x = vfmaq_f16(vdupq_n_f16(float16_t(5.550332471162809e-2f)), x, fpart);
x = vfmaq_f16(vdupq_n_f16(float16_t(2.402264791363012e-1f)), x, fpart);
x = vfmaq_f16(vdupq_n_f16(float16_t(6.931472028550421e-1f)), x, fpart);
x = vfmaq_f16(vdupq_n_f16(float16_t(1.000000000000000f)), x, fpart);
x = vdupq_n_f16(1.535336188319500e-4f);
x = vfmaq_f16(vdupq_n_f16(1.339887440266574e-3f), x, fpart);
x = vfmaq_f16(vdupq_n_f16(1.339887440266574e-3f), x, fpart);
x = vfmaq_f16(vdupq_n_f16(9.618437357674640e-3f), x, fpart);
x = vfmaq_f16(vdupq_n_f16(5.550332471162809e-2f), x, fpart);
x = vfmaq_f16(vdupq_n_f16(2.402264791363012e-1f), x, fpart);
x = vfmaq_f16(vdupq_n_f16(6.931472028550421e-1f), x, fpart);
x = vfmaq_f16(vdupq_n_f16(1.000000000000000f), x, fpart);
// generate 2**ipart in the floating point representation using integer
// bitshifting
@@ -111,55 +107,6 @@ inline float16_t neon_reduce_add(float16x8_t x) {
return vget_lane_f16(y, 0);
}
template <typename T, typename VT>
struct NeonFp16SimdOps {
VT init(T a) {
return vdupq_n_f16(a);
}
VT load(const T* a) {
return vld1q_f16(a);
}
void store(T* dst, VT x) {
vst1q_f16(dst, x);
}
VT max(VT a, VT b) {
return vmaxq_f16(a, b);
}
VT exp(VT x) {
return neon_fast_exp(x);
}
VT add(VT a, VT b) {
return vaddq_f16(a, b);
}
VT sub(VT a, T b) {
return vsubq_f16(a, vdupq_n_f16(b));
}
VT mul(VT a, VT b) {
return vmulq_f16(a, b);
}
VT mul(VT a, T b) {
return vmulq_f16(a, vdupq_n_f16(b));
}
T reduce_max(VT x) {
return neon_reduce_max(x);
}
T reduce_add(VT x) {
return neon_reduce_add(x);
}
};
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
template <typename T, typename VT>
struct AccelerateSimdOps {
VT init(T a) {
@@ -176,7 +123,7 @@ struct AccelerateSimdOps {
VT max(VT a, VT b) {
return simd_max(a, b);
}
};
VT exp(VT x) {
return simd_fast_exp(x);
@@ -207,6 +154,53 @@ struct AccelerateSimdOps {
}
};
template <typename T, typename VT>
struct NeonFp16SimdOps {
VT init(T a) {
return vdupq_n_f16(a);
}
VT load(const T* a) {
return vld1q_f16(a);
}
void store(T* dst, VT x) {
vst1q_f16(dst, x);
}
VT max(VT a, VT b) {
return vmaxq_f16(a, b);
};
VT exp(VT x) {
return neon_fast_exp(x);
}
VT add(VT a, VT b) {
return vaddq_f16(a, b);
}
VT sub(VT a, T b) {
return vsubq_f16(a, vdupq_n_f16(b));
}
VT mul(VT a, VT b) {
return vmulq_f16(a, b);
}
VT mul(VT a, T b) {
return vmulq_f16(a, vdupq_n_f16(b));
}
T reduce_max(VT x) {
return neon_reduce_max(x);
}
T reduce_add(VT x) {
return neon_reduce_add(x);
}
};
template <typename T, typename AccT, typename VT, typename Ops, int N>
void softmax(const array& in, array& out) {
Ops ops;
@@ -368,16 +362,12 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
AccelerateSimdOps<float, simd_float16>,
16>(in, out);
} else {
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
softmax<
float16_t,
float16_t,
float16x8_t,
NeonFp16SimdOps<float16_t, float16x8_t>,
8>(in, out);
#else // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
eval(inputs, out); // Redirect to common backend for consistency
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
}
break;
case bfloat16:

View File

@@ -1,8 +1,8 @@
// Copyright © 2023-2024 Apple Inc.
// Copyright © 2023 Apple Inc.
#pragma once
#include <Accelerate/Accelerate.h>
#include <vecLib/BNNS/bnns.h>
#include "mlx/dtype.h"
namespace mlx::core {

View File

@@ -42,15 +42,12 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
${CMAKE_CURRENT_SOURCE_DIR}/masked_mm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp

View File

@@ -196,20 +196,6 @@ void LogAddExp::eval(const std::vector<array>& inputs, array& out) {
}
}
void LogicalAnd::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); // LogicalAnd requires two input arrays
auto& in1 = inputs[0];
auto& in2 = inputs[1];
binary(in1, in2, out, detail::LogicalAnd());
}
void LogicalOr::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); // LogicalOr requires two input arrays
auto& in1 = inputs[0];
auto& in2 = inputs[1];
binary(in1, in2, out, detail::LogicalOr());
}
void Maximum::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& a = inputs[0];

View File

@@ -66,7 +66,7 @@ void Copy::eval(const std::vector<array>& inputs, array& out) {
out.copy_shared_buffer(inputs[0]);
}
void CustomTransforms::eval(
void CustomVJP::eval(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() > outputs.size());
@@ -250,6 +250,49 @@ void Split::eval(
}
}
std::tuple<bool, int64_t, std::vector<int64_t>> Slice::prepare_slice(
const array& in) {
int64_t data_offset = 0;
bool copy_needed = false;
std::vector<int64_t> inp_strides(in.ndim(), 0);
for (int i = 0; i < in.ndim(); ++i) {
data_offset += start_indices_[i] * in.strides()[i];
inp_strides[i] = in.strides()[i] * strides_[i];
copy_needed |= strides_[i] < 0;
}
return std::make_tuple(copy_needed, data_offset, inp_strides);
}
void Slice::shared_buffer_slice(
const array& in,
const std::vector<size_t>& out_strides,
size_t data_offset,
array& out) {
// Compute row/col contiguity
auto [data_size, is_row_contiguous, is_col_contiguous] =
check_contiguity(out.shape(), out_strides);
auto flags = in.flags();
flags.row_contiguous = is_row_contiguous;
flags.col_contiguous = is_col_contiguous;
if (data_size == 1) {
// Broadcasted scalar array is contiguous.
flags.contiguous = true;
} else if (data_size == in.data_size()) {
// Means we sliced a broadcasted dimension so leave the "no holes" flag
// alone.
} else {
// We sliced something. So either we are row or col contiguous or we
// punched a hole.
flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
}
out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset);
}
std::tuple<int64_t, std::vector<int64_t>> SliceUpdate::prepare_slice(
const array& in) {
int64_t data_offset = 0;

View File

@@ -205,8 +205,8 @@ void compiled_allocate_outputs(
// - Donatable
// - Correct size
// - Not a constant
if (in.flags().row_contiguous && in.size() == outputs[o].size() &&
in.itemsize() == outputs[o].itemsize() && in.is_donatable() &&
if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() &&
in.is_donatable() &&
constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
if (move_buffers) {
outputs[o].move_shared_buffer(

View File

@@ -4,7 +4,6 @@
#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/utils.h"
namespace mlx::core {
@@ -143,31 +142,29 @@ void copy_general(
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
int64_t i_offset) {
auto [new_shape, new_strides] = collapse_contiguous_dims(
data_shape, std::vector<std::vector<stride_t>>{i_strides});
switch (new_shape.size()) {
switch (src.ndim()) {
case 1:
copy_general_dim1<SrcT, DstT, stride_t>(
src, dst, new_shape, new_strides[0], i_offset);
src, dst, data_shape, i_strides, i_offset);
return;
case 2:
copy_general_dim2<SrcT, DstT, stride_t>(
src, dst, new_shape, new_strides[0], i_offset);
src, dst, data_shape, i_strides, i_offset);
return;
case 3:
copy_general_dim3<SrcT, DstT, stride_t>(
src, dst, new_shape, new_strides[0], i_offset);
src, dst, data_shape, i_strides, i_offset);
return;
case 4:
copy_general_dim4<SrcT, DstT, stride_t>(
src, dst, new_shape, new_strides[0], i_offset);
src, dst, data_shape, i_strides, i_offset);
return;
}
auto src_ptr = src.data<SrcT>() + i_offset;
auto dst_ptr = dst.data<DstT>();
for (size_t i = 0; i < dst.size(); ++i) {
stride_t src_elem = elem_to_loc(i, new_shape, new_strides[0]);
stride_t src_elem = elem_to_loc(i, data_shape, i_strides);
dst_ptr[i] = static_cast<DstT>(src_ptr[src_elem]);
}
}
@@ -198,10 +195,10 @@ inline void copy_general_general_dims(
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
const std::vector<stride_t>& o_strides,
int64_t i_offset,
int64_t o_offset) {
stride_t i_offset,
stride_t o_offset) {
if constexpr (D > 1) {
int axis = data_shape.size() - D;
int axis = src.ndim() - D;
auto stride_src = i_strides[axis];
auto stride_dst = o_strides[axis];
auto N = data_shape[axis];
@@ -212,7 +209,7 @@ inline void copy_general_general_dims(
o_offset += stride_dst;
}
} else {
int axis = data_shape.size() - 1;
int axis = src.ndim() - 1;
auto stride_src = i_strides[axis];
auto stride_dst = o_strides[axis];
auto N = data_shape[axis];
@@ -233,76 +230,38 @@ void copy_general_general(
const std::vector<int>& data_shape,
const std::vector<stride_t>& i_strides,
const std::vector<stride_t>& o_strides,
int64_t i_offset,
int64_t o_offset) {
auto [new_shape, new_strides] = collapse_contiguous_dims(
data_shape, std::vector<std::vector<stride_t>>{i_strides, o_strides});
switch (new_shape.size()) {
stride_t i_offset,
stride_t o_offset) {
switch (src.ndim()) {
case 1:
copy_general_general_dims<SrcT, DstT, stride_t, 1>(
src,
dst,
new_shape,
new_strides[0],
new_strides[1],
i_offset,
o_offset);
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
return;
case 2:
copy_general_general_dims<SrcT, DstT, stride_t, 2>(
src,
dst,
new_shape,
new_strides[0],
new_strides[1],
i_offset,
o_offset);
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
return;
case 3:
copy_general_general_dims<SrcT, DstT, stride_t, 3>(
src,
dst,
new_shape,
new_strides[0],
new_strides[1],
i_offset,
o_offset);
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
return;
case 4:
copy_general_general_dims<SrcT, DstT, stride_t, 4>(
src,
dst,
new_shape,
new_strides[0],
new_strides[1],
i_offset,
o_offset);
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
return;
case 5:
copy_general_general_dims<SrcT, DstT, stride_t, 5>(
src,
dst,
new_shape,
new_strides[0],
new_strides[1],
i_offset,
o_offset);
src, dst, data_shape, i_strides, o_strides, i_offset, o_offset);
return;
}
int size = std::accumulate(
new_shape.end() - 5, new_shape.end(), 1, std::multiplies<int>());
data_shape.end() - 5, data_shape.end(), 1, std::multiplies<int>());
for (int i = 0; i < src.size(); i += size) {
stride_t src_offset = i_offset + elem_to_loc(i, new_shape, new_strides[0]);
stride_t dst_offset = o_offset + elem_to_loc(i, new_shape, new_strides[1]);
stride_t src_offset = i_offset + elem_to_loc(i, data_shape, i_strides);
stride_t dst_offset = o_offset + elem_to_loc(i, dst.shape(), o_strides);
copy_general_general_dims<SrcT, DstT, stride_t, 5>(
src,
dst,
new_shape,
new_strides[0],
new_strides[1],
src_offset,
dst_offset);
src, dst, data_shape, i_strides, o_strides, src_offset, dst_offset);
}
}
@@ -485,17 +444,8 @@ void copy_inplace(
}
}
template void copy_inplace<size_t>(
const array& src,
array& dst,
const std::vector<int>& data_shape,
const std::vector<size_t>& i_strides,
const std::vector<size_t>& o_strides,
int64_t i_offset,
int64_t o_offset,
CopyType ctype);
template void copy_inplace<int64_t>(
template <>
void copy_inplace<int64_t>(
const array& src,
array& dst,
const std::vector<int>& data_shape,
@@ -503,6 +453,24 @@ template void copy_inplace<int64_t>(
const std::vector<int64_t>& o_strides,
int64_t i_offset,
int64_t o_offset,
CopyType ctype);
CopyType ctype) {
switch (ctype) {
case CopyType::General:
case CopyType::GeneralGeneral:
return copy_inplace_dispatch(
src,
dst,
ctype,
data_shape,
i_strides,
o_strides,
i_offset,
o_offset);
case CopyType::Scalar:
case CopyType::Vector:
return copy_inplace_dispatch(src, dst, ctype);
}
}
} // namespace mlx::core

View File

@@ -5,6 +5,7 @@
#else
#include <cblas.h>
#endif
#include <cstring>
#include "mlx/array.h"
@@ -52,7 +53,7 @@ DEFAULT(Convolution)
DEFAULT(Copy)
DEFAULT(Cos)
DEFAULT(Cosh)
DEFAULT_MULTI(CustomTransforms)
DEFAULT_MULTI(CustomVJP)
DEFAULT_MULTI(Depends)
DEFAULT(Divide)
DEFAULT(NumberOfElements)
@@ -68,7 +69,6 @@ DEFAULT(Full)
DEFAULT(Gather)
DEFAULT(Greater)
DEFAULT(GreaterEqual)
DEFAULT(Hadamard)
DEFAULT(Less)
DEFAULT(LessEqual)
DEFAULT(Load)

View File

@@ -1,107 +0,0 @@
// Copyright © 2024 Apple Inc.
#include <cassert>
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/hadamard.h"
#include "mlx/primitives.h"
namespace mlx::core {
// n = 2^k component
template <typename T>
void hadamard_n(array& out, int n, int m, float scale) {
for (int b = 0; b < out.size() / n; b++) {
size_t loc = b * n;
T* data_ptr = out.data<T>() + loc;
int h = 1;
int n_over_2 = n / 2;
while (h < n) {
for (int i = 0; i < n / 2; i++) {
int k = i & (h - 1);
int j = ((i - k) << 1) + k;
float x = *(data_ptr + j);
float y = *(data_ptr + j + h);
*(data_ptr + j) = x + y;
*(data_ptr + j + h) = x - y;
if (h == n_over_2) {
*(data_ptr + j) *= scale;
*(data_ptr + j + h) *= scale;
}
}
h <<= 1;
}
}
}
// m component
template <typename T>
void hadamard_m(array& out, int n, int m, float scale) {
auto h_matrices = hadamard_matrices();
auto& matrix = h_matrices[m];
auto start = 1;
auto end = matrix.find('\n', start);
std::vector<bool> hmat_vec;
while (end != std::string_view::npos) {
auto row = matrix.substr(start, end - start);
for (int i = 0; i < row.length(); i++) {
hmat_vec.push_back(row[i] == '+');
}
start = end + 1;
end = matrix.find('\n', start);
}
for (int b = 0; b < out.size() / m / n; b++) {
size_t loc = b * n * m;
T* data_ptr = out.data<T>() + loc;
for (int i = 0; i < n; i++) {
std::vector<float> out(m);
for (int j = 0; j < m; j++) {
for (int k = 0; k < m; k++) {
float x = *(data_ptr + i + k * n);
if (hmat_vec[k + j * m]) {
out[j] += x;
} else {
out[j] -= x;
}
}
}
for (int j = 0; j < m; j++) {
*(data_ptr + i + j * n) = out[j] * scale;
}
}
}
}
template <typename T>
void hadamard(array& out, int n, int m, float scale) {
float n_scale = m > 1 ? 1.0 : scale;
hadamard_n<T>(out, n, m, n_scale);
if (m > 1) {
hadamard_m<T>(out, n, m, scale);
}
}
void Hadamard::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
// Copy input to output
copy(in, out, CopyType::General);
int axis = out.ndim() - 1;
auto [n, m] = decompose_hadamard(out.shape(axis));
switch (in.dtype()) {
case float32:
return hadamard<float>(out, n, m, scale_);
case float16:
return hadamard<float16_t>(out, n, m, scale_);
case bfloat16:
return hadamard<bfloat16_t>(out, n, m, scale_);
default:
throw std::invalid_argument("[hadamard] Unsupported type.");
}
}
} // namespace mlx::core

View File

@@ -1,105 +0,0 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include <map>
#include "mlx/utils.h"
namespace mlx::core {
// From http://neilsloane.com/hadamard/
constexpr std::string_view h12 = R"(
+-++++++++++
--+-+-+-+-+-
+++-++----++
+---+--+-++-
+++++-++----
+-+---+--+-+
++--+++-++--
+--++---+--+
++----+++-++
+--+-++---+-
++++----+++-
+-+--+-++---
)";
constexpr std::string_view h20 = R"(
+----+----++--++-++-
-+----+---+++---+-++
--+----+---+++-+-+-+
---+----+---+++++-+-
----+----++--++-++-+
-+++++-----+--+++--+
+-+++-+---+-+--+++--
++-++--+---+-+--+++-
+++-+---+---+-+--+++
++++-----++--+-+--++
--++-+-++-+-----++++
---++-+-++-+---+-+++
+---++-+-+--+--++-++
++---++-+----+-+++-+
-++---++-+----+++++-
-+--+--++-+----+----
+-+-----++-+----+---
-+-+-+---+--+----+--
--+-+++------+----+-
+--+--++------+----+
)";
constexpr std::string_view h28 = R"(
+------++----++-+--+-+--++--
-+-----+++-----+-+--+-+--++-
--+-----+++---+-+-+----+--++
---+-----+++---+-+-+-+--+--+
----+-----+++---+-+-+++--+--
-----+-----++++--+-+--++--+-
------++----++-+--+-+--++--+
--++++-+-------++--+++-+--+-
---++++-+-----+-++--+-+-+--+
+---+++--+----++-++--+-+-+--
++---++---+----++-++--+-+-+-
+++---+----+----++-++--+-+-+
++++--------+-+--++-++--+-+-
-++++--------+++--++--+--+-+
-+-++-++--++--+--------++++-
+-+-++--+--++--+--------++++
-+-+-++--+--++--+----+---+++
+-+-+-++--+--+---+---++---++
++-+-+-++--+------+--+++---+
-++-+-+-++--+------+-++++---
+-++-+---++--+------+-++++--
-++--++-+-++-+++----++------
+-++--++-+-++-+++-----+-----
++-++---+-+-++-+++-----+----
-++-++-+-+-+-+--+++-----+---
--++-++++-+-+----+++-----+--
+--++-+-++-+-+----+++-----+-
++--++-+-++-+-+----++------+
)";
inline const std::map<int, std::string_view> hadamard_matrices() {
return {{12, h12}, {20, h20}, {28, h28}};
}
inline std::pair<int, int> decompose_hadamard(int n) {
// n = m*2^k
int m = 1;
if (!is_power_of_2(n)) {
auto h_matrices = hadamard_matrices();
for (auto [factor, _] : h_matrices) {
if (n % factor == 0) {
m = factor;
n /= factor;
break;
}
}
if (m == 1) {
throw std::invalid_argument(
"[hadamard] Only supports n = m*2^k where m in (1, 12, 20, 28).");
}
}
return {n, m};
}
} // namespace mlx::core

View File

@@ -10,106 +10,9 @@
#include <lapack.h>
#endif
// Wrapper to account for differences in
// LAPACK implementations (basically how to pass the 'uplo' string to fortran).
int strtri_wrapper(char uplo, char diag, float* matrix, int N) {
int info;
#ifdef LAPACK_FORTRAN_STRLEN_END
strtri_(
/* uplo = */ &uplo,
/* diag = */ &diag,
/* N = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* info = */ &info,
/* uplo_len = */ static_cast<size_t>(1),
/* diag_len = */ static_cast<size_t>(1));
#else
strtri_(
/* uplo = */ &uplo,
/* diag = */ &diag,
/* N = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* info = */ &info);
#endif
return info;
}
namespace mlx::core {
void general_inv(array& inv, int N, int i) {
int info;
auto ipiv = array::Data{allocator::malloc_or_wait(sizeof(int) * N)};
// Compute LU factorization.
sgetrf_(
/* m = */ &N,
/* n = */ &N,
/* a = */ inv.data<float>() + N * N * i,
/* lda = */ &N,
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: LU factorization failed with error code " << info;
throw std::runtime_error(ss.str());
}
static const int lwork_query = -1;
float workspace_size = 0;
// Compute workspace size.
sgetri_(
/* m = */ &N,
/* a = */ nullptr,
/* lda = */ &N,
/* ipiv = */ nullptr,
/* work = */ &workspace_size,
/* lwork = */ &lwork_query,
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: LU workspace calculation failed with error code "
<< info;
throw std::runtime_error(ss.str());
}
const int lwork = workspace_size;
auto scratch = array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
// Compute inverse.
sgetri_(
/* m = */ &N,
/* a = */ inv.data<float>() + N * N * i,
/* lda = */ &N,
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
/* work = */ static_cast<float*>(scratch.buffer.raw_ptr()),
/* lwork = */ &lwork,
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: inversion failed with error code " << info;
throw std::runtime_error(ss.str());
}
}
void tri_inv(array& inv, int N, int i, bool upper) {
const char uplo = upper ? 'L' : 'U';
const char diag = 'N';
int info = strtri_wrapper(uplo, diag, inv.data<float>() + N * N * i, N);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: triangular inversion failed with error code " << info;
throw std::runtime_error(ss.str());
}
}
void inverse_impl(const array& a, array& inv, bool tri, bool upper) {
void inverse_impl(const array& a, array& inv) {
// Lapack uses the column-major convention. We take advantage of the following
// identity to avoid transposing (see
// https://math.stackexchange.com/a/340234):
@@ -121,11 +24,63 @@ void inverse_impl(const array& a, array& inv, bool tri, bool upper) {
const int N = a.shape(-1);
const size_t num_matrices = a.size() / (N * N);
int info;
auto ipiv = array::Data{allocator::malloc_or_wait(sizeof(int) * N)};
for (int i = 0; i < num_matrices; i++) {
if (tri) {
tri_inv(inv, N, i, upper);
} else {
general_inv(inv, N, i);
// Compute LU factorization.
sgetrf_(
/* m = */ &N,
/* n = */ &N,
/* a = */ inv.data<float>() + N * N * i,
/* lda = */ &N,
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: LU factorization failed with error code " << info;
throw std::runtime_error(ss.str());
}
static const int lwork_query = -1;
float workspace_size = 0;
// Compute workspace size.
sgetri_(
/* m = */ &N,
/* a = */ nullptr,
/* lda = */ &N,
/* ipiv = */ nullptr,
/* work = */ &workspace_size,
/* lwork = */ &lwork_query,
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: LU workspace calculation failed with error code "
<< info;
throw std::runtime_error(ss.str());
}
const int lwork = workspace_size;
auto scratch =
array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
// Compute inverse.
sgetri_(
/* m = */ &N,
/* a = */ inv.data<float>() + N * N * i,
/* lda = */ &N,
/* ipiv = */ static_cast<int*>(ipiv.buffer.raw_ptr()),
/* work = */ static_cast<float*>(scratch.buffer.raw_ptr()),
/* lwork = */ &lwork,
/* info = */ &info);
if (info != 0) {
std::stringstream ss;
ss << "inverse_impl: inversion failed with error code " << info;
throw std::runtime_error(ss.str());
}
}
}
@@ -134,7 +89,7 @@ void Inverse::eval(const std::vector<array>& inputs, array& output) {
if (inputs[0].dtype() != float32) {
throw std::runtime_error("[Inverse::eval] only supports float32.");
}
inverse_impl(inputs[0], output, tri_, upper_);
inverse_impl(inputs[0], output);
}
} // namespace mlx::core

View File

@@ -28,7 +28,6 @@ const char* get_kernel_preamble() {
return R"preamble(
$INCLUDES
$CONTENT
using namespace mlx::core;
using namespace mlx::core::detail;
)preamble";
}

View File

@@ -108,105 +108,105 @@ struct Abs {
template <typename T>
T operator()(T x) {
return std::abs(x);
}
};
uint8_t operator()(uint8_t x) {
return x;
}
};
uint16_t operator()(uint16_t x) {
return x;
}
};
uint32_t operator()(uint32_t x) {
return x;
}
};
uint64_t operator()(uint64_t x) {
return x;
}
};
bool operator()(bool x) {
return x;
}
};
};
struct ArcCos {
template <typename T>
T operator()(T x) {
return std::acos(x);
}
};
};
struct ArcCosh {
template <typename T>
T operator()(T x) {
return std::acosh(x);
}
};
};
struct ArcSin {
template <typename T>
T operator()(T x) {
return std::asin(x);
}
};
};
struct ArcSinh {
template <typename T>
T operator()(T x) {
return std::asinh(x);
}
};
};
struct ArcTan {
template <typename T>
T operator()(T x) {
return std::atan(x);
}
};
};
struct ArcTan2 {
template <typename T>
T operator()(T y, T x) {
return std::atan2(y, x);
}
};
};
struct ArcTanh {
template <typename T>
T operator()(T x) {
return std::atanh(x);
}
};
};
struct Ceil {
template <typename T>
T operator()(T x) {
return std::ceil(x);
}
};
int8_t operator()(int8_t x) {
return x;
}
};
int16_t operator()(int16_t x) {
return x;
}
};
int32_t operator()(int32_t x) {
return x;
}
};
int64_t operator()(int64_t x) {
return x;
}
};
uint8_t operator()(uint8_t x) {
return x;
}
};
uint16_t operator()(uint16_t x) {
return x;
}
};
uint32_t operator()(uint32_t x) {
return x;
}
};
uint64_t operator()(uint64_t x) {
return x;
}
};
bool operator()(bool x) {
return x;
}
};
};
struct Conjugate {
@@ -219,35 +219,35 @@ struct Cos {
template <typename T>
T operator()(T x) {
return std::cos(x);
}
};
};
struct Cosh {
template <typename T>
T operator()(T x) {
return std::cosh(x);
}
};
};
struct Erf {
template <typename T>
T operator()(T x) {
return static_cast<T>(fast_erf(static_cast<float>(x)));
}
};
};
struct ErfInv {
template <typename T>
T operator()(T x) {
return static_cast<T>(fast_erfinv(static_cast<float>(x)));
}
};
};
struct Exp {
template <typename T>
T operator()(T x) {
return fast_exp(x);
}
};
complex64_t operator()(complex64_t x) {
return std::exp(x);
@@ -258,83 +258,83 @@ struct Expm1 {
template <typename T>
T operator()(T x) {
return expm1(x);
}
};
};
struct Floor {
template <typename T>
T operator()(T x) {
return std::floor(x);
}
};
int8_t operator()(int8_t x) {
return x;
}
};
int16_t operator()(int16_t x) {
return x;
}
};
int32_t operator()(int32_t x) {
return x;
}
};
int64_t operator()(int64_t x) {
return x;
}
};
uint8_t operator()(uint8_t x) {
return x;
}
};
uint16_t operator()(uint16_t x) {
return x;
}
};
uint32_t operator()(uint32_t x) {
return x;
}
};
uint64_t operator()(uint64_t x) {
return x;
}
};
bool operator()(bool x) {
return x;
}
};
};
struct Log {
template <typename T>
T operator()(T x) {
return std::log(x);
}
};
};
struct Log2 {
template <typename T>
T operator()(T x) {
return std::log2(x);
}
};
};
struct Log10 {
template <typename T>
T operator()(T x) {
return std::log10(x);
}
};
};
struct Log1p {
template <typename T>
T operator()(T x) {
return log1p(x);
}
};
};
struct LogicalNot {
template <typename T>
T operator()(T x) {
return !x;
}
};
};
struct Negative {
template <typename T>
T operator()(T x) {
return -x;
}
};
};
struct Round {
@@ -379,49 +379,49 @@ struct Sin {
template <typename T>
T operator()(T x) {
return std::sin(x);
}
};
};
struct Sinh {
template <typename T>
T operator()(T x) {
return std::sinh(x);
}
};
};
struct Square {
template <typename T>
T operator()(T x) {
return x * x;
}
};
};
struct Sqrt {
template <typename T>
T operator()(T x) {
return std::sqrt(x);
}
};
};
struct Rsqrt {
template <typename T>
T operator()(T x) {
return static_cast<decltype(x)>(1.0) / std::sqrt(x);
}
};
};
struct Tan {
template <typename T>
T operator()(T x) {
return std::tan(x);
}
};
};
struct Tanh {
template <typename T>
T operator()(T x) {
return std::tanh(x);
}
};
};
struct Add {
@@ -554,7 +554,7 @@ struct LogAddExp {
? maxval
: static_cast<decltype(x)>(
maxval + std::log1p(fast_exp(minval - maxval)));
}
};
};
struct Multiply {
@@ -602,14 +602,14 @@ struct LogicalAnd {
template <typename T>
T operator()(T x, T y) {
return x && y;
}
};
};
struct LogicalOr {
template <typename T>
T operator()(T x, T y) {
return x || y;
}
};
};
struct Select {
@@ -623,35 +623,35 @@ struct BitwiseAnd {
template <typename T>
T operator()(T x, T y) {
return x & y;
}
};
};
struct BitwiseOr {
template <typename T>
T operator()(T x, T y) {
return x | y;
}
};
};
struct BitwiseXor {
template <typename T>
T operator()(T x, T y) {
return x ^ y;
}
};
};
struct LeftShift {
template <typename T>
T operator()(T x, T y) {
return x << y;
}
};
};
struct RightShift {
template <typename T>
T operator()(T x, T y) {
return x >> y;
}
};
};
} // namespace mlx::core::detail

View File

@@ -8,9 +8,9 @@
#include "mlx/allocator.h"
#include "mlx/backend/common/arange.h"
#include "mlx/backend/common/binary.h"
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/ops.h"
#include "mlx/backend/common/slicing.h"
#include "mlx/backend/common/threefry.h"
#include "mlx/backend/common/unary.h"
#include "mlx/backend/common/utils.h"
@@ -313,6 +313,20 @@ void LogicalNot::eval(const std::vector<array>& inputs, array& out) {
unary(in, out, detail::LogicalNot());
}
void LogicalAnd::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); // LogicalAnd requires two input arrays
auto& in1 = inputs[0];
auto& in2 = inputs[1];
binary(in1, in2, out, detail::LogicalAnd());
}
void LogicalOr::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2); // LogicalOr requires two input arrays
auto& in1 = inputs[0];
auto& in2 = inputs[1];
binary(in1, in2, out, detail::LogicalOr());
}
void Negative::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
@@ -405,17 +419,7 @@ void Reshape::eval(const std::vector<array>& inputs, array& out) {
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
if (copy_necessary) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto out_strides = make_contiguous_strides<size_t>(in.shape());
copy_inplace<size_t>(
in,
out,
in.shape(),
in.strides(),
out_strides,
0,
0,
CopyType::General);
copy(in, out, in.data_size() == 1 ? CopyType::Scalar : CopyType::General);
} else {
shared_buffer_reshape(in, out_strides, out);
}
@@ -488,8 +492,7 @@ void Slice::eval(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0];
// Calculate out strides, initial offset and if copy needs to be made
auto [copy_needed, data_offset, inp_strides] =
prepare_slice(in, start_indices_, strides_);
auto [copy_needed, data_offset, inp_strides] = prepare_slice(in);
// Do copy if needed
if (copy_needed) {
@@ -587,36 +590,4 @@ void Tanh::eval(const std::vector<array>& inputs, array& out) {
}
}
void View::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
auto ibytes = size_of(in.dtype());
auto obytes = size_of(out.dtype());
// Conditions for buffer copying (disjunction):
// - type size is the same
// - type size is smaller and the last axis is contiguous
// - the entire array is row contiguous
if (ibytes == obytes || obytes < ibytes && in.strides().back() == 1 ||
in.flags().row_contiguous) {
auto strides = in.strides();
for (int i = 0; i < strides.size() - 1; ++i) {
strides[i] *= ibytes;
strides[i] /= obytes;
}
out.copy_shared_buffer(
in, strides, in.flags(), in.data_size() * obytes / ibytes);
} else {
auto tmp = array(in.shape(), in.dtype(), nullptr, {});
tmp.set_data(allocator::malloc_or_wait(tmp.nbytes()));
copy_inplace(in, tmp, CopyType::General);
auto flags = out.flags();
flags.contiguous = true;
flags.row_contiguous = true;
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
out.move_shared_buffer(tmp, out.strides(), flags, out.size());
}
}
} // namespace mlx::core

View File

@@ -104,14 +104,48 @@ void reduce_dispatch_out(
}
case Reduce::Sum: {
auto op = [](auto y, auto x) { (*y) = (*y) + x; };
if (out.dtype() == int32) {
// special case since the input type can be bool
reduction_op<InT, int32_t>(in, out, axes, 0, op);
} else {
reduction_op<InT, InT>(in, out, axes, 0, op);
switch (out.dtype()) {
case bool_:
reduction_op<InT, bool>(in, out, axes, false, op);
break;
case uint8:
reduction_op<InT, uint8_t>(in, out, axes, 0, op);
break;
case uint16:
reduction_op<InT, uint16_t>(in, out, axes, 0, op);
break;
case uint32:
reduction_op<InT, uint32_t>(in, out, axes, 0, op);
break;
case uint64:
reduction_op<InT, uint64_t>(in, out, axes, 0, op);
break;
case int8:
reduction_op<InT, int8_t>(in, out, axes, 0, op);
break;
case int16:
reduction_op<InT, int16_t>(in, out, axes, 0, op);
break;
case int32:
reduction_op<InT, int32_t>(in, out, axes, 0, op);
break;
case int64:
reduction_op<InT, int64_t>(in, out, axes, 0, op);
break;
case float16:
reduction_op<InT, float16_t>(in, out, axes, 0.0f, op);
break;
case float32:
reduction_op<InT, float>(in, out, axes, 0.0f, op);
break;
case bfloat16:
reduction_op<InT, bfloat16_t>(in, out, axes, 0.0f, op);
break;
case complex64:
reduction_op<InT, complex64_t>(in, out, axes, complex64_t{0.0f}, op);
break;
}
break;
}
} break;
case Reduce::Prod: {
auto op = [](auto y, auto x) { (*y) *= x; };
reduction_op<InT, InT>(in, out, axes, 1, op);
@@ -134,29 +168,6 @@ void reduce_dispatch_out(
} // namespace
void nd_loop(
std::function<void(int)> callback,
const std::vector<int>& shape,
const std::vector<size_t>& strides) {
std::function<void(int, int)> loop_inner;
loop_inner = [&](int dim, int offset) {
if (dim < shape.size() - 1) {
int size = shape[dim];
size_t stride = strides[dim];
for (int i = 0; i < size; i++) {
loop_inner(dim + 1, offset + i * stride);
}
} else {
int size = shape[dim];
size_t stride = strides[dim];
for (int i = 0; i < size; i++) {
callback(offset + i * stride);
}
}
};
loop_inner(0, 0);
}
void Reduce::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];

View File

@@ -49,18 +49,47 @@ struct ReductionPlan {
ReductionPlan(ReductionOpType type_) : type(type_) {}
};
ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes);
namespace {
// Helper for the ndimensional strided loop
// Should this be in utils?
void nd_loop(
inline void nd_loop(
std::function<void(int)> callback,
const std::vector<int>& shape,
const std::vector<size_t>& strides);
const std::vector<size_t>& strides) {
std::function<void(int, int)> loop_inner;
loop_inner = [&](int dim, int offset) {
if (dim < shape.size() - 1) {
int size = shape[dim];
size_t stride = strides[dim];
for (int i = 0; i < size; i++) {
loop_inner(dim + 1, offset + i * stride);
}
} else {
int size = shape[dim];
size_t stride = strides[dim];
for (int i = 0; i < size; i++) {
callback(offset + i * stride);
}
}
};
loop_inner(0, 0);
}
std::pair<std::vector<int>, std::vector<size_t>> shapes_without_reduction_axes(
const array& x,
const std::vector<int>& axes);
const std::vector<int>& axes) {
std::vector<int> shape = x.shape();
std::vector<size_t> strides = x.strides();
for (int i = axes.size() - 1; i >= 0; i--) {
int a = axes[i];
shape.erase(shape.begin() + a);
strides.erase(strides.begin() + a);
}
return std::make_pair(shape, strides);
}
template <typename T, typename U, typename Op>
struct DefaultStridedReduce {
@@ -94,6 +123,102 @@ struct DefaultContiguousReduce {
}
};
ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
// The data is all there and we are reducing over everything
if (x.size() == x.data_size() && axes.size() == x.ndim() &&
x.flags().contiguous) {
return ContiguousAllReduce;
}
// Row contiguous input so the output is row contiguous
if (x.flags().row_contiguous) {
// Merge consecutive axes
std::vector<int> shape = {x.shape(axes[0])};
std::vector<size_t> strides = {x.strides()[axes[0]]};
for (int i = 1; i < axes.size(); i++) {
if (axes[i] - 1 == axes[i - 1]) {
shape.back() *= x.shape(axes[i]);
strides.back() = x.strides()[axes[i]];
} else {
shape.push_back(x.shape(axes[i]));
strides.push_back(x.strides()[axes[i]]);
}
}
if (strides.back() == 1) {
return ReductionPlan(ContiguousReduce, shape, strides);
} else if (strides.back() > 1) {
return ReductionPlan(ContiguousStridedReduce, shape, strides);
}
}
// Let's check if we can optimize our access patterns
//
// 1. We have a reduction axis with stride 1. Simply call
// GeneralContiguousReduce and be done with it.
// 2. We have transpositions and we are not reducing over the axis with
// stride 1. However, we are reducing over an axis where everything is
// contiguous in memory to the right of that axis. We can call strided
// reduce and be done with it.
// 2. We have weird transpositions and expands. Copy the strides to the
// output, then call strided reduce.
// Sort reduction axes by stride in order to merge them and figure out if we
// have a contiguous reduction.
std::vector<std::pair<int, size_t>> reductions;
for (auto a : axes) {
reductions.push_back(std::make_pair(x.shape(a), x.strides()[a]));
}
std::sort(reductions.begin(), reductions.end(), [](auto a, auto b) {
return a.second > b.second;
});
// Extract the two smallest and try to merge them in case the contiguous
// reduction can be bigger than just the last axis.
for (int i = reductions.size() - 1; i >= 1; i--) {
auto a = reductions[i];
auto b = reductions[i - 1];
// b.stride = a.shape * a.stride then a and b are contiguous
if (b.second == a.first * a.second) {
reductions.erase(reductions.begin() + i);
reductions[i - 1] = std::make_pair(a.first * b.first, a.second);
}
}
std::vector<int> shape;
std::vector<size_t> strides;
for (auto r : reductions) {
shape.push_back(r.first);
strides.push_back(r.second);
}
// We can call the contiguous reduction op for every weird way the input is
// structured in the rest of the axes.
if (strides.back() == 1) {
return ReductionPlan(GeneralContiguousReduce, shape, strides);
}
// Delegate to the general strided reduction op if the axes after
// strides.back() are contiguous.
if (strides.back() > 1) {
int size = 1;
for (int i = x.ndim() - 1; i >= 0; i--) {
if (axes.back() == i) {
continue;
}
if (x.strides()[i] != size) {
break;
}
size *= x.shape(i);
}
if (size >= strides.back()) {
return ReductionPlan(GeneralStridedReduce, shape, strides);
}
}
return ReductionPlan(GeneralReduce, shape, strides);
}
template <typename T, typename U, typename OpS, typename OpC, typename Op>
void reduction_op(
const array& x,
@@ -236,4 +361,6 @@ void reduction_op(
reduction_op<T, U>(x, out, axes, init, ops, opc, op);
}
} // namespace
} // namespace mlx::core

View File

@@ -1,118 +0,0 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/common/reduce.h"
namespace mlx::core {
std::pair<std::vector<int>, std::vector<size_t>> shapes_without_reduction_axes(
const array& x,
const std::vector<int>& axes) {
std::vector<int> shape = x.shape();
std::vector<size_t> strides = x.strides();
for (int i = axes.size() - 1; i >= 0; i--) {
int a = axes[i];
shape.erase(shape.begin() + a);
strides.erase(strides.begin() + a);
}
return std::make_pair(shape, strides);
}
ReductionPlan get_reduction_plan(const array& x, const std::vector<int> axes) {
// The data is all there and we are reducing over everything
if (x.size() == x.data_size() && axes.size() == x.ndim() &&
x.flags().contiguous) {
return ContiguousAllReduce;
}
// Row contiguous input so the output is row contiguous
if (x.flags().row_contiguous) {
// Merge consecutive axes
std::vector<int> shape = {x.shape(axes[0])};
std::vector<size_t> strides = {x.strides()[axes[0]]};
for (int i = 1; i < axes.size(); i++) {
if (axes[i] - 1 == axes[i - 1]) {
shape.back() *= x.shape(axes[i]);
strides.back() = x.strides()[axes[i]];
} else {
shape.push_back(x.shape(axes[i]));
strides.push_back(x.strides()[axes[i]]);
}
}
if (strides.back() == 1) {
return ReductionPlan(ContiguousReduce, shape, strides);
} else if (strides.back() > 1) {
return ReductionPlan(ContiguousStridedReduce, shape, strides);
}
}
// Let's check if we can optimize our access patterns
//
// 1. We have a reduction axis with stride 1. Simply call
// GeneralContiguousReduce and be done with it.
// 2. We have transpositions and we are not reducing over the axis with
// stride 1. However, we are reducing over an axis where everything is
// contiguous in memory to the right of that axis. We can call strided
// reduce and be done with it.
// 2. We have weird transpositions and expands. Copy the strides to the
// output, then call strided reduce.
// Sort reduction axes by stride in order to merge them and figure out if we
// have a contiguous reduction.
std::vector<std::pair<int, size_t>> reductions;
for (auto a : axes) {
reductions.push_back(std::make_pair(x.shape(a), x.strides()[a]));
}
std::sort(reductions.begin(), reductions.end(), [](auto a, auto b) {
return a.second > b.second;
});
// Extract the two smallest and try to merge them in case the contiguous
// reduction can be bigger than just the last axis.
for (int i = reductions.size() - 1; i >= 1; i--) {
auto a = reductions[i];
auto b = reductions[i - 1];
// b.stride = a.shape * a.stride then a and b are contiguous
if (b.second == a.first * a.second) {
reductions.erase(reductions.begin() + i);
reductions[i - 1] = std::make_pair(a.first * b.first, a.second);
}
}
std::vector<int> shape;
std::vector<size_t> strides;
for (auto r : reductions) {
shape.push_back(r.first);
strides.push_back(r.second);
}
// We can call the contiguous reduction op for every weird way the input is
// structured in the rest of the axes.
if (strides.back() == 1) {
return ReductionPlan(GeneralContiguousReduce, shape, strides);
}
// Delegate to the general strided reduction op if the axes after
// strides.back() are contiguous.
if (strides.back() > 1) {
int size = 1;
for (int i = x.ndim() - 1; i >= 0; i--) {
if (axes.back() == i) {
continue;
}
if (x.strides()[i] != size) {
break;
}
size *= x.shape(i);
}
if (size >= strides.back()) {
return ReductionPlan(GeneralStridedReduce, shape, strides);
}
}
return ReductionPlan(GeneralReduce, shape, strides);
}
} // namespace mlx::core

View File

@@ -234,7 +234,7 @@ void scan_dispatch(
auto op = [](U* o, const U* y, const T* x) { *o = (*x < *y) ? *y : *x; };
auto init = (issubdtype(input.dtype(), floating))
? static_cast<U>(-std::numeric_limits<float>::infinity())
: std::numeric_limits<U>::min();
: std::numeric_limits<U>::max();
auto opcs = DefaultContiguousScan<T, U, decltype(op)>(op, init);
auto opss = DefaultStridedScan<T, U, decltype(op)>(op, init);
scan_op<T, U>(opcs, opss, input, output, axis, reverse, inclusive);

View File

@@ -1,52 +0,0 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/common/utils.h"
namespace mlx::core {
std::tuple<bool, int64_t, std::vector<int64_t>> prepare_slice(
const array& in,
std::vector<int>& start_indices,
std::vector<int>& strides) {
int64_t data_offset = 0;
bool copy_needed = false;
std::vector<int64_t> inp_strides(in.ndim(), 0);
for (int i = 0; i < in.ndim(); ++i) {
data_offset += start_indices[i] * in.strides()[i];
inp_strides[i] = in.strides()[i] * strides[i];
copy_needed |= strides[i] < 0;
}
return std::make_tuple(copy_needed, data_offset, inp_strides);
}
void shared_buffer_slice(
const array& in,
const std::vector<size_t>& out_strides,
size_t data_offset,
array& out) {
// Compute row/col contiguity
auto [data_size, is_row_contiguous, is_col_contiguous] =
check_contiguity(out.shape(), out_strides);
auto flags = in.flags();
flags.row_contiguous = is_row_contiguous;
flags.col_contiguous = is_col_contiguous;
if (data_size == 1) {
// Broadcasted scalar array is contiguous.
flags.contiguous = true;
} else if (data_size == in.data_size()) {
// Means we sliced a broadcasted dimension so leave the "no holes" flag
// alone.
} else {
// We sliced something. So either we are row or col contiguous or we
// punched a hole.
flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
}
out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset);
}
} // namespace mlx::core

View File

@@ -1,20 +0,0 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/array.h"
namespace mlx::core {
std::tuple<bool, int64_t, std::vector<int64_t>> prepare_slice(
const array& in,
std::vector<int>& start_indices,
std::vector<int>& strides);
void shared_buffer_slice(
const array& in,
const std::vector<size_t>& out_strides,
size_t data_offset,
array& out);
} // namespace mlx::core

View File

@@ -113,14 +113,14 @@ void sort(const array& in, array& out, int axis) {
axis = axis < 0 ? axis + in.ndim() : axis;
size_t n_rows = in.size() / in.shape(axis);
auto remaining_shape = out.shape();
auto remaining_shape = in.shape();
remaining_shape.erase(remaining_shape.begin() + axis);
auto remaining_strides = out.strides();
auto remaining_strides = in.strides();
remaining_strides.erase(remaining_strides.begin() + axis);
size_t axis_stride = out.strides()[axis];
int axis_size = out.shape(axis);
size_t axis_stride = in.strides()[axis];
int axis_size = in.shape(axis);
// Perform sorting in place
for (int i = 0; i < n_rows; i++) {
@@ -143,42 +143,34 @@ void argsort(const array& in, array& out, int axis) {
axis = axis < 0 ? axis + in.ndim() : axis;
size_t n_rows = in.size() / in.shape(axis);
auto in_remaining_shape = in.shape();
in_remaining_shape.erase(in_remaining_shape.begin() + axis);
auto remaining_shape = in.shape();
remaining_shape.erase(remaining_shape.begin() + axis);
auto in_remaining_strides = in.strides();
in_remaining_strides.erase(in_remaining_strides.begin() + axis);
auto remaining_strides = in.strides();
remaining_strides.erase(remaining_strides.begin() + axis);
auto out_remaining_shape = out.shape();
out_remaining_shape.erase(out_remaining_shape.begin() + axis);
auto out_remaining_strides = out.strides();
out_remaining_strides.erase(out_remaining_strides.begin() + axis);
size_t in_stride = in.strides()[axis];
size_t out_stride = out.strides()[axis];
size_t axis_stride = in.strides()[axis];
int axis_size = in.shape(axis);
// Perform sorting
for (int i = 0; i < n_rows; i++) {
size_t in_loc = elem_to_loc(i, in_remaining_shape, in_remaining_strides);
size_t out_loc = elem_to_loc(i, out_remaining_shape, out_remaining_strides);
const T* data_ptr = in.data<T>() + in_loc;
IdxT* idx_ptr = out.data<IdxT>() + out_loc;
size_t loc = elem_to_loc(i, remaining_shape, remaining_strides);
const T* data_ptr = in.data<T>() + loc;
IdxT* idx_ptr = out.data<IdxT>() + loc;
StridedIterator st_(idx_ptr, out_stride, 0);
StridedIterator ed_(idx_ptr, out_stride, axis_size);
StridedIterator st_(idx_ptr, axis_stride, 0);
StridedIterator ed_(idx_ptr, axis_stride, axis_size);
// Initialize with iota
std::iota(st_, ed_, IdxT(0));
// Sort according to vals
StridedIterator st(idx_ptr, out_stride, 0);
StridedIterator ed(idx_ptr, out_stride, axis_size);
StridedIterator st(idx_ptr, axis_stride, 0);
StridedIterator ed(idx_ptr, axis_stride, axis_size);
std::stable_sort(st, ed, [data_ptr, in_stride](IdxT a, IdxT b) {
auto v1 = data_ptr[a * in_stride];
auto v2 = data_ptr[b * in_stride];
std::stable_sort(st, ed, [data_ptr, axis_stride](IdxT a, IdxT b) {
auto v1 = data_ptr[a * axis_stride];
auto v2 = data_ptr[b * axis_stride];
return v1 < v2 || (v1 == v2 && a < b);
});
}

View File

@@ -29,15 +29,6 @@ inline size_t elem_to_loc(int elem, const array& a) {
return elem_to_loc(elem, a.shape(), a.strides());
}
template <typename stride_t>
std::vector<stride_t> make_contiguous_strides(const std::vector<int>& shape) {
std::vector<stride_t> strides(shape.size(), 1);
for (int i = shape.size() - 1; i > 0; i--) {
strides[i - 1] = strides[i] * shape[i];
}
return strides;
}
// Collapse dims that are contiguous to possibly route to a better kernel
// e.g. for x = transpose(array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 2, 2}), {2, 0, 1})
// should return {{2, 4}, {{1, 2}}}.

View File

@@ -18,7 +18,7 @@ function(make_jit_source SRC_FILE)
${CMAKE_C_COMPILER}
${PROJECT_SOURCE_DIR}
${SRC_FILE}
"-DMLX_METAL_VERSION=${MLX_METAL_VERSION}"
"-D${MLX_METAL_VERSION}"
DEPENDS make_compiled_preamble.sh
kernels/${SRC_FILE}.h
${ARGN}
@@ -52,7 +52,6 @@ make_jit_source(
)
make_jit_source(scatter)
make_jit_source(gather)
make_jit_source(hadamard)
if (MLX_METAL_JIT)
target_sources(
@@ -65,11 +64,6 @@ if (MLX_METAL_JIT)
make_jit_source(unary)
make_jit_source(binary)
make_jit_source(binary_two)
make_jit_source(
fft
kernels/fft/radix.h
kernels/fft/readwrite.h
)
make_jit_source(ternary)
make_jit_source(softmax)
make_jit_source(scan)
@@ -113,8 +107,6 @@ if (MLX_METAL_JIT)
kernels/steel/defines.h
kernels/steel/conv/loaders/loader_general.h
)
make_jit_source(quantized)
make_jit_source(gemv_masked)
else()
target_sources(
mlx
@@ -134,7 +126,6 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp
@@ -144,13 +135,11 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/normalization.cpp
${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/unary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
)
if (NOT MLX_METAL_PATH)

View File

@@ -242,17 +242,8 @@ void MetalAllocator::free(Buffer buffer) {
}
MetalAllocator& allocator() {
// By creating the |allocator_| on heap, the destructor of MetalAllocator will
// not be called on exit and all the buffers will be leaked. This is necessary
// because releasing buffers can take more than 30sec when the program holds a
// lot of RAM (for example inferencing a LLM), and it would feel frozen to
// users when exiting.
// TODO(zcbenz): Consider using the `base::NoDestructor` class from Chromium
// when applying this pattern to more places, or when introducing sanitizers
// to MLX.
// https://source.chromium.org/chromium/chromium/src/+/main:base/no_destructor.h
static MetalAllocator* allocator_ = new MetalAllocator;
return *allocator_;
static MetalAllocator allocator_;
return allocator_;
}
size_t set_cache_limit(size_t limit) {

View File

@@ -6,62 +6,20 @@
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
#define BINARY_GPU(func) \
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
binary_op_gpu(inputs, out, get_primitive_string(this)); \
}
#define BINARY_GPU_MULTI(func) \
void func::eval_gpu( \
const std::vector<array>& inputs, std::vector<array>& outputs) { \
binary_op_gpu(inputs, outputs, get_primitive_string(this)); \
}
namespace mlx::core {
constexpr int MAX_BINARY_SPECIALIZED_DIMS = 5;
std::string get_kernel_name(
BinaryOpType bopt,
const std::string& op,
const array& a,
bool use_2d,
int ndim) {
std::ostringstream kname;
switch (bopt) {
case BinaryOpType::ScalarScalar:
kname << "ss";
break;
case BinaryOpType::ScalarVector:
kname << (use_2d ? "sv2" : "sv");
break;
case BinaryOpType::VectorScalar:
kname << (use_2d ? "vs2" : "vs");
break;
case BinaryOpType::VectorVector:
kname << (use_2d ? "vv2" : "vv");
break;
case BinaryOpType::General:
kname << "g";
if (ndim <= MAX_BINARY_SPECIALIZED_DIMS) {
kname << ndim;
} else {
kname << "n";
}
break;
}
kname << op << type_to_name(a);
return kname.str();
}
void binary_op_gpu_inplace(
void binary_op(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::string& op,
const Stream& s) {
const std::string op) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, outputs[0], bopt, true);
set_binary_op_output_data(a, b, outputs[1], bopt, true);
auto& out = outputs[0];
if (out.size() == 0) {
@@ -74,12 +32,39 @@ void binary_op_gpu_inplace(
auto& strides_b = strides[1];
auto& strides_out = strides[2];
bool use_2d = out.data_size() > UINT32_MAX;
std::string kernel_name = get_kernel_name(bopt, op, a, use_2d, shape.size());
std::string kernel_name;
{
std::ostringstream kname;
switch (bopt) {
case BinaryOpType::ScalarScalar:
kname << "ss";
break;
case BinaryOpType::ScalarVector:
kname << "sv";
break;
case BinaryOpType::VectorScalar:
kname << "vs";
break;
case BinaryOpType::VectorVector:
kname << "vv";
break;
case BinaryOpType::General:
kname << "g";
if (shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
kname << shape.size();
} else {
kname << "n";
}
break;
}
kname << op << type_to_name(a);
kernel_name = kname.str();
}
auto& s = out.primitive().stream();
auto& d = metal::device(s.device);
auto kernel =
get_binary_two_kernel(d, kernel_name, a.dtype(), outputs[0].dtype(), op);
auto kernel = get_binary_two_kernel(d, kernel_name, a, outputs[0]);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
@@ -123,11 +108,9 @@ void binary_op_gpu_inplace(
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
// Launch a 1D or 2D grid of threads
// Launch a 1D grid of threads
size_t nthreads = out.data_size();
MTL::Size grid_dims = use_2d
? get_2d_grid_dims(outputs[0].shape(), outputs[0].strides())
: MTL::Size(nthreads, 1, 1);
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
@@ -137,36 +120,15 @@ void binary_op_gpu_inplace(
}
}
void binary_op_gpu(
void binary_op(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::string& op,
const Stream& s) {
array& out,
const std::string op) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, outputs[0], bopt, true);
set_binary_op_output_data(a, b, outputs[1], bopt, true);
binary_op_gpu_inplace(inputs, outputs, op, s);
}
void binary_op_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::string& op) {
auto& s = outputs[0].primitive().stream();
binary_op_gpu(inputs, outputs, op, s);
}
void binary_op_gpu_inplace(
const std::vector<array>& inputs,
array& out,
const std::string& op,
const Stream& s) {
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt, true);
if (out.size() == 0) {
return;
}
@@ -177,11 +139,39 @@ void binary_op_gpu_inplace(
auto& strides_b = strides[1];
auto& strides_out = strides[2];
bool use_2d = out.data_size() > UINT32_MAX;
std::string kernel_name = get_kernel_name(bopt, op, a, use_2d, shape.size());
std::string kernel_name;
{
std::ostringstream kname;
switch (bopt) {
case BinaryOpType::ScalarScalar:
kname << "ss";
break;
case BinaryOpType::ScalarVector:
kname << "sv";
break;
case BinaryOpType::VectorScalar:
kname << "vs";
break;
case BinaryOpType::VectorVector:
kname << "vv";
break;
case BinaryOpType::General:
kname << "g";
if (shape.size() <= MAX_BINARY_SPECIALIZED_DIMS) {
kname << shape.size();
} else {
kname << "n";
}
break;
}
kname << op << type_to_name(a);
kernel_name = kname.str();
}
auto& s = out.primitive().stream();
auto& d = metal::device(s.device);
auto kernel = get_binary_kernel(d, kernel_name, a.dtype(), out.dtype(), op);
auto kernel = get_binary_kernel(d, kernel_name, a, out);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
bool donate_a = a.data_shared_ptr() == nullptr;
@@ -218,11 +208,10 @@ void binary_op_gpu_inplace(
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
// Launch a 1D or 2D grid of threads
size_t nthreads = out.data_size();
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
// Launch a 1D grid of threads
size_t nthreads =
bopt == BinaryOpType::General ? out.size() : out.data_size();
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;
@@ -232,65 +221,102 @@ void binary_op_gpu_inplace(
}
}
void binary_op_gpu(
const std::vector<array>& inputs,
array& out,
const std::string& op,
const Stream& s) {
assert(inputs.size() == 2);
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt, true);
binary_op_gpu_inplace(inputs, out, op, s);
void Add::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "add");
}
void binary_op_gpu(
const std::vector<array>& inputs,
array& out,
const std::string& op) {
auto& s = out.primitive().stream();
binary_op_gpu(inputs, out, op, s);
void ArcTan2::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "arctan2");
}
BINARY_GPU(Add)
BINARY_GPU(ArcTan2)
BINARY_GPU(Divide)
BINARY_GPU_MULTI(DivMod)
BINARY_GPU(Remainder)
BINARY_GPU(Equal)
BINARY_GPU(Greater)
BINARY_GPU(GreaterEqual)
BINARY_GPU(Less)
BINARY_GPU(LessEqual)
BINARY_GPU(LogicalAnd)
BINARY_GPU(LogicalOr)
BINARY_GPU(LogAddExp)
BINARY_GPU(Maximum)
BINARY_GPU(Minimum)
BINARY_GPU(Multiply)
BINARY_GPU(NotEqual)
BINARY_GPU(Power)
BINARY_GPU(Subtract)
void BitwiseBinary::eval_gpu(const std::vector<array>& inputs, array& out) {
switch (op_) {
case BitwiseBinary::And:
binary_op_gpu(inputs, out, get_primitive_string(this));
binary_op(inputs, out, "bitwise_and");
break;
case BitwiseBinary::Or:
binary_op_gpu(inputs, out, get_primitive_string(this));
binary_op(inputs, out, "bitwise_or");
break;
case BitwiseBinary::Xor:
binary_op_gpu(inputs, out, get_primitive_string(this));
binary_op(inputs, out, "bitwise_xor");
break;
case BitwiseBinary::LeftShift:
binary_op_gpu(inputs, out, get_primitive_string(this));
binary_op(inputs, out, "left_shift");
break;
case BitwiseBinary::RightShift:
binary_op_gpu(inputs, out, get_primitive_string(this));
binary_op(inputs, out, "right_shift");
break;
}
}
void Divide::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "div");
}
void DivMod::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
binary_op(inputs, outputs, "divmod");
}
void Remainder::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "rem");
}
void Equal::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, equal_nan_ ? "naneq" : "eq");
}
void Greater::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "ge");
}
void GreaterEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "geq");
}
void Less::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "le");
}
void LessEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "leq");
}
void LogicalAnd::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "land");
}
void LogicalOr::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "lor");
}
void LogAddExp::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "lae");
}
void Maximum::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "max");
}
void Minimum::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "min");
}
void Multiply::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "mul");
}
void NotEqual::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "neq");
}
void Power::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "pow");
}
void Subtract::eval_gpu(const std::vector<array>& inputs, array& out) {
binary_op(inputs, out, "sub");
}
} // namespace mlx::core

View File

@@ -1,33 +0,0 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/array.h"
namespace mlx::core {
void binary_op_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::string& op,
const Stream& s);
void binary_op_gpu(
const std::vector<array>& inputs,
array& out,
const std::string& op,
const Stream& s);
void binary_op_gpu_inplace(
const std::vector<array>& inputs,
std::vector<array>& outputs,
const std::string& op,
const Stream& s);
void binary_op_gpu_inplace(
const std::vector<array>& inputs,
array& out,
const std::string& op,
const Stream& s);
} // namespace mlx::core

View File

@@ -56,15 +56,12 @@ inline void build_kernel(
} else {
add_indices = true;
os << " device const " << get_type_string(x.dtype()) << "* " << xname
<< " [[buffer(" << cnt++ << ")]]," << std::endl;
<< " [[buffer(" << cnt++ << ")]]," << std::endl
<< " constant const size_t* " << xname << "_strides [[buffer("
<< cnt++ << ")]]," << std::endl;
}
}
if (add_indices) {
os << " constant const size_t* in_strides [[buffer(" << cnt++
<< ")]],\n";
}
// Add the output arguments
for (auto& x : outputs) {
os << " device " << get_type_string(x.dtype()) << "* "
@@ -113,17 +110,13 @@ inline void build_kernel(
}
// Read the inputs in tmps
int nc_in_count = 0;
for (int i = 0; i < inputs.size(); ++i) {
auto& x = inputs[i];
for (auto& x : inputs) {
auto& xname = namer.get_name(x);
if (is_constant(x)) {
auto type_str = get_type_string(x.dtype());
os << " auto tmp_" << xname << " = static_cast<"
<< get_type_string(x.dtype()) << ">(";
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = ";
print_constant(os, x);
os << ");" << std::endl;
os << ";" << std::endl;
} else if (is_scalar(x)) {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[0];" << std::endl;
@@ -131,20 +124,17 @@ inline void build_kernel(
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[index];" << std::endl;
} else if (!dynamic_dims) {
int offset = nc_in_count * ndim;
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[";
os << "index_0 * " << "in_strides[" << offset << "]";
os << "index_0 * " << xname << "_strides[0]";
for (int i = 1; i < ndim; i++) {
os << " + index_" << i << " * " << "in_strides[" << offset + i << "]";
os << " + index_" << i << " * " << xname << "_strides[" << i << "]";
}
os << "];" << std::endl;
nc_in_count++;
} else {
os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = "
<< xname << "[elem_to_loc(index, output_shape, in_strides + "
<< nc_in_count * ndim << ", ndim)];" << std::endl;
nc_in_count++;
<< xname << "[elem_to_loc(index, output_shape, " << xname
<< "_strides, ndim)];" << std::endl;
}
}
@@ -306,7 +296,6 @@ void Compiled::eval_gpu(
// Put the inputs in
int cnt = 0;
int stride_idx = 1; // idx 0 is the output strides
std::vector<size_t> in_strides;
for (int i = 0; i < inputs.size(); i++) {
if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
continue;
@@ -314,17 +303,13 @@ void Compiled::eval_gpu(
auto& x = inputs[i];
compute_encoder.set_input_array(x, cnt++);
if (!contiguous && !is_scalar(x)) {
in_strides.insert(
in_strides.end(),
strides[stride_idx].begin(),
strides[stride_idx].end());
compute_encoder->setBytes(
strides[stride_idx].data(),
strides[stride_idx].size() * sizeof(size_t),
cnt++);
stride_idx++;
}
}
if (!in_strides.empty()) {
compute_encoder->setBytes(
in_strides.data(), in_strides.size() * sizeof(size_t), cnt++);
}
compiled_allocate_outputs(
inputs, outputs, inputs_, constant_ids_, contiguous, true);

View File

@@ -33,6 +33,9 @@ void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
if (out.size() == 0) {
return;
}
if (ctype == CopyType::GeneralGeneral) {
ctype = CopyType::General;
}
@@ -54,27 +57,22 @@ void copy_gpu_inplace(
int64_t out_offset,
CopyType ctype,
const Stream& s) {
if (out.size() == 0) {
return;
}
// Try to collapse contiguous dims
auto [shape, strides] = collapse_contiguous_dims(
data_shape, std::vector{strides_in_pre, strides_out_pre});
auto& strides_in_ = strides[0];
auto& strides_out_ = strides[1];
bool use_2d = out.data_size() > UINT32_MAX;
auto& d = metal::device(s.device);
std::string kernel_name;
{
std::ostringstream kname;
switch (ctype) {
case CopyType::Scalar:
kname << (use_2d ? "s2" : "s");
kname << "s";
break;
case CopyType::Vector:
kname << (use_2d ? "v2" : "v");
kname << "v";
break;
case CopyType::General:
kname << "g";
@@ -140,8 +138,7 @@ void copy_gpu_inplace(
compute_encoder.dispatchThreads(grid_dims, group_dims);
} else {
size_t nthreads = out.data_size();
MTL::Size grid_dims = use_2d ? get_2d_grid_dims(out.shape(), out.strides())
: MTL::Size(nthreads, 1, 1);
MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size > nthreads) {
thread_group_size = nthreads;

View File

@@ -14,6 +14,7 @@
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/metal.h"
#include "mlx/backend/metal/metal_impl.h"
#include "mlx/backend/metal/mps/gemm.h"
#include "mlx/backend/metal/utils.h"
namespace fs = std::filesystem;
@@ -29,29 +30,13 @@ constexpr int MAX_DISPATCHES_PER_ENCODER = 2;
constexpr const char* default_mtllib_path = METAL_PATH;
constexpr auto get_metal_version() {
#if (MLX_METAL_VERSION >= 320)
return MTL::LanguageVersion3_2;
#elif (MLX_METAL_VERSION >= 310)
#if defined METAL_3_1
return MTL::LanguageVersion3_1;
#else
return MTL::LanguageVersion3_0;
#endif
}
std::string get_colocated_mtllib_path(const std::string& lib_name) {
Dl_info info;
std::string mtllib_path;
std::string lib_ext = lib_name + ".metallib";
int success = dladdr((void*)get_colocated_mtllib_path, &info);
if (success) {
auto mtllib = fs::path(info.dli_fname).remove_filename() / lib_ext;
mtllib_path = mtllib.c_str();
}
return mtllib_path;
}
auto load_device() {
auto devices = MTL::CopyAllDevices();
auto device = static_cast<MTL::Device*>(devices->object(0))
@@ -139,49 +124,6 @@ MTL::Library* load_library(
} // namespace
CommandEncoder::CommandEncoder(MTL::CommandBuffer* cbuf) : cbuf(cbuf) {
enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
enc->retain();
}
CommandEncoder::~CommandEncoder() {
enc->endEncoding();
enc->release();
}
void CommandEncoder::set_input_array(
const array& a,
int idx,
int64_t offset /* = 0 */) {
auto r_buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
if (auto it = outputs.find(r_buf); it != outputs.end()) {
// Insert a barrier
enc->memoryBarrier(&r_buf, 1);
// Remove the output
outputs.erase(it);
}
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
auto base_offset = a.data<char>() -
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
base_offset += offset;
enc->setBuffer(a_buf, base_offset, idx);
}
void CommandEncoder::set_output_array(
array& a,
int idx,
int64_t offset /* = 0 */) {
// Add barriers before adding the output to the output set
set_input_array(a, idx, offset);
auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
if (concurrent) {
concurrent_outputs.insert(buf);
} else {
outputs.insert(buf);
}
}
void CommandEncoder::dispatchThreadgroups(
MTL::Size grid_dims,
MTL::Size group_dims) {
@@ -311,9 +253,13 @@ void Device::register_library(
}
}
void Device::register_library(const std::string& lib_name) {
void Device::register_library(
const std::string& lib_name,
const std::function<std::string(const std::string&)>& lib_path_func) {
if (auto it = library_map_.find(lib_name); it == library_map_.end()) {
register_library(lib_name, get_colocated_mtllib_path(lib_name));
std::string new_lib_path = lib_path_func(lib_name);
auto new_lib = load_library(device_, lib_name, new_lib_path.c_str());
library_map_.insert({lib_name, new_lib});
}
}
@@ -323,7 +269,7 @@ MTL::Library* Device::get_library_cache_(const std::string& lib_name) {
if (auto it = library_map_.find(lib_name); it != library_map_.end()) {
mtl_lib = it->second;
} else { // Look for metallib alongside library
register_library(lib_name, get_colocated_mtllib_path(lib_name));
register_library(lib_name);
mtl_lib = library_map_[lib_name];
}

View File

@@ -9,16 +9,38 @@
#include <unordered_map>
#include <unordered_set>
#include <dlfcn.h>
#include <filesystem>
#include "mlx/array.h"
#include "mlx/device.h"
namespace fs = std::filesystem;
namespace mlx::core::metal {
inline std::string get_colocated_mtllib_path(const std::string& lib_name) {
Dl_info info;
std::string mtllib_path;
std::string lib_ext = lib_name + ".metallib";
int success = dladdr((void*)get_colocated_mtllib_path, &info);
if (success) {
auto mtllib = fs::path(info.dli_fname).remove_filename() / lib_ext;
mtllib_path = mtllib.c_str();
}
return mtllib_path;
}
using MTLFCList =
std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;
struct CommandEncoder {
CommandEncoder(MTL::CommandBuffer* cbuf);
CommandEncoder(MTL::CommandBuffer* cbuf) : cbuf(cbuf) {
enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
enc->retain();
};
CommandEncoder(const CommandEncoder&) = delete;
CommandEncoder& operator=(const CommandEncoder&) = delete;
@@ -41,8 +63,34 @@ struct CommandEncoder {
return enc;
}
void set_input_array(const array& a, int idx, int64_t offset = 0);
void set_output_array(array& a, int idx, int64_t offset = 0);
void set_input_array(const array& a, int idx, int64_t offset = 0) {
auto r_buf =
static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
if (auto it = outputs.find(r_buf); it != outputs.end()) {
// Insert a barrier
enc->memoryBarrier(&r_buf, 1);
// Remove the output
outputs.erase(it);
}
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
auto base_offset = a.data<char>() -
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
base_offset += offset;
enc->setBuffer(a_buf, base_offset, idx);
}
void set_output_array(array& a, int idx, int64_t offset = 0) {
// Add barriers before adding the output to the output set
set_input_array(a, idx, offset);
auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
if (concurrent) {
concurrent_outputs.insert(buf);
} else {
outputs.insert(buf);
}
}
void dispatchThreadgroups(MTL::Size grid_dims, MTL::Size group_dims);
void dispatchThreads(MTL::Size grid_dims, MTL::Size group_dims);
@@ -50,7 +98,10 @@ struct CommandEncoder {
return ConcurrentContext(*this);
}
~CommandEncoder();
~CommandEncoder() {
enc->endEncoding();
enc->release();
}
private:
void maybe_split();
@@ -85,8 +136,10 @@ class Device {
void register_library(
const std::string& lib_name,
const std::string& lib_path);
void register_library(const std::string& lib_name);
void register_library(
const std::string& lib_name,
const std::function<std::string(const std::string&)>& lib_path_func =
get_colocated_mtllib_path);
MTL::Library* get_library(const std::string& name);

View File

@@ -1,803 +1,106 @@
// Copyright © 2024 Apple Inc.
#include <cassert>
#include <complex>
#include <map>
#include <numeric>
#include <set>
#include "mlx/3rdparty/pocketfft.h"
#include "mlx/backend/metal/binary.h"
// Copyright © 2023 Apple Inc.
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/slicing.h"
#include "mlx/backend/metal/unary.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/utils.h"
#include "mlx/mlx.h"
#include "mlx/primitives.h"
namespace mlx::core {
using MTLFC = std::tuple<const void*, MTL::DataType, NS::UInteger>;
#define MAX_STOCKHAM_FFT_SIZE 4096
#define MAX_RADER_FFT_SIZE 2048
#define MAX_BLUESTEIN_FFT_SIZE 2048
// Threadgroup memory batching improves throughput for small n
#define MIN_THREADGROUP_MEM_SIZE 256
// For strided reads/writes, coalesce at least this many complex64s
#define MIN_COALESCE_WIDTH 4
inline const std::vector<int> supported_radices() {
// Ordered by preference in decomposition.
return {13, 11, 8, 7, 6, 5, 4, 3, 2};
}
std::vector<int> prime_factors(int n) {
int z = 2;
std::vector<int> factors;
while (z * z <= n) {
if (n % z == 0) {
factors.push_back(z);
n /= z;
} else {
z++;
}
}
if (n > 1) {
factors.push_back(n);
}
return factors;
}
struct FourStepParams {
bool required = false;
bool first_step = true;
int n1 = 0;
int n2 = 0;
};
// Forward Declaration
void fft_op(
const array& in,
array& out,
size_t axis,
bool inverse,
bool real,
const FourStepParams four_step_params,
bool inplace,
const Stream& s);
struct FFTPlan {
int n = 0;
// Number of steps for each radix in the Stockham decomposition
std::vector<int> stockham;
// Number of steps for each radix in the Rader decomposition
std::vector<int> rader;
// Rader factor, 1 if no rader factors
int rader_n = 1;
int bluestein_n = -1;
// Four step FFT
bool four_step = false;
int n1 = 0;
int n2 = 0;
};
int next_fast_n(int n) {
return next_power_of_2(n);
}
std::vector<int> plan_stockham_fft(int n) {
auto radices = supported_radices();
std::vector<int> plan(radices.size(), 0);
int orig_n = n;
if (n == 1) {
return plan;
}
for (int i = 0; i < radices.size(); i++) {
int radix = radices[i];
// Manually tuned radices for powers of 2
if (is_power_of_2(orig_n) && orig_n < 512 && radix > 4) {
continue;
}
while (n % radix == 0) {
plan[i] += 1;
n /= radix;
if (n == 1) {
return plan;
}
}
}
throw std::runtime_error("Unplannable");
}
FFTPlan plan_fft(int n) {
auto radices = supported_radices();
std::set<int> radices_set(radices.begin(), radices.end());
FFTPlan plan;
plan.n = n;
plan.rader = std::vector<int>(radices.size(), 0);
auto factors = prime_factors(n);
int remaining_n = n;
// Four Step FFT when N is too large for shared mem.
if (n > MAX_STOCKHAM_FFT_SIZE && is_power_of_2(n)) {
// For power's of two we have a fast, no transpose four step implementation.
plan.four_step = true;
// Rough heuristic for choosing faster powers of two when we can
plan.n2 = n > 65536 ? 1024 : 64;
plan.n1 = n / plan.n2;
return plan;
} else if (n > MAX_STOCKHAM_FFT_SIZE) {
// Otherwise we use a multi-upload Bluestein's
plan.four_step = true;
plan.bluestein_n = next_fast_n(2 * n - 1);
return plan;
}
for (int factor : factors) {
// Make sure the factor is a supported radix
if (radices_set.find(factor) == radices_set.end()) {
// We only support a single Rader factor currently
// TODO(alexbarron) investigate weirdness with large
// Rader sizes -- possibly a compiler issue?
if (plan.rader_n > 1 || n > MAX_RADER_FFT_SIZE) {
plan.four_step = n > MAX_BLUESTEIN_FFT_SIZE;
plan.bluestein_n = next_fast_n(2 * n - 1);
plan.stockham = plan_stockham_fft(plan.bluestein_n);
plan.rader = std::vector<int>(radices.size(), 0);
return plan;
}
// See if we can use Rader's algorithm to Stockham decompose n - 1
auto rader_factors = prime_factors(factor - 1);
int last_factor = -1;
for (int rf : rader_factors) {
// We don't nest Rader's algorithm so if `factor - 1`
// isn't Stockham decomposable we give up and do Bluestein's.
if (radices_set.find(rf) == radices_set.end()) {
plan.four_step = n > MAX_BLUESTEIN_FFT_SIZE;
plan.bluestein_n = next_fast_n(2 * n - 1);
plan.stockham = plan_stockham_fft(plan.bluestein_n);
plan.rader = std::vector<int>(radices.size(), 0);
return plan;
}
}
plan.rader = plan_stockham_fft(factor - 1);
plan.rader_n = factor;
remaining_n /= factor;
}
}
plan.stockham = plan_stockham_fft(remaining_n);
return plan;
}
int compute_elems_per_thread(FFTPlan plan) {
// Heuristics for selecting an efficient number
// of threads to use for a particular mixed-radix FFT.
auto n = plan.n;
std::vector<int> steps;
auto radices = supported_radices();
steps.insert(steps.end(), plan.stockham.begin(), plan.stockham.end());
steps.insert(steps.end(), plan.rader.begin(), plan.rader.end());
std::set<int> used_radices;
for (int i = 0; i < steps.size(); i++) {
int radix = radices[i % radices.size()];
if (steps[i] > 0) {
used_radices.insert(radix);
}
}
// Manual tuning for 7/11/13
if (used_radices.find(7) != used_radices.end() &&
(used_radices.find(11) != used_radices.end() ||
used_radices.find(13) != used_radices.end())) {
return 7;
} else if (
used_radices.find(11) != used_radices.end() &&
used_radices.find(13) != used_radices.end()) {
return 11;
}
// TODO(alexbarron) Some really weird stuff is going on
// for certain `elems_per_thread` on large composite n.
// Possibly a compiler issue?
if (n == 3159)
return 13;
if (n == 3645)
return 5;
if (n == 3969)
return 7;
if (n == 1982)
return 5;
if (used_radices.size() == 1) {
return *(used_radices.begin());
}
if (used_radices.size() == 2) {
if (used_radices.find(11) != used_radices.end() ||
used_radices.find(13) != used_radices.end()) {
return std::accumulate(used_radices.begin(), used_radices.end(), 0) / 2;
}
std::vector<int> radix_vec(used_radices.begin(), used_radices.end());
return radix_vec[1];
}
// In all other cases use the second smallest radix.
std::vector<int> radix_vec(used_radices.begin(), used_radices.end());
return radix_vec[1];
}
// Rader
int mod_exp(int x, int y, int n) {
int out = 1;
while (y) {
if (y & 1) {
out = out * x % n;
}
y >>= 1;
x = x * x % n;
}
return out;
}
int primitive_root(int n) {
auto factors = prime_factors(n - 1);
for (int r = 2; r < n - 1; r++) {
bool found = true;
for (int factor : factors) {
if (mod_exp(r, (n - 1) / factor, n) == 1) {
found = false;
break;
}
}
if (found) {
return r;
}
}
return -1;
}
std::tuple<array, array, array> compute_raders_constants(
int rader_n,
const Stream& s) {
int proot = primitive_root(rader_n);
// Fermat's little theorem
int inv = mod_exp(proot, rader_n - 2, rader_n);
std::vector<short> g_q(rader_n - 1);
std::vector<short> g_minus_q(rader_n - 1);
for (int i = 0; i < rader_n - 1; i++) {
g_q[i] = mod_exp(proot, i, rader_n);
g_minus_q[i] = mod_exp(inv, i, rader_n);
}
array g_q_arr(g_q.begin(), {rader_n - 1});
array g_minus_q_arr(g_minus_q.begin(), {rader_n - 1});
std::vector<std::complex<float>> b_q(rader_n - 1);
for (int i = 0; i < rader_n - 1; i++) {
float pi_i = (float)g_minus_q[i] * -2.0 * M_PI / rader_n;
b_q[i] = std::exp(std::complex<float>(0, pi_i));
}
array b_q_fft({rader_n - 1}, complex64, nullptr, {});
b_q_fft.set_data(allocator::malloc_or_wait(b_q_fft.nbytes()));
auto b_q_fft_ptr =
reinterpret_cast<std::complex<float>*>(b_q_fft.data<complex64_t>());
std::ptrdiff_t item_size = b_q_fft.itemsize();
size_t fft_size = rader_n - 1;
// This FFT is always small (<4096, batch 1) so save some overhead
// and do it on the CPU
pocketfft::c2c(
/* shape= */ {fft_size},
/* stride_in= */ {item_size},
/* stride_out= */ {item_size},
/* axes= */ {0},
/* forward= */ true,
/* data_in= */ b_q.data(),
/* data_out= */ b_q_fft_ptr,
/* scale= */ 1.0f);
return std::make_tuple(b_q_fft, g_q_arr, g_minus_q_arr);
}
// Bluestein
std::pair<array, array> compute_bluestein_constants(int n, int bluestein_n) {
// We need to calculate the Bluestein twiddle factors
// in double precision for the overall numerical stability
// of Bluestein's FFT algorithm to be acceptable.
//
// Metal doesn't support float64, so instead we
// manually implement the required operations on cpu.
//
// In numpy:
// w_k = np.exp(-1j * np.pi / N * (np.arange(-N + 1, N) ** 2))
// w_q = np.fft.fft(1/w_k)
// return w_k, w_q
int length = 2 * n - 1;
std::vector<std::complex<float>> w_k_vec(n);
std::vector<std::complex<float>> w_q_vec(bluestein_n, 0);
for (int i = -n + 1; i < n; i++) {
double theta = pow(i, 2) * M_PI / (double)n;
w_q_vec[i + n - 1] = std::exp(std::complex<double>(0, theta));
if (i >= 0) {
w_k_vec[i] = std::exp(std::complex<double>(0, -theta));
}
}
array w_k({n}, complex64, nullptr, {});
w_k.set_data(allocator::malloc_or_wait(w_k.nbytes()));
std::copy(w_k_vec.begin(), w_k_vec.end(), w_k.data<complex64_t>());
array w_q({bluestein_n}, complex64, nullptr, {});
w_q.set_data(allocator::malloc_or_wait(w_q.nbytes()));
auto w_q_ptr =
reinterpret_cast<std::complex<float>*>(w_q.data<complex64_t>());
std::ptrdiff_t item_size = w_q.itemsize();
size_t fft_size = bluestein_n;
pocketfft::c2c(
/* shape= */ {fft_size},
/* stride_in= */ {item_size},
/* stride_out= */ {item_size},
/* axes= */ {0},
/* forward= */ true,
/* data_in= */ w_q_vec.data(),
/* data_out= */ w_q_ptr,
/* scale= */ 1.0f);
return std::make_tuple(w_k, w_q);
}
void multi_upload_bluestein_fft(
const array& in,
array& out,
size_t axis,
bool inverse,
bool real,
FFTPlan& plan,
std::vector<array> copies,
const Stream& s) {
// TODO(alexbarron) Implement fused kernels for mutli upload bluestein's
// algorithm
int n = inverse ? out.shape(axis) : in.shape(axis);
auto [w_k, w_q] = compute_bluestein_constants(n, plan.bluestein_n);
// Broadcast w_q and w_k to the batch size
std::vector<size_t> b_strides(in.ndim(), 0);
b_strides[axis] = 1;
array w_k_broadcast({}, complex64, nullptr, {});
array w_q_broadcast({}, complex64, nullptr, {});
w_k_broadcast.copy_shared_buffer(w_k, b_strides, {}, w_k.data_size());
w_q_broadcast.copy_shared_buffer(w_q, b_strides, {}, w_q.data_size());
auto temp_shape = inverse ? out.shape() : in.shape();
array temp(temp_shape, complex64, nullptr, {});
array temp1(temp_shape, complex64, nullptr, {});
if (real && !inverse) {
// Convert float32->complex64
copy_gpu(in, temp, CopyType::General, s);
} else if (real && inverse) {
int back_offset = n % 2 == 0 ? 2 : 1;
auto slice_shape = in.shape();
slice_shape[axis] -= back_offset;
array slice_temp(slice_shape, complex64, nullptr, {});
array conj_temp(in.shape(), complex64, nullptr, {});
copies.push_back(slice_temp);
copies.push_back(conj_temp);
std::vector<int> rstarts(in.ndim(), 0);
std::vector<int> rstrides(in.ndim(), 1);
rstarts[axis] = in.shape(axis) - back_offset;
rstrides[axis] = -1;
unary_op_gpu({in}, conj_temp, "Conjugate", s);
slice_gpu(in, slice_temp, rstarts, rstrides, s);
concatenate_gpu({conj_temp, slice_temp}, temp, (int)axis, s);
} else if (inverse) {
unary_op_gpu({in}, temp, "Conjugate", s);
} else {
temp.copy_shared_buffer(in);
}
binary_op_gpu({temp, w_k_broadcast}, temp1, "Multiply", s);
std::vector<std::pair<int, int>> pads;
auto padded_shape = out.shape();
padded_shape[axis] = plan.bluestein_n;
array pad_temp(padded_shape, complex64, nullptr, {});
pad_gpu(temp1, array(complex64_t{0.0f, 0.0f}), pad_temp, {(int)axis}, {0}, s);
array pad_temp1(padded_shape, complex64, nullptr, {});
fft_op(
pad_temp,
pad_temp1,
axis,
/*inverse=*/false,
/*real=*/false,
FourStepParams(),
/*inplace=*/false,
s);
binary_op_gpu_inplace({pad_temp1, w_q_broadcast}, pad_temp, "Multiply", s);
fft_op(
pad_temp,
pad_temp1,
axis,
/* inverse= */ true,
/* real= */ false,
FourStepParams(),
/*inplace=*/true,
s);
int offset = plan.bluestein_n - (2 * n - 1);
std::vector<int> starts(in.ndim(), 0);
std::vector<int> strides(in.ndim(), 1);
starts[axis] = plan.bluestein_n - offset - n;
slice_gpu(pad_temp1, temp, starts, strides, s);
binary_op_gpu_inplace({temp, w_k_broadcast}, temp1, "Multiply", s);
if (real && !inverse) {
std::vector<int> rstarts(in.ndim(), 0);
std::vector<int> rstrides(in.ndim(), 1);
slice_gpu(temp1, out, rstarts, strides, s);
} else if (real && inverse) {
std::vector<size_t> b_strides(in.ndim(), 0);
auto inv_n = array({1.0f / n}, {1}, float32);
array temp_float(out.shape(), out.dtype(), nullptr, {});
copies.push_back(temp_float);
copies.push_back(inv_n);
copy_gpu(temp1, temp_float, CopyType::General, s);
binary_op_gpu({temp_float, inv_n}, out, "Multiply", s);
} else if (inverse) {
auto inv_n = array({1.0f / n}, {1}, complex64);
unary_op_gpu({temp1}, temp, "Conjugate", s);
binary_op_gpu({temp, inv_n}, out, "Multiply", s);
copies.push_back(inv_n);
} else {
out.copy_shared_buffer(temp1);
}
copies.push_back(w_k);
copies.push_back(w_q);
copies.push_back(w_k_broadcast);
copies.push_back(w_q_broadcast);
copies.push_back(temp);
copies.push_back(temp1);
copies.push_back(pad_temp);
copies.push_back(pad_temp1);
}
void four_step_fft(
const array& in,
array& out,
size_t axis,
bool inverse,
bool real,
FFTPlan& plan,
std::vector<array> copies,
const Stream& s) {
void FFT::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = out.primitive().stream();
auto& d = metal::device(s.device);
if (plan.bluestein_n == -1) {
// Fast no transpose implementation for powers of 2.
FourStepParams four_step_params = {
/* required= */ true, /* first_step= */ true, plan.n1, plan.n2};
auto temp_shape = (real && inverse) ? out.shape() : in.shape();
array temp(temp_shape, complex64, nullptr, {});
fft_op(
in, temp, axis, inverse, real, four_step_params, /*inplace=*/false, s);
four_step_params.first_step = false;
fft_op(
temp, out, axis, inverse, real, four_step_params, /*inplace=*/false, s);
copies.push_back(temp);
} else {
multi_upload_bluestein_fft(in, out, axis, inverse, real, plan, copies, s);
}
}
auto& in = inputs[0];
void fft_op(
const array& in,
array& out,
size_t axis,
bool inverse,
bool real,
const FourStepParams four_step_params,
bool inplace,
const Stream& s) {
auto& d = metal::device(s.device);
size_t n = out.dtype() == float32 ? out.shape(axis) : in.shape(axis);
if (n == 1) {
out.copy_shared_buffer(in);
return;
if (axes_.size() == 0 || axes_.size() > 1 || inverse_ ||
in.dtype() != complex64 || out.dtype() != complex64) {
// Could also fallback to CPU implementation here.
throw std::runtime_error(
"GPU FFT is only implemented for 1D, forward, complex FFTs.");
}
if (four_step_params.required) {
// Four Step FFT decomposes into two FFTs: n1 on columns, n2 on rows
n = four_step_params.first_step ? four_step_params.n1 : four_step_params.n2;
size_t n = in.shape(axes_[0]);
if (!is_power_of_2(n) || n > 2048 || n < 4) {
throw std::runtime_error(
"GPU FFT is only implemented for the powers of 2 from 4 -> 2048");
}
// Make sure that the array is contiguous and has stride 1 in the FFT dim
std::vector<array> copies;
auto check_input = [&axis, &copies, &s](const array& x) {
auto check_input = [this, &copies, &s](const array& x) {
// TODO: Pass the strides to the kernel so
// we can avoid the copy when x is not contiguous.
bool no_copy = x.strides()[axis] == 1 &&
(x.flags().row_contiguous || x.flags().col_contiguous);
bool no_copy = x.strides()[axes_[0]] == 1 && x.flags().row_contiguous ||
x.flags().col_contiguous;
if (no_copy) {
return x;
} else {
array x_copy(x.shape(), x.dtype(), nullptr, {});
std::vector<size_t> strides;
size_t cur_stride = x.shape(axis);
for (int a = 0; a < x.ndim(); a++) {
if (a == axis) {
size_t cur_stride = x.shape(axes_[0]);
for (int axis = 0; axis < x.ndim(); axis++) {
if (axis == axes_[0]) {
strides.push_back(1);
} else {
strides.push_back(cur_stride);
cur_stride *= x.shape(a);
cur_stride *= x.shape(axis);
}
}
auto flags = x.flags();
auto [data_size, is_row_contiguous, is_col_contiguous] =
check_contiguity(x.shape(), strides);
flags.col_contiguous = is_row_contiguous;
flags.row_contiguous = is_col_contiguous;
flags.contiguous = data_size == x_copy.size();
size_t f_stride = 1;
size_t b_stride = 1;
flags.col_contiguous = true;
flags.row_contiguous = true;
for (int i = 0, ri = x.ndim() - 1; i < x.ndim(); ++i, --ri) {
flags.col_contiguous &= (strides[i] == f_stride || x.shape(i) == 1);
f_stride *= x.shape(i);
flags.row_contiguous &= (strides[ri] == b_stride || x.shape(ri) == 1);
b_stride *= x.shape(ri);
}
// This is probably over-conservative
flags.contiguous = false;
x_copy.set_data(
allocator::malloc_or_wait(x.nbytes()), data_size, strides, flags);
allocator::malloc_or_wait(x.nbytes()), x.data_size(), strides, flags);
copy_gpu_inplace(x, x_copy, CopyType::GeneralGeneral, s);
copies.push_back(x_copy);
return x_copy;
}
};
const array& in_contiguous = check_input(in);
// real to complex: n -> (n/2)+1
// complex to real: (n/2)+1 -> n
auto out_strides = in_contiguous.strides();
size_t out_data_size = in_contiguous.data_size();
if (in.shape(axis) != out.shape(axis)) {
for (int i = 0; i < out_strides.size(); i++) {
if (out_strides[i] != 1) {
out_strides[i] = out_strides[i] / in.shape(axis) * out.shape(axis);
}
}
out_data_size = out_data_size / in.shape(axis) * out.shape(axis);
}
auto plan = plan_fft(n);
if (plan.four_step) {
four_step_fft(in, out, axis, inverse, real, plan, copies, s);
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
return;
}
const array& in_contiguous = check_input(inputs[0]);
// TODO: allow donation here
if (!inplace) {
out.set_data(
allocator::malloc_or_wait(out.nbytes()),
out_data_size,
out_strides,
in_contiguous.flags());
}
out.set_data(
allocator::malloc_or_wait(out.nbytes()),
in_contiguous.data_size(),
in_contiguous.strides(),
in_contiguous.flags());
auto radices = supported_radices();
int fft_size = plan.bluestein_n > 0 ? plan.bluestein_n : n;
// Setup function constants
bool power_of_2 = is_power_of_2(fft_size);
auto make_int = [](int* a, int i) {
return std::make_tuple(a, MTL::DataType::DataTypeInt, i);
};
auto make_bool = [](bool* a, int i) {
return std::make_tuple(a, MTL::DataType::DataTypeBool, i);
};
std::vector<MTLFC> func_consts = {
make_bool(&inverse, 0), make_bool(&power_of_2, 1)};
// Start of radix/rader step constants
int index = 4;
for (int i = 0; i < plan.stockham.size(); i++) {
func_consts.push_back(make_int(&plan.stockham[i], index));
index += 1;
}
for (int i = 0; i < plan.rader.size(); i++) {
func_consts.push_back(make_int(&plan.rader[i], index));
index += 1;
}
int elems_per_thread = compute_elems_per_thread(plan);
func_consts.push_back(make_int(&elems_per_thread, 2));
int rader_m = n / plan.rader_n;
func_consts.push_back(make_int(&rader_m, 3));
// The overall number of FFTs we're going to compute for this input
int size = out.dtype() == float32 ? out.size() : in.size();
if (real && inverse && four_step_params.required) {
size = out.size();
}
int total_batch_size = size / n;
int threads_per_fft = (fft_size + elems_per_thread - 1) / elems_per_thread;
// We batch among threadgroups for improved efficiency when n is small
int threadgroup_batch_size = std::max(MIN_THREADGROUP_MEM_SIZE / fft_size, 1);
if (four_step_params.required) {
// Require a threadgroup batch size of at least 4 for four step FFT
// so we can coalesce the memory accesses.
threadgroup_batch_size =
std::max(threadgroup_batch_size, MIN_COALESCE_WIDTH);
}
int threadgroup_mem_size = next_power_of_2(threadgroup_batch_size * fft_size);
// FFTs up to 2^20 are currently supported
assert(threadgroup_mem_size <= MAX_STOCKHAM_FFT_SIZE);
// ceil divide
int batch_size =
(total_batch_size + threadgroup_batch_size - 1) / threadgroup_batch_size;
if (real && !four_step_params.required) {
// We can perform 2 RFFTs at once so the batch size is halved.
batch_size = (batch_size + 2 - 1) / 2;
}
int out_buffer_size = out.size();
// We use n / 4 threads by default since radix-4
// is the largest single threaded radix butterfly
// we currently implement.
size_t m = n / 4;
size_t batch = in.size() / in.shape(axes_[0]);
auto& compute_encoder = d.get_command_encoder(s.index);
auto in_type_str = in.dtype() == float32 ? "float" : "float2";
auto out_type_str = out.dtype() == float32 ? "float" : "float2";
// Only required by four step
int step = -1;
{
std::ostringstream kname;
std::string inv_string = inverse ? "true" : "false";
std::string real_string = real ? "true" : "false";
std::string func_name;
if (plan.bluestein_n > 0) {
kname << "bluestein_fft_mem_" << threadgroup_mem_size << "_"
<< in_type_str << "_" << out_type_str;
func_name = "bluestein_fft";
} else if (plan.rader_n > 1) {
kname << "rader_fft_mem_" << threadgroup_mem_size << "_" << in_type_str
<< "_" << out_type_str;
func_name = "rader_fft";
} else if (four_step_params.required) {
step = four_step_params.first_step ? 0 : 1;
kname << "four_step_mem_" << threadgroup_mem_size << "_" << in_type_str
<< "_" << out_type_str << "_" << step << "_" << real_string;
func_name = "four_step_fft";
} else {
kname << "fft_mem_" << threadgroup_mem_size << "_" << in_type_str << "_"
<< out_type_str;
func_name = "fft";
}
std::string base_name = kname.str();
// We use a specialized kernel for each FFT size
kname << "_n" << fft_size << "_inv_" << inverse;
std::string hash_name = kname.str();
auto template_def = func_name == "four_step_fft" ? get_template_definition(
base_name,
func_name,
threadgroup_mem_size,
in_type_str,
out_type_str,
step,
real)
: get_template_definition(
base_name,
func_name,
threadgroup_mem_size,
in_type_str,
out_type_str);
auto kernel =
get_fft_kernel(d, base_name, hash_name, func_consts, template_def);
kname << "fft_" << n;
auto kernel = d.get_kernel(kname.str());
bool donated = in.data_shared_ptr() == nullptr;
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(in_contiguous, 0);
compute_encoder.set_output_array(out, 1);
if (plan.bluestein_n > 0) {
// Precomputed twiddle factors for Bluestein's
auto [w_k, w_q] = compute_bluestein_constants(n, plan.bluestein_n);
copies.push_back(w_q);
copies.push_back(w_k);
compute_encoder.set_input_array(w_q, 2); // w_q
compute_encoder.set_input_array(w_k, 3); // w_k
compute_encoder->setBytes(&n, sizeof(int), 4);
compute_encoder->setBytes(&plan.bluestein_n, sizeof(int), 5);
compute_encoder->setBytes(&total_batch_size, sizeof(int), 6);
} else if (plan.rader_n > 1) {
auto [b_q, g_q, g_minus_q] = compute_raders_constants(plan.rader_n, s);
copies.push_back(b_q);
copies.push_back(g_q);
copies.push_back(g_minus_q);
compute_encoder.set_input_array(b_q, 2);
compute_encoder.set_input_array(g_q, 3);
compute_encoder.set_input_array(g_minus_q, 4);
compute_encoder->setBytes(&n, sizeof(int), 5);
compute_encoder->setBytes(&total_batch_size, sizeof(int), 6);
compute_encoder->setBytes(&plan.rader_n, sizeof(int), 7);
} else if (four_step_params.required) {
compute_encoder->setBytes(&four_step_params.n1, sizeof(int), 2);
compute_encoder->setBytes(&four_step_params.n2, sizeof(int), 3);
compute_encoder->setBytes(&total_batch_size, sizeof(int), 4);
} else {
compute_encoder->setBytes(&n, sizeof(int), 2);
compute_encoder->setBytes(&total_batch_size, sizeof(int), 3);
}
auto group_dims = MTL::Size(1, threadgroup_batch_size, threads_per_fft);
auto grid_dims =
MTL::Size(batch_size, threadgroup_batch_size, threads_per_fft);
compute_encoder->dispatchThreads(grid_dims, group_dims);
auto group_dims = MTL::Size(1, m, 1);
auto grid_dims = MTL::Size(batch, m, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}
void fft_op(
const array& in,
array& out,
size_t axis,
bool inverse,
bool real,
bool inplace,
const Stream& s) {
fft_op(in, out, axis, inverse, real, FourStepParams(), inplace, s);
}
void nd_fft_op(
const array& in,
array& out,
const std::vector<size_t>& axes,
bool inverse,
bool real,
const Stream& s) {
// Perform ND FFT on GPU as a series of 1D FFTs
auto temp_shape = inverse ? in.shape() : out.shape();
array temp1(temp_shape, complex64, nullptr, {});
array temp2(temp_shape, complex64, nullptr, {});
std::vector<array> temp_arrs = {temp1, temp2};
for (int i = axes.size() - 1; i >= 0; i--) {
int reverse_index = axes.size() - i - 1;
// For 5D and above, we don't want to reallocate our two temporary arrays
bool inplace = reverse_index >= 3 && i != 0;
// Opposite order for fft vs ifft
int index = inverse ? reverse_index : i;
size_t axis = axes[index];
// Mirror np.fft.(i)rfftn and perform a real transform
// only on the final axis.
bool step_real = (real && index == axes.size() - 1);
int step_shape = inverse ? out.shape(axis) : in.shape(axis);
const array& in_arr = i == axes.size() - 1 ? in : temp_arrs[1 - i % 2];
array& out_arr = i == 0 ? out : temp_arrs[i % 2];
fft_op(in_arr, out_arr, axis, inverse, step_real, inplace, s);
}
auto& d = metal::device(s.device);
d.get_command_buffer(s.index)->addCompletedHandler(
[temp_arrs](MTL::CommandBuffer*) mutable { temp_arrs.clear(); });
}
void FFT::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = stream();
auto& in = inputs[0];
if (axes_.size() > 1) {
nd_fft_op(in, out, axes_, inverse_, real_, s);
} else {
fft_op(in, out, axes_[0], inverse_, real_, /*inplace=*/false, s);
}
}
} // namespace mlx::core

View File

@@ -1,203 +0,0 @@
// Copyright © 2024 Apple Inc.
#include <map>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/common/hadamard.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
constexpr int MAX_HADAMARD_THREADS_PER_GROUP = 256;
constexpr int MAX_HADAMARD_BYTES = 32768; // 32KB
std::string gen_hadamard_codelet(int m) {
// Generate a O(m^2) hadamard codelet for a given M
// using the hadamard matrices above
//
// e.g. m = 2
// METAL_FUNC void hadamard_m(thread float *x) {
// float tmp[2];
// tmp[0] = + x[0] + x[1];
// tmp[1] = + x[0] - x[1];
// for (int i = 0; i < 2; i++) { x[i] = tmp[i]; }
// }
//
auto h_matrices = hadamard_matrices();
auto& matrix = h_matrices[m];
std::ostringstream source;
source << "METAL_FUNC void hadamard_radix_m(thread float *x) {" << std::endl;
if (m == 1) {
source << "}" << std::endl;
return source.str();
}
source << " float tmp[" << m << "];" << std::endl;
auto start = 1;
auto end = matrix.find('\n', start);
int index = 0;
while (end != std::string_view::npos) {
source << " tmp[" << index << "] = ";
auto row = matrix.substr(start, end - start);
for (int i = 0; i < row.length(); i++) {
source << " " << row[i] << " x[" << i << "]";
}
source << ";" << std::endl;
start = end + 1;
end = matrix.find('\n', start);
index++;
}
source << " for (int i = 0; i < " << m << "; i++) { x[i] = tmp[i]; }"
<< std::endl;
source << "}" << std::endl;
return source.str();
}
void launch_hadamard(
const array& in,
array& out,
int batch_size,
int threads_per,
const std::string kernel_name,
float scale,
const Stream& s) {
auto& d = metal::device(s.device);
const auto& lib_name = kernel_name.substr(1);
auto lib = d.get_library(lib_name);
auto kernel = d.get_kernel(kernel_name, lib);
assert(threads_per <= kernel->maxTotalThreadsPerThreadgroup());
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder->setBytes(&scale, sizeof(float), 2);
MTL::Size group_dims = MTL::Size(1, threads_per, 1);
MTL::Size grid_dims = MTL::Size(batch_size, threads_per, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = stream();
auto& in = inputs[0];
std::vector<array> copies;
// Only support the last axis for now
int axis = in.ndim() - 1;
auto check_input = [&copies, &s](const array& x) {
// TODO(alexbarron) pass strides to kernel to relax this constraint
bool no_copy = x.flags().row_contiguous;
if (no_copy) {
return x;
} else {
copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
copy_gpu(x, copies.back(), CopyType::General, s);
return copies.back();
}
};
const array& in_contiguous = check_input(in);
if (in_contiguous.is_donatable()) {
out.move_shared_buffer(in_contiguous);
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
auto [n, m] = decompose_hadamard(in.shape(axis));
if (n * (int)size_of(in.dtype()) > MAX_HADAMARD_BYTES) {
throw std::invalid_argument(
"[hadamard] For n = m*2^k, 2^k > 8192 for FP32 or 2^k > 16384 for FP16/BF16 NYI");
}
int max_radix = std::min(n, 16);
// Use read_width 2 for m = 28 to avoid register spilling
int read_width = (n == 2 || m == 28) ? 2 : 4;
std::ostringstream kname;
kname << "hadamard_" << n * m << "_" << type_to_name(out);
auto kernel_name = kname.str();
auto& d = metal::device(s.device);
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
auto codelet = gen_hadamard_codelet(m);
kernel_source << metal::utils() << codelet << metal::hadamard();
kernel_source << get_template_definition(
"n" + kernel_name,
"hadamard_n",
get_type_string(in.dtype()),
n,
max_radix,
read_width);
kernel_source << get_template_definition(
"m" + kernel_name,
"hadamard_m",
get_type_string(in.dtype()),
n,
m,
read_width);
lib = d.get_library(lib_name, kernel_source.str());
}
int batch_size = in.size() / n;
int threads_per = n / max_radix;
if (m > 1) {
// When m is greater than 1, we decompose the
// computation into two uploads to the GPU:
//
// e.g. len(x) = 12*4 = 48, m = 12, n = 4
//
// y = h48 @ x
//
// Upload 1:
// tmp = a.reshape(12, 4) @ h4
//
// Upload 2:
// y = h12 @ tmp
array temp(in.shape(), in.dtype(), nullptr, {});
temp.set_data(allocator::malloc_or_wait(temp.nbytes()));
copies.push_back(temp);
launch_hadamard(
in_contiguous,
temp,
batch_size,
threads_per,
"n" + kernel_name,
1.0,
s);
// Metal sometimes reports 256 max threads per group for hadamard_m kernel
threads_per = std::min(n / read_width, MAX_HADAMARD_THREADS_PER_GROUP);
batch_size = in.size() / m / read_width / threads_per;
launch_hadamard(
temp, out, batch_size, threads_per, "m" + kernel_name, scale_, s);
} else {
launch_hadamard(
in_contiguous,
out,
batch_size,
threads_per,
"n" + kernel_name,
scale_,
s);
}
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}
} // namespace mlx::core

View File

@@ -293,18 +293,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
out.shape().data(), out.shape().size() * sizeof(int), 3);
compute_encoder->setBytes(
out.strides().data(), out.strides().size() * sizeof(size_t), 4);
size_t out_ndim = out.ndim();
compute_encoder->setBytes(&out_ndim, sizeof(out_ndim), 5);
if (upd_ndim <= 1) {
// Placeholder so Metal doesn't compalain
int shape_ = 0;
compute_encoder->setBytes(&shape_, sizeof(int), 6);
} else {
compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 6);
}
compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 7);
compute_encoder->setBytes(&upd_size, sizeof(size_t), 8);
compute_encoder->setBytes(&upd_size, sizeof(size_t), 5);
// Set index buffers
for (int i = 0; i < nidx; ++i) {

View File

@@ -0,0 +1,87 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view binary_kernels = R"(
template [[host_name("ss{0}")]] [[kernel]]
void binary_ss<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
uint index [[thread_position_in_grid]]);
template [[host_name("vs{0}")]] [[kernel]]
void binary_vs<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
uint index [[thread_position_in_grid]]);
template [[host_name("sv{0}")]] [[kernel]]
void binary_sv<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
uint index [[thread_position_in_grid]]);
template [[host_name("vv{0}")]] [[kernel]]
void binary_vv<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
uint index [[thread_position_in_grid]]);
template [[host_name("g4{0}")]] [[kernel]] void
binary_g_nd<{1}, {2}, {3}, 4>(
device const {1}* a,
device const {1}* b,
device {2}* c,
constant const int shape[4],
constant const size_t a_strides[4],
constant const size_t b_strides[4],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("g5{0}")]] [[kernel]] void
binary_g_nd<{1}, {2}, {3}, 5>(
device const {1}* a,
device const {1}* b,
device {2}* c,
constant const int shape[5],
constant const size_t a_strides[5],
constant const size_t b_strides[5],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("g1{0}")]] [[kernel]] void
binary_g_nd1<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
constant const size_t& a_stride,
constant const size_t& b_stride,
uint index [[thread_position_in_grid]]);
template [[host_name("g2{0}")]] [[kernel]] void
binary_g_nd2<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
constant const size_t a_strides[2],
constant const size_t b_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]);
template [[host_name("g3{0}")]] [[kernel]] void
binary_g_nd3<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
constant const size_t a_strides[3],
constant const size_t b_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("gn{0}")]] [[kernel]]
void binary_g<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
)";

View File

@@ -0,0 +1,98 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view binary_two_kernels = R"(
template [[host_name("ss{0}")]] [[kernel]]
void binary_ss<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
device {2}* d,
uint index [[thread_position_in_grid]]);
template [[host_name("vs{0}")]] [[kernel]]
void binary_vs<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
device {2}* d,
uint index [[thread_position_in_grid]]);
template [[host_name("sv{0}")]] [[kernel]]
void binary_sv<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
device {2}* d,
uint index [[thread_position_in_grid]]);
template [[host_name("vv{0}")]] [[kernel]]
void binary_vv<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
device {2}* d,
uint index [[thread_position_in_grid]]);
template [[host_name("g4{0}")]] [[kernel]] void
binary_g_nd<{1}, {2}, {3}, 4>(
device const {1}* a,
device const {1}* b,
device {2}* c,
device {2}* d,
constant const int shape[4],
constant const size_t a_strides[4],
constant const size_t b_strides[4],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("g5{0}")]] [[kernel]] void
binary_g_nd<{1}, {2}, {3}, 5>(
device const {1}* a,
device const {1}* b,
device {2}* c,
device {2}* d,
constant const int shape[5],
constant const size_t a_strides[5],
constant const size_t b_strides[5],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("g1{0}")]] [[kernel]] void
binary_g_nd1<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
device {2}* d,
constant const size_t& a_stride,
constant const size_t& b_stride,
uint index [[thread_position_in_grid]]);
template [[host_name("g2{0}")]] [[kernel]] void
binary_g_nd2<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
device {2}* d,
constant const size_t a_strides[2],
constant const size_t b_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]);
template [[host_name("g3{0}")]] [[kernel]] void
binary_g_nd3<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
device {2}* d,
constant const size_t a_strides[3],
constant const size_t b_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("gn{0}")]] [[kernel]]
void binary_g<{1}, {2}, {3}>(
device const {1}* a,
device const {1}* b,
device {2}* c,
device {2}* d,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
)";

View File

@@ -1,25 +0,0 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view gemv_masked_kernel = R"(
template [[host_name("{name}")]] [[kernel]] void
gemv_{trans}masked<{itype}, {outm_t}, {opm_t}, {bm}, {bn}, {sm}, {sn}, {tm}, {tn}, {nc}>(
const device {itype}* mat [[buffer(0)]],
const device {itype}* in_vec [[buffer(1)]],
device {itype}* out_vec [[buffer(3)]],
const constant int& in_vec_size [[buffer(4)]],
const constant int& out_vec_size [[buffer(5)]],
const constant int& marix_ld [[buffer(6)]],
const constant int& batch_ndim [[buffer(9)]],
const constant int* batch_shape [[buffer(10)]],
const constant size_t* vector_batch_stride [[buffer(11)]],
const constant size_t* matrix_batch_stride [[buffer(12)]],
const device {outm_t}* out_mask [[buffer(20)]],
const device {opm_t}* mat_mask [[buffer(21)]],
const device {opm_t}* vec_mask [[buffer(22)]],
const constant int* mask_strides [[buffer(23)]],
const constant size_t* mask_batch_strides [[buffer(24)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]);
)";

View File

@@ -17,9 +17,6 @@ const char* unary();
const char* binary();
const char* binary_two();
const char* copy();
const char* fft();
const char* hadamard();
const char* quantized();
const char* ternary();
const char* scan();
const char* softmax();
@@ -33,6 +30,5 @@ const char* steel_gemm_splitk();
const char* conv();
const char* steel_conv();
const char* steel_conv_general();
const char* gemv_masked();
} // namespace mlx::core::metal

View File

@@ -38,24 +38,12 @@ constexpr std::string_view scatter_kernels = R"(
device mlx_atomic<{1}>* out [[buffer(2)]],
const constant int* out_shape [[buffer(3)]],
const constant size_t* out_strides [[buffer(4)]],
const constant size_t& out_ndim [[buffer(5)]],
const constant int* upd_shape [[buffer(6)]],
const constant size_t& upd_ndim [[buffer(7)]],
const constant size_t& upd_size [[buffer(8)]],
const constant size_t& upd_size [[buffer(5)]],
{5}
uint2 gid [[thread_position_in_grid]]) {{
const array<const device {2}*, {4}> idx_buffers = {{ {6} }};
return scatter_1d_index_impl<{1}, {2}, {3}, {4}>(
updates,
out,
out_shape,
out_strides,
out_ndim,
upd_shape,
upd_ndim,
upd_size,
idx_buffers,
gid);
updates, out, out_shape, out_strides, upd_size, idx_buffers, gid);
}}
[[kernel]] void scatter{0}_{4}(

View File

@@ -0,0 +1,81 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view block_sort_kernels = R"(
template [[host_name("carg_{0}")]] [[kernel]] void
block_sort<{1}, {2}, true, {3}, {4}>(
const device {1}* inp [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant int& size_sorted_axis [[buffer(2)]],
const constant int& stride_sorted_axis [[buffer(3)]],
const constant int& stride_segment_axis [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
template [[host_name("ncarg_{0}")]] [[kernel]] void
block_sort_nc<{1}, {2}, true, {3}, {4}>(
const device {1}* inp [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant int& size_sorted_axis [[buffer(2)]],
const constant int& stride_sorted_axis [[buffer(3)]],
const constant int& nc_dim [[buffer(4)]],
const device int* nc_shape [[buffer(5)]],
const device size_t* nc_strides [[buffer(6)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
template [[host_name("c_{0}")]] [[kernel]] void
block_sort<{1}, {2}, false, {3}, {4}>(
const device {1}* inp [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant int& size_sorted_axis [[buffer(2)]],
const constant int& stride_sorted_axis [[buffer(3)]],
const constant int& stride_segment_axis [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
template [[host_name("nc_{0}")]] [[kernel]] void
block_sort_nc<{1}, {2}, false, {3}, {4}>(
const device {1}* inp [[buffer(0)]],
device {2}* out [[buffer(1)]],
const constant int& size_sorted_axis [[buffer(2)]],
const constant int& stride_sorted_axis [[buffer(3)]],
const constant int& nc_dim [[buffer(4)]],
const device int* nc_shape [[buffer(5)]],
const device size_t* nc_strides [[buffer(6)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
)";
constexpr std::string_view multiblock_sort_kernels = R"(
template [[host_name("sort_{0}")]] [[kernel]] void
mb_block_sort<{1}, {2}, true, {3}, {4}>(
const device {1}* inp [[buffer(0)]],
device {1}* out_vals [[buffer(1)]],
device {2}* out_idxs [[buffer(2)]],
const constant int& size_sorted_axis [[buffer(3)]],
const constant int& stride_sorted_axis [[buffer(4)]],
const constant int& nc_dim [[buffer(5)]],
const device int* nc_shape [[buffer(6)]],
const device size_t* nc_strides [[buffer(7)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
template [[host_name("partition_{0}")]] [[kernel]] void
mb_block_partition<{1}, {2}, true, {3}, {4}>(
device {2}* block_partitions [[buffer(0)]],
const device {1}* dev_vals [[buffer(1)]],
const device {2}* dev_idxs [[buffer(2)]],
const constant int& size_sorted_axis [[buffer(3)]],
const constant int& merge_tiles [[buffer(4)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 tgp_dims [[threads_per_threadgroup]]);
template [[host_name("merge_{0}")]] [[kernel]] void
mb_block_merge<{1}, {2}, true, {3}, {4}>(
const device {2}* block_partitions [[buffer(0)]],
const device {1}* dev_vals_in [[buffer(1)]],
const device {2}* dev_idxs_in [[buffer(2)]],
device {1}* dev_vals_out [[buffer(3)]],
device {2}* dev_idxs_out [[buffer(4)]],
const constant int& size_sorted_axis [[buffer(5)]],
const constant int& merge_tiles [[buffer(6)]],
const constant int& num_tiles [[buffer(7)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]);
)";

View File

@@ -0,0 +1,80 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view ternary_kernels = R"(
template [[host_name("v_{0}")]] [[kernel]] void ternary_v<{1}, {2}>(
device const bool* a,
device const {1}* b,
device const {1}* c,
device {1}* d,
uint index [[thread_position_in_grid]]);
template [[host_name("g_{0}")]] [[kernel]] void ternary_g<{1}, {2}>(
device const bool* a,
device const {1}* b,
device const {1}* c,
device {1}* d,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const size_t* c_strides,
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("g1_{0}")]] [[kernel]] void
ternary_g_nd1<{1}, {2}>(
device const bool* a,
device const {1}* b,
device const {1}* c,
device {1}* d,
constant const size_t& a_strides,
constant const size_t& b_strides,
constant const size_t& c_strides,
uint index [[thread_position_in_grid]]);
template [[host_name("g2_{0}")]] [[kernel]] void
ternary_g_nd2<{1}, {2}>(
device const bool* a,
device const {1}* b,
device const {1}* c,
device {1}* d,
constant const size_t a_strides[2],
constant const size_t b_strides[2],
constant const size_t c_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]);
template [[host_name("g3_{0}")]] [[kernel]] void
ternary_g_nd3<{1}, {2}>(
device const bool* a,
device const {1}* b,
device const {1}* c,
device {1}* d,
constant const size_t a_strides[3],
constant const size_t b_strides[3],
constant const size_t c_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("g4_{0}")]] [[kernel]] void
ternary_g_nd<{1}, {2}, 4>(
device const bool* a,
device const {1}* b,
device const {1}* c,
device {1}* d,
constant const int shape[4],
constant const size_t a_strides[4],
constant const size_t b_strides[4],
constant const size_t c_strides[4],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
template [[host_name("g5_{0}")]] [[kernel]] void
ternary_g_nd<{1}, {2}, 5>(
device const bool* a,
device const {1}* b,
device const {1}* c,
device {1}* d,
constant const int shape[5],
constant const size_t a_strides[5],
constant const size_t b_strides[5],
constant const size_t c_strides[5],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]);
)";

View File

@@ -0,0 +1,16 @@
// Copyright © 2024 Apple Inc.
constexpr std::string_view unary_kernels = R"(
template [[host_name("v{0}")]] [[kernel]] void unary_v<{1}, {2}>(
device const {1}* in,
device {1}* out,
uint index [[thread_position_in_grid]]);
template [[host_name("g{0}")]] [[kernel]] void unary_g<{1}, {2}>(
device const {1}* in,
device {1}* out,
device const int* in_shape,
device const size_t* in_strides,
device const int& ndim,
uint index [[thread_position_in_grid]]);
)";

View File

@@ -1,16 +1,20 @@
// Copyright © 2024 Apple Inc.
#include <map>
#include <fmt/format.h>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/metal/jit/arange.h"
#include "mlx/backend/metal/jit/binary.h"
#include "mlx/backend/metal/jit/binary_two.h"
#include "mlx/backend/metal/jit/copy.h"
#include "mlx/backend/metal/jit/gemv_masked.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/jit/reduce.h"
#include "mlx/backend/metal/jit/scan.h"
#include "mlx/backend/metal/jit/softmax.h"
#include "mlx/backend/metal/jit/sort.h"
#include "mlx/backend/metal/jit/steel_conv.h"
#include "mlx/backend/metal/jit/steel_gemm.h"
#include "mlx/backend/metal/jit/ternary.h"
#include "mlx/backend/metal/jit/unary.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/utils.h"
@@ -43,81 +47,38 @@ MTL::ComputePipelineState* get_arange_kernel(
MTL::ComputePipelineState* get_unary_kernel(
metal::Device& d,
const std::string& kernel_name,
Dtype out_type,
const std::string op) {
const array& out) {
std::string lib_name = kernel_name.substr(1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
auto u_def = get_template_definition(
"v" + lib_name, "unary_v", get_type_string(out_type), op);
auto u2_def = get_template_definition(
"v2" + lib_name, "unary_v2", get_type_string(out_type), op);
auto g_def = get_template_definition(
"g" + lib_name, "unary_g", get_type_string(out_type), op);
kernel_source << metal::utils() << metal::unary_ops() << metal::unary()
<< u_def << u2_def << g_def;
<< fmt::format(
unary_kernels,
lib_name,
get_type_string(out.dtype()),
op_name(out));
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
void add_binary_kernels(
const std::string lib_name,
Dtype in_type,
Dtype out_type,
const std::string op,
std::ostringstream& kernel_source) {
const std::map<std::string, std::string> kernel_types = {
{"ss", "binary_ss"},
{"vs", "binary_vs"},
{"sv", "binary_sv"},
{"vv", "binary_vv"},
{"vs2", "binary_vs2"},
{"sv2", "binary_sv2"},
{"vv2", "binary_vv2"},
{"g1", "binary_g_nd1"},
{"g2", "binary_g_nd2"},
{"g3", "binary_g_nd3"},
{"g4", "binary_g_nd"},
{"g5", "binary_g_nd"},
{"gn", "binary_g"},
};
for (auto [name, func] : kernel_types) {
std::string template_def;
if (name == "g4" || name == "g5") {
int dim = std::stoi(name.substr(1));
template_def = get_template_definition(
name + lib_name,
func,
get_type_string(in_type),
get_type_string(out_type),
op,
dim);
} else {
template_def = get_template_definition(
name + lib_name,
func,
get_type_string(in_type),
get_type_string(out_type),
op);
}
kernel_source << template_def;
}
}
MTL::ComputePipelineState* get_binary_kernel(
metal::Device& d,
const std::string& kernel_name,
Dtype in_type,
Dtype out_type,
const std::string op) {
const array& in,
const array& out) {
std::string lib_name = kernel_name.substr(2);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::binary_ops() << metal::binary();
add_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
kernel_source << metal::utils() << metal::binary_ops() << metal::binary()
<< fmt::format(
binary_kernels,
lib_name,
get_type_string(in.dtype()),
get_type_string(out.dtype()),
op_name(out));
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
@@ -126,16 +87,20 @@ MTL::ComputePipelineState* get_binary_kernel(
MTL::ComputePipelineState* get_binary_two_kernel(
metal::Device& d,
const std::string& kernel_name,
Dtype in_type,
Dtype out_type,
const std::string op) {
const array& in,
const array& out) {
std::string lib_name = kernel_name.substr(2);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::binary_ops()
<< metal::binary_two();
add_binary_kernels(lib_name, in_type, out_type, op, kernel_source);
<< metal::binary_two()
<< fmt::format(
binary_two_kernels,
lib_name,
get_type_string(in.dtype()),
get_type_string(out.dtype()),
op_name(out));
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
@@ -144,35 +109,17 @@ MTL::ComputePipelineState* get_binary_two_kernel(
MTL::ComputePipelineState* get_ternary_kernel(
metal::Device& d,
const std::string& kernel_name,
Dtype type,
const std::string op) {
const array& out) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
const std::map<std::string, std::string> kernel_types = {
{"v", "ternary_v"},
{"v2", "ternary_v2"},
{"g", "ternary_g"},
{"g1", "ternary_g_nd1"},
{"g2", "ternary_g_nd2"},
{"g3", "ternary_g_nd3"},
{"g4", "ternary_g_nd"},
{"g5", "ternary_g_nd"},
};
kernel_source << metal::utils() << metal::ternary_ops() << metal::ternary();
for (auto [name, func] : kernel_types) {
std::string template_def;
if (name == "g4" || name == "g5") {
int dim = std::stoi(name.substr(1));
template_def = get_template_definition(
name + "_" + lib_name, func, get_type_string(type), op, dim);
} else {
template_def = get_template_definition(
name + "_" + lib_name, func, get_type_string(type), op);
}
kernel_source << template_def;
}
kernel_source << metal::utils() << metal::ternary_ops() << metal::ternary()
<< fmt::format(
ternary_kernels,
lib_name,
get_type_string(out.dtype()),
op_name(out));
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
@@ -223,14 +170,11 @@ MTL::ComputePipelineState* get_scan_kernel(
const std::string& kernel_name,
bool reverse,
bool inclusive,
const std::string& reduce_type,
const array& in,
const array& out) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::string op_name = "Cum" + reduce_type;
op_name[3] = toupper(op_name[3]);
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::scan()
<< fmt::format(
@@ -238,7 +182,7 @@ MTL::ComputePipelineState* get_scan_kernel(
lib_name,
get_type_string(in.dtype()),
get_type_string(out.dtype()),
op_name,
op_name(out),
inclusive,
reverse);
lib = d.get_library(lib_name, kernel_source.str());
@@ -257,29 +201,14 @@ MTL::ComputePipelineState* get_sort_kernel(
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
auto in_type = get_type_string(in.dtype());
auto out_type = get_type_string(out.dtype());
kernel_source << metal::utils() << metal::sort();
for (bool is_argsort : {true, false}) {
std::string bool_string = is_argsort ? "true" : "false";
std::string func_string = is_argsort ? "carg_" : "c_";
kernel_source << get_template_definition(
func_string + lib_name,
"block_sort",
in_type,
out_type,
bool_string,
bn,
tn);
kernel_source << get_template_definition(
"n" + func_string + lib_name,
"block_sort_nc",
in_type,
out_type,
bool_string,
bn,
tn);
}
kernel_source << metal::utils() << metal::sort()
<< fmt::format(
block_sort_kernels,
lib_name,
get_type_string(in.dtype()),
get_type_string(out.dtype()),
bn,
tn);
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
@@ -296,21 +225,14 @@ MTL::ComputePipelineState* get_mb_sort_kernel(
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::sort();
std::vector<std::pair<std::string, std::string>> kernel_types = {
{"sort_", "mb_block_sort"},
{"partition_", "mb_block_partition"},
{"merge_", "mb_block_merge"}};
for (auto [name, func] : kernel_types) {
kernel_source << get_template_definition(
name + lib_name,
func,
get_type_string(in.dtype()),
get_type_string(idx.dtype()),
"true",
bn,
tn);
}
kernel_source << metal::utils() << metal::sort()
<< fmt::format(
multiblock_sort_kernels,
lib_name,
get_type_string(in.dtype()),
get_type_string(idx.dtype()),
bn,
tn);
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
@@ -337,14 +259,11 @@ MTL::ComputePipelineState* get_reduce_init_kernel(
MTL::ComputePipelineState* get_reduce_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& op_name,
const array& in,
const array& out) {
std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1);
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::string op_type = op_name;
op_type[0] = std::toupper(op_name[0]);
bool non_atomic = out.dtype() == int64 || out.dtype() == uint64;
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::reduce_utils() << metal::reduce()
@@ -354,7 +273,7 @@ MTL::ComputePipelineState* get_reduce_kernel(
lib_name,
get_type_string(in.dtype()),
get_type_string(out.dtype()),
op_type);
op_name(out));
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
@@ -503,49 +422,6 @@ MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_gemv_masked_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out,
const std::optional<array>& mask_out,
const std::optional<array>& mask_op,
bool transpose_mat,
int bm,
int bn,
int sm,
int sn,
int tm,
int tn,
bool contiguous) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
auto out_mask_type = mask_out.has_value()
? get_type_string((*mask_out).dtype())
: "nomask_t";
auto op_mask_type =
mask_op.has_value() ? get_type_string((*mask_op).dtype()) : "nomask_t";
kernel_source << metal::utils() << metal::gemv_masked()
<< fmt::format(
gemv_masked_kernel,
"name"_a = lib_name,
"itype"_a = get_type_string(out.dtype()),
"outm_t"_a = out_mask_type,
"opm_t"_a = op_mask_type,
"bm"_a = bm,
"bn"_a = bn,
"sm"_a = sm,
"sn"_a = sn,
"tm"_a = tm,
"tn"_a = tn,
"trans"_a = transpose_mat ? "t_" : "",
"nc"_a = contiguous ? "0" : "1");
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_steel_conv_kernel(
metal::Device& d,
const std::string& kernel_name,
@@ -607,36 +483,4 @@ MTL::ComputePipelineState* get_steel_conv_general_kernel(
return d.get_kernel(kernel_name, lib);
}
MTL::ComputePipelineState* get_fft_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const std::string& template_def) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
std::string kernel_string;
kernel_source << metal::fft() << template_def;
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib, hash_name, func_consts);
}
MTL::ComputePipelineState* get_quantized_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& template_def) {
const auto& lib_name = kernel_name;
auto lib = d.get_library(lib_name);
if (lib == nullptr) {
std::ostringstream kernel_source;
kernel_source << metal::utils() << metal::gemm() << metal::quantized()
<< template_def;
lib = d.get_library(lib_name, kernel_source.str());
}
return d.get_kernel(kernel_name, lib);
}
} // namespace mlx::core

View File

@@ -1,7 +1,5 @@
// Copyright © 2024 Apple Inc.
#include <fmt/format.h>
#include "mlx/array.h"
#include "mlx/backend/metal/device.h"
@@ -15,28 +13,24 @@ MTL::ComputePipelineState* get_arange_kernel(
MTL::ComputePipelineState* get_unary_kernel(
metal::Device& d,
const std::string& kernel_name,
Dtype out_type,
const std::string op);
const array& out);
MTL::ComputePipelineState* get_binary_kernel(
metal::Device& d,
const std::string& kernel_name,
Dtype in_type,
Dtype out_type,
const std::string op);
const array& in,
const array& out);
MTL::ComputePipelineState* get_binary_two_kernel(
metal::Device& d,
const std::string& kernel_name,
Dtype in_type,
Dtype out_type,
const std::string op);
const array& in,
const array& out);
MTL::ComputePipelineState* get_ternary_kernel(
metal::Device& d,
const std::string& kernel_name,
Dtype type,
const std::string op);
const array& out);
MTL::ComputePipelineState* get_copy_kernel(
metal::Device& d,
@@ -55,7 +49,6 @@ MTL::ComputePipelineState* get_scan_kernel(
const std::string& kernel_name,
bool reverse,
bool inclusive,
const std::string& reduce_type,
const array& in,
const array& out);
@@ -83,7 +76,6 @@ MTL::ComputePipelineState* get_reduce_init_kernel(
MTL::ComputePipelineState* get_reduce_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& op_name,
const array& in,
const array& out);
@@ -151,21 +143,6 @@ MTL::ComputePipelineState* get_steel_conv_kernel(
int n_channel_specialization,
bool small_filter);
MTL::ComputePipelineState* get_gemv_masked_kernel(
metal::Device& d,
const std::string& kernel_name,
const array& out,
const std::optional<array>& mask_out,
const std::optional<array>& mask_op,
bool transpose_mat,
int bm,
int bn,
int sm,
int sn,
int tm,
int tn,
bool contiguous);
MTL::ComputePipelineState* get_steel_conv_general_kernel(
metal::Device& d,
const std::string& kernel_name,
@@ -176,38 +153,4 @@ MTL::ComputePipelineState* get_steel_conv_general_kernel(
int wm,
int wn);
MTL::ComputePipelineState* get_fft_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& hash_name,
const metal::MTLFCList& func_consts,
const std::string& template_def);
MTL::ComputePipelineState* get_quantized_kernel(
metal::Device& d,
const std::string& kernel_name,
const std::string& template_def);
// Create a GPU kernel template definition for JIT compilation
template <typename... Args>
std::string
get_template_definition(std::string name, std::string func, Args... args) {
std::ostringstream s;
s << func << "<";
bool first = true;
auto add_arg = [&s, &first](const auto& arg) {
if (!first) {
s << ", ";
}
first = false;
s << arg;
};
(add_arg(args), ...);
s << ">";
std::string base_string = R"(
template [[host_name("{0}")]] [[kernel]] decltype({1}) {1};
)";
return fmt::format(base_string, name, s.str());
}
} // namespace mlx::core

View File

@@ -1,15 +1,66 @@
set(
BASE_HEADERS
HEADERS
bf16.h
bf16_math.h
complex.h
defines.h
expm1f.h
utils.h
steel/conv/params.h
)
set(
KERNELS
"arg_reduce"
"conv"
"fft"
"gemv"
"quantized"
"random"
"rms_norm"
"layer_norm"
"rope"
"scaled_dot_product_attention"
)
if (NOT MLX_METAL_JIT)
set(
KERNELS
${KERNELS}
"arange"
"binary"
"binary_two"
"unary"
"ternary"
"copy"
"softmax"
"sort"
"scan"
"reduce"
)
set(
HEADERS
${HEADERS}
atomic.h
arange.h
unary_ops.h
unary.h
binary_ops.h
binary.h
ternary.h
copy.h
softmax.h
sort.h
scan.h
reduction/ops.h
reduction/reduce_init.h
reduction/reduce_all.h
reduction/reduce_col.h
reduction/reduce_row.h
)
endif()
function(build_kernel_base TARGET SRCFILE DEPS)
set(METAL_FLAGS -Wall -Wextra -fno-fast-math)
set(METAL_FLAGS -Wall -Wextra -fno-fast-math -D${MLX_METAL_VERSION})
if(MLX_METAL_DEBUG)
set(METAL_FLAGS ${METAL_FLAGS}
-gline-tables-only
@@ -21,7 +72,7 @@ function(build_kernel_base TARGET SRCFILE DEPS)
-c ${SRCFILE}
-I${PROJECT_SOURCE_DIR}
-o ${TARGET}.air
DEPENDS ${SRCFILE} ${DEPS} ${BASE_HEADERS}
DEPENDS ${SRCFILE} ${DEPS}
OUTPUT ${TARGET}.air
COMMENT "Building ${TARGET}.air"
VERBATIM
@@ -30,100 +81,49 @@ endfunction(build_kernel_base)
function(build_kernel KERNEL)
set(SRCFILE ${CMAKE_CURRENT_SOURCE_DIR}/${KERNEL}.metal)
cmake_path(GET KERNEL STEM TARGET)
build_kernel_base(${TARGET} ${SRCFILE} "${ARGN}")
set(KERNEL_AIR ${TARGET}.air ${KERNEL_AIR} PARENT_SCOPE)
build_kernel_base(${KERNEL} ${SRCFILE} "${HEADERS}")
endfunction(build_kernel)
build_kernel(arg_reduce)
build_kernel(conv steel/conv/params.h)
build_kernel(gemv steel/utils.h)
build_kernel(layer_norm)
build_kernel(random)
build_kernel(rms_norm)
build_kernel(rope)
build_kernel(
scaled_dot_product_attention
scaled_dot_product_attention_params.h
steel/defines.h
steel/gemm/transforms.h
steel/utils.h
)
set(
STEEL_HEADERS
steel/defines.h
steel/utils.h
steel/conv/conv.h
steel/conv/loader.h
steel/conv/loaders/loader_channel_l.h
steel/conv/loaders/loader_channel_n.h
steel/conv/loaders/loader_general.h
steel/conv/kernels/steel_conv.h
steel/conv/kernels/steel_conv_general.h
steel/gemm/gemm.h
steel/gemm/mma.h
steel/gemm/loader.h
steel/gemm/transforms.h
steel/gemm/kernels/steel_gemm_fused.h
steel/gemm/kernels/steel_gemm_masked.h
steel/gemm/kernels/steel_gemm_splitk.h
)
foreach(KERNEL ${KERNELS})
build_kernel(${KERNEL})
set(KERNEL_AIR ${KERNEL}.air ${KERNEL_AIR})
endforeach()
if (NOT MLX_METAL_JIT)
build_kernel(arange arange.h)
build_kernel(binary binary.h binary_ops.h)
build_kernel(binary_two binary_two.h)
build_kernel(copy copy.h)
build_kernel(
fft
fft.h
fft/radix.h
fft/readwrite.h
)
build_kernel(
reduce
atomic.h
reduction/ops.h
reduction/reduce_init.h
reduction/reduce_all.h
reduction/reduce_col.h
reduction/reduce_row.h
)
build_kernel(
quantized
quantized.h
${STEEL_HEADERS}
)
build_kernel(scan scan.h)
build_kernel(softmax softmax.h)
build_kernel(sort sort.h)
build_kernel(ternary ternary.h ternary_ops.h)
build_kernel(unary unary.h unary_ops.h)
build_kernel(
steel/conv/kernels/steel_conv
${STEEL_HEADERS}
)
build_kernel(
steel/conv/kernels/steel_conv_general
${STEEL_HEADERS}
)
build_kernel(
steel/gemm/kernels/steel_gemm_fused
${STEEL_HEADERS}
)
build_kernel(
steel/gemm/kernels/steel_gemm_masked
${STEEL_HEADERS}
)
build_kernel(
steel/gemm/kernels/steel_gemm_splitk
${STEEL_HEADERS}
)
build_kernel(gemv_masked steel/utils.h)
set(
STEEL_KERNELS
${CMAKE_CURRENT_SOURCE_DIR}/steel/conv/kernels/steel_conv.metal
${CMAKE_CURRENT_SOURCE_DIR}/steel/conv/kernels/steel_conv_general.metal
${CMAKE_CURRENT_SOURCE_DIR}/steel/gemm/kernels/steel_gemm_fused.metal
${CMAKE_CURRENT_SOURCE_DIR}/steel/gemm/kernels/steel_gemm_masked.metal
${CMAKE_CURRENT_SOURCE_DIR}/steel/gemm/kernels/steel_gemm_splitk.metal
)
set(
STEEL_HEADERS
steel/defines.h
steel/utils.h
steel/conv/conv.h
steel/conv/loader.h
steel/conv/loaders/loader_channel_l.h
steel/conv/loaders/loader_channel_n.h
steel/conv/loaders/loader_general.h
steel/conv/kernels/steel_conv.h
steel/conv/kernels/steel_conv_general.h
steel/gemm/gemm.h
steel/gemm/mma.h
steel/gemm/loader.h
steel/gemm/transforms.h
steel/gemm/kernels/steel_gemm_fused.h
steel/gemm/kernels/steel_gemm_masked.h
steel/gemm/kernels/steel_gemm_splitk.h
)
foreach(KERNEL ${STEEL_KERNELS})
cmake_path(GET KERNEL STEM TARGET)
build_kernel_base(${TARGET} ${KERNEL} "${STEEL_HEADERS}")
set(KERNEL_AIR ${TARGET}.air ${KERNEL_AIR})
endforeach()
endif()
add_custom_command(
OUTPUT ${MLX_METAL_PATH}/mlx.metallib
COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o ${MLX_METAL_PATH}/mlx.metallib

View File

@@ -6,7 +6,7 @@
using namespace metal;
#if (MLX_METAL_VERSION >= 310) || (__METAL_VERSION__ >= 310)
#if defined METAL_3_1 || (__METAL_VERSION__ >= 310)
typedef bfloat bfloat16_t;

View File

@@ -369,7 +369,7 @@ instantiate_metal_math_funcs(
return static_cast<otype>(__metal_simd_xor(static_cast<ctype>(data))); \
}
#if (MLX_METAL_VERSION >= 310) || (__METAL_VERSION__ >= 310)
#if defined METAL_3_1 || (__METAL_VERSION__ >= 310)
#define bfloat16_to_uint16(x) as_type<uint16_t>(x)
#define uint16_to_bfloat16(x) as_type<bfloat16_t>(x)

View File

@@ -36,39 +36,6 @@ template <typename T, typename U, typename Op>
c[index] = Op()(a[index], b[index]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_sv2(
device const T* a,
device const T* b,
device U* c,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
c[offset] = Op()(a[0], b[offset]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_vs2(
device const T* a,
device const T* b,
device U* c,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
c[offset] = Op()(a[offset], b[0]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_vv2(
device const T* a,
device const T* b,
device U* c,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
c[offset] = Op()(a[offset], b[offset]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd1(
device const T* a,

View File

@@ -4,94 +4,148 @@
#include <metal_math>
// clang-format off
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/binary_ops.h"
#include "mlx/backend/metal/kernels/binary.h"
#define instantiate_binary_all(op, tname, itype, otype) \
instantiate_kernel("ss" #op #tname, binary_ss, itype, otype, op) \
instantiate_kernel("sv" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vs" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vv" #op #tname, binary_vv, itype, otype, op) \
instantiate_kernel("sv2" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2" #op #tname, binary_vv2, itype, otype, op) \
instantiate_kernel("gn" #op #tname, binary_g, itype, otype, op) \
instantiate_kernel("g1" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g2" #op #tname, binary_g_nd2, itype, otype, op) \
instantiate_kernel("g3" #op #tname, binary_g_nd3, itype, otype, op) \
instantiate_kernel("g4" #op #tname, binary_g_nd, itype, otype, op, 4) \
instantiate_kernel("g5" #op #tname, binary_g_nd, itype, otype, op, 5)
#define instantiate_binary(name, itype, otype, op, bopt) \
template \
[[host_name(name)]] [[kernel]] void binary_##bopt<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
uint index [[thread_position_in_grid]]);
#define instantiate_binary_integer(op) \
instantiate_binary_all(op, uint8, uint8_t, uint8_t) \
instantiate_binary_all(op, uint16, uint16_t, uint16_t) \
instantiate_binary_all(op, uint32, uint32_t, uint32_t) \
instantiate_binary_all(op, uint64, uint64_t, uint64_t) \
instantiate_binary_all(op, int8, int8_t, int8_t) \
instantiate_binary_all(op, int16, int16_t, int16_t) \
instantiate_binary_all(op, int32, int32_t, int32_t) \
instantiate_binary_all(op, int64, int64_t, int64_t)
#define instantiate_binary_g_dim(name, itype, otype, op, dims) \
template [[host_name("g" #dims name)]] [[kernel]] void \
binary_g_nd<itype, otype, op, dims>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
constant const int shape[dims], \
constant const size_t a_strides[dims], \
constant const size_t b_strides[dims], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]);
#define instantiate_binary_float(op) \
instantiate_binary_all(op, float16, half, half) \
instantiate_binary_all(op, float32, float, float) \
instantiate_binary_all(op, bfloat16, bfloat16_t, bfloat16_t)
#define instantiate_binary_g_nd(name, itype, otype, op) \
template [[host_name("g1" name)]] [[kernel]] void \
binary_g_nd1<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
constant const size_t& a_stride, \
constant const size_t& b_stride, \
uint index [[thread_position_in_grid]]); \
template [[host_name("g2" name)]] [[kernel]] void \
binary_g_nd2<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
constant const size_t a_strides[2], \
constant const size_t b_strides[2], \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]); \
template [[host_name("g3" name)]] [[kernel]] void \
binary_g_nd3<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
constant const size_t a_strides[3], \
constant const size_t b_strides[3], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
instantiate_binary_g_dim(name, itype, otype, op, 4) \
instantiate_binary_g_dim(name, itype, otype, op, 5)
#define instantiate_binary_types(op) \
instantiate_binary_all(op, bool_, bool, bool) \
instantiate_binary_integer(op) \
instantiate_binary_all(op, complex64, complex64_t, complex64_t) \
instantiate_binary_float(op)
#define instantiate_binary_g(name, itype, otype, op) \
template [[host_name("gn" name)]] [[kernel]] void binary_g<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
constant const int* shape, \
constant const size_t* a_strides, \
constant const size_t* b_strides, \
constant const int& ndim, \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]);
#define instantiate_binary_types_bool(op) \
instantiate_binary_all(op, bool_, bool, bool) \
instantiate_binary_all(op, uint8, uint8_t, bool) \
instantiate_binary_all(op, uint16, uint16_t, bool) \
instantiate_binary_all(op, uint32, uint32_t, bool) \
instantiate_binary_all(op, uint64, uint64_t, bool) \
instantiate_binary_all(op, int8, int8_t, bool) \
instantiate_binary_all(op, int16, int16_t, bool) \
instantiate_binary_all(op, int32, int32_t, bool) \
instantiate_binary_all(op, int64, int64_t, bool) \
instantiate_binary_all(op, float16, half, bool) \
instantiate_binary_all(op, float32, float, bool) \
instantiate_binary_all(op, bfloat16, bfloat16_t, bool) \
instantiate_binary_all(op, complex64, complex64_t, bool)
#define instantiate_binary_all(name, tname, itype, otype, op) \
instantiate_binary("ss" #name #tname, itype, otype, op, ss) \
instantiate_binary("sv" #name #tname, itype, otype, op, sv) \
instantiate_binary("vs" #name #tname, itype, otype, op, vs) \
instantiate_binary("vv" #name #tname, itype, otype, op, vv) \
instantiate_binary_g(#name #tname, itype, otype, op) \
instantiate_binary_g_nd(#name #tname, itype, otype, op)
instantiate_binary_types(Add)
instantiate_binary_types(Divide)
instantiate_binary_types_bool(Equal)
instantiate_binary_types_bool(Greater)
instantiate_binary_types_bool(GreaterEqual)
instantiate_binary_types_bool(Less)
instantiate_binary_types_bool(LessEqual)
instantiate_binary_types_bool(NotEqual)
instantiate_binary_float(LogAddExp)
instantiate_binary_types(Maximum)
instantiate_binary_types(Minimum)
instantiate_binary_types(Multiply)
instantiate_binary_types(Subtract)
instantiate_binary_types(Power)
instantiate_binary_types(Remainder)
instantiate_binary_float(ArcTan2)
#define instantiate_binary_integer(name, op) \
instantiate_binary_all(name, uint8, uint8_t, uint8_t, op) \
instantiate_binary_all(name, uint16, uint16_t, uint16_t, op) \
instantiate_binary_all(name, uint32, uint32_t, uint32_t, op) \
instantiate_binary_all(name, uint64, uint64_t, uint64_t, op) \
instantiate_binary_all(name, int8, int8_t, int8_t, op) \
instantiate_binary_all(name, int16, int16_t, int16_t, op) \
instantiate_binary_all(name, int32, int32_t, int32_t, op) \
instantiate_binary_all(name, int64, int64_t, int64_t, op)
#define instantiate_binary_float(name, op) \
instantiate_binary_all(name, float16, half, half, op) \
instantiate_binary_all(name, float32, float, float, op) \
instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op)
#define instantiate_binary_types(name, op) \
instantiate_binary_all(name, bool_, bool, bool, op) \
instantiate_binary_integer(name, op) \
instantiate_binary_all(name, complex64, complex64_t, complex64_t, op) \
instantiate_binary_float(name, op)
#define instantiate_binary_types_bool(name, op) \
instantiate_binary_all(name, bool_, bool, bool, op) \
instantiate_binary_all(name, uint8, uint8_t, bool, op) \
instantiate_binary_all(name, uint16, uint16_t, bool, op) \
instantiate_binary_all(name, uint32, uint32_t, bool, op) \
instantiate_binary_all(name, uint64, uint64_t, bool, op) \
instantiate_binary_all(name, int8, int8_t, bool, op) \
instantiate_binary_all(name, int16, int16_t, bool, op) \
instantiate_binary_all(name, int32, int32_t, bool, op) \
instantiate_binary_all(name, int64, int64_t, bool, op) \
instantiate_binary_all(name, float16, half, bool, op) \
instantiate_binary_all(name, float32, float, bool, op) \
instantiate_binary_all(name, bfloat16, bfloat16_t, bool, op) \
instantiate_binary_all(name, complex64, complex64_t, bool, op)
instantiate_binary_types(add, Add)
instantiate_binary_types(div, Divide)
instantiate_binary_types_bool(eq, Equal)
instantiate_binary_types_bool(ge, Greater)
instantiate_binary_types_bool(geq, GreaterEqual)
instantiate_binary_types_bool(le, Less)
instantiate_binary_types_bool(leq, LessEqual)
instantiate_binary_types_bool(neq, NotEqual)
instantiate_binary_float(lae, LogAddExp)
instantiate_binary_types(max, Maximum)
instantiate_binary_types(min, Minimum)
instantiate_binary_types(mul, Multiply)
instantiate_binary_types(sub, Subtract)
instantiate_binary_types(pow, Power)
instantiate_binary_types(rem, Remainder)
instantiate_binary_float(arctan2, ArcTan2)
// NaNEqual only needed for floating point types with boolean output
instantiate_binary_all(NaNEqual, float16, half, bool)
instantiate_binary_all(NaNEqual, float32, float, bool)
instantiate_binary_all(NaNEqual, bfloat16, bfloat16_t, bool)
instantiate_binary_all(NaNEqual, complex64, complex64_t, bool)
instantiate_binary_all(naneq, float16, half, bool, NaNEqual)
instantiate_binary_all(naneq, float32, float, bool, NaNEqual)
instantiate_binary_all(naneq, bfloat16, bfloat16_t, bool, NaNEqual)
instantiate_binary_all(naneq, complex64, complex64_t, bool, NaNEqual)
instantiate_binary_all(LogicalOr, bool_, bool, bool)
instantiate_binary_all(LogicalAnd, bool_, bool, bool)
instantiate_binary_all(lor, bool_, bool, bool, LogicalOr)
instantiate_binary_all(land, bool_, bool, bool, LogicalAnd)
// Bitwise ops only need integer types and bool (except for l/r shift)
instantiate_binary_integer(BitwiseAnd)
instantiate_binary_all(BitwiseAnd, bool_, bool, bool)
instantiate_binary_integer(BitwiseOr)
instantiate_binary_all(BitwiseOr, bool_, bool, bool)
instantiate_binary_integer(BitwiseXor)
instantiate_binary_all(BitwiseXor, bool_, bool, bool)
instantiate_binary_integer(LeftShift)
instantiate_binary_integer(RightShift) // clang-format on
instantiate_binary_integer(bitwise_and, BitwiseAnd)
instantiate_binary_all(bitwise_and, bool_, bool, bool, BitwiseAnd)
instantiate_binary_integer(bitwise_or, BitwiseOr)
instantiate_binary_all(bitwise_or, bool_, bool, bool, BitwiseOr)
instantiate_binary_integer(bitwise_xor, BitwiseXor)
instantiate_binary_all(bitwise_xor, bool_, bool, bool, BitwiseXor)
instantiate_binary_integer(left_shift, LeftShift)
instantiate_binary_integer(right_shift, RightShift) // clang-format on

View File

@@ -48,48 +48,6 @@ template <typename T, typename U, typename Op>
d[index] = out[1];
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_sv2(
device const T* a,
device const T* b,
device U* c,
device U* d,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
auto out = Op()(a[0], b[offset]);
c[offset] = out[0];
d[offset] = out[1];
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_vs2(
device const T* a,
device const T* b,
device U* c,
device U* d,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
auto out = Op()(a[offset], b[0]);
c[offset] = out[0];
d[offset] = out[1];
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_vv2(
device const T* a,
device const T* b,
device U* c,
device U* d,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
auto out = Op()(a[offset], b[offset]);
c[offset] = out[0];
d[offset] = out[1];
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd1(
device const T* a,

View File

@@ -7,37 +7,99 @@
#include "mlx/backend/metal/kernels/binary_ops.h"
#include "mlx/backend/metal/kernels/binary_two.h"
#define instantiate_binary_all(op, tname, itype, otype) \
instantiate_kernel("ss" #op #tname, binary_ss, itype, otype, op) \
instantiate_kernel("sv" #op #tname, binary_sv, itype, otype, op) \
instantiate_kernel("vs" #op #tname, binary_vs, itype, otype, op) \
instantiate_kernel("vv" #op #tname, binary_vv, itype, otype, op) \
instantiate_kernel("sv2" #op #tname, binary_sv2, itype, otype, op) \
instantiate_kernel("vs2" #op #tname, binary_vs2, itype, otype, op) \
instantiate_kernel("vv2" #op #tname, binary_vv2, itype, otype, op) \
instantiate_kernel("gn" #op #tname, binary_g, itype, otype, op) \
instantiate_kernel("g1" #op #tname, binary_g_nd1, itype, otype, op) \
instantiate_kernel("g2" #op #tname, binary_g_nd2, itype, otype, op) \
instantiate_kernel("g3" #op #tname, binary_g_nd3, itype, otype, op) \
instantiate_kernel("g4" #op #tname, binary_g_nd, itype, otype, op, 4) \
instantiate_kernel("g5" #op #tname, binary_g_nd, itype, otype, op, 5)
#define instantiate_binary(name, itype, otype, op, bopt) \
template [[host_name(name)]] [[kernel]] void \
binary_##bopt<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
device otype* d, \
uint index [[thread_position_in_grid]]);
#define instantiate_binary_float(op) \
instantiate_binary_all(op, float16, half, half) \
instantiate_binary_all(op, float32, float, float) \
instantiate_binary_all(op, bfloat16, bfloat16_t, bfloat16_t)
#define instantiate_binary_g_dim(name, itype, otype, op, dims) \
template [[host_name("g" #dims name)]] [[kernel]] void \
binary_g_nd<itype, otype, op, dims>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
device otype* d, \
constant const int shape[dims], \
constant const size_t a_strides[dims], \
constant const size_t b_strides[dims], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]);
#define instantiate_binary_types(op) \
instantiate_binary_all(op, bool_, bool, bool) \
instantiate_binary_all(op, uint8, uint8_t, uint8_t) \
instantiate_binary_all(op, uint16, uint16_t, uint16_t) \
instantiate_binary_all(op, uint32, uint32_t, uint32_t) \
instantiate_binary_all(op, uint64, uint64_t, uint64_t) \
instantiate_binary_all(op, int8, int8_t, int8_t) \
instantiate_binary_all(op, int16, int16_t, int16_t) \
instantiate_binary_all(op, int32, int32_t, int32_t) \
instantiate_binary_all(op, int64, int64_t, int64_t) \
instantiate_binary_all(op, complex64, complex64_t, complex64_t) \
instantiate_binary_float(op)
#define instantiate_binary_g_nd(name, itype, otype, op) \
template [[host_name("g1" name)]] [[kernel]] void \
binary_g_nd1<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
device otype* d, \
constant const size_t& a_stride, \
constant const size_t& b_stride, \
uint index [[thread_position_in_grid]]); \
template [[host_name("g2" name)]] [[kernel]] void \
binary_g_nd2<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
device otype* d, \
constant const size_t a_strides[2], \
constant const size_t b_strides[2], \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]); \
template [[host_name("g3" name)]] [[kernel]] void \
binary_g_nd3<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
device otype* d, \
constant const size_t a_strides[3], \
constant const size_t b_strides[3], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
instantiate_binary_g_dim(name, itype, otype, op, 4) \
instantiate_binary_g_dim(name, itype, otype, op, 5)
instantiate_binary_types(DivMod) // clang-format on
#define instantiate_binary_g(name, itype, otype, op) \
template [[host_name("gn" name)]] [[kernel]] void \
binary_g<itype, otype, op>( \
device const itype* a, \
device const itype* b, \
device otype* c, \
device otype* d, \
constant const int* shape, \
constant const size_t* a_strides, \
constant const size_t* b_strides, \
constant const int& ndim, \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]);
#define instantiate_binary_all(name, tname, itype, otype, op) \
instantiate_binary("ss" #name #tname, itype, otype, op, ss) \
instantiate_binary("sv" #name #tname, itype, otype, op, sv) \
instantiate_binary("vs" #name #tname, itype, otype, op, vs) \
instantiate_binary("vv" #name #tname, itype, otype, op, vv) \
instantiate_binary_g(#name #tname, itype, otype, op) \
instantiate_binary_g_nd(#name #tname, itype, otype, op)
#define instantiate_binary_float(name, op) \
instantiate_binary_all(name, float16, half, half, op) \
instantiate_binary_all(name, float32, float, float, op) \
instantiate_binary_all(name, bfloat16, bfloat16_t, bfloat16_t, op)
#define instantiate_binary_types(name, op) \
instantiate_binary_all(name, bool_, bool, bool, op) \
instantiate_binary_all(name, uint8, uint8_t, uint8_t, op) \
instantiate_binary_all(name, uint16, uint16_t, uint16_t, op) \
instantiate_binary_all(name, uint32, uint32_t, uint32_t, op) \
instantiate_binary_all(name, uint64, uint64_t, uint64_t, op) \
instantiate_binary_all(name, int8, int8_t, int8_t, op) \
instantiate_binary_all(name, int16, int16_t, int16_t, op) \
instantiate_binary_all(name, int32, int32_t, int32_t, op) \
instantiate_binary_all(name, int64, int64_t, int64_t, op) \
instantiate_binary_all(name, complex64, complex64_t, complex64_t, op) \
instantiate_binary_float(name, op)
instantiate_binary_types(divmod, DivMod) // clang-format on

View File

@@ -344,12 +344,12 @@ winograd_conv_2d_weight_transform(
const short sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
// Initialize G matrix
simdgroup_matrix<float, 8, 8> G;
simdgroup_matrix<T, 8, 8> G;
G.thread_elements()[0] = WGT::wt_transform[sm][sn];
G.thread_elements()[1] = WGT::wt_transform[sm][sn + 1];
// Initialize Gt matrix
simdgroup_matrix<float, 8, 8> Gt;
simdgroup_matrix<T, 8, 8> Gt;
Gt.thread_elements()[0] = WGT::wt_transform[sn][sm];
Gt.thread_elements()[1] = WGT::wt_transform[sn + 1][sm];
@@ -381,15 +381,15 @@ winograd_conv_2d_weight_transform(
threadgroup_barrier(mem_flags::mem_threadgroup);
// Do transform and store the result
for (int c = 0; c < BC; ++c) {
simdgroup_matrix<float, 8, 8> g;
simdgroup_matrix<T, 8, 8> g;
g.thread_elements()[0] =
sm < R && sn < R ? Ws[simd_group_id][sm][sn][c] : T(0);
g.thread_elements()[1] =
sm < R && sn + 1 < R ? Ws[simd_group_id][sm][sn + 1][c] : T(0);
simdgroup_matrix<float, 8, 8> g_out = (G * g) * Gt;
wt_out_0[c * O] = static_cast<T>(g_out.thread_elements()[0]);
wt_out_1[c * O] = static_cast<T>(g_out.thread_elements()[1]);
simdgroup_matrix<T, 8, 8> g_out = (G * g) * Gt;
wt_out_0[c * O] = g_out.thread_elements()[0];
wt_out_1[c * O] = g_out.thread_elements()[1];
}
wt_in += BC;
@@ -433,12 +433,12 @@ winograd_conv_2d_input_transform(
const short sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
// Initialize B matrix
simdgroup_matrix<float, 8, 8> B;
simdgroup_matrix<T, 8, 8> B;
B.thread_elements()[0] = WGT::in_transform[sm][sn];
B.thread_elements()[1] = WGT::in_transform[sm][sn + 1];
// Initialize Bt matrix
simdgroup_matrix<float, 8, 8> Bt;
simdgroup_matrix<T, 8, 8> Bt;
Bt.thread_elements()[0] = WGT::in_transform[sn][sm];
Bt.thread_elements()[1] = WGT::in_transform[sn + 1][sm];
@@ -493,13 +493,13 @@ winograd_conv_2d_input_transform(
threadgroup_barrier(mem_flags::mem_threadgroup);
// Do transform and store the result
for (int c = simd_group_id; c < BC; c += N_SIMD_GROUPS) {
simdgroup_matrix<float, 8, 8> I;
simdgroup_matrix<T, 8, 8> I;
I.thread_elements()[0] = Is[sm][sn][c];
I.thread_elements()[1] = Is[sm][sn + 1][c];
simdgroup_matrix<float, 8, 8> I_out = (Bt * I) * B;
inp_out_0[c] = static_cast<T>(I_out.thread_elements()[0]);
inp_out_1[c] = static_cast<T>(I_out.thread_elements()[1]);
simdgroup_matrix<T, 8, 8> I_out = (Bt * I) * B;
inp_out_0[c] = I_out.thread_elements()[0];
inp_out_1[c] = I_out.thread_elements()[1];
}
inp_in += BC;
@@ -543,12 +543,12 @@ winograd_conv_2d_output_transform(
const short sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
// Initialize A matrix
simdgroup_matrix<float, 8, 8> B;
simdgroup_matrix<T, 8, 8> B;
B.thread_elements()[0] = WGT::out_transform[sm][sn];
B.thread_elements()[1] = WGT::out_transform[sm][sn + 1];
// Initialize At matrix
simdgroup_matrix<float, 8, 8> Bt;
simdgroup_matrix<T, 8, 8> Bt;
Bt.thread_elements()[0] = WGT::out_transform[sn][sm];
Bt.thread_elements()[1] = WGT::out_transform[sn + 1][sm];
@@ -597,16 +597,16 @@ winograd_conv_2d_output_transform(
threadgroup_barrier(mem_flags::mem_threadgroup);
// Do transform and store the result
for (int c = simd_group_id; c < BO; c += N_SIMD_GROUPS) {
simdgroup_matrix<float, 8, 8> O_mat;
simdgroup_matrix<T, 8, 8> O_mat;
O_mat.thread_elements()[0] = out_in_0[c];
O_mat.thread_elements()[1] = out_in_1[c];
simdgroup_matrix<float, 8, 8> O_out = (Bt * (O_mat * B));
simdgroup_matrix<T, 8, 8> O_out = (Bt * (O_mat * B));
if ((sm < M) && (sn < M)) {
Os[sm][sn][c] = static_cast<T>(O_out.thread_elements()[0]);
Os[sm][sn][c] = O_out.thread_elements()[0];
}
if ((sm < M) && ((sn + 1) < M)) {
Os[sm][sn + 1][c] = static_cast<T>(O_out.thread_elements()[1]);
Os[sm][sn + 1][c] = O_out.thread_elements()[1];
}
}
@@ -650,5 +650,4 @@ winograd_conv_2d_output_transform(
// clang-format off
instantiate_winograd_conv_2d(float32, float);
instantiate_winograd_conv_2d(bfloat16, bfloat16_t);
instantiate_winograd_conv_2d(float16, half); // clang-format on

View File

@@ -16,26 +16,6 @@ template <typename T, typename U>
dst[index] = static_cast<U>(src[index]);
}
template <typename T, typename U>
[[kernel]] void copy_s2(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
dst[offset] = static_cast<U>(src[0]);
}
template <typename T, typename U>
[[kernel]] void copy_v2(
device const T* src [[buffer(0)]],
device U* dst [[buffer(1)]],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
size_t offset = index.x + grid_dim.x * size_t(index.y);
dst[offset] = static_cast<U>(src[offset]);
}
template <typename T, typename U>
[[kernel]] void copy_g_nd1(
device const T* src [[buffer(0)]],

View File

@@ -5,23 +5,95 @@
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/copy.h"
#define instantiate_copy(name, itype, otype, ctype) \
template [[host_name(name)]] [[kernel]] void copy_##ctype<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
uint index [[thread_position_in_grid]]);
#define instantiate_copy_g_dim(name, itype, otype, dims) \
template [[host_name("g" #dims "_" name)]] [[kernel]] void \
copy_g_nd<itype, otype, dims>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \
constant const int64_t* src_strides [[buffer(3)]], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
template [[host_name("gg" #dims "_" name)]] [[kernel]] void \
copy_gg_nd<itype, otype, dims>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int64_t* dst_strides [[buffer(4)]], \
uint3 index [[thread_position_in_grid]]);
#define instantiate_copy_g_nd(name, itype, otype) \
template [[host_name("g1_" name)]] [[kernel]] void copy_g_nd1<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t& src_stride [[buffer(3)]], \
uint index [[thread_position_in_grid]]); \
template [[host_name("g2_" name)]] [[kernel]] void copy_g_nd2<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \
uint2 index [[thread_position_in_grid]], \
uint2 grid_dim [[threads_per_grid]]); \
template [[host_name("g3_" name)]] [[kernel]] void copy_g_nd3<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
template [[host_name("gg1_" name )]] [[kernel]] void \
copy_gg_nd1<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t& src_stride [[buffer(3)]], \
constant const int64_t& dst_stride [[buffer(4)]], \
uint index [[thread_position_in_grid]]); \
template [[host_name("gg2_" name)]] [[kernel]] void \
copy_gg_nd2<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int64_t* dst_strides [[buffer(4)]], \
uint2 index [[thread_position_in_grid]]); \
template [[host_name("gg3_" name)]] [[kernel]] void \
copy_gg_nd3<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int64_t* dst_strides [[buffer(4)]], \
uint3 index [[thread_position_in_grid]]); \
instantiate_copy_g_dim(name, itype, otype, 4) \
instantiate_copy_g_dim(name, itype, otype, 5)
#define instantiate_copy_g(name, itype, otype) \
template [[host_name("g_" name)]] [[kernel]] void copy_g<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int& ndim [[buffer(5)]], \
uint3 index [[thread_position_in_grid]], \
uint3 grid_dim [[threads_per_grid]]); \
template [[host_name("gg_" name)]] [[kernel]] void copy_gg<itype, otype>( \
device const itype* src [[buffer(0)]], \
device otype* dst [[buffer(1)]], \
constant const int* src_shape [[buffer(2)]], \
constant const int64_t* src_strides [[buffer(3)]], \
constant const int64_t* dst_strides [[buffer(4)]], \
constant const int& ndim [[buffer(5)]], \
uint3 index [[thread_position_in_grid]]);
#define instantiate_copy_all(tname, itype, otype) \
instantiate_kernel("s_copy" #tname, copy_s, itype, otype) \
instantiate_kernel("v_copy" #tname, copy_v, itype, otype) \
instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype) \
instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype) \
instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype) \
instantiate_kernel("g2_copy" #tname, copy_g_nd2, itype, otype) \
instantiate_kernel("g3_copy" #tname, copy_g_nd3, itype, otype) \
instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, itype, otype) \
instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, itype, otype) \
instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, itype, otype) \
instantiate_kernel("g4_copy" #tname, copy_g_nd, itype, otype, 4) \
instantiate_kernel("g5_copy" #tname, copy_g_nd, itype, otype, 5) \
instantiate_kernel("gg4_copy" #tname, copy_gg_nd, itype, otype, 4) \
instantiate_kernel("gg5_copy" #tname, copy_gg_nd, itype, otype, 5) \
instantiate_kernel("g_copy" #tname, copy_g, itype, otype) \
instantiate_kernel("gg_copy" #tname, copy_gg, itype, otype)
instantiate_copy("s_copy" #tname, itype, otype, s) \
instantiate_copy("v_copy" #tname, itype, otype, v) \
instantiate_copy_g("copy" #tname, itype, otype) \
instantiate_copy_g_nd("copy" #tname, itype, otype)
#define instantiate_copy_itype(itname, itype) \
instantiate_copy_all(itname ##bool_, itype, bool) \

View File

@@ -13,11 +13,3 @@ static MTL_CONST constexpr int REDUCE_N_READS = 16;
static MTL_CONST constexpr int SOFTMAX_N_READS = 4;
static MTL_CONST constexpr int RMS_N_READS = 4;
static MTL_CONST constexpr int RMS_LOOPED_LIMIT = 4096;
// Instantiate a templated kernel.
// Extra args are used as template parameters:
// e.g. instantiate_kernel(binary_int, binary, a, b) ->
// [[host_name(binary_int)]] [kernel] binary<a, b>
#define instantiate_kernel(name, func, ...) \
template [[host_name( \
name)]] [[kernel]] decltype(func<__VA_ARGS__>) func<__VA_ARGS__>;

View File

@@ -83,7 +83,6 @@ float expm1f(float a) {
r = expm1f_scaled_unchecked(a, 1.0f);
/* handle severe overflow and underflow */
if (abs(a - 1.0f) > 88.0f) {
r = pow(2, a);
r = fma(r, r, -1.0f);
}
return r;

View File

@@ -1,486 +0,0 @@
// Copyright © 2024 Apple Inc.
// Metal FFT using Stockham's algorithm
//
// References:
// - VkFFT (https://github.com/DTolm/VkFFT)
// - Eric Bainville's excellent page (http://www.bealto.com/gpu-fft.html)
#include <metal_common>
#include "mlx/backend/metal/kernels/fft/radix.h"
#include "mlx/backend/metal/kernels/fft/readwrite.h"
#include "mlx/backend/metal/kernels/steel/defines.h"
using namespace metal;
#define MAX_RADIX 13
// Reached when elems_per_thread_ = 6, max_radix = 13
// and some threads have to do 3 radix 6s requiring 18 float2s.
#define MAX_OUTPUT_SIZE 18
// Specialize for a particular value of N at runtime
STEEL_CONST bool inv_ [[function_constant(0)]];
STEEL_CONST bool is_power_of_2_ [[function_constant(1)]];
STEEL_CONST int elems_per_thread_ [[function_constant(2)]];
// rader_m = n / rader_n
STEEL_CONST int rader_m_ [[function_constant(3)]];
// Stockham steps
STEEL_CONST int radix_13_steps_ [[function_constant(4)]];
STEEL_CONST int radix_11_steps_ [[function_constant(5)]];
STEEL_CONST int radix_8_steps_ [[function_constant(6)]];
STEEL_CONST int radix_7_steps_ [[function_constant(7)]];
STEEL_CONST int radix_6_steps_ [[function_constant(8)]];
STEEL_CONST int radix_5_steps_ [[function_constant(9)]];
STEEL_CONST int radix_4_steps_ [[function_constant(10)]];
STEEL_CONST int radix_3_steps_ [[function_constant(11)]];
STEEL_CONST int radix_2_steps_ [[function_constant(12)]];
// Rader steps
STEEL_CONST int rader_13_steps_ [[function_constant(13)]];
STEEL_CONST int rader_11_steps_ [[function_constant(14)]];
STEEL_CONST int rader_8_steps_ [[function_constant(15)]];
STEEL_CONST int rader_7_steps_ [[function_constant(16)]];
STEEL_CONST int rader_6_steps_ [[function_constant(17)]];
STEEL_CONST int rader_5_steps_ [[function_constant(18)]];
STEEL_CONST int rader_4_steps_ [[function_constant(19)]];
STEEL_CONST int rader_3_steps_ [[function_constant(20)]];
STEEL_CONST int rader_2_steps_ [[function_constant(21)]];
// See "radix.h" for radix codelets
typedef void (*RadixFunc)(thread float2*, thread float2*);
// Perform a single radix n butterfly with appropriate twiddles
template <int radix, RadixFunc radix_func>
METAL_FUNC void radix_butterfly(
int i,
int p,
thread float2* x,
thread short* indices,
thread float2* y) {
// i: the index in the overall DFT that we're processing.
// p: the size of the DFTs we're merging at this step.
// m: how many threads are working on this DFT.
int k, j;
// Use faster bitwise operations when working with powers of two
constexpr bool radix_p_2 = (radix & (radix - 1)) == 0;
if (radix_p_2 && is_power_of_2_) {
constexpr short power = __builtin_ctz(radix);
k = i & (p - 1);
j = ((i - k) << power) + k;
} else {
k = i % p;
j = (i / p) * radix * p + k;
}
// Apply twiddles
if (p > 1) {
float2 twiddle_1 = get_twiddle(k, radix * p);
float2 twiddle = twiddle_1;
x[1] = complex_mul(x[1], twiddle);
STEEL_PRAGMA_UNROLL
for (int t = 2; t < radix; t++) {
twiddle = complex_mul(twiddle, twiddle_1);
x[t] = complex_mul(x[t], twiddle);
}
}
radix_func(x, y);
STEEL_PRAGMA_UNROLL
for (int t = 0; t < radix; t++) {
indices[t] = j + t * p;
}
}
// Perform all the radix steps required for a
// particular radix size n.
template <int radix, RadixFunc radix_func>
METAL_FUNC void radix_n_steps(
int i,
thread int* p,
int m,
int n,
int num_steps,
thread float2* inputs,
thread short* indices,
thread float2* values,
threadgroup float2* buf) {
int m_r = n / radix;
// When combining different sized radices, we have to do
// multiple butterflies in a single thread.
// E.g. n = 28 = 4 * 7
// 4 threads, 7 elems_per_thread
// All threads do 1 radix7 butterfly.
// 3 threads do 2 radix4 butterflies.
// 1 thread does 1 radix4 butterfly.
int max_radices_per_thread = (elems_per_thread_ + radix - 1) / radix;
int index = 0;
int r_index = 0;
for (int s = 0; s < num_steps; s++) {
for (int t = 0; t < max_radices_per_thread; t++) {
index = i + t * m;
if (index < m_r) {
for (int r = 0; r < radix; r++) {
inputs[r] = buf[index + r * m_r];
}
radix_butterfly<radix, radix_func>(
index, *p, inputs, indices + t * radix, values + t * radix);
}
}
// Wait until all threads have read their inputs into thread local mem
threadgroup_barrier(mem_flags::mem_threadgroup);
for (int t = 0; t < max_radices_per_thread; t++) {
index = i + t * m;
if (index < m_r) {
for (int r = 0; r < radix; r++) {
r_index = t * radix + r;
buf[indices[r_index]] = values[r_index];
}
}
}
// Wait until all threads have written back to threadgroup mem
threadgroup_barrier(mem_flags::mem_threadgroup);
*p *= radix;
}
}
#define RADIX_STEP(radix, radix_func, num_steps) \
radix_n_steps<radix, radix_func>( \
fft_idx, p, m, n, num_steps, inputs, indices, values, buf);
template <bool rader = false>
METAL_FUNC void
perform_fft(int fft_idx, thread int* p, int m, int n, threadgroup float2* buf) {
float2 inputs[MAX_RADIX];
short indices[MAX_OUTPUT_SIZE];
float2 values[MAX_OUTPUT_SIZE];
RADIX_STEP(2, radix2, rader ? rader_2_steps_ : radix_2_steps_);
RADIX_STEP(3, radix3, rader ? rader_3_steps_ : radix_3_steps_);
RADIX_STEP(4, radix4, rader ? rader_4_steps_ : radix_4_steps_);
RADIX_STEP(5, radix5, rader ? rader_5_steps_ : radix_5_steps_);
RADIX_STEP(6, radix6, rader ? rader_6_steps_ : radix_6_steps_);
RADIX_STEP(7, radix7, rader ? rader_7_steps_ : radix_7_steps_);
RADIX_STEP(8, radix8, rader ? rader_8_steps_ : radix_8_steps_);
RADIX_STEP(11, radix11, rader ? rader_11_steps_ : radix_11_steps_);
RADIX_STEP(13, radix13, rader ? rader_13_steps_ : radix_13_steps_);
}
// Each FFT is computed entirely in shared GPU memory.
//
// N is decomposed into radix-n DFTs:
// e.g. 128 = 2 * 4 * 4 * 4
template <int tg_mem_size, typename in_T, typename out_T>
[[kernel]] void fft(
const device in_T* in [[buffer(0)]],
device out_T* out [[buffer(1)]],
constant const int& n,
constant const int& batch_size,
uint3 elem [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]) {
threadgroup float2 shared_in[tg_mem_size];
thread ReadWriter<in_T, out_T> read_writer = ReadWriter<in_T, out_T>(
in,
&shared_in[0],
out,
n,
batch_size,
elems_per_thread_,
elem,
grid,
inv_);
if (read_writer.out_of_bounds()) {
return;
};
read_writer.load();
threadgroup_barrier(mem_flags::mem_threadgroup);
int p = 1;
int fft_idx = elem.z; // Thread index in DFT
int m = grid.z; // Threads per DFT
int tg_idx = elem.y * n; // Index of this DFT in threadgroup
threadgroup float2* buf = &shared_in[tg_idx];
perform_fft(fft_idx, &p, m, n, buf);
read_writer.write();
}
template <int tg_mem_size, typename in_T, typename out_T>
[[kernel]] void rader_fft(
const device in_T* in [[buffer(0)]],
device out_T* out [[buffer(1)]],
const device float2* raders_b_q [[buffer(2)]],
const device short* raders_g_q [[buffer(3)]],
const device short* raders_g_minus_q [[buffer(4)]],
constant const int& n,
constant const int& batch_size,
constant const int& rader_n,
uint3 elem [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]) {
// Use Rader's algorithm to compute fast FFTs
// when a prime factor `p` of `n` is greater than 13 but
// has `p - 1` Stockham decomposable into to prime factors <= 13.
//
// E.g. n = 102
// = 2 * 3 * 17
// . = 2 * 3 * RADER(16)
// . = 2 * 3 * RADER(4 * 4)
//
// In numpy:
// x_perm = x[g_q]
// y = np.fft.fft(x_perm) * b_q
// z = np.fft.ifft(y) + x[0]
// out = z[g_minus_q]
// out[0] = x[1:].sum()
//
// Where the g_q and g_minus_q are permutations formed
// by the group under multiplicative modulo N using the
// primitive root of N and b_q is a constant.
// See https://en.wikipedia.org/wiki/Rader%27s_FFT_algorithm
//
// Rader's uses fewer operations than Bluestein's and so
// is more accurate. It's also faster in most cases.
threadgroup float2 shared_in[tg_mem_size];
thread ReadWriter<in_T, out_T> read_writer = ReadWriter<in_T, out_T>(
in,
&shared_in[0],
out,
n,
batch_size,
elems_per_thread_,
elem,
grid,
inv_);
if (read_writer.out_of_bounds()) {
return;
};
read_writer.load();
threadgroup_barrier(mem_flags::mem_threadgroup);
// The number of the threads we're using for each DFT
int m = grid.z;
int fft_idx = elem.z;
int tg_idx = elem.y * n;
threadgroup float2* buf = &shared_in[tg_idx];
// rader_m = n / rader_n;
int rader_m = rader_m_;
// We have to load two x_0s for each thread since sometimes
// elems_per_thread_ crosses a boundary.
// E.g. with n = 34, rader_n = 17, elems_per_thread_ = 4
// 0 0 0 0 1 1 1 1 2 2 2 2 3 3 3 3 4 4 4 4 5 5 5 5 6 6 6 6 7 7 7 7 8 8
// 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
short x_0_index =
metal::min(fft_idx * elems_per_thread_ / (rader_n - 1), rader_m - 1);
float2 x_0[2] = {buf[x_0_index], buf[x_0_index + 1]};
// Do the Rader permutation in shared memory
float2 temp[MAX_RADIX];
int max_index = n - rader_m - 1;
for (int e = 0; e < elems_per_thread_; e++) {
short index = metal::min(fft_idx * elems_per_thread_ + e, max_index);
short g_q = raders_g_q[index / rader_m];
temp[e] = buf[rader_m + (g_q - 1) * rader_m + index % rader_m];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
for (int e = 0; e < elems_per_thread_; e++) {
short index = metal::min(fft_idx * elems_per_thread_ + e, max_index);
buf[index + rader_m] = temp[e];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Rader FFT on x[rader_m:]
int p = 1;
perform_fft</*rader=*/true>(fft_idx, &p, m, n - rader_m, buf + rader_m);
// x_1 + ... + x_n is computed for us in the first FFT step so
// we save it in the first rader_m indices of the array for later.
int x_sum_index = metal::min(fft_idx, rader_m - 1);
buf[x_sum_index] = buf[rader_m + x_sum_index * (rader_n - 1)];
float2 inv = {1.0f, -1.0f};
for (int e = 0; e < elems_per_thread_; e++) {
short index = metal::min(fft_idx * elems_per_thread_ + e, max_index);
short interleaved_index =
index / rader_m + (index % rader_m) * (rader_n - 1);
temp[e] = complex_mul(
buf[rader_m + interleaved_index],
raders_b_q[interleaved_index % (rader_n - 1)]);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
for (int e = 0; e < elems_per_thread_; e++) {
short index = metal::min(fft_idx * elems_per_thread_ + e, max_index);
buf[rader_m + index] = temp[e] * inv;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Rader IFFT on x[rader_m:]
p = 1;
perform_fft</*rader=*/true>(fft_idx, &p, m, n - rader_m, buf + rader_m);
float2 rader_inv_factor = {1.0f / (rader_n - 1), -1.0f / (rader_n - 1)};
for (int e = 0; e < elems_per_thread_; e++) {
short index = metal::min(fft_idx * elems_per_thread_ + e, n - rader_m - 1);
short diff_index = index / (rader_n - 1) - x_0_index;
temp[e] = buf[rader_m + index] * rader_inv_factor + x_0[diff_index];
}
// Use the sum of elements that was computed in the first FFT
float2 x_sum = buf[x_0_index] + x_0[0];
threadgroup_barrier(mem_flags::mem_threadgroup);
for (int e = 0; e < elems_per_thread_; e++) {
short index = metal::min(fft_idx * elems_per_thread_ + e, max_index);
short g_q_index = index % (rader_n - 1);
short g_q = raders_g_minus_q[g_q_index];
short out_index = index - g_q_index + g_q + (index / (rader_n - 1));
buf[out_index] = temp[e];
}
buf[x_0_index * rader_n] = x_sum;
threadgroup_barrier(mem_flags::mem_threadgroup);
p = rader_n;
perform_fft(fft_idx, &p, m, n, buf);
read_writer.write();
}
template <int tg_mem_size, typename in_T, typename out_T>
[[kernel]] void bluestein_fft(
const device in_T* in [[buffer(0)]],
device out_T* out [[buffer(1)]],
const device float2* w_q [[buffer(2)]],
const device float2* w_k [[buffer(3)]],
constant const int& length,
constant const int& n,
constant const int& batch_size,
uint3 elem [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]) {
// Computes arbitrary length FFTs with Bluestein's algorithm
//
// In numpy:
// bluestein_n = next_power_of_2(2*n - 1)
// out = w_k * np.fft.ifft(np.fft.fft(w_k * in, bluestein_n) * w_q)
//
// Where w_k and w_q are precomputed on CPU in high precision as:
// w_k = np.exp(-1j * np.pi / n * (np.arange(-n + 1, n) ** 2))
// w_q = np.fft.fft(1/w_k[-n:])
threadgroup float2 shared_in[tg_mem_size];
thread ReadWriter<in_T, out_T> read_writer = ReadWriter<in_T, out_T>(
in,
&shared_in[0],
out,
n,
batch_size,
elems_per_thread_,
elem,
grid,
inv_);
if (read_writer.out_of_bounds()) {
return;
};
read_writer.load_padded(length, w_k);
threadgroup_barrier(mem_flags::mem_threadgroup);
int p = 1;
int fft_idx = elem.z; // Thread index in DFT
int m = grid.z; // Threads per DFT
int tg_idx = elem.y * n; // Index of this DFT in threadgroup
threadgroup float2* buf = &shared_in[tg_idx];
// fft
perform_fft(fft_idx, &p, m, n, buf);
float2 inv = float2(1.0f, -1.0f);
for (int t = 0; t < elems_per_thread_; t++) {
int index = fft_idx + t * m;
buf[index] = complex_mul(buf[index], w_q[index]) * inv;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// ifft
p = 1;
perform_fft(fft_idx, &p, m, n, buf);
read_writer.write_padded(length, w_k);
}
template <
int tg_mem_size,
typename in_T,
typename out_T,
int step,
bool real = false>
[[kernel]] void four_step_fft(
const device in_T* in [[buffer(0)]],
device out_T* out [[buffer(1)]],
constant const int& n1,
constant const int& n2,
constant const int& batch_size,
uint3 elem [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]) {
// Fast four step FFT implementation for powers of 2.
int overall_n = n1 * n2;
int n = step == 0 ? n1 : n2;
int stride = step == 0 ? n2 : n1;
// The number of the threads we're using for each DFT
int m = grid.z;
int fft_idx = elem.z;
threadgroup float2 shared_in[tg_mem_size];
threadgroup float2* buf = &shared_in[elem.y * n];
using read_writer_t = ReadWriter<in_T, out_T, step, real>;
read_writer_t read_writer = read_writer_t(
in,
&shared_in[0],
out,
n,
batch_size,
elems_per_thread_,
elem,
grid,
inv_);
if (read_writer.out_of_bounds()) {
return;
};
read_writer.load_strided(stride, overall_n);
threadgroup_barrier(mem_flags::mem_threadgroup);
int p = 1;
perform_fft(fft_idx, &p, m, n, buf);
read_writer.write_strided(stride, overall_n);
}

View File

@@ -1,67 +1,199 @@
// Copyright © 2024 Apple Inc.
// Metal FFT using Stockham's algorithm
//
// References:
// - VkFFT (https://github.com/DTolm/VkFFT)
// - Eric Bainville's excellent page (http://www.bealto.com/gpu-fft.html)
#include <metal_common>
#include <metal_math>
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/fft.h"
#include "mlx/backend/metal/kernels/utils.h"
#define instantiate_fft(tg_mem_size, in_T, out_T) \
instantiate_kernel( \
"fft_mem_" #tg_mem_size "_" #in_T "_" #out_T, \
fft, \
tg_mem_size, \
in_T, \
out_T)
using namespace metal;
#define instantiate_rader(tg_mem_size, in_T, out_T) \
instantiate_kernel( \
"rader_fft_mem_" #tg_mem_size "_" #in_T "_" #out_T, \
rader_fft, \
tg_mem_size, \
in_T, \
out_T)
float2 complex_mul(float2 a, float2 b) {
float2 c;
c.x = a.x * b.x - a.y * b.y;
c.y = a.x * b.y + a.y * b.x;
return c;
}
#define instantiate_bluestein(tg_mem_size, in_T, out_T) \
instantiate_kernel( \
"bluestein_fft_mem_" #tg_mem_size "_" #in_T "_" #out_T, \
bluestein_fft, \
tg_mem_size, \
in_T, \
out_T)
float2 get_twiddle(int k, int p) {
float theta = -1.0f * k * M_PI_F / (2 * p);
#define instantiate_four_step(tg_mem_size, in_T, out_T, step, real) \
instantiate_kernel( \
"four_step_mem_" #tg_mem_size "_" #in_T "_" #out_T "_" #step "_" #real, \
four_step_fft, \
tg_mem_size, \
in_T, \
out_T, \
step, \
real)
float2 twiddle;
twiddle.x = metal::fast::cos(theta);
twiddle.y = metal::fast::sin(theta);
return twiddle;
}
// single threaded radix2 implemetation
void radix2(
int i,
int p,
int m,
threadgroup float2* read_buf,
threadgroup float2* write_buf) {
float2 x_0 = read_buf[i];
float2 x_1 = read_buf[i + m];
// The index within this sub-DFT
int k = i & (p - 1);
float2 twiddle = get_twiddle(k, p);
float2 z = complex_mul(x_1, twiddle);
float2 y_0 = x_0 + z;
float2 y_1 = x_0 - z;
int j = (i << 1) - k;
write_buf[j] = y_0;
write_buf[j + p] = y_1;
}
// single threaded radix4 implemetation
void radix4(
int i,
int p,
int m,
threadgroup float2* read_buf,
threadgroup float2* write_buf) {
float2 x_0 = read_buf[i];
float2 x_1 = read_buf[i + m];
float2 x_2 = read_buf[i + 2 * m];
float2 x_3 = read_buf[i + 3 * m];
// The index within this sub-DFT
int k = i & (p - 1);
float2 twiddle = get_twiddle(k, p);
// e^a * e^b = e^(a + b)
float2 twiddle_2 = complex_mul(twiddle, twiddle);
float2 twiddle_3 = complex_mul(twiddle, twiddle_2);
x_1 = complex_mul(x_1, twiddle);
x_2 = complex_mul(x_2, twiddle_2);
x_3 = complex_mul(x_3, twiddle_3);
float2 minus_i;
minus_i.x = 0;
minus_i.y = -1;
// Hard coded twiddle factors for DFT4
float2 z_0 = x_0 + x_2;
float2 z_1 = x_0 - x_2;
float2 z_2 = x_1 + x_3;
float2 z_3 = complex_mul(x_1 - x_3, minus_i);
float2 y_0 = z_0 + z_2;
float2 y_1 = z_1 + z_3;
float2 y_2 = z_0 - z_2;
float2 y_3 = z_1 - z_3;
int j = ((i - k) << 2) + k;
write_buf[j] = y_0;
write_buf[j + p] = y_1;
write_buf[j + 2 * p] = y_2;
write_buf[j + 3 * p] = y_3;
}
// Each FFT is computed entirely in shared GPU memory.
//
// N is decomposed into radix-2 and radix-4 DFTs:
// e.g. 128 = 2 * 4 * 4 * 4
//
// At each step we use n / 4 threads, each performing
// a single-threaded radix-4 or radix-2 DFT.
//
// We provide the number of radix-2 and radix-4
// steps at compile time for a ~20% performance boost.
template <size_t n, size_t radix_2_steps, size_t radix_4_steps>
[[kernel]] void fft(
const device float2* in [[buffer(0)]],
device float2* out [[buffer(1)]],
uint3 thread_position_in_grid [[thread_position_in_grid]],
uint3 threads_per_grid [[threads_per_grid]]) {
// Index of the DFT in batch
int batch_idx = thread_position_in_grid.x * n;
// The index in the DFT we're working on
int i = thread_position_in_grid.y;
// The number of the threads we're using for each DFT
int m = threads_per_grid.y;
// Allocate 2 shared memory buffers for Stockham.
// We alternate reading from one and writing to the other at each radix step.
threadgroup float2 shared_in[n];
threadgroup float2 shared_out[n];
// Pointers to facilitate Stockham buffer swapping
threadgroup float2* read_buf = shared_in;
threadgroup float2* write_buf = shared_out;
threadgroup float2* tmp;
// Copy input into shared memory
shared_in[i] = in[batch_idx + i];
shared_in[i + m] = in[batch_idx + i + m];
shared_in[i + 2 * m] = in[batch_idx + i + 2 * m];
shared_in[i + 3 * m] = in[batch_idx + i + 3 * m];
threadgroup_barrier(mem_flags::mem_threadgroup);
int p = 1;
for (size_t r = 0; r < radix_2_steps; r++) {
radix2(i, p, m * 2, read_buf, write_buf);
radix2(i + m, p, m * 2, read_buf, write_buf);
p *= 2;
threadgroup_barrier(mem_flags::mem_threadgroup);
// Stockham switch of buffers
tmp = write_buf;
write_buf = read_buf;
read_buf = tmp;
}
for (size_t r = 0; r < radix_4_steps; r++) {
radix4(i, p, m, read_buf, write_buf);
p *= 4;
threadgroup_barrier(mem_flags::mem_threadgroup);
// Stockham switch of buffers
tmp = write_buf;
write_buf = read_buf;
read_buf = tmp;
}
// Copy shared memory to output
out[batch_idx + i] = read_buf[i];
out[batch_idx + i + m] = read_buf[i + m];
out[batch_idx + i + 2 * m] = read_buf[i + 2 * m];
out[batch_idx + i + 3 * m] = read_buf[i + 3 * m];
}
#define instantiate_fft(name, n, radix_2_steps, radix_4_steps) \
template [[host_name("fft_" #name)]] [[kernel]] void \
fft<n, radix_2_steps, radix_4_steps>( \
const device float2* in [[buffer(0)]], \
device float2* out [[buffer(1)]], \
uint3 thread_position_in_grid [[thread_position_in_grid]], \
uint3 threads_per_grid [[threads_per_grid]]);
// Explicitly define kernels for each power of 2.
// clang-format off
#define instantiate_ffts(tg_mem_size) \
instantiate_fft(tg_mem_size, float2, float2) \
instantiate_fft(tg_mem_size, float, float2) \
instantiate_fft(tg_mem_size, float2, float) \
instantiate_rader(tg_mem_size, float2, float2) \
instantiate_rader(tg_mem_size, float, float2) \
instantiate_rader(tg_mem_size, float2, float) \
instantiate_bluestein(tg_mem_size, float2, float2) \
instantiate_bluestein(tg_mem_size, float, float2) \
instantiate_bluestein(tg_mem_size, float2, float) \
instantiate_four_step(tg_mem_size, float2, float2, 0, /*real=*/false) \
instantiate_four_step(tg_mem_size, float2, float2, 1, /*real=*/false) \
instantiate_four_step(tg_mem_size, float, float2, 0, /*real=*/true) \
instantiate_four_step(tg_mem_size, float2, float2, 1, /*real=*/true) \
instantiate_four_step(tg_mem_size, float2, float2, 0, /*real=*/true) \
instantiate_four_step(tg_mem_size, float2, float, 1, /*real=*/true)
// It's substantially faster to statically define the
// threadgroup memory size rather than using
// `setThreadgroupMemoryLength` on the compute encoder.
// For non-power of 2 sizes we round up the shared memory.
instantiate_ffts(256)
instantiate_ffts(512)
instantiate_ffts(1024)
instantiate_ffts(2048)
// 4096 is the max that will fit into 32KB of threadgroup memory.
instantiate_ffts(4096) // clang-format on
instantiate_fft(4, /* n= */ 4, /* radix_2_steps= */ 0, /* radix_4_steps= */ 1)
instantiate_fft(8, 8, 1, 1) instantiate_fft(16, 16, 0, 2)
instantiate_fft(32, 32, 1, 2) instantiate_fft(64, 64, 0, 3)
instantiate_fft(128, 128, 1, 3) instantiate_fft(256, 256, 0, 4)
instantiate_fft(512, 512, 1, 4)
instantiate_fft(1024, 1024, 0, 5)
// 2048 is the max that will fit into 32KB of threadgroup memory.
// TODO: implement 4 step FFT for larger n.
instantiate_fft(2048, 2048, 1, 5) // clang-format on

View File

@@ -1,328 +0,0 @@
// Copyright © 2024 Apple Inc.
/* Radix kernels
We provide optimized, single threaded Radix codelets
for n=2,3,4,5,6,7,8,10,11,12,13.
For n=2,3,4,5,6 we hand write the codelets.
For n=8,10,12 we combine smaller codelets.
For n=7,11,13 we use Rader's algorithm which decomposes
them into (n-1)=6,10,12 codelets. */
#pragma once
#include <metal_common>
#include <metal_math>
#include <metal_stdlib>
METAL_FUNC float2 complex_mul(float2 a, float2 b) {
return float2(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x);
}
// Complex mul followed by conjugate
METAL_FUNC float2 complex_mul_conj(float2 a, float2 b) {
return float2(a.x * b.x - a.y * b.y, -a.x * b.y - a.y * b.x);
}
// Compute an FFT twiddle factor
METAL_FUNC float2 get_twiddle(int k, int p) {
float theta = -2.0f * k * M_PI_F / p;
float2 twiddle = {metal::fast::cos(theta), metal::fast::sin(theta)};
return twiddle;
}
METAL_FUNC void radix2(thread float2* x, thread float2* y) {
y[0] = x[0] + x[1];
y[1] = x[0] - x[1];
}
METAL_FUNC void radix3(thread float2* x, thread float2* y) {
float pi_2_3 = -0.8660254037844387;
float2 a_1 = x[1] + x[2];
float2 a_2 = x[1] - x[2];
y[0] = x[0] + a_1;
float2 b_1 = x[0] - 0.5 * a_1;
float2 b_2 = pi_2_3 * a_2;
float2 b_2_j = {-b_2.y, b_2.x};
y[1] = b_1 + b_2_j;
y[2] = b_1 - b_2_j;
}
METAL_FUNC void radix4(thread float2* x, thread float2* y) {
float2 z_0 = x[0] + x[2];
float2 z_1 = x[0] - x[2];
float2 z_2 = x[1] + x[3];
float2 z_3 = x[1] - x[3];
float2 z_3_i = {z_3.y, -z_3.x};
y[0] = z_0 + z_2;
y[1] = z_1 + z_3_i;
y[2] = z_0 - z_2;
y[3] = z_1 - z_3_i;
}
METAL_FUNC void radix5(thread float2* x, thread float2* y) {
float2 root_5_4 = 0.5590169943749475;
float2 sin_2pi_5 = 0.9510565162951535;
float2 sin_1pi_5 = 0.5877852522924731;
float2 a_1 = x[1] + x[4];
float2 a_2 = x[2] + x[3];
float2 a_3 = x[1] - x[4];
float2 a_4 = x[2] - x[3];
float2 a_5 = a_1 + a_2;
float2 a_6 = root_5_4 * (a_1 - a_2);
float2 a_7 = x[0] - a_5 / 4;
float2 a_8 = a_7 + a_6;
float2 a_9 = a_7 - a_6;
float2 a_10 = sin_2pi_5 * a_3 + sin_1pi_5 * a_4;
float2 a_11 = sin_1pi_5 * a_3 - sin_2pi_5 * a_4;
float2 a_10_j = {a_10.y, -a_10.x};
float2 a_11_j = {a_11.y, -a_11.x};
y[0] = x[0] + a_5;
y[1] = a_8 + a_10_j;
y[2] = a_9 + a_11_j;
y[3] = a_9 - a_11_j;
y[4] = a_8 - a_10_j;
}
METAL_FUNC void radix6(thread float2* x, thread float2* y) {
float sin_pi_3 = 0.8660254037844387;
float2 a_1 = x[2] + x[4];
float2 a_2 = x[0] - a_1 / 2;
float2 a_3 = sin_pi_3 * (x[2] - x[4]);
float2 a_4 = x[5] + x[1];
float2 a_5 = x[3] - a_4 / 2;
float2 a_6 = sin_pi_3 * (x[5] - x[1]);
float2 a_7 = x[0] + a_1;
float2 a_3_i = {a_3.y, -a_3.x};
float2 a_6_i = {a_6.y, -a_6.x};
float2 a_8 = a_2 + a_3_i;
float2 a_9 = a_2 - a_3_i;
float2 a_10 = x[3] + a_4;
float2 a_11 = a_5 + a_6_i;
float2 a_12 = a_5 - a_6_i;
y[0] = a_7 + a_10;
y[1] = a_8 - a_11;
y[2] = a_9 + a_12;
y[3] = a_7 - a_10;
y[4] = a_8 + a_11;
y[5] = a_9 - a_12;
}
METAL_FUNC void radix7(thread float2* x, thread float2* y) {
// Rader's algorithm
float2 inv = {1 / 6.0, -1 / 6.0};
// fft
float2 in1[6] = {x[1], x[3], x[2], x[6], x[4], x[5]};
radix6(in1, y + 1);
y[0] = y[1] + x[0];
// b_q
y[1] = complex_mul_conj(y[1], float2(-1, 0));
y[2] = complex_mul_conj(y[2], float2(2.44013336, -1.02261879));
y[3] = complex_mul_conj(y[3], float2(2.37046941, -1.17510629));
y[4] = complex_mul_conj(y[4], float2(0, -2.64575131));
y[5] = complex_mul_conj(y[5], float2(2.37046941, 1.17510629));
y[6] = complex_mul_conj(y[6], float2(-2.44013336, -1.02261879));
// ifft
radix6(y + 1, x + 1);
y[1] = x[1] * inv + x[0];
y[5] = x[2] * inv + x[0];
y[4] = x[3] * inv + x[0];
y[6] = x[4] * inv + x[0];
y[2] = x[5] * inv + x[0];
y[3] = x[6] * inv + x[0];
}
METAL_FUNC void radix8(thread float2* x, thread float2* y) {
float cos_pi_4 = 0.7071067811865476;
float2 w_0 = {cos_pi_4, -cos_pi_4};
float2 w_1 = {-cos_pi_4, -cos_pi_4};
float2 temp[8] = {x[0], x[2], x[4], x[6], x[1], x[3], x[5], x[7]};
radix4(temp, x);
radix4(temp + 4, x + 4);
y[0] = x[0] + x[4];
y[4] = x[0] - x[4];
float2 x_5 = complex_mul(x[5], w_0);
y[1] = x[1] + x_5;
y[5] = x[1] - x_5;
float2 x_6 = {x[6].y, -x[6].x};
y[2] = x[2] + x_6;
y[6] = x[2] - x_6;
float2 x_7 = complex_mul(x[7], w_1);
y[3] = x[3] + x_7;
y[7] = x[3] - x_7;
}
template <bool raders_perm>
METAL_FUNC void radix10(thread float2* x, thread float2* y) {
float2 w[4];
w[0] = {0.8090169943749475, -0.5877852522924731};
w[1] = {0.30901699437494745, -0.9510565162951535};
w[2] = {-w[1].x, w[1].y};
w[3] = {-w[0].x, w[0].y};
if (raders_perm) {
float2 temp[10] = {
x[0], x[3], x[4], x[8], x[2], x[1], x[7], x[9], x[6], x[5]};
radix5(temp, x);
radix5(temp + 5, x + 5);
} else {
float2 temp[10] = {
x[0], x[2], x[4], x[6], x[8], x[1], x[3], x[5], x[7], x[9]};
radix5(temp, x);
radix5(temp + 5, x + 5);
}
y[0] = x[0] + x[5];
y[5] = x[0] - x[5];
for (int t = 1; t < 5; t++) {
float2 a = complex_mul(x[t + 5], w[t - 1]);
y[t] = x[t] + a;
y[t + 5] = x[t] - a;
}
}
METAL_FUNC void radix11(thread float2* x, thread float2* y) {
// Raders Algorithm
float2 inv = {1 / 10.0, -1 / 10.0};
// fft
radix10<true>(x + 1, y + 1);
y[0] = y[1] + x[0];
// b_q
y[1] = complex_mul_conj(y[1], float2(-1, 0));
y[2] = complex_mul_conj(y[2], float2(0.955301878, -3.17606649));
y[3] = complex_mul_conj(y[3], float2(2.63610556, 2.01269656));
y[4] = complex_mul_conj(y[4], float2(2.54127802, 2.13117479));
y[5] = complex_mul_conj(y[5], float2(2.07016210, 2.59122150));
y[6] = complex_mul_conj(y[6], float2(0, -3.31662479));
y[7] = complex_mul_conj(y[7], float2(2.07016210, -2.59122150));
y[8] = complex_mul_conj(y[8], float2(-2.54127802, 2.13117479));
y[9] = complex_mul_conj(y[9], float2(2.63610556, -2.01269656));
y[10] = complex_mul_conj(y[10], float2(-0.955301878, -3.17606649));
// ifft
radix10<false>(y + 1, x + 1);
y[1] = x[1] * inv + x[0];
y[6] = x[2] * inv + x[0];
y[3] = x[3] * inv + x[0];
y[7] = x[4] * inv + x[0];
y[9] = x[5] * inv + x[0];
y[10] = x[6] * inv + x[0];
y[5] = x[7] * inv + x[0];
y[8] = x[8] * inv + x[0];
y[4] = x[9] * inv + x[0];
y[2] = x[10] * inv + x[0];
}
template <bool raders_perm>
METAL_FUNC void radix12(thread float2* x, thread float2* y) {
float2 w[6];
float sin_pi_3 = 0.8660254037844387;
w[0] = {sin_pi_3, -0.5};
w[1] = {0.5, -sin_pi_3};
w[2] = {0, -1};
w[3] = {-0.5, -sin_pi_3};
w[4] = {-sin_pi_3, -0.5};
if (raders_perm) {
float2 temp[12] = {
x[0],
x[3],
x[2],
x[11],
x[8],
x[9],
x[1],
x[7],
x[5],
x[10],
x[4],
x[6]};
radix6(temp, x);
radix6(temp + 6, x + 6);
} else {
float2 temp[12] = {
x[0],
x[2],
x[4],
x[6],
x[8],
x[10],
x[1],
x[3],
x[5],
x[7],
x[9],
x[11]};
radix6(temp, x);
radix6(temp + 6, x + 6);
}
y[0] = x[0] + x[6];
y[6] = x[0] - x[6];
for (int t = 1; t < 6; t++) {
float2 a = complex_mul(x[t + 6], w[t - 1]);
y[t] = x[t] + a;
y[t + 6] = x[t] - a;
}
}
METAL_FUNC void radix13(thread float2* x, thread float2* y) {
// Raders Algorithm
float2 inv = {1 / 12.0, -1 / 12.0};
// fft
radix12<true>(x + 1, y + 1);
y[0] = y[1] + x[0];
// b_q
y[1] = complex_mul_conj(y[1], float2(-1, 0));
y[2] = complex_mul_conj(y[2], float2(3.07497206, -1.88269669));
y[3] = complex_mul_conj(y[3], float2(3.09912468, 1.84266823));
y[4] = complex_mul_conj(y[4], float2(3.45084438, -1.04483161));
y[5] = complex_mul_conj(y[5], float2(0.91083583, 3.48860690));
y[6] = complex_mul_conj(y[6], float2(-3.60286363, 0.139189267));
y[7] = complex_mul_conj(y[7], float2(3.60555128, 0));
y[8] = complex_mul_conj(y[8], float2(3.60286363, 0.139189267));
y[9] = complex_mul_conj(y[9], float2(0.91083583, -3.48860690));
y[10] = complex_mul_conj(y[10], float2(-3.45084438, -1.04483161));
y[11] = complex_mul_conj(y[11], float2(3.09912468, -1.84266823));
y[12] = complex_mul_conj(y[12], float2(-3.07497206, -1.88269669));
// ifft
radix12<false>(y + 1, x + 1);
y[1] = x[1] * inv + x[0];
y[7] = x[2] * inv + x[0];
y[10] = x[3] * inv + x[0];
y[5] = x[4] * inv + x[0];
y[9] = x[5] * inv + x[0];
y[11] = x[6] * inv + x[0];
y[12] = x[7] * inv + x[0];
y[6] = x[8] * inv + x[0];
y[3] = x[9] * inv + x[0];
y[8] = x[10] * inv + x[0];
y[4] = x[11] * inv + x[0];
y[2] = x[12] * inv + x[0];
}

Some files were not shown because too many files have changed in this diff Show More