jagrit's commit files

This commit is contained in:
Jagrit Digani 2023-11-29 10:52:08 -08:00
parent d1f86272a2
commit e6306cfee9
74 changed files with 15964 additions and 2 deletions

View File

@ -1,2 +1,61 @@
# MLX
MLX is an array framework for machine learning specifically targeting Apple
silicon. MLX is designed with inspiration from JAX, PyTorch, and ArrayFire.
[Documentation](https://at.apple.com/mlx)
## Build
```
mkdir -p build && cd build
cmake .. && make -j
```
Run the C++ tests with `make test` (or `./tests/tests` for more detailed output).
### Python bindings
To install, run:
```
env CMAKE_BUILD_PARALLEL_LEVEL="" pip install .
```
For development, use an editable install:
```
env CMAKE_BUILD_PARALLEL_LEVEL="" pip install -e .
```
To make sure the install is working, run the tests with:
```
python -m unittest discover python/tests
```
## Develop
- Fork and submit pull requests to the repo.
- Every PR should have passing tests and at least one review.
- If a change is likely to impact efficiency, run some of the benchmarks before
and after the change. Examples of benchmarks can be found in `benchmarks/cpp/`.
- Install `pre-commit` using something like `pip install pre-commit` and run `pre-commit install`.
This should install hooks for running `black` and `clang-format` to ensure
consistent style for C++ and Python code.
You can also run the formatters manually as follows:
```
clang-format -i file.cpp
```
```
black file.py
```
or run `pre-commit run --all-files` to check all files in the repo.

View File

@ -0,0 +1,11 @@
function(build_benchmark SRCFILE)
get_filename_component(src_name ${SRCFILE} NAME_WE)
set(target "${src_name}")
add_executable(${target} ${SRCFILE})
target_link_libraries(${target} PRIVATE mlx)
endfunction(build_benchmark)
build_benchmark(single_ops.cpp)
build_benchmark(irregular_strides.cpp)
build_benchmark(compare_devices.cpp)
build_benchmark(autograd.cpp)

View File

@ -0,0 +1,37 @@
#include <iostream>
#include "mlx/mlx.h"
#include "time_utils.h"
using namespace mlx::core;
void time_value_and_grad() {
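// Compare computing the value and gradient via separate fn and grad(fn) calls
// against a single fused value_and_grad(fn) call.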
auto x = ones({200, 1000});
eval(x);
auto fn = [](array x) {
for (int i = 0; i < 20; ++i) {
x = log(exp(x));
}
return sum(x);
};
auto grad_fn = grad(fn);
auto independent_value_and_grad = [&]() {
auto value = fn(x);
auto dfdx = grad_fn(x);
return std::vector<array>{value, dfdx};
};
TIME(independent_value_and_grad);
auto value_and_grad_fn = value_and_grad(fn);
auto combined_value_and_grad = [&]() {
auto [value, dfdx] = value_and_grad_fn(x);
return std::vector<array>{value, dfdx};
};
TIME(combined_value_and_grad);
}
int main() {
std::cout << "Benchmarks for " << default_device() << std::endl;
time_value_and_grad();
}

View File

@ -0,0 +1,25 @@
#include <iostream>
#include "mlx/mlx.h"
#include "time_utils.h"
using namespace mlx::core;
void time_add_op() {
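// Time elementwise add on both the CPU and GPU for sizes 1, 10, ..., 1e9.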
std::vector<int> sizes(1, 1);
for (int i = 0; i < 9; ++i) {
sizes.push_back(10 * sizes.back());
}
set_default_device(Device::cpu);
for (auto size : sizes) {
auto a = random::uniform({size});
auto b = random::uniform({size});
eval(a, b);
std::cout << "Size " << size << std::endl;
TIMEM("cpu", add, a, b, Device::cpu);
TIMEM("gpu", add, a, b, Device::gpu);
}
}
int main() {
time_add_op();
}

View File

@ -0,0 +1,38 @@
import numpy as np
from time_utils import time_fn
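# Baseline NumPy timings for add, matmul, exp, and take, using the shared
# time_fn helper from time_utils.py.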
def time_add():
a = np.ones((100, 100, 10), dtype=np.float32)
b = np.ones((100, 100, 10), dtype=np.float32)
time_fn(np.add, a, b)
def time_matmul():
a = np.random.rand(1000, 500).astype(np.float32)
b = np.random.rand(500, 1000).astype(np.float32)
time_fn(np.matmul, a, b)
def time_exp():
a = np.random.randn(1000, 100).astype(np.float32)
time_fn(np.exp, a)
def time_take():
a = np.random.rand(10000, 500)
ids = np.random.randint(0, 10000, (20, 10))
ids = [idx.reshape(-1) for idx in np.split(ids, 20)]
def random_take():
return [np.take(a, idx, 0) for idx in ids]
time_fn(random_take)
if __name__ == "__main__":
time_add()
time_matmul()
time_exp()
time_take()

View File

@ -0,0 +1,18 @@
import time
def time_fn(fn, *args):
print(f"Timing {fn.__name__} ...", end=" ")
# warmup
for _ in range(5):
fn(*args)
num_iters = 100
tic = time.perf_counter()
for _ in range(num_iters):
x = fn(*args)
toc = time.perf_counter()
msec = 1e3 * (toc - tic) / num_iters
print(f"{msec:.5f} msec")

View File

@ -0,0 +1,190 @@
import numpy as np
import argparse
import mlx.core as mx
import time
import torch
import os
import math
import subprocess
device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
device_name = device_name.decode("utf-8").strip("\n")
N_warmup = 8
N_iter_bench = 80
N_iter_func = 5
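# Each benchmarked call runs the matmul N_iter_func times; MLX is timed on its
# default device and PyTorch on the MPS backend.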
def bench(f, a, b):
for i in range(N_warmup):
f(a, b)
torch.mps.synchronize()
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(a, b)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def gemm_nn_mlx(a, b):
ys = []
for i in range(N_iter_func):
y = a @ b
ys.append(y)
mx.eval(ys)
return ys
def gemm_nt_mlx(a, b):
ys = []
for i in range(N_iter_func):
y = a @ b.transpose((0, 2, 1))
ys.append(y)
mx.eval(ys)
return ys
def gemm_tn_mlx(a, b):
ys = []
for i in range(N_iter_func):
y = a.transpose((0, 2, 1)) @ b
ys.append(y)
mx.eval(ys)
return ys
def gemm_tt_mlx(a, b):
ys = []
for i in range(N_iter_func):
y = a.transpose((0, 2, 1)) @ b.transpose((0, 2, 1))
ys.append(y)
mx.eval(ys)
return ys
@torch.no_grad()
def gemm_nn_torch(a, b):
ys = []
for i in range(N_iter_func):
y = a @ b
ys.append(y)
torch.mps.synchronize()
return ys
@torch.no_grad()
def gemm_nt_torch(a, b):
ys = []
for i in range(N_iter_func):
y = a @ b.transpose(-1, -2)
ys.append(y)
torch.mps.synchronize()
return ys
@torch.no_grad()
def gemm_tn_torch(a, b):
ys = []
for i in range(N_iter_func):
y = a.transpose(-1, -2) @ b
ys.append(y)
torch.mps.synchronize()
return ys
@torch.no_grad()
def gemm_tt_torch(a, b):
ys = []
for i in range(N_iter_func):
y = a.transpose(-1, -2) @ b.transpose(-1, -2)
ys.append(y)
torch.mps.synchronize()
return ys
def bench_shape(B, M, N, K, np_dtype, transpose="nn"):
shape_a = (B, M, K) if transpose[0] == "n" else (B, K, M)
shape_b = (B, K, N) if transpose[1] == "n" else (B, N, K)
a_np = np.random.normal(0.0, 1.0 / math.sqrt(M + K), shape_a).astype(np_dtype)
b_np = np.random.normal(0.0, 1.0 / math.sqrt(N + K), shape_b).astype(np_dtype)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
a_pt = torch.from_numpy(a_np).to("mps")
b_pt = torch.from_numpy(b_np).to("mps")
torch.mps.synchronize()
f_mx = {
"nn": gemm_nn_mlx,
"nt": gemm_nt_mlx,
"tn": gemm_tn_mlx,
"tt": gemm_tt_mlx,
}[transpose]
f_pt = {
"nn": gemm_nn_torch,
"nt": gemm_nt_torch,
"tn": gemm_tn_torch,
"tt": gemm_tt_torch,
}[transpose]
time_torch = bench(f_pt, a_pt, b_pt)
time_mlx = bench(f_mx, a_mx, b_mx)
t_a = (0, 1, 2) if transpose[0] == "n" else (0, 2, 1)
t_b = (0, 1, 2) if transpose[1] == "n" else (0, 2, 1)
c_mlx = a_mx.transpose(t_a) @ b_mx.transpose(t_b)
c_npy = a_np.transpose(t_a).astype(np.float32) @ b_np.transpose(t_b).astype(
np.float32
)
atol = 1e-5 if np_dtype == np.float32 else 1e-4
if not np.allclose(c_mlx, c_npy.astype(np_dtype), atol=atol):
print(
f"Failed at {(B, M, N, K)} [transpose = {transpose}] with max(|a - b|) = {np.max(np.abs(c_npy - c_mlx))}"
)
return time_mlx, time_torch
def get_gflop_count(B, M, N, K):
return float(2.0 * N_iter_bench * N_iter_func * B * M * N * K) / float(1024.0**3)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run gemm benchmarks")
dtypes = ("float32", "float16")
transposes = ("nn", "nt", "tn")
shapes = (
(16, 1024, 1024, 1024),
(1, 1024, 1024, 2048),
(4, 1024, 1024, 4096),
(4, 1024, 4096, 1024),
(1, 4096, 4096, 4096),
(15, 1023, 1023, 1023),
(17, 1025, 1025, 1025),
)
for dtype in dtypes:
for transpose in transposes:
for B, M, N, K in shapes:
np_dtype = getattr(np, dtype)
time_mlx, time_torch = bench_shape(B, M, N, K, np_dtype, transpose)
gflop_count = get_gflop_count(B, M, N, K)
gflops_mx = gflop_count / (time_mlx)
gflops_pt = gflop_count / (time_torch)
diff = gflops_mx / gflops_pt - 1.0
print(
f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100. * diff:+5.2f}%"
)
if gflops_pt >= 2.0 * gflops_mx:
print("ATTENTION ^^^^^^^")

View File

@ -0,0 +1,219 @@
import matplotlib.pyplot as plt
import numpy as np
import argparse
import mlx.core as mx
import time
import torch
import os
import subprocess
results_dir = "./results"
if not os.path.isdir(results_dir):
os.mkdir(results_dir)
device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
device_name = device_name.decode("utf-8").strip("\n")
N_warmup = 5
N_iter_bench = 50
N_iter_func = 20
out_vec_sizes = [128, 512, 2048, 4096]
in_vec_sizes = [128, 512, 2048, 4096]
benchmark_vector_lens = []
benchmark_vector_lens += [(i + 1) * 4096 for i in range(8)][::2]
benchmark_vector_lens += [(i + 1) * 4095 for i in range(8)][::2]
benchmark_vector_lens += [(i + 1) * 4097 for i in range(8)][::2]
benchmark_vector_lens += [64, 128, 512, 1024, 2048, 11008, 32000]
benchmark_vector_lens.sort()
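# Compare MLX and PyTorch (MPS) GEMV throughput in GB/s over the vector lengths
# above, saving one figure per (dtype, transpose) combination.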
def bench(f, m, v):
for i in range(N_warmup):
f(m, v)
torch.mps.synchronize()
s = time.perf_counter_ns()
for i in range(N_iter_bench):
f(m, v)
e = time.perf_counter_ns()
return (e - s) * 1e-9
def gemv_mlx(m, v):
ys = []
for i in range(N_iter_func):
y = m @ v
ys.append(y)
mx.eval(ys)
return ys
def gemv_t_mlx(m, v):
ys = []
for i in range(N_iter_func):
y = v @ m
ys.append(y)
mx.eval(ys)
return ys
@torch.no_grad()
def gemv_torch(m, v):
ys = []
for i in range(N_iter_func):
y = m @ v
ys.append(y)
torch.mps.synchronize()
return ys
@torch.no_grad()
def gemv_t_torch(m, v):
ys = []
for i in range(N_iter_func):
y = v @ m
ys.append(y)
torch.mps.synchronize()
return ys
def bench_lens(in_vec_len, out_vec_len, np_dtype, transpose=False):
shape_mat = (in_vec_len, out_vec_len) if transpose else (out_vec_len, in_vec_len)
shape_vec = (1, in_vec_len) if transpose else (in_vec_len, 1)
mat_npy = np.random.normal(0.0, 2.0 / in_vec_len, shape_mat).astype(np_dtype)
vec_npy = np.random.normal(0.0, 2.0 / in_vec_len, shape_vec).astype(np_dtype)
mat_mlx = mx.array(mat_npy)
vec_mlx = mx.array(vec_npy)
mat_trc = torch.from_numpy(mat_npy).to("mps")
vec_trc = torch.from_numpy(vec_npy).to("mps")
torch.mps.synchronize()
time_torch = (
bench(gemv_t_torch, mat_trc, vec_trc)
if transpose
else bench(gemv_torch, mat_trc, vec_trc)
)
time_mlx = (
bench(gemv_t_mlx, mat_mlx, vec_mlx)
if transpose
else bench(gemv_mlx, mat_mlx, vec_mlx)
)
c_mlx = (
np.asarray(vec_mlx @ mat_mlx) if transpose else np.asarray(mat_mlx @ vec_mlx)
)
c_npy = (vec_npy @ mat_npy) if transpose else (mat_npy @ vec_npy)
if not np.allclose(c_mlx, c_npy, atol=2e-5):
print(
f"Failed at {shape_mat} [transpose = {transpose}] with max(|a - b|) = {np.max(np.abs(c_npy - c_mlx))}"
)
return time_mlx, time_torch
def get_gflop_count(in_vec_len, out_vec_len):
return float(2.0 * N_iter_bench * N_iter_func * in_vec_len * out_vec_len) / float(
1024**3
)
def get_gbyte_size(in_vec_len, out_vec_len, np_dtype):
n_elem = in_vec_len * out_vec_len + in_vec_len + out_vec_len
item_size = 4 if np_dtype == np.float32 else 2
return float(N_iter_bench * N_iter_func * n_elem * item_size) / float(1024**3)
def bench_with_in_len(ax, in_vec_len, out_vector_lens, dtype, transpose):
np_dtype = getattr(np, dtype)
mlx_gb_s = []
mlx_gflops = []
pyt_gb_s = []
pyt_gflops = []
for out_vec_len in out_vector_lens:
gflop_count = get_gflop_count(in_vec_len, out_vec_len)
gbyte_size = get_gbyte_size(in_vec_len, out_vec_len, np_dtype)
time_mlx, time_torch = bench_lens(in_vec_len, out_vec_len, np_dtype, transpose)
mlx_gb_s.append(gbyte_size / time_mlx)
pyt_gb_s.append(gbyte_size / time_torch)
mlx_gflops.append(gflop_count / time_mlx)
pyt_gflops.append(gflop_count / time_torch)
if transpose:
title = f"gemv_t ([1, {in_vec_len}] [{in_vec_len}, out_vec_len]) | {dtype}"
else:
title = f"gemv ([out_vec_len, {in_vec_len}] X [{in_vec_len}, 1] ) | {dtype}"
ax.plot(out_vector_lens, mlx_gb_s, "tab:blue", label="MLX")
ax.plot(out_vector_lens, pyt_gb_s, "tab:red", label="Torch")
ax.set_title(title)
ax.set(xlabel="out_vector_len", ylabel="Performance (GB/s)")
ax.legend()
def bench_with_out_len(ax, out_vec_len, in_vector_lens, dtype, transpose):
np_dtype = getattr(np, dtype)
mlx_gb_s = []
mlx_gflops = []
pyt_gb_s = []
pyt_gflops = []
for in_vec_len in in_vector_lens:
gflop_count = get_gflop_count(in_vec_len, out_vec_len)
gbyte_size = get_gbyte_size(in_vec_len, out_vec_len, np_dtype)
time_mlx, time_torch = bench_lens(in_vec_len, out_vec_len, np_dtype, transpose)
mlx_gb_s.append(gbyte_size / time_mlx)
pyt_gb_s.append(gbyte_size / time_torch)
mlx_gflops.append(gflop_count / time_mlx)
pyt_gflops.append(gflop_count / time_torch)
if transpose:
title = f"([1, in_vec_len] [in_vec_len, {out_vec_len}])"
else:
title = f"([{out_vec_len}, in_vec_len] X [in_vec_len, 1] )"
ax.plot(in_vector_lens, mlx_gb_s, "tab:blue", label="MLX")
ax.plot(in_vector_lens, pyt_gb_s, "tab:red", label="Torch")
ax.set_title(title)
ax.set(xlabel="in_vector_len", ylabel="Performance (GB/s)")
ax.legend()
for transpose in (False, True):
for dtype in ("float32", "float16"):
fig, axs = plt.subplots(
len(in_vec_sizes), 2, figsize=(8.5, 11), layout="constrained"
)
for i, in_vec_len in enumerate(in_vec_sizes):
bench_with_in_len(
axs[i][0], in_vec_len, benchmark_vector_lens, dtype, transpose
)
for i, out_vec_len in enumerate(out_vec_sizes):
bench_with_out_len(
axs[i][1], out_vec_len, benchmark_vector_lens, dtype, transpose
)
op_name = "gemv_t" if transpose else "gemv"
fig.suptitle(f"{device_name}: {dtype} {op_name}")
fig.savefig(
os.path.join(
results_dir, f'{device_name.replace(" ", "_")}_{dtype}_{op_name}.pdf'
)
)
plt.close(fig)

View File

@ -0,0 +1,116 @@
import math
import time
import mlx.core as mx
import mlx.nn as nn
import mlx.utils
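# Benchmark a single Llama-style decoder layer (RoPE attention + gated MLP)
# decoding one token at a time with a pre-filled key/value cache.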
class LlamaAttention(nn.Module):
def __init__(self, dims: int, num_heads: int):
super().__init__()
self.num_heads = num_heads
self.rope = nn.RoPE(dims // num_heads, True)
self.query_proj = nn.Linear(dims, dims, False)
self.key_proj = nn.Linear(dims, dims, False)
self.value_proj = nn.Linear(dims, dims, False)
self.out_proj = nn.Linear(dims, dims, False)
def __call__(self, queries, keys, values, mask=None, cache=None):
queries = self.query_proj(queries)
keys = self.key_proj(keys)
values = self.value_proj(values)
num_heads = self.num_heads
B, L, D = queries.shape
queries = mx.transpose(mx.reshape(queries, (B, L, num_heads, -1)), (0, 2, 1, 3))
keys = mx.transpose(mx.reshape(keys, (B, L, num_heads, -1)), (0, 2, 1, 3))
values = mx.transpose(mx.reshape(values, (B, L, num_heads, -1)), (0, 2, 1, 3))
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = mx.concatenate([key_cache, keys], axis=2)
values = mx.concatenate([value_cache, values], axis=2)
else:
queries = self.rope(queries)
keys = self.rope(keys)
# Dimensions are [batch x num heads x sequence x hidden dim]
scale = mx.array(math.sqrt(1 / queries.shape[-1]), dtype=queries.dtype)
scores = (queries * scale) @ mx.transpose(keys, (0, 1, 3, 2))
if mask is not None:
scores = scores + mask
scores = mx.softmax(scores, axis=-1)
values_hat = mx.reshape(mx.transpose(scores @ values, (0, 2, 1, 3)), (B, L, -1))
return self.out_proj(values_hat), (keys, values)
class LlamaEncoderLayer(nn.Module):
def __init__(self, dims: int, mlp_dims: int, num_heads: int):
super().__init__()
self.attention = LlamaAttention(dims, num_heads)
self.norm1 = nn.RMSNorm(dims)
self.norm2 = nn.RMSNorm(dims)
self.linear1 = nn.Linear(dims, mlp_dims, False)
self.linear2 = nn.Linear(dims, mlp_dims, False)
self.linear3 = nn.Linear(mlp_dims, dims, False)
def __call__(self, x, mask=None, cache=None):
y = self.norm1(x)
y, cache = self.attention(y, y, y, mask, cache)
x = x + y
y = self.norm2(x)
a = self.linear1(y)
b = self.linear2(y)
y = a * mx.sigmoid(a) * b
y = self.linear3(y)
x = x + y
return x, cache
def measure(model, x, cache):
for i in range(5):
y, c = model(x, mask=None, cache=cache)
mx.eval(y, c)
start = time.time()
rs = []
for i in range(5):
y, c = model(x, mask=None, cache=cache)
rs.append((y, c))
mx.eval(rs)
end = time.time()
return (end - start) * 1000 / 5
if __name__ == "__main__":
H = 32
D = 4096
F = 43 * 256
C = 1000
mx.set_default_device(mx.gpu)
dtype = mx.float16
layer = LlamaEncoderLayer(D, F, H)
layer.update(mlx.utils.tree_map(lambda x: x.astype(dtype), layer.parameters()))
k1, k2, k3 = mx.random.split(mx.random.key(0), 3)
x = mx.random.normal([1, 1, D], dtype=dtype)
cache = [
mx.random.normal([1, H, C, D // H], dtype=dtype),
mx.random.normal([1, H, C, D // H], dtype=dtype),
]
mx.eval(x, cache)
T = measure(layer, x, cache)
print("Time per layer per token:", T, "ms")
print("Lower bound total time per token:", T * 32, "ms")

56
cmake/extension.cmake Normal file
View File

@ -0,0 +1,56 @@
include(CMakeParseArguments)
###############################################################################
# Build metal library
#
# Adds a custom target ${TARGET} to build ${OUTPUT_DIRECTORY}/${TITLE}.metallib
# from list ${SOURCES}, including list ${INCLUDE_DIRS}, depends on list ${DEPS}
#
# Args:
# TARGET: Custom target to be added for the metal library
# TITLE: Name of the .metallib
# OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib
# SOURCES: List of source files
# INCLUDE_DIRS: List of include dirs
# DEPS: List of dependency files (like headers)
#
macro(mlx_build_metallib)
# Parse args
set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY)
set(multiValueArgs SOURCES INCLUDE_DIRS DEPS)
cmake_parse_arguments(
MTLLIB
""
"${oneValueArgs}"
"${multiValueArgs}"
${ARGN}
)
# Set output
set(MTLLIB_BUILD_TARGET "${MTLLIB_OUTPUT_DIRECTORY}/${MTLLIB_TITLE}.metallib")
# Collect compile options
set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math)
# Prepare metallib build command
add_custom_command(
OUTPUT ${MTLLIB_BUILD_TARGET}
COMMAND xcrun -sdk macosx metal
"$<LIST:TRANSFORM,${MTLLIB_INCLUDE_DIRS},PREPEND,-I>"
${MTLLIB_COMPILE_OPTIONS}
${MTLLIB_SOURCES}
-o ${MTLLIB_BUILD_TARGET}
DEPENDS ${MTLLIB_DEPS} ${MTLLIB_SOURCES}
COMMAND_EXPAND_LISTS
COMMENT "Building ${MTLLIB_TITLE}.metallib"
VERBATIM
)
# Add metallib custom target
add_custom_target(
${MTLLIB_TARGET}
DEPENDS
${MTLLIB_BUILD_TARGET}
)
endmacro(mlx_build_metallib)

0
docs/.nojekyll Normal file
View File

1
docs/index.html Normal file
View File

@ -0,0 +1 @@
<meta http-equiv="refresh" content="0; url=./build/html/index.html" />

View File

@ -0,0 +1,19 @@
{{ fullname | escape | underline}}
.. currentmodule:: {{ module }}
.. autoclass:: {{ objname }}
{#{% block methods %}
{% if methods %}
.. rubric:: {{ _('Methods') }}
.. autosummary::
{% for item in methods %}
{%- if item not in inherited_members and item != '__init__' %}
~{{ name }}.{{ item }}
{%- endif %}
{%- endfor %}
{% endif %}
{% endblock %}#}

6
docs/src/cpp/ops.rst Normal file
View File

@ -0,0 +1,6 @@
.. _cpp_ops:
Operations
==========

948
docs/src/dev/extensions.rst Normal file
View File

@ -0,0 +1,948 @@
Developer Documentation
=======================
MLX provides an open and flexible backend to which users may add operations
and specialized implementations without much hassle. While the library supplies
efficient operations that can be used and composed for any number of
applications, there may arise cases where new functionalities or highly
optimized implementations are needed. For such cases, you may design and
implement your own operations that link to and build on top of :mod:`mlx.core`.
We will introduce the inner workings of MLX and go over a simple example to
learn the steps involved in adding new operations to MLX with your own CPU
and GPU implementations.
Introducing the Example
-----------------------
Let's say that you would like an operation that takes in two arrays,
``x`` and ``y``, scales them both by some coefficients ``alpha`` and ``beta``
respectively, and then adds them together to get the result
``z = alpha * x + beta * y``. Well, you can very easily do that by just
writing out a function as follows:
.. code-block:: python
import mlx.core as mx
def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
return alpha * x + beta * y
This function performs that operation while leaving the implementations and
differentiation to MLX.
However, you work with vector math libraries often and realize that the
``axpby`` routine defines the same operation ``Y = (alpha * X) + (beta * Y)``.
You would really like the part of your applications that does this operation
on the CPU to be very fast - so you decide that you want it to rely on the
``axpby`` routine provided by the Accelerate_ framework. Continuing to impose
our assumptions onto you, let's also assume that you want to learn how to add
your own implementation for the gradients of your new operation while going
over the ins-and-outs of the MLX framework.
Well, what a coincidence! You are in the right place. Over the course of this
example, we will learn:
* The structure of the MLX library from the frontend API to the backend implementations.
* How to implement your own CPU backend that redirects to Accelerate_ when appropriate (and a fallback if needed).
* How to implement your own GPU implementation using metal.
* How to add your own ``vjp`` and ``jvp``.
* How to build your implementations, link them to MLX, and bind them to python.
Operations and Primitives
-------------------------
In one sentence, operations in MLX build the computation graph, and primitives
provide the rules for evaluation and transformations of said graph. Let's start
by discussing operations in more detail.
Operations
^^^^^^^^^^^
Operations are the frontend functions that operate on arrays. They are defined
in the C++ API (:ref:`cpp_ops`) and then we provide bindings to these
operations in the Python API (:ref:`ops`).
We would like an operation, :meth:`axpby`, that takes in two arrays ``x`` and ``y``,
and two scalars, ``alpha`` and ``beta``. This is how we would define it in the
C++ API:
.. code-block:: C++
/**
* Scale and sum two vectors elementwise
* z = alpha * x + beta * y
*
* Follows numpy style broadcasting between x and y
* Inputs are upcasted to floats if needed
**/
array axpby(
const array& x, // Input array x
const array& y, // Input array y
const float alpha, // Scaling factor for x
const float beta, // Scaling factor for y
StreamOrDevice s = {} // Stream on which to schedule the operation
);
This operation itself can call other operations within it if needed. So, the
simplest way to go about implementing this operation would be to do so in terms
of existing operations.
.. code-block:: C++
array axpby(
const array& x, // Input array x
const array& y, // Input array y
const float alpha, // Scaling factor for x
const float beta, // Scaling factor for y
StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
) {
// Scale x and y on the provided stream
auto ax = multiply(array(alpha), x, s);
auto by = multiply(array(beta), y, s);
// Add and return
return add(ax, by, s);
}
However, as we discussed earlier, this is not our goal. The operations themselves
do not contain the implementations that act on the data, nor do they contain the
rules of transformations. Rather, they are an easy-to-use interface that builds
on top of the building blocks we call :class:`Primitive`.
Primitives
^^^^^^^^^^^
A :class:`Primitive` is part of the computation graph of an :class:`array`. It
defines how to create an output given a set of input :class:`array` objects. Further,
a :class:`Primitive` is a class that contains rules on how it is evaluated
on the CPU or GPU, and how it acts under transformations such as ``vjp`` and
``jvp``. These words on their own can be a bit abstract, so let's take a step
back and go to our example to give ourselves a more concrete image.
.. code-block:: C++
class Axpby : public Primitive {
public:
explicit Axpby(Stream stream, float alpha, float beta)
: Primitive(stream), alpha_(alpha), beta_(beta){};
/**
* A primitive must know how to evaluate itself on the CPU/GPU
* for the given inputs and populate the output array.
*
* To avoid unnecessary allocations, the evaluation function
* is responsible for allocating space for the array.
*/
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
/** The Jacobian-vector product. */
array jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) override;
/** The vector-Jacobian product. */
std::vector<array> vjp(
const std::vector<array>& primals,
const array& cotan,
const std::vector<int>& argnums) override;
/**
* The primitive must know how to vectorize itself across
* the given axes. The output is a pair containing the array
* representing the vectorized computation and the axis which
* corresponds to the output vectorized dimension.
*/
std::pair<array, int> vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) override;
/** Print the primitive. */
void print(std::ostream& os) override {
os << "Axpby";
}
/** Equivalence check **/
bool is_equivalent(const Primitive& other) const override;
private:
float alpha_;
float beta_;
/** Fall back implementation for evaluation on CPU */
void eval(const std::vector<array>& inputs, array& out);
};
The :class:`Axpby` class derives from the base :class:`Primitive` class and
follows the above demonstrated interface. :class:`Axpby` treats ``alpha`` and
``beta`` as parameters. It then provides implementations of how the array ``out``
is produced given ``inputs`` through :meth:`Axpby::eval_cpu` and
:meth:`Axpby::eval_gpu`. Further, it provides rules of transformations in
:meth:`Axpby::jvp`, :meth:`Axpby::vjp`, and :meth:`Axpby::vmap`.
Using the Primitives
^^^^^^^^^^^^^^^^^^^^^
Operations can use this :class:`Primitive` to add a new :class:`array` to
the computation graph. An :class:`array` can be constructed by providing its
data type, shape, the :class:`Primitive` that computes it, and the
:class:`array` inputs that are passed to the primitive.
Let's re-implement our operation now in terms of our :class:`Axpby` primitive.
.. code-block:: C++
array axpby(
const array& x, // Input array x
const array& y, // Input array y
const float alpha, // Scaling factor for x
const float beta, // Scaling factor for y
StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
) {
// Promote dtypes between x and y as needed
auto promoted_dtype = promote_types(x.dtype(), y.dtype());
// Upcast to float32 for non-floating point inputs x and y
auto out_dtype = is_floating_point(promoted_dtype)
? promoted_dtype
: promote_types(promoted_dtype, float32);
// Cast x and y up to the determined dtype (on the same stream s)
auto x_casted = astype(x, out_dtype, s);
auto y_casted = astype(y, out_dtype, s);
// Broadcast the shapes of x and y (on the same stream s)
auto broadcasted_inputs = broadcast_arrays({x_casted, y_casted}, s);
auto out_shape = broadcasted_inputs[0].shape();
// Construct the array as the output of the Axpby primitive
// with the broadcasted and upcasted arrays as inputs
return array(
/* const std::vector<int>& shape = */ out_shape,
/* Dtype dtype = */ out_dtype,
/* std::unique_ptr<Primitive> primitive = */
std::make_unique<Axpby>(to_stream(s), alpha, beta),
/* const std::vector<array>& inputs = */ broadcasted_inputs);
}
This operation now handles the following:
#. Upcast inputs and resolve the output data type.
#. Broadcast the inputs and resolve the output shape.
#. Construct the primitive :class:`Axpby` using the given stream, ``alpha``, and ``beta``.
#. Construct the output :class:`array` using the primitive and the inputs.
Implementing the Primitive
--------------------------
No computation happens when we call the operation alone. In effect, the
operation only builds the computation graph. When we evaluate the output
array, MLX schedules the execution of the computation graph, and calls
:meth:`Axpby::eval_cpu` or :meth:`Axpby::eval_gpu` depending on the
stream/device specified by the user.
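The same lazy behavior is easy to see from Python; here is a minimal sketch
(using built-in operations rather than our custom :meth:`axpby`):
.. code-block:: python
import mlx.core as mx
x = mx.ones((3, 4))
y = mx.ones((3, 4))
# Building the graph does no computation yet
z = 4.0 * x + 2.0 * y
# Evaluation schedules the graph and runs each primitive's
# eval_cpu / eval_gpu on the chosen stream
mx.eval(z)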
.. warning::
When :meth:`Primitive::eval_cpu` or :meth:`Primitive::eval_gpu` are called,
no memory has been allocated for the output array. It falls on the implementation
of these functions to allocate memory as needed.
Implementing the CPU Backend
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Let's start by trying to implement a naive and generic version of
:meth:`Axpby::eval_cpu`. We declared this earlier as a private member function
of :class:`Axpby` called :meth:`Axpby::eval`.
Our naive method will go over each element of the output array, find the
corresponding input elements of ``x`` and ``y`` and perform the operation
pointwise. This is captured in the templated function :meth:`axpby_impl`.
.. code-block:: C++
template <typename T>
void axpby_impl(
const array& x,
const array& y,
array& out,
float alpha_,
float beta_) {
// We only allocate memory when we are ready to fill the output
// malloc_or_wait synchronously allocates available memory
// There may be a wait executed here if the allocation is requested
// under memory-pressured conditions
out.set_data(allocator::malloc_or_wait(out.nbytes()));
// Collect input and output data pointers
const T* x_ptr = x.data<T>();
const T* y_ptr = y.data<T>();
T* out_ptr = out.data<T>();
// Cast alpha and beta to the relevant types
T alpha = static_cast<T>(alpha_);
T beta = static_cast<T>(beta_);
// Do the elementwise operation for each output
for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
// Map linear indices to offsets in x and y
auto x_offset = elem_to_loc(out_idx, x.shape(), x.strides());
auto y_offset = elem_to_loc(out_idx, y.shape(), y.strides());
// We allocate the output to be contiguous and regularly strided
// (defaults to row major) and hence it doesn't need additional mapping
out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
}
}
Now, we would like our implementation to be able to do this pointwise operation
for all incoming floating point arrays. Accordingly, we add dispatches for
``float32``, ``float16``, ``bfloat16`` and ``complex64``. We throw an error
if we encounter an unexpected type.
.. code-block:: C++
/** Fall back implementation for evaluation on CPU */
void Axpby::eval(const std::vector<array>& inputs, array& out) {
// Check the inputs (registered in the op while constructing the out array)
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
// Dispatch to the correct dtype
if (out.dtype() == float32) {
return axpby_impl<float>(x, y, out, alpha_, beta_);
} else if (out.dtype() == float16) {
return axpby_impl<float16_t>(x, y, out, alpha_, beta_);
} else if (out.dtype() == bfloat16) {
return axpby_impl<bfloat16_t>(x, y, out, alpha_, beta_);
} else if (out.dtype() == complex64) {
return axpby_impl<complex64_t>(x, y, out, alpha_, beta_);
} else {
throw std::runtime_error(
"Axpby is only supported for floating point types.");
}
}
We have a fallback implementation! Now, to do what we are really here to do.
Remember we wanted to use the ``axpby`` routine provided by the Accelerate_
framework? Well, there are 3 complications to keep in mind:
#. Accelerate does not provide implementations of ``axpby`` for half precision
floats. We can only direct to it for ``float32`` types.
#. Accelerate assumes the inputs ``x`` and ``y`` are contiguous and all elements
have fixed strides between them. Possibly due to broadcasts and transposes,
we aren't guaranteed that the inputs fit this requirement. We can
only direct to Accelerate if both ``x`` and ``y`` are row contiguous or
column contiguous.
#. Accelerate performs the routine ``Y = (alpha * X) + (beta * Y)`` inplace.
MLX expects to write out the answer to a new array. We must copy the elements
of ``y`` into the output array and use that as an input to ``axpby``.
Let's write out an implementation that uses Accelerate in the right conditions.
It must simply allocate data for the output, copy elements of ``y`` into it,
and then call :meth:`catlas_saxpby` from Accelerate.
.. code-block:: C++
template <typename T>
void axpby_impl_accelerate(
const array& x,
const array& y,
array& out,
float alpha_,
float beta_) {
// Accelerate library provides catlas_saxpby which does
// Y = (alpha * X) + (beta * Y) in place
// To use it, we first copy the data in y over to the output array
// This specialization requires both x and y be contiguous in the same mode
// i.e: corresponding linear indices in both point to corresponding elements
// The data in the output array is allocated to match the strides in y
// such that x, y, and out are contiguous in the same mode and
// no transposition is needed
out.set_data(
allocator::malloc_or_wait(y.data_size() * out.itemsize()),
y.data_size(),
y.strides(),
y.flags());
// We then copy over the elements using the contiguous vector specialization
copy_inplace(y, out, CopyType::Vector);
// Get x and y pointers for catlas_saxpby
const T* x_ptr = x.data<T>();
T* y_ptr = out.data<T>();
T alpha = static_cast<T>(alpha_);
T beta = static_cast<T>(beta_);
// Call the inplace accelerate operator
catlas_saxpby(
/* N = */ out.size(),
/* ALPHA = */ alpha,
/* X = */ x_ptr,
/* INCX = */ 1,
/* BETA = */ beta,
/* Y = */ y_ptr,
/* INCY = */ 1);
}
Great! But what about the inputs that do not fit the criteria for Accelerate?
Luckily, we can always just fall back to :meth:`Axpby::eval`.
With this in mind, let's finally implement our :meth:`Axpby::eval_cpu`.
.. code-block:: C++
/** Evaluate primitive on CPU using accelerate specializations */
void Axpby::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
// Accelerate specialization for contiguous single precision float arrays
if (out.dtype() == float32 &&
((x.flags().row_contiguous && y.flags().row_contiguous) ||
(x.flags().col_contiguous && y.flags().col_contiguous))) {
axpby_impl_accelerate<float>(x, y, out, alpha_, beta_);
return;
}
// Fall back to common backend if specializations are not available
eval(inputs, out);
}
We have now hit a milestone! Just this much is enough to run the operation
:meth:`axpby` on a CPU stream!
If you do not plan on running the operation on the GPU or using transforms on
computation graphs that contain :class:`Axpby`, you can stop implementing the
primitive here and enjoy the speed-ups you get from the Accelerate library.
Implementing the GPU Backend
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Apple silicon devices address their GPUs using the Metal_ shading language, and
all GPU kernels in MLX are written in Metal.
.. note::
Here are some helpful resources if you are new to Metal!
* A walkthrough of the Metal compute pipeline: `Metal Example`_
* Documentation for the Metal shading language: `Metal Specification`_
* Using Metal from C++: `Metal-cpp`_
Let's keep the GPU algorithm simple. We will launch exactly as many threads
as there are elements in the output. Each thread will pick the element it needs
from ``x`` and ``y``, do the pointwise operation, and then update its assigned
element in the output.
.. code-block:: C++
template <typename T>
[[kernel]] void axpby_general(
device const T* x [[buffer(0)]],
device const T* y [[buffer(1)]],
device T* out [[buffer(2)]],
constant const float& alpha [[buffer(3)]],
constant const float& beta [[buffer(4)]],
constant const int* shape [[buffer(5)]],
constant const size_t* x_strides [[buffer(6)]],
constant const size_t* y_strides [[buffer(7)]],
constant const int& ndim [[buffer(8)]],
uint index [[thread_position_in_grid]]) {
// Convert linear indices to offsets in array
auto x_offset = elem_to_loc(index, shape, x_strides, ndim);
auto y_offset = elem_to_loc(index, shape, y_strides, ndim);
// Do the operation and update the output
out[index] =
static_cast<T>(alpha) * x[x_offset] + static_cast<T>(beta) * y[y_offset];
}
We then need to instantiate this template for all floating point types and give
each instantiation a unique host name so we can identify the right kernel for
each data type.
.. code-block:: C++
#define instantiate_axpby(type_name, type) \
template [[host_name("axpby_general_" #type_name)]] \
[[kernel]] void axpby_general<type>( \
device const type* x [[buffer(0)]], \
device const type* y [[buffer(1)]], \
device type* out [[buffer(2)]], \
constant const float& alpha [[buffer(3)]], \
constant const float& beta [[buffer(4)]], \
constant const int* shape [[buffer(5)]], \
constant const size_t* x_strides [[buffer(6)]], \
constant const size_t* y_strides [[buffer(7)]], \
constant const int& ndim [[buffer(8)]], \
uint index [[thread_position_in_grid]]);
instantiate_axpby(float32, float);
instantiate_axpby(float16, half);
instantiate_axpby(bfloat16, bfloat16_t);
instantiate_axpby(complex64, complex64_t);
This kernel will be compiled into a metal library ``mlx_ext.metallib`` as we
will see later in :ref:`Building with CMake`. In the following example, we
assume that the library ``mlx_ext.metallib`` will always be co-located with
the executable/shared library calling the :meth:`register_library` function.
The :meth:`register_library` function takes the library's name and potential
path (or in this case, a function that can produce the path of the metal
library) and tries to load that library if it hasn't already been registered
by the relevant static :class:`mlx::core::metal::Device` object. This is why
it is important to package your C++ library with the metal library. We will
go over this process in more detail later.
The logic to determine the kernel, set the inputs, resolve the grid dimensions
and dispatch it to the GPU is contained in :meth:`Axpby::eval_gpu` as shown
below.
.. code-block:: C++
/** Evaluate primitive on GPU */
void Axpby::eval_gpu(const std::vector<array>& inputs, array& out) {
// Prepare inputs
assert(inputs.size() == 2);
auto& x = inputs[0];
auto& y = inputs[1];
// Each primitive carries the stream it should execute on
// and each stream carries its device identifiers
auto& s = stream();
// We get the needed metal device using the stream
auto& d = metal::device(s.device);
// Allocate output memory
out.set_data(allocator::malloc_or_wait(out.nbytes()));
// Resolve name of kernel (corresponds to axpby.metal)
std::ostringstream kname;
kname << "axpby_" << "general_" << type_to_name(out);
// Make sure the metal library is available and look for it
// in the same folder as this executable if needed
d.register_library("mlx_ext", metal::get_colocated_mtllib_path);
// Make a kernel from this metal library
auto kernel = d.get_kernel(kname.str(), "mlx_ext");
// Prepare to encode kernel
auto compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel);
// Kernel parameters are registered with buffer indices corresponding to
// those in the kernel declaration at axpby.metal
int ndim = out.ndim();
size_t nelem = out.size();
// Encode input arrays to kernel
set_array_buffer(compute_encoder, x, 0);
set_array_buffer(compute_encoder, y, 1);
// Encode output arrays to kernel
set_array_buffer(compute_encoder, out, 2);
// Encode alpha and beta
compute_encoder->setBytes(&alpha_, sizeof(float), 3);
compute_encoder->setBytes(&beta_, sizeof(float), 4);
// Encode shape, strides and ndim
compute_encoder->setBytes(x.shape().data(), ndim * sizeof(int), 5);
compute_encoder->setBytes(x.strides().data(), ndim * sizeof(size_t), 6);
compute_encoder->setBytes(y.strides().data(), ndim * sizeof(size_t), 7);
compute_encoder->setBytes(&ndim, sizeof(int), 8);
// We launch 1 thread for each input and make sure that the number of
// threads in any given threadgroup is not higher than the max allowed
size_t tgp_size = std::min(nelem, kernel->maxTotalThreadsPerThreadgroup());
// Fix the 3D size of each threadgroup (in terms of threads)
MTL::Size group_dims = MTL::Size(tgp_size, 1, 1);
// Fix the 3D size of the launch grid (in terms of threads)
MTL::Size grid_dims = MTL::Size(nelem, 1, 1);
// Launch the grid with the given number of threads divided among
// the given threadgroups
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
We can now call the :meth:`axpby` operation on both the CPU and the GPU!
A few things to note about MLX and metal before moving on. MLX keeps track
of the active ``compute_encoder``. We rely on :meth:`d.get_command_encoder`
to give us the active metal compute command encoder instead of building a
new one and calling :meth:`compute_encoder->end_encoding` at the end.
MLX keeps adding kernels (compute pipelines) to the active command encoder
until some specified limit is hit or the compute encoder needs to be flushed
for synchronization. MLX also handles enqueuing and committing the associated
command buffers as needed. We suggest taking a deeper dive into
:class:`metal::Device` if you would like to study this routine further.
Primitive Transforms
^^^^^^^^^^^^^^^^^^^^^
Now that we have come this far, let's also learn how to add implementations to
transformations in a :class:`Primitive`. These transformations can be built on
top of our operations, including the one we just defined. This gives us the
following :meth:`Axpby::jvp` and :meth:`Axpby::vjp` implementations.
.. code-block:: C++
/** The Jacobian-vector product. */
array Axpby::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
// Forward mode diff that pushes along the tangents
// The jvp transform on the primitive can be built with ops
// that are scheduled on the same stream as the primitive
// If argnums = {0}, we only push along x in which case the
// jvp is just the tangent scaled by alpha
// Similarly, if argnums = {1}, the jvp is just the tangent
// scaled by beta
if (argnums.size() == 1) {
auto scale = argnums[0] == 0 ? alpha_ : beta_;
auto scale_arr = array(scale, tangents[0].dtype());
return multiply(scale_arr, tangents[0], stream());
}
// If argnums = {0, 1}, we take contributions from both
// which gives us jvp = tangent_x * alpha + tangent_y * beta
else {
return axpby(tangents[0], tangents[1], alpha_, beta_, stream());
}
}
.. code-block:: C++
/** The vector-Jacobian product. */
std::vector<array> Axpby::vjp(
const std::vector<array>& primals,
const array& cotan,
const std::vector<int>& argnums) {
// Reverse mode diff
std::vector<array> vjps;
for (auto arg : argnums) {
auto scale = arg == 0 ? alpha_ : beta_;
auto scale_arr = array(scale, cotan.dtype());
vjps.push_back(multiply(scale_arr, cotan, stream()));
}
return vjps;
}
Finally, you need not have a transformation fully defined to start using your
own :class:`Primitive`.
.. code-block:: C++
/** Vectorize primitive along given axis */
std::pair<array, int> Axpby::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
throw std::runtime_error("Axpby has no vmap implementation.");
}
Building and Binding
--------------------
Let's look at the overall directory structure first.
| extensions
| ├── axpby
| │ ├── axpby.cpp
| │ ├── axpby.h
| │ └── axpby.metal
| ├── mlx_sample_extensions
| │ └── __init__.py
| ├── bindings.cpp
| ├── CMakeLists.txt
| └── setup.py
* ``extensions/axpby/`` defines the C++ extension library
* ``extensions/mlx_sample_extensions`` sets out the structure for the
associated python package
* ``extensions/bindings.cpp`` provides python bindings for our operation
* ``extensions/CMakeLists.txt`` holds CMake rules to build the library and
python bindings
* ``extensions/setup.py`` holds the ``setuptools`` rules to build and install
the python package
Binding to Python
^^^^^^^^^^^^^^^^^^
We use PyBind11_ to build a Python API for the C++ library. Since bindings
for all needed components such as `mlx.core.array`, `mlx.core.stream`, etc.
are already provided, adding our :meth:`axpby` becomes very simple!
.. code-block:: C++
PYBIND11_MODULE(mlx_sample_extensions, m) {
m.doc() = "Sample C++ and metal extensions for MLX";
m.def(
"axpby",
&axpby,
"x"_a,
"y"_a,
py::pos_only(),
"alpha"_a,
"beta"_a,
py::kw_only(),
"stream"_a = py::none(),
R"pbdoc(
Scale and sum two vectors elementwise
``z = alpha * x + beta * y``
Follows numpy style broadcasting between ``x`` and ``y``
Inputs are upcasted to floats if needed
Args:
x (array): Input array.
y (array): Input array.
alpha (float): Scaling factor for ``x``.
beta (float): Scaling factor for ``y``.
Returns:
array: ``alpha * x + beta * y``
)pbdoc");
}
Most of the complexity in the above example comes from additional bells and
whistles such as the literal names and doc-strings.
.. warning::
:mod:`mlx.core` needs to be imported before importing
:mod:`mlx_sample_extensions` as defined by the pybind11 module above to
ensure that the casters for :mod:`mlx.core` components like
:class:`mlx.core.array` are available.
.. _Building with CMake:
Building with CMake
^^^^^^^^^^^^^^^^^^^^
Building the C++ extension library itself is simple: it only requires that you
``find_package(MLX CONFIG)`` and then link it to your library.
.. code-block:: cmake
# Add library
add_library(mlx_ext)
# Add sources
target_sources(
mlx_ext
PUBLIC
${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.cpp
)
# Add include headers
target_include_directories(
mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR}
)
# Link to mlx
target_link_libraries(mlx_ext PUBLIC mlx)
We also need to build the attached metal library. For convenience, we provide a
:meth:`mlx_build_metallib` function that builds a ``.metallib`` target given
sources, headers, destinations, etc. (defined in ``cmake/extension.cmake`` and
automatically imported with the MLX package).
Here is what that looks like in practice!
.. code-block:: cmake
# Build metallib
if(MLX_BUILD_METAL)
mlx_build_metallib(
TARGET mlx_ext_metallib
TITLE mlx_ext
SOURCES ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.metal
INCLUDE_DIRS ${PROJECT_SOURCE_DIR} ${MLX_INCLUDE_DIRS}
OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}
)
add_dependencies(
mlx_ext
mlx_ext_metallib
)
endif()
Finally, we build the PyBind11_ bindings:
.. code-block:: cmake
pybind11_add_module(
mlx_sample_extensions
${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
)
target_link_libraries(mlx_sample_extensions PRIVATE mlx_ext)
if(BUILD_SHARED_LIBS)
target_link_options(mlx_sample_extensions PRIVATE -Wl,-rpath,@loader_path)
endif()
Building with ``setuptools``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Once we have set out the CMake build rules as described above, we can use the
build utilities defined in :mod:`mlx.extension` for a simple build process.
.. code-block:: python
from mlx import extension
from setuptools import setup
if __name__ == "__main__":
setup(
name="mlx_sample_extensions",
version="0.0.0",
description="Sample C++ and Metal extensions for MLX primitives.",
ext_modules=[extension.CMakeExtension("mlx_sample_extensions")],
cmdclass={"build_ext": extension.CMakeBuild},
packages = ["mlx_sample_extensions"],
package_dir = {"": "mlx_sample_extensions"},
package_data = {"mlx_sample_extensions" : ["*.so", "*.dylib", "*.metallib"]},
zip_safe=False,
python_requires=">=3.7",
)
.. note::
We treat ``extensions/mlx_sample_extensions`` as the package directory
even though it only contains a ``__init__.py`` to ensure the following:
* :mod:`mlx.core` is always imported before importing :mod:`mlx_sample_extensions`
* The C++ extension library and the metal library are co-located with the python
bindings and copied together if the package is installed
You can build inplace for development using
``python setup.py build_ext -j8 --inplace`` (in ``extensions/``).
This will result in a directory structure as follows:
| extensions
| ├── mlx_sample_extensions
| │ ├── __init__.py
| │ ├── libmlx_ext.dylib # C++ extension library
| │ ├── mlx_ext.metallib # Metal library
| │ └── mlx_sample_extensions.cpython-3x-darwin.so # Python Binding
| ...
When you try to install using the command ``python -m pip install .``
(in ``extensions/``), the package will be installed with the same structure as
``extensions/mlx_sample_extensions`` and the C++ and metal libraries will be
copied along with the python binding since they are specified as ``package_data``.
Usage
-----
After installing the extension as described above, you should be able to simply
import the python package and play with it as you would any other MLX operation!
Let's look at a simple script and its results!
.. code-block:: python
import mlx.core as mx
from mlx_sample_extensions import axpby
a = mx.ones((3, 4))
b = mx.ones((3, 4))
c = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
print(f"c shape: {c.shape}")
print(f"c dtype: {c.dtype}")
print(f"c correctness: {mx.all(c == 6.0).item()}")
Output:
.. code-block::
c shape: [3, 4]
c dtype: float32
c correctness: True
Results
^^^^^^^^^^^^^^^^
Let's run a quick benchmark and see how our new ``axpby`` operation compares
with the naive :meth:`simple_axpby` we defined earlier, running on the CPU.
.. code-block:: python
import mlx.core as mx
from mlx_sample_extensions import axpby
import time
mx.set_default_device(mx.cpu)
def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
return alpha * x + beta * y
M = 256
N = 512
x = mx.random.normal((M, N))
y = mx.random.normal((M, N))
alpha = 4.0
beta = 2.0
mx.eval((x, y))
def bench(f):
# Warm up
for i in range(100):
z = f(x, y, alpha, beta)
mx.eval(z)
# Timed run
s = time.time()
for i in range(5000):
z = f(x, y, alpha, beta)
mx.eval(z)
e = time.time()
return e - s
simple_time = bench(simple_axpby)
custom_time = bench(axpby)
print(f"Simple axpby: {simple_time:.3f} s | Custom axpby: {custom_time:.3f} s")
Results:
.. code-block::
Simple axpby: 0.114 s | Custom axpby: 0.109 s
We see some modest improvements right away!
This operation can now be used to build other operations,
in :class:`mlx.nn.Module` calls, and also as a part of graph
transformations such as :meth:`grad` and :meth:`simplify`!
Scripts
-------
.. admonition:: Download the code
The full example code is available in `mlx-examples <code>`_.
.. code: `TODO_LINK/extensions`_
.. _Accelerate: https://developer.apple.com/documentation/accelerate/blas?language=objc
.. _Metal: https://developer.apple.com/documentation/metal?language=objc
.. _Metal-cpp: https://developer.apple.com/metal/cpp/
.. _`Metal Specification`: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
.. _`Metal Example`: https://developer.apple.com/documentation/metal/performing_calculations_on_a_gpu?language=objc
.. _PyBind11: https://pybind11.readthedocs.io/en/stable/

View File

@ -0,0 +1,77 @@
.. _linear_regression:
Linear Regression
-----------------
Let's implement a basic linear regression model as a starting point to
learn MLX. First import the core package and set up some problem metadata:
.. code-block:: python
import mlx.core as mx
num_features = 100
num_examples = 1_000
num_iters = 10_000 # iterations of SGD
lr = 0.01 # learning rate for SGD
We'll generate a synthetic dataset by:
1. Sampling the design matrix ``X``.
2. Sampling a ground truth parameter vector ``w_star``.
3. Computing the dependent values ``y`` by adding Gaussian noise to ``X @ w_star``.
.. code-block:: python
# True parameters
w_star = mx.random.normal((num_features,))
# Input examples (design matrix)
X = mx.random.normal((num_examples, num_features))
# Noisy labels
eps = 1e-2 * mx.random.normal((num_examples,))
y = X @ w_star + eps
We will use SGD to find the optimal weights. To start, define the squared loss
and get the gradient function of the loss with respect to the parameters.
.. code-block:: python
def loss_fn(w):
return 0.5 * mx.mean(mx.square(X @ w - y))
grad_fn = mx.grad(loss_fn)
Start the optimization by initializing the parameters ``w`` randomly. Then
repeatedly update the parameters for ``num_iters`` iterations.
.. code-block:: python
w = 1e-2 * mx.random.normal((num_features,))
for _ in range(num_iters):
grad = grad_fn(w)
w = w - lr * grad
mx.eval(w)
Finally, compute the loss of the learned parameters and verify that they are
close to the ground truth parameters.
.. code-block:: python
loss = loss_fn(w)
error_norm = mx.sum(mx.square(w - w_star)).item() ** 0.5
print(
f"Loss {loss.item():.5f}, |w-w*| = {error_norm:.5f}, "
)
# Should print something close to: Loss 0.00005, |w-w*| = 0.00364
Complete `linear regression
<https://github.com/ml-explore/mlx/tree/main/examples/python/linear_regression.py>`_
and `logistic regression
<https://github.com/ml-explore/mlx/tree/main/examples/python/logistic_regression.py>`_
examples are available in the MLX GitHub repo.

View File

@ -0,0 +1,52 @@
.. _data_types:
:orphan:
Data Types
==========
.. currentmodule:: mlx.core
The default floating point type is ``float32`` and the default integer type is
``int32``. The table below shows supported values for :obj:`Dtype`.
.. list-table:: Supported Data Types
:widths: 5 3 20
:header-rows: 1
* - Type
- Bytes
- Description
* - ``bool_``
- 1
- Boolean (``True``, ``False``) data type
* - ``uint8``
- 1
- 8-bit unsigned integer
* - ``uint16``
- 2
- 16-bit unsigned integer
* - ``uint32``
- 4
- 32-bit unsigned integer
* - ``uint64``
- 8
- 64-bit unsigned integer
* - ``int8``
- 1
- 8-bit signed integer
* - ``int16``
- 2
- 16-bit signed integer
* - ``int32``
- 4
- 32-bit signed integer
* - ``int64``
- 8
- 64-bit signed integer
* - ``float16``
- 2
- 16-bit float, only available with `ARM C language extensions <https://developer.arm.com/documentation/101028/0012/3--C-language-extensions?lang=en>`_
* - ``float32``
- 4
- 32-bit float
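For illustration, a minimal sketch of the defaults described above:
.. code-block:: python
import mlx.core as mx
x = mx.array([1, 2, 3])  # integer input defaults to int32
y = mx.array([1.0, 2.0, 3.0])  # floating point input defaults to float32
print(x.dtype, y.dtype)
# Cast to another supported type
z = y.astype(mx.float16)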

View File

@ -0,0 +1,17 @@
.. _devices_and_streams:
Devices and Streams
===================
.. currentmodule:: mlx.core
.. autosummary::
:toctree: _autosummary
Device
default_device
set_default_device
Stream
default_stream
new_stream
set_default_stream

View File

@ -0,0 +1,16 @@
.. _transforms:
Transforms
==========
.. currentmodule:: mlx.core
.. autosummary::
:toctree: _autosummary
eval
grad
value_and_grad
jvp
vjp
vmap

View File

@ -0,0 +1,21 @@
.. _utils:
Tree Utils
==========
In MLX we consider a python tree to be an arbitrarily nested collection of
dictionaries, lists and tuples without cycles. Functions in this module that
return python trees use the default python ``dict``, ``list``, and ``tuple``,
but they can usually process objects that inherit from any of these.
.. note::
Dictionaries should have keys that are valid python identifiers.
.. currentmodule:: mlx.utils
.. autosummary::
:toctree: _autosummary
tree_flatten
tree_unflatten
tree_map
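As a quick illustration, a minimal sketch using the functions listed above:
.. code-block:: python
import mlx.core as mx
from mlx.utils import tree_flatten, tree_map, tree_unflatten
params = {"layer": {"w": mx.zeros((4, 4)), "b": mx.zeros((4,))}}
# Apply a function to every leaf array in the tree
scaled = tree_map(lambda x: 2 * x, params)
# Flatten to a list of (key, value) pairs and rebuild the tree
flat = tree_flatten(scaled)
restored = tree_unflatten(flat)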

View File

@ -0,0 +1,10 @@
function(build_example SRCFILE)
get_filename_component(src_name ${SRCFILE} NAME_WE)
set(target "${src_name}")
add_executable(${target} ${SRCFILE})
target_link_libraries(${target} PRIVATE mlx)
endfunction(build_example)
build_example(tutorial.cpp)
build_example(linear_regression.cpp)
build_example(logistic_regression.cpp)

View File

@ -0,0 +1,52 @@
#include <chrono>
#include <cmath>
#include <iostream>
#include "mlx/mlx.h"
#include "timer.h"
/**
* An example of linear regression with MLX.
*/
using namespace mlx::core;
int main() {
int num_features = 100;
int num_examples = 1'000;
int num_iters = 10'000;
float learning_rate = 0.01;
// True parameters
auto w_star = random::normal({num_features});
// The input examples (design matrix)
auto X = random::normal({num_examples, num_features});
// Noisy labels
auto eps = 1e-2 * random::normal({num_examples});
auto y = matmul(X, w_star) + eps;
// Initialize random parameters
array w = 1e-2 * random::normal({num_features});
auto loss_fn = [&](array w) {
auto yhat = matmul(X, w);
return (0.5f / num_examples) * sum(square(yhat - y));
};
auto grad_fn = grad(loss_fn);
auto tic = timer::time();
for (int it = 0; it < num_iters; ++it) {
auto grad = grad_fn(w);
w = w - learning_rate * grad;
eval(w);
}
auto toc = timer::time();
auto loss = loss_fn(w);
auto error_norm = std::sqrt(sum(square(w - w_star)).item<float>());
auto throughput = num_iters / timer::seconds(toc - tic);
std::cout << "Loss " << loss << ", |w - w*| = " << error_norm
<< ", Throughput " << throughput << " (it/s)." << std::endl;
}

3581
mlx/3rdparty/pocketfft.h vendored Normal file

File diff suppressed because it is too large

64
mlx/allocator.h Normal file
View File

@ -0,0 +1,64 @@
#pragma once
#include <cstdlib>
namespace mlx::core::allocator {
// Simple wrapper around buffer pointers
// WARNING: Only Buffer objects constructed from, or wrapping, raw pointers
// returned by mlx::allocator are supported.
class Buffer {
private:
void* ptr_;
public:
Buffer(void* ptr) : ptr_(ptr){};
// Get the raw data pointer from the buffer
void* raw_ptr();
// Get the buffer pointer from the buffer
const void* ptr() const {
return ptr_;
};
void* ptr() {
return ptr_;
};
};
Buffer malloc(size_t size);
void free(Buffer buffer);
// Wait for running tasks to finish and free up memory
// if allocation fails
Buffer malloc_or_wait(size_t size);
class Allocator {
/** Abstract base class for a memory allocator. */
public:
virtual Buffer malloc(size_t size) = 0;
virtual void free(Buffer buffer) = 0;
Allocator() = default;
Allocator(const Allocator& other) = delete;
Allocator(Allocator&& other) = delete;
Allocator& operator=(const Allocator& other) = delete;
Allocator& operator=(Allocator&& other) = delete;
virtual ~Allocator() = default;
};
Allocator& allocator();
class CommonAllocator : public Allocator {
/** A general CPU allocator. */
public:
virtual Buffer malloc(size_t size) override;
virtual void free(Buffer buffer) override;
private:
CommonAllocator() = default;
friend Allocator& allocator();
};
} // namespace mlx::core::allocator


@ -0,0 +1,18 @@
#include <cassert>
#include <simd/vector.h>
#include <vecLib/vDSP.h>
#include "mlx/backend/common/copy.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core {
void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
// TODO: Add accelerate based optimizations for CPU conv
}
} // namespace mlx::core


@ -0,0 +1,26 @@
#pragma once
#include <vecLib/BNNS/bnns.h>
#include "mlx/dtype.h"
namespace mlx::core {
BNNSDataType to_bnns_dtype(Dtype mlx_dtype) {
uint32_t size_bits = size_of(mlx_dtype) * 8;
switch (kindof(mlx_dtype)) {
case Dtype::Kind::b:
return BNNSDataTypeBoolean;
case Dtype::Kind::u:
return BNNSDataType(BNNSDataTypeUIntBit | size_bits);
case Dtype::Kind::i:
return BNNSDataType(BNNSDataTypeIntBit | size_bits);
case Dtype::Kind::f:
return BNNSDataType(BNNSDataTypeFloatBit | size_bits);
case Dtype::Kind::V:
return BNNSDataTypeBFloat16;
case Dtype::Kind::c:
throw std::invalid_argument("BNNS does not support complex types");
}
}
} // namespace mlx::core

mlx/backend/common/copy.h Normal file

@ -0,0 +1,27 @@
#pragma once
#include "mlx/array.h"
#include "mlx/backend/common/utils.h"
namespace mlx::core {
enum class CopyType {
// Copy a raw scalar input into the full contiguous output
Scalar,
// Copy the raw input buffer contiguously into a raw output buffer of the same
// size
Vector,
// Copy the full virtual input to the full contiguous output
General,
// Copy the full virtual input to the full virtual output. We assume the
// input and output have the same shape.
GeneralGeneral
};
void copy(const array& src, array& dst, CopyType ctype);
void copy_inplace(const array& src, array& dst, CopyType ctype);
} // namespace mlx::core

mlx/backend/common/sort.cpp Normal file

@ -0,0 +1,394 @@
#include <algorithm>
#include <cassert>
#include <cmath>
#include <numeric>
#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
template <typename T, typename IdxT = int32_t>
struct StridedIterator {
using iterator_category = std::random_access_iterator_tag;
using difference_type = IdxT;
using value_type = T;
using reference = value_type&;
using pointer = value_type*;
// Constructors
StridedIterator() = default;
explicit StridedIterator(T* ptr, size_t stride, difference_type offset = 0)
: ptr_(ptr + offset * stride), stride_(stride) {}
explicit StridedIterator(array& arr, int axis, difference_type offset = 0)
: StridedIterator(arr.data<T>(), arr.strides()[axis], offset) {}
// Accessors
reference operator*() const {
return ptr_[0];
}
reference operator[](difference_type idx) const {
return ptr_[idx * stride_];
}
// Comparisons
bool operator==(const StridedIterator& other) const {
return ptr_ == other.ptr_ && stride_ == other.stride_;
}
bool operator!=(const StridedIterator& other) const {
return ptr_ != other.ptr_;
}
bool operator<(const StridedIterator& other) const {
return ptr_ < other.ptr_;
}
bool operator>(const StridedIterator& other) const {
return ptr_ > other.ptr_;
}
bool operator<=(const StridedIterator& other) const {
return ptr_ <= other.ptr_;
}
bool operator>=(const StridedIterator& other) const {
return ptr_ >= other.ptr_;
}
difference_type operator-(const StridedIterator& other) const {
return (ptr_ - other.ptr_) / stride_;
}
// Moving
StridedIterator& operator++() {
ptr_ += stride_;
return *this;
}
StridedIterator& operator--() {
ptr_ -= stride_;
return *this;
}
StridedIterator& operator+=(difference_type diff) {
ptr_ += diff * stride_;
return *this;
}
StridedIterator& operator-=(difference_type diff) {
ptr_ -= diff * stride_;
return *this;
}
StridedIterator operator+(difference_type diff) {
return StridedIterator(ptr_, stride_, diff);
}
StridedIterator operator-(difference_type diff) {
return StridedIterator(ptr_, stride_, -diff);
}
private:
size_t stride_;
T* ptr_;
};
template <typename T, typename IdxT = uint32_t>
void sort(const array& in, array& out, int axis) {
// Copy input to output
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
copy(in, out, ctype);
// Get axis, shape and stride info
axis = axis < 0 ? axis + in.ndim() : axis;
size_t n_rows = in.size() / in.shape(axis);
auto remaining_shape = in.shape();
remaining_shape.erase(remaining_shape.begin() + axis);
auto remaining_strides = in.strides();
remaining_strides.erase(remaining_strides.begin() + axis);
size_t axis_stride = in.strides()[axis];
int axis_size = in.shape(axis);
// Perform sorting in place
for (int i = 0; i < n_rows; i++) {
size_t loc = elem_to_loc(i, remaining_shape, remaining_strides);
T* data_ptr = out.data<T>() + loc;
StridedIterator st(data_ptr, axis_stride, 0);
StridedIterator ed(data_ptr, axis_stride, axis_size);
std::stable_sort(st, ed);
}
}
template <typename T, typename IdxT = uint32_t>
void argsort(const array& in, array& out, int axis) {
// Allocate output
out.set_data(allocator::malloc_or_wait(out.nbytes()));
// Get axis, shape and stride info
axis = axis < 0 ? axis + in.ndim() : axis;
size_t n_rows = in.size() / in.shape(axis);
auto remaining_shape = in.shape();
remaining_shape.erase(remaining_shape.begin() + axis);
auto remaining_strides = in.strides();
remaining_strides.erase(remaining_strides.begin() + axis);
size_t axis_stride = in.strides()[axis];
int axis_size = in.shape(axis);
// Perform sorting
for (int i = 0; i < n_rows; i++) {
size_t loc = elem_to_loc(i, remaining_shape, remaining_strides);
const T* data_ptr = in.data<T>() + loc;
IdxT* idx_ptr = out.data<IdxT>() + loc;
StridedIterator st_(idx_ptr, axis_stride, 0);
StridedIterator ed_(idx_ptr, axis_stride, axis_size);
// Initialize with iota
std::iota(st_, ed_, IdxT(0));
// Sort according to vals
StridedIterator st(idx_ptr, axis_stride, 0);
StridedIterator ed(idx_ptr, axis_stride, axis_size);
std::stable_sort(st, ed, [data_ptr, axis_stride](IdxT a, IdxT b) {
auto v1 = data_ptr[a * axis_stride];
auto v2 = data_ptr[b * axis_stride];
return v1 < v2 || (v1 == v2 && a < b);
});
}
}
template <typename T, typename IdxT = uint32_t>
void partition(const array& in, array& out, int axis, int kth) {
// Copy input to output
CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
copy(in, out, ctype);
// Get axis, shape and stride info
axis = axis < 0 ? axis + in.ndim() : axis;
size_t n_rows = in.size() / in.shape(axis);
auto remaining_shape = in.shape();
remaining_shape.erase(remaining_shape.begin() + axis);
auto remaining_strides = in.strides();
remaining_strides.erase(remaining_strides.begin() + axis);
size_t axis_stride = in.strides()[axis];
int axis_size = in.shape(axis);
kth = kth < 0 ? kth + axis_size : kth;
// Perform partition in place
for (int i = 0; i < n_rows; i++) {
size_t loc = elem_to_loc(i, remaining_shape, remaining_strides);
T* data_ptr = out.data<T>() + loc;
StridedIterator st(data_ptr, axis_stride, 0);
StridedIterator md(data_ptr, axis_stride, kth);
StridedIterator ed(data_ptr, axis_stride, axis_size);
std::nth_element(st, md, ed);
}
}
template <typename T, typename IdxT = uint32_t>
void argpartition(const array& in, array& out, int axis, int kth) {
// Allocate output
out.set_data(allocator::malloc_or_wait(out.nbytes()));
// Get axis, shape and stride info
axis = axis < 0 ? axis + in.ndim() : axis;
size_t n_rows = in.size() / in.shape(axis);
auto remaining_shape = in.shape();
remaining_shape.erase(remaining_shape.begin() + axis);
auto remaining_strides = in.strides();
remaining_strides.erase(remaining_strides.begin() + axis);
size_t axis_stride = in.strides()[axis];
int axis_size = in.shape(axis);
kth = kth < 0 ? kth + axis_size : kth;
// Perform partition
for (int i = 0; i < n_rows; i++) {
size_t loc = elem_to_loc(i, remaining_shape, remaining_strides);
const T* data_ptr = in.data<T>() + loc;
IdxT* idx_ptr = out.data<IdxT>() + loc;
StridedIterator st_(idx_ptr, axis_stride, 0);
StridedIterator ed_(idx_ptr, axis_stride, axis_size);
// Initialize with iota
std::iota(st_, ed_, IdxT(0));
// Sort according to vals
StridedIterator st(idx_ptr, axis_stride, 0);
StridedIterator md(idx_ptr, axis_stride, kth);
StridedIterator ed(idx_ptr, axis_stride, axis_size);
std::nth_element(st, md, ed, [data_ptr, axis_stride](IdxT a, IdxT b) {
auto v1 = data_ptr[a * axis_stride];
auto v2 = data_ptr[b * axis_stride];
return v1 < v2 || (v1 == v2 && a < b);
});
}
}
} // namespace
void ArgSort::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
switch (in.dtype()) {
case bool_:
return argsort<bool>(in, out, axis_);
case uint8:
return argsort<uint8_t>(in, out, axis_);
case uint16:
return argsort<uint16_t>(in, out, axis_);
case uint32:
return argsort<uint32_t>(in, out, axis_);
case uint64:
return argsort<uint64_t>(in, out, axis_);
case int8:
return argsort<int8_t>(in, out, axis_);
case int16:
return argsort<int16_t>(in, out, axis_);
case int32:
return argsort<int32_t>(in, out, axis_);
case int64:
return argsort<int64_t>(in, out, axis_);
case float32:
return argsort<float>(in, out, axis_);
case float16:
return argsort<float16_t>(in, out, axis_);
case bfloat16:
return argsort<bfloat16_t>(in, out, axis_);
case complex64:
return argsort<complex64_t>(in, out, axis_);
}
}
void Sort::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
switch (in.dtype()) {
case bool_:
return sort<bool>(in, out, axis_);
case uint8:
return sort<uint8_t>(in, out, axis_);
case uint16:
return sort<uint16_t>(in, out, axis_);
case uint32:
return sort<uint32_t>(in, out, axis_);
case uint64:
return sort<uint64_t>(in, out, axis_);
case int8:
return sort<int8_t>(in, out, axis_);
case int16:
return sort<int16_t>(in, out, axis_);
case int32:
return sort<int32_t>(in, out, axis_);
case int64:
return sort<int64_t>(in, out, axis_);
case float32:
return sort<float>(in, out, axis_);
case float16:
return sort<float16_t>(in, out, axis_);
case bfloat16:
return sort<bfloat16_t>(in, out, axis_);
case complex64:
return sort<complex64_t>(in, out, axis_);
}
}
void ArgPartition::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
switch (in.dtype()) {
case bool_:
return argpartition<bool>(in, out, axis_, kth_);
case uint8:
return argpartition<uint8_t>(in, out, axis_, kth_);
case uint16:
return argpartition<uint16_t>(in, out, axis_, kth_);
case uint32:
return argpartition<uint32_t>(in, out, axis_, kth_);
case uint64:
return argpartition<uint64_t>(in, out, axis_, kth_);
case int8:
return argpartition<int8_t>(in, out, axis_, kth_);
case int16:
return argpartition<int16_t>(in, out, axis_, kth_);
case int32:
return argpartition<int32_t>(in, out, axis_, kth_);
case int64:
return argpartition<int64_t>(in, out, axis_, kth_);
case float32:
return argpartition<float>(in, out, axis_, kth_);
case float16:
return argpartition<float16_t>(in, out, axis_, kth_);
case bfloat16:
return argpartition<bfloat16_t>(in, out, axis_, kth_);
case complex64:
return argpartition<complex64_t>(in, out, axis_, kth_);
}
}
void Partition::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
switch (in.dtype()) {
case bool_:
return partition<bool>(in, out, axis_, kth_);
case uint8:
return partition<uint8_t>(in, out, axis_, kth_);
case uint16:
return partition<uint16_t>(in, out, axis_, kth_);
case uint32:
return partition<uint32_t>(in, out, axis_, kth_);
case uint64:
return partition<uint64_t>(in, out, axis_, kth_);
case int8:
return partition<int8_t>(in, out, axis_, kth_);
case int16:
return partition<int16_t>(in, out, axis_, kth_);
case int32:
return partition<int32_t>(in, out, axis_, kth_);
case int64:
return partition<int64_t>(in, out, axis_, kth_);
case float32:
return partition<float>(in, out, axis_, kth_);
case float16:
return partition<float16_t>(in, out, axis_, kth_);
case bfloat16:
return partition<bfloat16_t>(in, out, axis_, kth_);
case complex64:
return partition<complex64_t>(in, out, axis_, kth_);
}
}
} // namespace mlx::core


@ -0,0 +1,257 @@
#include <dlfcn.h>
#include <cstdlib>
#include <filesystem>
#include <sstream>
#define NS_PRIVATE_IMPLEMENTATION
#define CA_PRIVATE_IMPLEMENTATION
#define MTL_PRIVATE_IMPLEMENTATION
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/metal.h"
#include "mlx/backend/metal/mps/gemm.h"
namespace fs = std::filesystem;
namespace mlx::core::metal {
static Device metal_device_;
namespace {
// TODO nicer way to set this or possibly expose as an environment variable
static constexpr int MAX_BUFFERS_PER_QUEUE = 12;
static constexpr const char* default_mtllib_path = METAL_PATH;
auto load_device() {
MTL::Device* device = MTL::CreateSystemDefaultDevice();
if (!device) {
throw std::runtime_error("Failed to load device");
}
return device;
}
std::pair<MTL::Library*, NS::Error*> load_library_from_path(
MTL::Device* device,
const char* path) {
auto library = NS::String::string(path, NS::UTF8StringEncoding);
NS::Error* error;
auto lib = device->newLibrary(library, &error);
return std::make_pair(lib, error);
}
MTL::Library* load_library(
MTL::Device* device,
const std::string& lib_name = "mlx",
const char* lib_path = default_mtllib_path) {
// Firstly, search for the metallib in the same path as this binary
std::string first_path = get_colocated_mtllib_path(lib_name);
if (first_path.size() != 0) {
auto [lib, error] = load_library_from_path(device, first_path.c_str());
if (lib) {
return lib;
}
}
// Couldn't find it so let's load it from default_mtllib_path
{
auto [lib, error] = load_library_from_path(device, lib_path);
if (!lib) {
std::ostringstream msg;
msg << error->localizedDescription()->utf8String() << "\n"
<< "Failed to load device library from <" << lib_path << ">"
<< " or <" << first_path << ">.";
throw std::runtime_error(msg.str());
}
return lib;
}
}
} // namespace
Device::Device()
: pool_(NS::AutoreleasePool::alloc()->init()),
device_(load_device()),
library_map_({{"mlx", load_library(device_)}}) {}
Device::~Device() {
for (auto& q : queue_map_) {
q.second->release();
}
for (auto& k : kernel_map_) {
k.second->release();
}
for (auto& l : library_map_) {
l.second->release();
}
device_->release();
pool_->release();
}
void Device::new_queue(int index) {
// Multiple threads can ask the device for queues
// We lock this as a critical section for safety
const std::lock_guard<std::mutex> lock(mtx_);
auto q = device_->newCommandQueue(MAX_BUFFERS_PER_QUEUE);
if (!q) {
throw std::runtime_error(
"[metal::Device] Failed to make new command queue.");
}
queue_map_.insert({index, q});
}
int Device::get_command_buffer_ops(int index) {
auto bit = buffer_map_.find(index);
return bit->second.first;
}
void Device::increment_command_buffer_ops(int index) {
auto bit = buffer_map_.find(index);
bit->second.first++;
}
MTL::CommandBuffer* Device::get_command_buffer(int index) {
auto bit = buffer_map_.find(index);
return (bit == buffer_map_.end()) ? nullptr : bit->second.second;
}
MTL::CommandBuffer* Device::new_command_buffer(int index) {
auto qit = queue_map_.find(index);
if (qit == queue_map_.end()) {
throw std::runtime_error(
"[metal::Device] Attempting to get command buffer for invalid queue.");
}
auto cb = qit->second->commandBufferWithUnretainedReferences();
if (!cb) {
throw std::runtime_error(
"[metal::Device] Unable to create new command buffer");
}
// Increment ref count so the buffer is not garbage collected
cb->retain();
return buffer_map_.insert({index, {0, cb}}).first->second.second;
}
void Device::commit_command_buffer(int index) {
auto bit = buffer_map_.find(index);
bit->second.second->commit();
bit->second.second->release();
buffer_map_.erase(bit);
}
void Device::end_encoding(int index) {
auto eit = encoder_map_.find(index);
if (eit != encoder_map_.end()) {
eit->second->endEncoding();
eit->second->release();
encoder_map_.erase(eit);
}
}
MTL::ComputeCommandEncoder* Device::get_command_encoder(int index) {
auto eit = encoder_map_.find(index);
if (eit == encoder_map_.end()) {
auto cb = get_command_buffer(index);
auto compute_encoder = cb->computeCommandEncoder();
// Increment ref count so the encoder is not garbage collected
compute_encoder->retain();
eit = encoder_map_.insert({index, compute_encoder}).first;
}
return eit->second;
}
MTL::ArgumentEncoder* Device::argument_encoder(
const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const {
// NB array here is already autoreleased but the returned argument
// encoder is owned by the caller and must be released/autoreleased
NS::Array* arg_desc_arr = NS::Array::array(
reinterpret_cast<NS::Object* const*>(arg_descs.data()), arg_descs.size());
return device_->newArgumentEncoder(arg_desc_arr);
}
void Device::register_library(
const std::string& lib_name,
const std::string& lib_path) {
if (auto it = library_map_.find(lib_name); it == library_map_.end()) {
auto new_lib = load_library(device_, lib_name, lib_path.c_str());
library_map_.insert({lib_name, new_lib});
}
}
void Device::register_library(
const std::string& lib_name,
const std::function<std::string(const std::string&)>& lib_path_func) {
if (auto it = library_map_.find(lib_name); it == library_map_.end()) {
std::string new_lib_path = lib_path_func(lib_name);
auto new_lib = load_library(device_, lib_name, new_lib_path.c_str());
library_map_.insert({lib_name, new_lib});
}
}
MTL::ComputePipelineState* Device::get_kernel(
const std::string& name,
const std::string& lib_name /* = "mlx" */) {
// Look for cached kernel
if (auto it = kernel_map_.find(name); it != kernel_map_.end()) {
return it->second;
}
// Prepare new kernel
// Search for cached metal lib
MTL::Library* mtl_lib;
if (auto it = library_map_.find(name); it != library_map_.end()) {
mtl_lib = it->second;
} else { // Look for metallib alongside library
register_library(lib_name);
mtl_lib = library_map_[lib_name];
}
// Pull kernel from library
auto ns_name = NS::String::string(name.c_str(), NS::ASCIIStringEncoding);
auto mtl_function = mtl_lib->newFunction(ns_name);
// Compile kernel to compute pipeline
NS::Error* error = nullptr;
MTL::ComputePipelineState* kernel;
if (mtl_function) {
kernel = device_->newComputePipelineState(mtl_function, &error);
mtl_function->release();
}
if (!mtl_function || !kernel) {
std::ostringstream msg;
msg << "[metal::Device] Unable to load kernel " << name << "\n";
if (error) {
msg << error->localizedDescription()->utf8String() << "\n";
}
throw std::runtime_error(msg.str());
}
// Add kernel to cache
kernel_map_.insert({name, kernel});
return kernel;
}
Device& device(mlx::core::Device) {
return metal_device_;
}
NS::AutoreleasePool*& thread_autorelease_pool() {
static thread_local NS::AutoreleasePool* p =
NS::AutoreleasePool::alloc()->init();
return p;
}
void new_stream(Stream stream) {
thread_autorelease_pool();
if (stream.device == mlx::core::Device::gpu) {
device(stream.device).new_queue(stream.index);
}
}
} // namespace mlx::core::metal

mlx/backend/metal/fft.cpp Normal file

@ -0,0 +1,10 @@
#include "mlx/primitives.h"
namespace mlx::core {
void FFT::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0];
throw std::runtime_error("[FFT] NYI for Metal backend.");
}
} // namespace mlx::core


@ -0,0 +1,83 @@
set(
HEADERS
${CMAKE_CURRENT_SOURCE_DIR}/bf16.h
${CMAKE_CURRENT_SOURCE_DIR}/bf16_math.h
${CMAKE_CURRENT_SOURCE_DIR}/complex.h
${CMAKE_CURRENT_SOURCE_DIR}/defines.h
${CMAKE_CURRENT_SOURCE_DIR}/erf.h
${CMAKE_CURRENT_SOURCE_DIR}/reduce.h
${CMAKE_CURRENT_SOURCE_DIR}/utils.h
)
set(
KERNELS
"arange"
"arg_reduce"
"binary"
"conv"
"copy"
"gemm"
"gemv"
"random"
"reduce"
"scan"
"softmax"
"sort"
"unary"
"indexing"
)
function(build_kernel KERNEL)
set(SRCFILE ${CMAKE_CURRENT_SOURCE_DIR}/${KERNEL}.metal)
set(HEADERS_PADDED ${HEADERS})
if(${KERNEL} STREQUAL "gemm")
set(HEADERS_PADDED ${HEADERS_PADDED} ${CMAKE_CURRENT_SOURCE_DIR}/gemm/gemm.h)
endif()
if(${KERNEL} STREQUAL "conv")
set(HEADERS_PADDED ${HEADERS_PADDED} ${CMAKE_CURRENT_SOURCE_DIR}/gemm/conv.h)
endif()
add_custom_command(
COMMAND xcrun -sdk macosx metal -Wall -Wextra
-fno-fast-math
-c ${SRCFILE}
-I${PROJECT_SOURCE_DIR}
-o ${KERNEL}.air
DEPENDS ${SRCFILE} ${HEADERS_PADDED}
OUTPUT ${KERNEL}.air
COMMENT "Building ${KERNEL}.air"
VERBATIM
)
endfunction(build_kernel)
foreach(KERNEL ${KERNELS})
build_kernel(${KERNEL})
set(KERNEL_AIR ${KERNEL}.air ${KERNEL_AIR})
endforeach()
add_custom_command(
OUTPUT ${MLX_METAL_PATH}/mlx.metallib
COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o ${MLX_METAL_PATH}/mlx.metallib
DEPENDS ${KERNEL_AIR}
COMMENT "Building mlx.metallib"
VERBATIM
)
add_custom_target(
mlx-metallib
DEPENDS
${MLX_METAL_PATH}/mlx.metallib
)
add_dependencies(
mlx
mlx-metallib
)
# Install metallib
include(GNUInstallDirs)
install(
FILES ${MLX_METAL_PATH}/mlx.metallib
DESTINATION ${CMAKE_INSTALL_LIBDIR}
COMPONENT metallib
)


@ -0,0 +1,17 @@
#pragma once
template <int NDIM>
struct MLXConvParams {
const int N; // Batch size
const int C; // In channels
const int O; // Out channels
const int iS[NDIM]; // Input spatial dim
const int wS[NDIM]; // Weight spatial dim
const int oS[NDIM]; // Output spatial dim
const int str[NDIM]; // Kernel strides
const int pad[NDIM]; // Input padding
const int dil[NDIM]; // Kernel dilation
const size_t in_strides[NDIM + 2]; // In strides
const size_t wt_strides[NDIM + 2]; // Wt strides
const size_t out_strides[NDIM + 2]; // Out strides
};


@ -0,0 +1,253 @@
#include <metal_atomic>
#include <metal_texture>
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/reduce.h"
#include "mlx/backend/metal/kernels/utils.h"
using namespace metal;
/////////////////////////////////////////////////////////////////////
// Gather kernel
/////////////////////////////////////////////////////////////////////
template <typename IdxT, int NIDX>
struct Indices {
const array<device IdxT*, NIDX> buffers [[id(0)]];
device int* shapes [[id(NIDX + 1)]];
device size_t* strides [[id(NIDX + 2)]];
const int ndim [[id(NIDX + 3)]];
};
template <typename IdxT>
inline size_t offset_neg_idx(IdxT idx, size_t size) {
return (idx < 0) ? idx + size : idx;
}
template <>
inline size_t offset_neg_idx(bool idx, size_t) {
return idx;
}
template <>
inline size_t offset_neg_idx(uint32_t idx, size_t) {
return idx;
}
template <typename T, typename IdxT, int NIDX>
[[kernel]] void gather(
const device T *src [[buffer(0)]],
const device Indices<IdxT, NIDX>& indices [[buffer(1)]],
device T *out [[buffer(2)]],
const device int *src_shape [[buffer(3)]],
const device size_t *src_strides [[buffer(4)]],
const device size_t& src_ndim [[buffer(5)]],
const device int *slice_sizes [[buffer(6)]],
const device size_t& slice_size [[buffer(7)]],
const device int *axes [[buffer(8)]],
uint gid [[thread_position_in_grid]]) {
auto ind_idx = gid / slice_size;
auto ind_offset = gid % slice_size;
size_t src_idx = 0;
for (int i = 0; i < NIDX; ++i) {
auto idx_loc = elem_to_loc(
ind_idx,
&indices.shapes[indices.ndim * i],
&indices.strides[indices.ndim * i],
indices.ndim);
auto ax = axes[i];
auto idx_val = offset_neg_idx(
indices.buffers[i][idx_loc], src_shape[ax]);
src_idx += idx_val * src_strides[ax];
}
auto src_offset = elem_to_loc(
ind_offset, slice_sizes, src_strides, src_ndim);
out[gid] = src[src_idx + src_offset];
}
#define instantiate_gather4(name, src_type, ind_type, nindex) \
template [[host_name("gather" name "_" #nindex)]] \
[[kernel]] void gather( \
const device src_type *src [[buffer(0)]], \
const device Indices<ind_type, nindex>& indices [[buffer(1)]], \
device src_type *out [[buffer(2)]], \
const device int *src_shape [[buffer(3)]], \
const device size_t *src_strides [[buffer(4)]], \
const device size_t& src_ndim [[buffer(5)]], \
const device int *slice_sizes [[buffer(6)]], \
const device size_t& slice_size [[buffer(7)]], \
const device int* axes [[buffer(8)]], \
uint gid [[thread_position_in_grid]]);
// Special case for NIDX=0
instantiate_gather4("bool_", bool, bool, 0)
instantiate_gather4("uint8", uint8_t, bool, 0)
instantiate_gather4("uint16", uint16_t, bool, 0)
instantiate_gather4("uint32", uint32_t, bool, 0)
instantiate_gather4("uint64", uint64_t, bool, 0)
instantiate_gather4("int8", int8_t, bool, 0)
instantiate_gather4("int16", int16_t, bool, 0)
instantiate_gather4("int32", int32_t, bool, 0)
instantiate_gather4("int64", int64_t, bool, 0)
instantiate_gather4("float16", half, bool, 0)
instantiate_gather4("float32", float, bool, 0)
instantiate_gather4("bfloat16", bfloat16_t, bool, 0)
#define instantiate_gather3(name, src_type, ind_type) \
instantiate_gather4(name, src_type, ind_type, 1) \
instantiate_gather4(name, src_type, ind_type, 2) \
instantiate_gather4(name, src_type, ind_type, 3) \
instantiate_gather4(name, src_type, ind_type, 4) \
instantiate_gather4(name, src_type, ind_type, 5) \
instantiate_gather4(name, src_type, ind_type, 6) \
instantiate_gather4(name, src_type, ind_type, 7) \
instantiate_gather4(name, src_type, ind_type, 8) \
instantiate_gather4(name, src_type, ind_type, 9) \
instantiate_gather4(name, src_type, ind_type, 10)
#define instantiate_gather(name, src_type) \
instantiate_gather3(#name "bool_", src_type, bool) \
instantiate_gather3(#name "uint8", src_type, uint8_t) \
instantiate_gather3(#name "uint16", src_type, uint16_t) \
instantiate_gather3(#name "uint32", src_type, uint32_t) \
instantiate_gather3(#name "uint64", src_type, uint64_t) \
instantiate_gather3(#name "int8", src_type, int8_t) \
instantiate_gather3(#name "int16", src_type, int16_t) \
instantiate_gather3(#name "int32", src_type, int32_t) \
instantiate_gather3(#name "int64", src_type, int64_t)
instantiate_gather(bool_, bool)
instantiate_gather(uint8, uint8_t)
instantiate_gather(uint16, uint16_t)
instantiate_gather(uint32, uint32_t)
instantiate_gather(uint64, uint64_t)
instantiate_gather(int8, int8_t)
instantiate_gather(int16, int16_t)
instantiate_gather(int32, int32_t)
instantiate_gather(int64, int64_t)
instantiate_gather(float16, half)
instantiate_gather(float32, float)
instantiate_gather(bfloat16, bfloat16_t)
/////////////////////////////////////////////////////////////////////
// Scatter kernel
/////////////////////////////////////////////////////////////////////
template <typename T, typename IdxT, typename Op, int NIDX>
[[kernel]] void scatter(
const device Indices<IdxT, NIDX>& indices [[buffer(0)]],
const device T *updates [[buffer(1)]],
device mlx_atomic<T> *out [[buffer(2)]],
const device int *upd_shape [[buffer(3)]],
const device size_t *upd_strides [[buffer(4)]],
const device size_t& upd_ndim [[buffer(5)]],
const device size_t& upd_size [[buffer(6)]],
const device int *out_shape [[buffer(7)]],
const device size_t *out_strides [[buffer(8)]],
const device size_t& out_ndim [[buffer(9)]],
const device int* axes [[buffer(10)]],
uint gid [[thread_position_in_grid]]) {
Op op;
auto ind_idx = gid / upd_size;
auto ind_offset = gid % upd_size;
size_t out_idx = 0;
for (int i = 0; i < NIDX; ++i) {
auto idx_loc = elem_to_loc(
ind_idx,
&indices.shapes[indices.ndim * i],
&indices.strides[indices.ndim * i],
indices.ndim);
auto ax = axes[i];
auto idx_val = offset_neg_idx(
indices.buffers[i][idx_loc], out_shape[ax]);
out_idx += idx_val * out_strides[ax];
}
auto out_offset = elem_to_loc(
ind_offset, upd_shape + indices.ndim, out_strides, out_ndim);
auto upd_idx = elem_to_loc(gid, upd_shape, upd_strides, upd_ndim);
op.atomic_update(out + out_idx + out_offset, updates[upd_idx]);
}
#define instantiate_scatter4(name, type, ind_type, op_type, nindex) \
template [[host_name("scatter" name "_" #nindex)]] \
[[kernel]] void scatter<type, ind_type, op_type, nindex>( \
const device Indices<ind_type, nindex>& indices [[buffer(0)]], \
const device type *updates [[buffer(1)]], \
device mlx_atomic<type> *out [[buffer(2)]], \
const device int *upd_shape [[buffer(3)]], \
const device size_t *upd_strides [[buffer(4)]], \
const device size_t& upd_ndim [[buffer(5)]], \
const device size_t& upd_size [[buffer(6)]], \
const device int *out_shape [[buffer(7)]], \
const device size_t *out_strides [[buffer(8)]], \
const device size_t& out_ndim [[buffer(9)]], \
const device int* axes [[buffer(10)]], \
uint gid [[thread_position_in_grid]]);
// Special case NINDEX=0
#define instantiate_scatter_nd0(name, type) \
instantiate_scatter4(#name "none", type, bool, None, 0) \
instantiate_scatter4(#name "_sum", type, bool, Sum<type>, 0) \
instantiate_scatter4(#name "_prod", type, bool, Prod<type>, 0) \
instantiate_scatter4(#name "_max", type, bool, Max<type>, 0) \
instantiate_scatter4(#name "_min", type, bool, Min<type>, 0)
#define instantiate_scatter3(name, type, ind_type, op_type) \
instantiate_scatter4(name, type, ind_type, op_type, 1) \
instantiate_scatter4(name, type, ind_type, op_type, 2) \
instantiate_scatter4(name, type, ind_type, op_type, 3) \
instantiate_scatter4(name, type, ind_type, op_type, 4) \
instantiate_scatter4(name, type, ind_type, op_type, 5) \
instantiate_scatter4(name, type, ind_type, op_type, 6) \
instantiate_scatter4(name, type, ind_type, op_type, 7) \
instantiate_scatter4(name, type, ind_type, op_type, 8) \
instantiate_scatter4(name, type, ind_type, op_type, 9) \
instantiate_scatter4(name, type, ind_type, op_type, 10)
#define instantiate_scatter2(name, type, ind_type) \
instantiate_scatter3(name "_none", type, ind_type, None) \
instantiate_scatter3(name "_sum", type, ind_type, Sum<type>) \
instantiate_scatter3(name "_prod", type, ind_type, Prod<type>) \
instantiate_scatter3(name "_max", type, ind_type, Max<type>) \
instantiate_scatter3(name "_min", type, ind_type, Min<type>)
#define instantiate_scatter(name, type) \
instantiate_scatter2(#name "bool_", type, bool) \
instantiate_scatter2(#name "uint8", type, uint8_t) \
instantiate_scatter2(#name "uint16", type, uint16_t) \
instantiate_scatter2(#name "uint32", type, uint32_t) \
instantiate_scatter2(#name "uint64", type, uint64_t) \
instantiate_scatter2(#name "int8", type, int8_t) \
instantiate_scatter2(#name "int16", type, int16_t) \
instantiate_scatter2(#name "int32", type, int32_t) \
instantiate_scatter2(#name "int64", type, int64_t)
// TODO uint64 and int64 unsupported
instantiate_scatter_nd0(bool_, bool)
instantiate_scatter_nd0(uint8, uint8_t)
instantiate_scatter_nd0(uint16, uint16_t)
instantiate_scatter_nd0(uint32, uint32_t)
instantiate_scatter_nd0(int8, int8_t)
instantiate_scatter_nd0(int16, int16_t)
instantiate_scatter_nd0(int32, int32_t)
instantiate_scatter_nd0(float16, half)
instantiate_scatter_nd0(float32, float)
instantiate_scatter_nd0(bfloat16, bfloat16_t)
instantiate_scatter(bool_, bool)
instantiate_scatter(uint8, uint8_t)
instantiate_scatter(uint16, uint16_t)
instantiate_scatter(uint32, uint32_t)
instantiate_scatter(int8, int8_t)
instantiate_scatter(int16, int16_t)
instantiate_scatter(int32, int32_t)
instantiate_scatter(float16, half)
instantiate_scatter(float32, float)
instantiate_scatter(bfloat16, bfloat16_t)


@ -0,0 +1,174 @@
#pragma once
#include <metal_atomic>
#include <metal_simdgroup>
#include "mlx/backend/metal/kernels/atomic.h"
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"
union bool4_or_uint {
bool4 b;
unsigned int i;
};
struct None {
template <typename T>
void atomic_update(device mlx_atomic<T>* out, T val, int offset = 0) {
mlx_atomic_store_explicit(out, val, offset);
}
};
struct And {
bool simd_reduce(bool val) {
return simd_all(val);
};
static constexpr constant bool init = true;
void atomic_update(
device mlx_atomic<unsigned int>* out,
bool val,
int elem_idx,
int offset = 0) {
if (!val) {
bool4_or_uint update;
update.b = {true, true, true, true};
update.b[elem_idx] = false;
mlx_atomic_fetch_and_explicit(out, update.i, offset);
}
}
void atomic_update(device mlx_atomic<bool>* out, bool val, int offset = 0) {
if (!val) {
mlx_atomic_store_explicit(out, val, offset);
}
}
// Non atomic update
void update(device bool* out, bool val) {
*out &= val;
}
// Operator
bool operator()(bool a, bool b) {
return a && b;
}
};
struct Or {
bool simd_reduce(bool val) {
return simd_any(val);
};
static constexpr constant bool init = false;
void atomic_update(
device mlx_atomic<unsigned int>* out,
bool val,
int elem_idx,
int offset = 0) {
if (val) {
bool4_or_uint update;
update.b = {false, false, false, false};
update.b[elem_idx] = true;
mlx_atomic_fetch_or_explicit(out, update.i, offset);
}
}
void atomic_update(device mlx_atomic<bool>* out, bool val, int offset = 0) {
if (val) {
mlx_atomic_store_explicit(out, val, offset);
}
}
// Non atomic update
void update(device bool* out, bool val) {
*out |= val;
}
// Operator
bool operator()(bool a, bool b) {
return a || b;
}
};
template <typename U>
struct Sum {
template <typename T>
T simd_reduce(T val) {
return simd_sum(val);
};
static constexpr constant U init = U(0);
template <typename T>
void atomic_update(device mlx_atomic<T>* out, T val, int offset = 0) {
mlx_atomic_fetch_add_explicit(out, val, offset);
}
// Operator
U operator()(U a, U b) {
return a + b;
}
};
template <typename U>
struct Prod {
template <typename T>
T simd_reduce(T val) {
return simd_product(val);
};
static constexpr constant U init = U(1);
template <typename T>
void atomic_update(device mlx_atomic<T>* out, T val, int offset = 0) {
mlx_atomic_fetch_mul_explicit(out, val, offset);
}
// Operator
U operator()(U a, U b) {
return a * b;
}
};
template <typename U>
struct Min {
template <typename T>
T simd_reduce(T val) {
return simd_min(val);
};
static constexpr constant U init = Limits<U>::max;
template <typename T>
void atomic_update(device mlx_atomic<T>* out, T val, int offset = 0) {
mlx_atomic_fetch_min_explicit(out, val, offset);
}
// Operator
U operator()(U a, U b) {
return a < b ? a : b;
}
};
template <typename U>
struct Max {
template <typename T>
T simd_reduce(T val) {
return simd_max(val);
};
static constexpr constant U init = Limits<U>::min;
template <typename T>
void atomic_update(device mlx_atomic<T>* out, T val, int offset = 0) {
mlx_atomic_fetch_max_explicit(out, val, offset);
}
// Operator
U operator()(U a, U b) {
return a > b ? a : b;
}
};


@ -0,0 +1,492 @@
#include <metal_math>
#include <metal_simdgroup>
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"
using namespace metal;
template <typename U>
struct CumSum {
static constexpr constant U init = static_cast<U>(0);
template <typename T>
U operator()(U a, T b) {
return a + b;
}
U simd_scan(U x) {
return simd_prefix_inclusive_sum(x);
}
U simd_exclusive_scan(U x) {
return simd_prefix_exclusive_sum(x);
}
};
template <typename U>
struct CumProd {
static constexpr constant U init = static_cast<U>(1.0f);
template <typename T>
U operator()(U a, T b) {
return a * b;
}
U simd_scan(U x) {
return simd_prefix_inclusive_product(x);
}
U simd_exclusive_scan(U x) {
return simd_prefix_exclusive_product(x);
}
};
template <>
struct CumProd<bool> {
static constexpr constant bool init = true;
template <typename T>
bool operator()(bool a, T b) {
return a & static_cast<bool>(b);
}
bool simd_scan(bool x) {
for (int i=1; i<=16; i*=2) {
bool other = simd_shuffle_up(x, i);
x &= other;
}
return x;
}
bool simd_exclusive_scan(bool x) {
x = simd_scan(x);
return simd_shuffle_and_fill_up(x, init, 1);
}
};
template <typename U>
struct CumMax {
static constexpr constant U init = Limits<U>::min;
template <typename T>
U operator()(U a, T b) {
return (a >= b) ? a : b;
}
U simd_scan(U x) {
for (int i=1; i<=16; i*=2) {
U other = simd_shuffle_up(x, i);
x = (x >= other) ? x : other;
}
return x;
}
U simd_exclusive_scan(U x) {
x = simd_scan(x);
return simd_shuffle_and_fill_up(x, init, 1);
}
};
template <typename U>
struct CumMin {
static constexpr constant U init = Limits<U>::max;
template <typename T>
U operator()(U a, T b) {
return (a <= b) ? a : b;
}
U simd_scan(U x) {
for (int i=1; i<=16; i*=2) {
U other = simd_shuffle_up(x, i);
x = (x <= other) ? x : other;
}
return x;
}
U simd_exclusive_scan(U x) {
x = simd_scan(x);
return simd_shuffle_and_fill_up(x, init, 1);
}
};
template <typename T, typename U, int N_READS, bool reverse>
inline void load_unsafe(U values[N_READS], const device T * input) {
if (reverse) {
for (int i=0; i<N_READS; i++) {
values[N_READS-i-1] = input[i];
}
} else {
for (int i=0; i<N_READS; i++) {
values[i] = input[i];
}
}
}
template <typename T, typename U, int N_READS, bool reverse>
inline void load_safe(U values[N_READS], const device T * input, int start, int total, U init) {
if (reverse) {
for (int i=0; i<N_READS; i++) {
values[N_READS-i-1] = (start + N_READS - i - 1 < total) ? input[i] : init;
}
} else {
for (int i=0; i<N_READS; i++) {
values[i] = (start + i < total) ? input[i] : init;
}
}
}
template <typename U, int N_READS, bool reverse>
inline void write_unsafe(U values[N_READS], device U * out) {
if (reverse) {
for (int i=0; i<N_READS; i++) {
out[i] = values[N_READS-i-1];
}
} else {
for (int i=0; i<N_READS; i++) {
out[i] = values[i];
}
}
}
template <typename U, int N_READS, bool reverse>
inline void write_safe(U values[N_READS], device U * out, int start, int total) {
if (reverse) {
for (int i=0; i<N_READS; i++) {
if (start + N_READS - i - 1 < total) {
out[i] = values[N_READS-i-1];
}
}
} else {
for (int i=0; i<N_READS; i++) {
if (start + i < total) {
out[i] = values[i];
}
}
}
}
template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool reverse>
[[kernel]] void contiguous_scan(
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const constant size_t & axis_size [[buffer(2)]],
uint gid [[thread_position_in_grid]],
uint lid [[thread_position_in_threadgroup]],
uint lsize [[threads_per_threadgroup]],
uint simd_size [[threads_per_simdgroup]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
Op op;
// Position the pointers
in += (gid / lsize) * axis_size;
out += (gid / lsize) * axis_size;
// Compute the number of simd_groups
uint simd_groups = lsize / simd_size;
// Allocate memory
U prefix = Op::init;
U values[N_READS];
threadgroup U simdgroup_sums[32];
// Loop over the reduced axis in blocks of size ceildiv(axis_size, N_READS*lsize)
// Read block
// Compute inclusive scan of the block
// Compute inclusive scan per thread
// Compute exclusive scan of thread sums in simdgroup
// Write simdgroup sums in SM
// Compute exclusive scan of simdgroup sums
// Compute the output by scanning prefix, prev_simdgroup, prev_thread, value
// Write block
for (uint r = 0; r < ceildiv(axis_size, N_READS*lsize); r++) {
// Compute the block offset
uint offset = r*lsize*N_READS + lid*N_READS;
// Read the values
if (reverse) {
if ((offset + N_READS) < axis_size) {
load_unsafe<T, U, N_READS, reverse>(values, in + axis_size - offset - N_READS);
} else {
load_safe<T, U, N_READS, reverse>(values, in + axis_size - offset - N_READS, offset, axis_size, Op::init);
}
} else {
if ((offset + N_READS) < axis_size) {
load_unsafe<T, U, N_READS, reverse>(values, in + offset);
} else {
load_safe<T, U, N_READS, reverse>(values, in + offset, offset, axis_size, Op::init);
}
}
// Compute an inclusive scan per thread
for (int i=1; i<N_READS; i++) {
values[i] = op(values[i], values[i-1]);
}
// Compute exclusive scan of thread sums
U prev_thread = op.simd_exclusive_scan(values[N_READS-1]);
// Write simdgroup_sums to SM
if (simd_lane_id == simd_size - 1) {
simdgroup_sums[simd_group_id] = op(prev_thread, values[N_READS - 1]);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Compute exclusive scan of simdgroup_sums
if (simd_group_id == 0) {
U prev_simdgroup = op.simd_exclusive_scan(simdgroup_sums[simd_lane_id]);
simdgroup_sums[simd_lane_id] = prev_simdgroup;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Compute the output
for (int i=0; i<N_READS; i++) {
values[i] = op(values[i], prefix);
values[i] = op(values[i], simdgroup_sums[simd_group_id]);
values[i] = op(values[i], prev_thread);
}
// Write the values
if (reverse) {
if (inclusive) {
if ((offset + N_READS) < axis_size) {
write_unsafe<U, N_READS, reverse>(values, out + axis_size - offset - N_READS);
} else {
write_safe<U, N_READS, reverse>(values, out + axis_size - offset - N_READS, offset, axis_size);
}
} else {
if (lid == 0 && offset == 0) {
out[axis_size-1] = Op::init;
}
if ((offset + N_READS + 1) < axis_size) {
write_unsafe<U, N_READS, reverse>(values, out + axis_size - offset - 1 - N_READS);
} else {
write_safe<U, N_READS, reverse>(values, out + axis_size - offset - 1 - N_READS, offset + 1, axis_size);
}
}
} else {
if (inclusive) {
if ((offset + N_READS) < axis_size) {
write_unsafe<U, N_READS, reverse>(values, out + offset);
} else {
write_safe<U, N_READS, reverse>(values, out + offset, offset, axis_size);
}
} else {
if (lid == 0 && offset == 0) {
out[0] = Op::init;
}
if ((offset + N_READS + 1) < axis_size) {
write_unsafe<U, N_READS, reverse>(values, out + offset + 1);
} else {
write_safe<U, N_READS, reverse>(values, out + offset + 1, offset + 1, axis_size);
}
}
}
// Share the prefix
if (simd_group_id == simd_groups - 1 && simd_lane_id == simd_size - 1) {
simdgroup_sums[0] = values[N_READS-1];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
prefix = simdgroup_sums[0];
}
}
template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool reverse>
[[kernel]] void strided_scan(
const device T* in [[buffer(0)]],
device U* out [[buffer(1)]],
const constant size_t & axis_size [[buffer(2)]],
const constant size_t & stride [[buffer(3)]],
uint2 gid [[threadgroup_position_in_grid]],
uint2 lid [[thread_position_in_threadgroup]],
uint2 lsize [[threads_per_threadgroup]],
uint simd_size [[threads_per_simdgroup]]) {
Op op;
// Allocate memory
threadgroup U read_buffer[N_READS*32*32 + N_READS*32];
U values[N_READS];
U prefix[N_READS];
for (int i=0; i<N_READS; i++) {
prefix[i] = Op::init;
}
// Compute offsets
int offset = gid.y * axis_size * stride;
int global_index_x = gid.x * lsize.y * N_READS;
for (uint j=0; j<axis_size; j+=simd_size) {
// Calculate the indices for the current thread
uint index_y = j + lid.y;
uint check_index_y = index_y;
uint index_x = global_index_x + lid.x * N_READS;
if (reverse) {
index_y = axis_size - 1 - index_y;
}
// Read in SM
if (check_index_y < axis_size && (index_x + N_READS) < stride) {
for (int i=0; i<N_READS; i++) {
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] = in[offset + index_y * stride + index_x + i];
}
} else {
for (int i=0; i<N_READS; i++) {
if (check_index_y < axis_size && (index_x + i) < stride) {
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] = in[offset + index_y * stride + index_x + i];
} else {
read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] = Op::init;
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Read strided into registers
for (int i=0; i<N_READS; i++) {
values[i] = read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i];
}
// Do we need the following barrier? Shouldn't all simd threads execute simultaneously?
simdgroup_barrier(mem_flags::mem_threadgroup);
// Perform the scan
for (int i=0; i<N_READS; i++) {
values[i] = op.simd_scan(values[i]);
values[i] = op(values[i], prefix[i]);
prefix[i] = simd_shuffle(values[i], simd_size-1);
}
// Write to SM
for (int i=0; i<N_READS; i++) {
read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i] = values[i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Write to device memory
if (!inclusive) {
if (check_index_y == 0) {
if ((index_x + N_READS) < stride) {
for (int i=0; i<N_READS; i++) {
out[offset + index_y * stride + index_x + i] = Op::init;
}
} else {
for (int i=0; i<N_READS; i++) {
if ((index_x + i) < stride) {
out[offset + index_y * stride + index_x + i] = Op::init;
}
}
}
}
if (reverse) {
index_y -= 1;
check_index_y += 1;
} else {
index_y += 1;
check_index_y += 1;
}
}
if (check_index_y < axis_size && (index_x + N_READS) < stride) {
for (int i=0; i<N_READS; i++) {
out[offset + index_y * stride + index_x + i] = read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i];
}
} else {
for (int i=0; i<N_READS; i++) {
if (check_index_y < axis_size && (index_x + i) < stride) {
out[offset + index_y * stride + index_x + i] = read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i];
}
}
}
}
}
#define instantiate_contiguous_scan(name, itype, otype, op, inclusive, reverse, nreads) \
template [[host_name("contiguous_scan_" #name)]] \
[[kernel]] void contiguous_scan<itype, otype, op<otype>, nreads, inclusive, reverse>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t & axis_size [[buffer(2)]], \
uint gid [[thread_position_in_grid]], \
uint lid [[thread_position_in_threadgroup]], \
uint lsize [[threads_per_threadgroup]], \
uint simd_size [[threads_per_simdgroup]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]]);
#define instantiate_strided_scan(name, itype, otype, op, inclusive, reverse, nreads) \
template [[host_name("strided_scan_" #name)]] \
[[kernel]] void strided_scan<itype, otype, op<otype>, nreads, inclusive, reverse>( \
const device itype* in [[buffer(0)]], \
device otype* out [[buffer(1)]], \
const constant size_t & axis_size [[buffer(2)]], \
const constant size_t & stride [[buffer(3)]], \
uint2 gid [[thread_position_in_grid]], \
uint2 lid [[thread_position_in_threadgroup]], \
uint2 lsize [[threads_per_threadgroup]], \
uint simd_size [[threads_per_simdgroup]]);
#define instantiate_scan_helper(name, itype, otype, op, nreads) \
instantiate_contiguous_scan(inclusive_##name, itype, otype, op, true, false, nreads) \
instantiate_contiguous_scan(exclusive_##name, itype, otype, op, false, false, nreads) \
instantiate_contiguous_scan(reverse_inclusive_##name, itype, otype, op, true, true, nreads) \
instantiate_contiguous_scan(reverse_exclusive_##name, itype, otype, op, false, true, nreads) \
instantiate_strided_scan(inclusive_##name, itype, otype, op, true, false, nreads) \
instantiate_strided_scan(exclusive_##name, itype, otype, op, false, false, nreads) \
instantiate_strided_scan(reverse_inclusive_##name, itype, otype, op, true, true, nreads) \
instantiate_strided_scan(reverse_exclusive_##name, itype, otype, op, false, true, nreads)
instantiate_scan_helper(sum_bool__int32, bool, int32_t, CumSum, 4)
instantiate_scan_helper(sum_uint8_uint8, uint8_t, uint8_t, CumSum, 4)
instantiate_scan_helper(sum_uint16_uint16, uint16_t, uint16_t, CumSum, 4)
instantiate_scan_helper(sum_uint32_uint32, uint32_t, uint32_t, CumSum, 4)
//instantiate_scan_helper(sum_uint64_uint64, uint64_t, uint64_t, CumSum, 2)
instantiate_scan_helper(sum_int8_int8, int8_t, int8_t, CumSum, 4)
instantiate_scan_helper(sum_int16_int16, int16_t, int16_t, CumSum, 4)
instantiate_scan_helper(sum_int32_int32, int32_t, int32_t, CumSum, 4)
//instantiate_scan_helper(sum_int64_int64, int64_t, int64_t, CumSum, 2)
instantiate_scan_helper(sum_float16_float16, half, half, CumSum, 4)
instantiate_scan_helper(sum_float32_float32, float, float, CumSum, 4)
//instantiate_scan_helper(sum_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumSum, 4)
//instantiate_scan_helper(sum_complex64_complex64, complex64_t, complex64_t, CumSum)
//instantiate_scan_helper(prod_bool__bool_, bool, bool, CumProd, 4)
instantiate_scan_helper(prod_uint8_uint8, uint8_t, uint8_t, CumProd, 4)
instantiate_scan_helper(prod_uint16_uint16, uint16_t, uint16_t, CumProd, 4)
instantiate_scan_helper(prod_uint32_uint32, uint32_t, uint32_t, CumProd, 4)
//instantiate_scan_helper(prod_uint64_uint64, uint64_t, uint64_t, CumProd, 2)
instantiate_scan_helper(prod_int8_int8, int8_t, int8_t, CumProd, 4)
instantiate_scan_helper(prod_int16_int16, int16_t, int16_t, CumProd, 4)
instantiate_scan_helper(prod_int32_int32, int32_t, int32_t, CumProd, 4)
//instantiate_scan_helper(prod_int64_int64, int64_t, int64_t, CumProd, 2)
instantiate_scan_helper(prod_float16_float16, half, half, CumProd, 4)
instantiate_scan_helper(prod_float32_float32, float, float, CumProd, 4)
//instantiate_scan_helper(prod_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumProd, 4)
//instantiate_scan_helper(prod_complex64_complex64, complex64_t, complex64_t, CumProd)
//instantiate_scan_helper(max_bool__bool_, bool, bool, CumMax, 4)
instantiate_scan_helper(max_uint8_uint8, uint8_t, uint8_t, CumMax, 4)
instantiate_scan_helper(max_uint16_uint16, uint16_t, uint16_t, CumMax, 4)
instantiate_scan_helper(max_uint32_uint32, uint32_t, uint32_t, CumMax, 4)
//instantiate_scan_helper(max_uint64_uint64, uint64_t, uint64_t, CumMax, 2)
instantiate_scan_helper(max_int8_int8, int8_t, int8_t, CumMax, 4)
instantiate_scan_helper(max_int16_int16, int16_t, int16_t, CumMax, 4)
instantiate_scan_helper(max_int32_int32, int32_t, int32_t, CumMax, 4)
//instantiate_scan_helper(max_int64_int64, int64_t, int64_t, CumMax, 2)
instantiate_scan_helper(max_float16_float16, half, half, CumMax, 4)
instantiate_scan_helper(max_float32_float32, float, float, CumMax, 4)
//instantiate_scan_helper(max_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMax, 4)
//instantiate_scan_helper(max_complex64_complex64, complex64_t, complex64_t, CumMax)
//instantiate_scan_helper(min_bool__bool_, bool, bool, CumMin, 4)
instantiate_scan_helper(min_uint8_uint8, uint8_t, uint8_t, CumMin, 4)
instantiate_scan_helper(min_uint16_uint16, uint16_t, uint16_t, CumMin, 4)
instantiate_scan_helper(min_uint32_uint32, uint32_t, uint32_t, CumMin, 4)
//instantiate_scan_helper(min_uint64_uint64, uint64_t, uint64_t, CumMin, 2)
instantiate_scan_helper(min_int8_int8, int8_t, int8_t, CumMin, 4)
instantiate_scan_helper(min_int16_int16, int16_t, int16_t, CumMin, 4)
instantiate_scan_helper(min_int32_int32, int32_t, int32_t, CumMin, 4)
//instantiate_scan_helper(min_int64_int64, int64_t, int64_t, CumMin, 2)
instantiate_scan_helper(min_float16_float16, half, half, CumMin, 4)
instantiate_scan_helper(min_float32_float32, float, float, CumMin, 4)
//instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMin, 4)
//instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin)


@ -0,0 +1,88 @@
#include <cstdlib>
#include <future>
#include <memory>
#include "mlx/array.h"
#include "mlx/backend/metal/device.h"
#include "mlx/primitives.h"
#include "mlx/scheduler.h"
namespace mlx::core::metal {
int max_ops_per_buffer() {
auto get_val = []() {
if (const char* buff_str = std::getenv("MLX_MAX_OPS_PER_BUFFER")) {
return atoi(buff_str);
} else {
return 10;
}
};
static int max_ops_per_buffer_ = get_val();
return max_ops_per_buffer_;
}
#define MAX_OPS_PER_BUFFER max_ops_per_buffer()
MTL::CommandBuffer* increment_command_buffer(Stream s) {
auto& d = metal::device(s.device);
auto command_buffer = d.get_command_buffer(s.index);
if (command_buffer == nullptr ||
d.get_command_buffer_ops(s.index) >= MAX_OPS_PER_BUFFER) {
if (command_buffer != nullptr) {
d.end_encoding(s.index);
scheduler::notify_new_task(s);
command_buffer->addCompletedHandler(
[s](MTL::CommandBuffer*) { scheduler::notify_task_completion(s); });
d.commit_command_buffer(s.index);
}
command_buffer = d.new_command_buffer(s.index);
}
d.increment_command_buffer_ops(s.index);
return command_buffer;
}
std::function<void()> make_task(
array& arr,
std::vector<std::shared_future<void>> deps,
std::shared_ptr<std::promise<void>> p,
bool retain_graph) {
auto task =
[retain_graph, arr, deps = std::move(deps), p = std::move(p)]() mutable {
for (auto& d : deps) {
d.wait();
}
auto s = arr.primitive().stream();
auto command_buffer = increment_command_buffer(s);
arr.primitive().eval_gpu(arr.inputs(), arr);
if (p) {
metal::device(s.device).end_encoding(s.index);
scheduler::notify_new_task(s);
command_buffer->addCompletedHandler(
[retain_graph, s, arr, p = std::move(p)](
MTL::CommandBuffer*) mutable {
if (!retain_graph) {
arr.detach();
}
p->set_value();
// Signal this thread to clear the pool on a synchronization.
scheduler::enqueue(s, []() {
thread_autorelease_pool()->release();
thread_autorelease_pool() =
NS::AutoreleasePool::alloc()->init();
});
scheduler::notify_task_completion(s);
});
metal::device(s.device).commit_command_buffer(s.index);
} else {
command_buffer->addCompletedHandler(
[retain_graph, s, arr](MTL::CommandBuffer*) mutable {
if (!retain_graph) {
arr.detach();
}
});
}
};
return task;
}
} // namespace mlx::core::metal

mlx/backend/metal/metal.h Normal file

@ -0,0 +1,28 @@
#pragma once
#include <future>
#include <memory>
#include <vector>
#include "mlx/array.h"
#include "mlx/stream.h"
namespace mlx::core::metal {
constexpr bool is_available() {
#ifdef _METAL_
return true;
#else
return false;
#endif
}
void new_stream(Stream stream);
std::function<void()> make_task(
array& arr,
std::vector<std::shared_future<void>> deps,
std::shared_ptr<std::promise<void>> p,
bool retain_graph);
} // namespace mlx::core::metal


@ -0,0 +1,82 @@
#include <algorithm>
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
if (!is_floating_point(out.dtype())) {
throw std::runtime_error(
"[softmax] Does not support non-floating point types.");
}
auto& s = stream();
auto& d = metal::device(s.device);
// Make sure that the last dimension is contiguous
std::vector<array> copies;
auto check_input = [&copies, &s](const array& x) {
if (x.strides()[x.ndim() - 1] == 1) {
return x;
} else {
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
copies.push_back(x_copy);
return x_copy;
}
};
const array& in = check_input(inputs[0]);
out.set_data(
allocator::malloc_or_wait(in.data_size() * in.itemsize()),
in.data_size(),
in.strides(),
in.flags());
int axis_size = in.shape().back();
int n_rows = in.data_size() / axis_size;
const int simd_size = 32;
const int n_reads = SOFTMAX_N_READS;
const int looped_limit = SOFTMAX_LOOPED_LIMIT;
std::string op_name = "softmax_";
if (axis_size > looped_limit) {
op_name += "looped_";
}
op_name += type_to_name(out);
auto compute_encoder = d.get_command_encoder(s.index);
{
auto kernel = d.get_kernel(op_name);
MTL::Size grid_dims, group_dims;
if (axis_size <= looped_limit) {
size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;
size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size;
size_t threadgroup_size = simd_size * simds_needed;
assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup());
size_t n_threads = n_rows * threadgroup_size;
grid_dims = MTL::Size(n_threads, 1, 1);
group_dims = MTL::Size(threadgroup_size, 1, 1);
} else {
size_t threadgroup_size = kernel->maxTotalThreadsPerThreadgroup();
size_t n_threads = n_rows * threadgroup_size;
grid_dims = MTL::Size(n_threads, 1, 1);
group_dims = MTL::Size(threadgroup_size, 1, 1);
}
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(&axis_size, sizeof(int), 2);
compute_encoder->setThreadgroupMemoryLength(simd_size * in.itemsize(), 0);
compute_encoder->setThreadgroupMemoryLength(simd_size * in.itemsize(), 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}
} // namespace mlx::core

mlx/backend/metal/sort.cpp Normal file

@ -0,0 +1,336 @@
#include <algorithm>
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
namespace {
template <bool ARGSORT>
void single_block_sort(
const Stream& s,
metal::Device& d,
const array& in,
array& out,
int axis,
int bn,
int tn) {
// Prepare shapes
int n_rows = in.size() / in.shape(axis);
std::vector<size_t> nc_str = in.strides();
nc_str.erase(nc_str.begin() + axis);
std::vector<int> nc_shape = in.shape();
nc_shape.erase(nc_shape.begin() + axis);
int nc_dim = nc_shape.size();
int size_sorted_axis = in.shape(axis);
int stride_sorted_axis = in.strides()[axis];
int stride_segment_axis = *std::min_element(nc_str.begin(), nc_str.end());
// Check if remaining strides are contiguous
bool contiguous_write = true;
if (axis != in.ndim() - 1 && axis != 0) {
for (int i = 0; i < nc_str.size() - 1; ++i) {
size_t expected = nc_str[i + 1] * nc_shape[i + 1];
contiguous_write &= (nc_str[i] == expected);
}
}
// Prepare kernel name
std::ostringstream kname;
if (ARGSORT) {
kname << "arg_";
}
kname << "block_merge_sort_" << type_to_name(in) << "_" << type_to_name(out)
<< "_bn" << bn << "_tn" << tn;
if (!contiguous_write) {
kname << "_nc";
}
// Prepare command encoder
auto compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
// Set inputs
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, out, 1);
compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 2);
compute_encoder->setBytes(&stride_sorted_axis, sizeof(int), 3);
if (contiguous_write) {
compute_encoder->setBytes(&stride_segment_axis, sizeof(int), 4);
} else {
compute_encoder->setBytes(&nc_dim, sizeof(int), 4);
compute_encoder->setBytes(nc_shape.data(), nc_dim * sizeof(int), 5);
compute_encoder->setBytes(nc_str.data(), nc_dim * sizeof(size_t), 6);
}
MTL::Size group_dims = MTL::Size(bn, 1, 1);
MTL::Size grid_dims = MTL::Size(1, n_rows, 1);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
}
template <bool ARGSORT>
void multi_block_sort(
const Stream& s,
metal::Device& d,
const array& in,
array& out,
int axis,
int bn,
int tn,
int n_blocks) {
// Prepare shapes
int n_rows = in.size() / in.shape(axis);
std::vector<size_t> nc_str = in.strides();
nc_str.erase(nc_str.begin() + axis);
std::vector<int> nc_shape = in.shape();
nc_shape.erase(nc_shape.begin() + axis);
int nc_dim = nc_shape.size();
int size_sorted_axis = in.shape(axis);
int stride_sorted_axis = in.strides()[axis];
// Make temporary copies
array dev_vals_0({n_rows, size_sorted_axis}, in.dtype(), nullptr, {});
array dev_vals_1({n_rows, size_sorted_axis}, in.dtype(), nullptr, {});
array dev_idxs_0({n_rows, size_sorted_axis}, uint32, nullptr, {});
array dev_idxs_1({n_rows, size_sorted_axis}, uint32, nullptr, {});
array block_partitions({n_rows, n_blocks + 1}, uint32, nullptr, {});
// Do allocations
dev_vals_0.set_data(allocator::malloc_or_wait(dev_vals_0.nbytes()));
dev_vals_1.set_data(allocator::malloc_or_wait(dev_vals_1.nbytes()));
dev_idxs_0.set_data(allocator::malloc_or_wait(dev_idxs_0.nbytes()));
dev_idxs_1.set_data(allocator::malloc_or_wait(dev_idxs_1.nbytes()));
block_partitions.set_data(
allocator::malloc_or_wait(block_partitions.nbytes()));
std::vector<array> copies = {
dev_vals_0, dev_vals_1, dev_idxs_0, dev_idxs_1, block_partitions};
// Prepare command encoder
auto compute_encoder = d.get_command_encoder(s.index);
// Do blockwise sort
{
std::ostringstream kname;
kname << "mb_block_sort_" << type_to_name(dev_vals_0) << "_"
<< type_to_name(dev_idxs_0) << "_bn" << bn << "_tn" << tn;
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, in, 0);
set_array_buffer(compute_encoder, dev_vals_0, 1);
set_array_buffer(compute_encoder, dev_idxs_0, 2);
compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 3);
compute_encoder->setBytes(&stride_sorted_axis, sizeof(int), 4);
compute_encoder->setBytes(&nc_dim, sizeof(int), 5);
compute_encoder->setBytes(nc_shape.data(), nc_dim * sizeof(int), 6);
compute_encoder->setBytes(nc_str.data(), nc_dim * sizeof(size_t), 7);
MTL::Size group_dims = MTL::Size(bn, 1, 1);
MTL::Size grid_dims = MTL::Size(n_blocks, n_rows, 1);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
}
// Do merges
bool ping = false;
array dev_vals_in = dev_vals_0;
array dev_idxs_in = dev_idxs_0;
array dev_vals_out = dev_vals_1;
array dev_idxs_out = dev_idxs_1;
for (int merge_tiles = 2; merge_tiles <= n_blocks; merge_tiles *= 2) {
dev_vals_in = ping ? dev_vals_1 : dev_vals_0;
dev_idxs_in = ping ? dev_idxs_1 : dev_idxs_0;
dev_vals_out = ping ? dev_vals_0 : dev_vals_1;
dev_idxs_out = ping ? dev_idxs_0 : dev_idxs_1;
ping = !ping;
// Do partition
{
std::ostringstream kname;
kname << "mb_block_partiton_" << type_to_name(dev_vals_in) << "_"
<< type_to_name(dev_idxs_in) << "_bn" << bn << "_tn" << tn;
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, block_partitions, 0);
set_array_buffer(compute_encoder, dev_vals_in, 1);
set_array_buffer(compute_encoder, dev_idxs_in, 2);
compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 3);
compute_encoder->setBytes(&merge_tiles, sizeof(int), 4);
MTL::Size group_dims = MTL::Size(n_blocks + 1, 1, 1);
MTL::Size grid_dims = MTL::Size(1, n_rows, 1);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
}
// Do merge
{
std::ostringstream kname;
kname << "mb_block_merge_" << type_to_name(dev_vals_in) << "_"
<< type_to_name(dev_idxs_in) << "_bn" << bn << "_tn" << tn;
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, block_partitions, 0);
set_array_buffer(compute_encoder, dev_vals_in, 1);
set_array_buffer(compute_encoder, dev_idxs_in, 2);
set_array_buffer(compute_encoder, dev_vals_out, 3);
set_array_buffer(compute_encoder, dev_idxs_out, 4);
compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 5);
compute_encoder->setBytes(&merge_tiles, sizeof(int), 6);
compute_encoder->setBytes(&n_blocks, sizeof(int), 7);
MTL::Size group_dims = MTL::Size(bn, 1, 1);
MTL::Size grid_dims = MTL::Size(n_blocks, n_rows, 1);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
}
}
// Copy outputs with appropriate strides
array strided_out_arr = ARGSORT ? dev_idxs_out : dev_vals_out;
if (axis == strided_out_arr.ndim() - 1) {
copy_gpu_inplace(strided_out_arr, out, CopyType::Vector, s);
} else {
std::vector<int> strided_out_shape = strided_out_arr.shape();
std::vector<size_t> strided_out_str = strided_out_arr.strides();
int out_axis_shape = strided_out_shape[axis];
int out_axis_str = strided_out_str[axis];
strided_out_shape.erase(strided_out_shape.begin() + axis);
strided_out_str.erase(strided_out_str.begin() + axis);
strided_out_shape.push_back(out_axis_shape);
strided_out_str.push_back(out_axis_str);
array strided_out_slice(strided_out_shape, out.dtype(), nullptr, {});
strided_out_slice.copy_shared_buffer(
strided_out_arr,
strided_out_str,
strided_out_arr.flags(),
strided_out_arr.size(),
0);
copy_gpu_inplace(strided_out_slice, out, CopyType::General, s);
}
// Clear copies
d.get_command_buffer(s.index)->addCompletedHandler(
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}
template <bool ARGSORT>
void gpu_merge_sort(
const Stream& s,
metal::Device& d,
const array& in,
array& out,
int axis_) {
// Get size info
int axis = axis_ < 0 ? axis_ + in.ndim() : axis_;
int size_sorted_axis = in.shape(axis);
// Get kernel size
int tn = 8;
int bn = 128;
int potential_bn = (size_sorted_axis + tn - 1) / tn;
if (potential_bn > 256) {
bn = 512;
} else if (potential_bn > 128) {
bn = 256;
} else {
bn = 128;
}
if (bn == 512 && size_of(in.dtype()) > 4) {
bn = 256;
}
int n_per_block = bn * tn;
int n_blocks = (size_sorted_axis + n_per_block - 1) / n_per_block;
if (n_blocks > 1) {
return multi_block_sort<ARGSORT>(s, d, in, out, axis, bn, tn, n_blocks);
} else {
return single_block_sort<ARGSORT>(s, d, in, out, axis, bn, tn);
}
}
} // namespace
void ArgSort::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
auto& in = inputs[0];
gpu_merge_sort<true>(s, d, in, out, axis_);
}
void Sort::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
auto& in = inputs[0];
gpu_merge_sort<false>(s, d, in, out, axis_);
}
void ArgPartition::eval_gpu(const std::vector<array>& inputs, array& out) {
// We direct arg partition to sort for now
assert(inputs.size() == 1);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
auto& in = inputs[0];
gpu_merge_sort<true>(s, d, in, out, axis_);
}
void Partition::eval_gpu(const std::vector<array>& inputs, array& out) {
// We direct partition to sort for now
assert(inputs.size() == 1);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
auto& in = inputs[0];
gpu_merge_sort<false>(s, d, in, out, axis_);
}
} // namespace mlx::core
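
A hedged sketch of how the dispatch above is reached from Python: rows with at most `bn * tn` elements take `single_block_sort`, while longer rows go through `multi_block_sort` with its partition and merge passes.

```
import mlx.core as mx

small = mx.random.uniform(shape=(4, 512))      # fits in a single block
large = mx.random.uniform(shape=(4, 100_000))  # needs the multi-block path
idx = mx.argsort(small, axis=-1)
s = mx.sort(large, axis=-1)
mx.eval(idx, s)
```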


@ -0,0 +1,7 @@
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
)


@ -0,0 +1,77 @@
#include "mlx/primitives.h"
#define NO_GPU(func) \
void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
throw std::runtime_error(#func " has no GPU implementation."); \
}
namespace mlx::core {
NO_GPU(Abs)
NO_GPU(Add)
NO_GPU(Arange)
NO_GPU(ArcCos)
NO_GPU(ArcCosh)
NO_GPU(ArcSin)
NO_GPU(ArcSinh)
NO_GPU(ArcTan)
NO_GPU(ArcTanh)
NO_GPU(ArgPartition)
NO_GPU(ArgReduce)
NO_GPU(ArgSort)
NO_GPU(AsType)
NO_GPU(AsStrided)
NO_GPU(Broadcast)
NO_GPU(Concatenate)
NO_GPU(Convolution)
NO_GPU(Copy)
NO_GPU(Cos)
NO_GPU(Cosh)
NO_GPU(Divide)
NO_GPU(Equal)
NO_GPU(Erf)
NO_GPU(ErfInv)
NO_GPU(Exp)
NO_GPU(FFT)
NO_GPU(Full)
NO_GPU(Gather)
NO_GPU(Greater)
NO_GPU(GreaterEqual)
NO_GPU(Less)
NO_GPU(LessEqual)
NO_GPU(Load)
NO_GPU(Log)
NO_GPU(Log1p)
NO_GPU(LogicalNot)
NO_GPU(LogAddExp)
NO_GPU(Matmul)
NO_GPU(Maximum)
NO_GPU(Minimum)
NO_GPU(Multiply)
NO_GPU(Negative)
NO_GPU(NotEqual)
NO_GPU(Pad)
NO_GPU(Partition)
NO_GPU(Power)
NO_GPU(RandomBits)
NO_GPU(Reduce)
NO_GPU(Reshape)
NO_GPU(Scan)
NO_GPU(Scatter)
NO_GPU(Sigmoid)
NO_GPU(Sign)
NO_GPU(Sin)
NO_GPU(Sinh)
NO_GPU(Slice)
NO_GPU(Softmax)
NO_GPU(Sort)
NO_GPU(Square)
NO_GPU(Sqrt)
NO_GPU(StopGradient)
NO_GPU(Subtract)
NO_GPU(Tan)
NO_GPU(Tanh)
NO_GPU(Transpose)
} // namespace mlx::core

29
mlx/device.cpp Normal file

@ -0,0 +1,29 @@
#include "mlx/device.h"
#include "mlx/backend/metal/metal.h"
namespace mlx::core {
static Device default_device_{
metal::is_available() ? Device::gpu : Device::cpu};
const Device& default_device() {
return default_device_;
}
void set_default_device(const Device& d) {
if (!metal::is_available() && d == Device::gpu) {
throw std::invalid_argument(
"[set_default_device] Cannot set gpu device without gpu backend.");
}
default_device_ = d;
}
bool operator==(const Device& lhs, const Device& rhs) {
return lhs.type == rhs.type && lhs.index == rhs.index;
}
bool operator!=(const Device& lhs, const Device& rhs) {
return !(lhs == rhs);
}
} // namespace mlx::core
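
The default device resolves to the GPU only when Metal is available, and asking for the GPU without Metal raises. A small Python sketch:

```
import mlx.core as mx

print(mx.default_device())      # gpu on Apple silicon with Metal, cpu otherwise
mx.set_default_device(mx.cpu)   # always allowed
a = mx.ones((2, 2)) + mx.ones((2, 2))
mx.eval(a)
# mx.set_default_device(mx.gpu) raises if the Metal backend is unavailable.
```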

99
mlx/dtype.h Normal file

@ -0,0 +1,99 @@
#pragma once
#include <complex>
#include <cstdint>
#include <ostream>
#include <string>
#include "mlx/types/complex.h"
#include "mlx/types/half_types.h"
namespace mlx::core {
struct Dtype {
enum class Val {
bool_,
uint8,
uint16,
uint32,
uint64,
int8,
int16,
int32,
int64,
float16,
float32,
bfloat16,
complex64,
};
enum class Kind {
b, /* bool */
u, /* unsigned int */
i, /* signed int */
f, /* float */
c, /* complex */
V, /* void - used for brain float */
};
Val val;
const uint8_t size;
constexpr explicit Dtype(Val val, uint8_t size) : val(val), size(size){};
constexpr operator Val() const {
return val;
};
};
inline bool is_available(const Dtype& dtype) {
return true;
}
static constexpr Dtype bool_{Dtype::Val::bool_, sizeof(bool)};
static constexpr Dtype uint8{Dtype::Val::uint8, sizeof(uint8_t)};
static constexpr Dtype uint16{Dtype::Val::uint16, sizeof(uint16_t)};
static constexpr Dtype uint32{Dtype::Val::uint32, sizeof(uint32_t)};
static constexpr Dtype uint64{Dtype::Val::uint64, sizeof(uint64_t)};
static constexpr Dtype int8{Dtype::Val::int8, sizeof(int8_t)};
static constexpr Dtype int16{Dtype::Val::int16, sizeof(int16_t)};
static constexpr Dtype int32{Dtype::Val::int32, sizeof(int32_t)};
static constexpr Dtype int64{Dtype::Val::int64, sizeof(int64_t)};
static constexpr Dtype float16{Dtype::Val::float16, sizeof(uint16_t)};
static constexpr Dtype float32{Dtype::Val::float32, sizeof(float)};
static constexpr Dtype bfloat16{Dtype::Val::bfloat16, sizeof(uint16_t)};
static constexpr Dtype complex64{Dtype::Val::complex64, sizeof(complex64_t)};
Dtype promote_types(const Dtype& t1, const Dtype& t2);
inline uint8_t size_of(const Dtype& t) {
return t.size;
}
Dtype::Kind kindof(const Dtype& t);
inline bool is_unsigned(const Dtype& t) {
return kindof(t) == Dtype::Kind::u || kindof(t) == Dtype::Kind::b;
}
inline bool is_floating_point(const Dtype& t) {
return kindof(t) == Dtype::Kind::f || kindof(t) == Dtype::Kind::V ||
kindof(t) == Dtype::Kind::c;
}
inline bool is_integral(const Dtype& t) {
return !(is_floating_point(t));
}
template <typename T>
struct TypeToDtype {
operator Dtype();
};
// Array protocol typestring for Dtype
std::string dtype_to_array_protocol(const Dtype& t);
// Dtype from array protocol type string
Dtype dtype_from_array_protocol(const std::string& t);
} // namespace mlx::core
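
A pure-Python sketch (not the mlx API) of the kind table above, showing why `is_floating_point` treats regular floats, `bfloat16` (kind `V`), and `complex64` the same way:

```
# Dtype kinds as declared in dtype.h: b, u, i, f, c, V (bfloat16).
KINDS = {
    "bool_": "b",
    "uint8": "u", "uint16": "u", "uint32": "u", "uint64": "u",
    "int8": "i", "int16": "i", "int32": "i", "int64": "i",
    "float16": "f", "float32": "f",
    "bfloat16": "V", "complex64": "c",
}

def is_floating_point(name):
    return KINDS[name] in ("f", "V", "c")

def is_integral(name):
    return not is_floating_point(name)

assert is_floating_point("bfloat16") and is_integral("uint8")
```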

190
mlx/fft.cpp Normal file

@ -0,0 +1,190 @@
#include <numeric>
#include <set>
#include "mlx/fft.h"
#include "mlx/ops.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
namespace mlx::core::fft {
array fft_impl(
const array& a,
std::vector<int> n,
const std::vector<int>& axes,
bool real,
bool inverse,
StreamOrDevice s) {
if (a.ndim() < 1) {
throw std::invalid_argument(
"[fftn] Requires array with at least one dimension.");
}
if (n.size() != axes.size()) {
throw std::invalid_argument("[fftn] Shape and axes have different sizes.");
}
if (axes.empty()) {
return a;
}
std::vector<size_t> valid_axes;
for (int ax : axes) {
valid_axes.push_back(ax < 0 ? ax + a.ndim() : ax);
}
std::set<int> unique_axes(valid_axes.begin(), valid_axes.end());
if (unique_axes.size() != axes.size()) {
std::ostringstream msg;
msg << "[fftn] Duplicated axis received " << axes;
throw std::invalid_argument(msg.str());
}
if (*unique_axes.begin() < 0 || *unique_axes.rbegin() >= a.ndim()) {
std::ostringstream msg;
msg << "[fftn] Invalid axis received for array with " << a.ndim()
<< " dimensions.";
throw std::invalid_argument(msg.str());
}
// In the following shape manipulations there are three cases to consider:
// 1. In a complex to complex transform (fftn / ifftn) the output
// and input shapes are the same.
// 2. In a real to complex transform (rfftn) n specifies the input dims
// and the output dims are n[i] / 2 + 1
// 3. In a complex to real transform (irfftn) n specifies the output dims
// and the input dims are n[i] / 2 + 1
if (std::any_of(n.begin(), n.end(), [](auto i) { return i <= 0; })) {
std::ostringstream msg;
msg << "[fftn] Invalid FFT output size requested " << n;
throw std::invalid_argument(msg.str());
}
std::vector<int> in_shape = a.shape();
for (int i = 0; i < valid_axes.size(); ++i) {
in_shape[valid_axes[i]] = n[i];
}
if (real && inverse) {
in_shape[valid_axes.back()] = n.back() / 2 + 1;
}
bool any_greater = false;
bool any_less = false;
for (int i = 0; i < in_shape.size(); ++i) {
any_greater |= in_shape[i] > a.shape()[i];
any_less |= in_shape[i] < a.shape()[i];
}
auto in = a;
if (any_less) {
in = slice(in, std::vector<int>(in.ndim(), 0), in_shape, s);
}
if (any_greater) {
// Pad with zeros
auto tmp = zeros(in_shape, a.dtype(), s);
in = scatter(tmp, std::vector<array>{}, in, std::vector<int>{}, s);
}
auto out_shape = in_shape;
if (real) {
auto ax = valid_axes.back();
out_shape[ax] = inverse ? n.back() : out_shape[ax] / 2 + 1;
}
auto in_type = real && !inverse ? float32 : complex64;
auto out_type = real && inverse ? float32 : complex64;
return array(
out_shape,
out_type,
std::make_unique<FFT>(to_stream(s), valid_axes, inverse, real),
{astype(in, in_type, s)});
}
array fft_impl(
const array& a,
const std::vector<int>& axes,
bool real,
bool inverse,
StreamOrDevice s) {
std::vector<int> n;
for (auto ax : axes) {
n.push_back(a.shape(ax));
}
if (real && inverse) {
n.back() = (n.back() - 1) * 2;
}
return fft_impl(a, n, axes, real, inverse, s);
}
array fft_impl(const array& a, bool real, bool inverse, StreamOrDevice s) {
std::vector<int> axes(a.ndim());
std::iota(axes.begin(), axes.end(), 0);
return fft_impl(a, axes, real, inverse, s);
}
array fftn(
const array& a,
const std::vector<int>& n,
const std::vector<int>& axes,
StreamOrDevice s /* = {} */) {
return fft_impl(a, n, axes, false, false, s);
}
array fftn(
const array& a,
const std::vector<int>& axes,
StreamOrDevice s /* = {} */) {
return fft_impl(a, axes, false, false, s);
}
array fftn(const array& a, StreamOrDevice s /* = {} */) {
return fft_impl(a, false, false, s);
}
array ifftn(
const array& a,
const std::vector<int>& n,
const std::vector<int>& axes,
StreamOrDevice s /* = {} */) {
return fft_impl(a, n, axes, false, true, s);
}
array ifftn(
const array& a,
const std::vector<int>& axes,
StreamOrDevice s /* = {} */) {
return fft_impl(a, axes, false, true, s);
}
array ifftn(const array& a, StreamOrDevice s /* = {} */) {
return fft_impl(a, false, true, s);
}
array rfftn(
const array& a,
const std::vector<int>& n,
const std::vector<int>& axes,
StreamOrDevice s /* = {} */) {
return fft_impl(a, n, axes, true, false, s);
}
array rfftn(
const array& a,
const std::vector<int>& axes,
StreamOrDevice s /* = {} */) {
return fft_impl(a, axes, true, false, s);
}
array rfftn(const array& a, StreamOrDevice s /* = {} */) {
return fft_impl(a, true, false, s);
}
array irfftn(
const array& a,
const std::vector<int>& n,
const std::vector<int>& axes,
StreamOrDevice s /* = {} */) {
return fft_impl(a, n, axes, true, true, s);
}
array irfftn(
const array& a,
const std::vector<int>& axes,
StreamOrDevice s /* = {} */) {
return fft_impl(a, axes, true, true, s);
}
array irfftn(const array& a, StreamOrDevice s /* = {} */) {
return fft_impl(a, true, true, s);
}
} // namespace mlx::core::fft
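
A short sketch of the real-transform shape rules in `fft_impl`, assuming the bindings expose these ops under `mx.fft`; only positional defaults are used here since the keyword names for the size and axis arguments may differ.

```
import mlx.core as mx

x = mx.random.uniform(shape=(4, 6))
y = mx.fft.rfftn(x)    # complex64; the last transformed axis becomes 6 // 2 + 1 = 4
z = mx.fft.irfftn(y)   # back to real samples of shape (4, 6)
mx.eval(y, z)
print(y.shape, z.shape)
```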

21
mlx/graph_utils.h Normal file

@ -0,0 +1,21 @@
#pragma once
#include "mlx/array.h"
namespace mlx::core {
void print_graph(std::ostream& os, const std::vector<array>& outputs);
template <typename... Arrays>
void print_graph(std::ostream& os, Arrays... outputs) {
print_graph(os, std::vector<array>{std::forward<Arrays>(outputs)...});
}
void export_to_dot(std::ostream& os, const std::vector<array>& outputs);
template <typename... Arrays>
void export_to_dot(std::ostream& os, Arrays... outputs) {
export_to_dot(os, std::vector<array>{std::forward<Arrays>(outputs)...});
}
} // namespace mlx::core
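
`export_to_dot` is bound to Python later in this commit (see transforms.cpp below); a minimal sketch:

```
import mlx.core as mx

x = mx.ones((2, 2))
y = (x + x) * x
# Accepts a filename string or any object with a write() method.
mx.export_to_dot("graph.dot", y)
```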

2323
mlx/ops.cpp Normal file

File diff suppressed because it is too large

16
mlx/transforms_impl.h Normal file

@ -0,0 +1,16 @@
namespace mlx::core::detail {
std::pair<std::vector<array>, std::vector<array>> vmap_trace(
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
const std::vector<array>& inputs,
const std::vector<int>& in_axes);
std::vector<array> vmap_replace(
const std::vector<array>& inputs,
const std::vector<array>& s_inputs,
const std::vector<array>& s_outputs,
const std::vector<int>& in_axes,
const std::vector<int>& out_axes);
} // namespace mlx::core::detail

75
mlx/types/complex.h Normal file

@ -0,0 +1,75 @@
#pragma once
#include <complex>
#include "mlx/types/half_types.h"
namespace mlx::core {
struct complex64_t;
template <typename T>
static constexpr bool can_convert_to_complex64 =
!std::is_same_v<T, complex64_t> && std::is_convertible_v<T, float>;
struct complex64_t : public std::complex<float> {
complex64_t(float v, float u) : std::complex<float>(v, u){};
complex64_t(std::complex<float> v) : std::complex<float>(v){};
template <
typename T,
typename = typename std::enable_if<can_convert_to_complex64<T>>::type>
complex64_t(T x) : std::complex<float>(x){};
operator float() const {
return real();
};
};
inline bool operator>=(const complex64_t& a, const complex64_t& b) {
return (a.real() > b.real()) ||
(a.real() == b.real() && a.imag() >= b.imag());
}
inline bool operator>(const complex64_t& a, const complex64_t& b) {
return (a.real() > b.real()) || (a.real() == b.real() && a.imag() > b.imag());
}
inline bool operator<=(const complex64_t& a, const complex64_t& b) {
return operator>=(b, a);
}
inline bool operator<(const complex64_t& a, const complex64_t& b) {
return operator>(b, a);
}
inline complex64_t operator-(const complex64_t& v) {
return -static_cast<std::complex<float>>(v);
}
// clang-format off
#define complex_binop_helper(_op_, _operator_, itype) \
inline complex64_t _operator_(itype x, const complex64_t& y) { \
return x _op_ static_cast<std::complex<float>>(y); \
} \
inline complex64_t _operator_(const complex64_t& x, itype y) { \
return static_cast<std::complex<float>>(x) _op_ y; \
}
#define complex_binop(_op_, _operator_) \
inline complex64_t _operator_(const complex64_t& x, const complex64_t& y) { \
return static_cast<std::complex<float>>(x) \
_op_ static_cast<std::complex<float>>(y); \
} \
complex_binop_helper(_op_, _operator_, bool) \
complex_binop_helper(_op_, _operator_, uint32_t) \
complex_binop_helper(_op_, _operator_, uint64_t) \
complex_binop_helper(_op_, _operator_, int32_t) \
complex_binop_helper(_op_, _operator_, int64_t) \
complex_binop_helper(_op_, _operator_, float16_t) \
complex_binop_helper(_op_, _operator_, bfloat16_t) \
complex_binop_helper(_op_, _operator_, const std::complex<float>&) \
complex_binop_helper(_op_, _operator_, float)
// clang-format on
complex_binop(+, operator+)
} // namespace mlx::core
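
The comparison operators above order complex numbers lexicographically: real part first, imaginary part as a tie-breaker. A pure-Python sketch of the same rule (not the mlx API):

```
def cplx_greater(a, b):
    # Mirrors complex64_t::operator>: compare real parts, then imaginary parts.
    return (a.real > b.real) or (a.real == b.real and a.imag > b.imag)

assert cplx_greater(2 + 0j, 1 + 5j)   # larger real part wins
assert cplx_greater(1 + 2j, 1 + 1j)   # real parts tie, compare imaginary parts
```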

232
mlx/types/fp16.h Normal file

@ -0,0 +1,232 @@
#pragma once
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>
#define __MLX_HALF_NAN__ 0x7D00
namespace mlx::core {
namespace {
union float_bits_fp16 {
float f;
uint32_t u;
};
} // namespace
struct _MLX_Float16 {
uint16_t bits_;
// Default constructor
_MLX_Float16() = default;
// Default copy constructor
_MLX_Float16(_MLX_Float16 const&) = default;
// Appease std::vector<bool> for being special
_MLX_Float16& operator=(std::vector<bool>::reference x) {
bits_ = x;
return *this;
}
_MLX_Float16& operator=(const float& x) {
return (*this = _MLX_Float16(x));
}
// From float32
_MLX_Float16(const float& x) : bits_(0) {
// Conversion following
// https://github.com/Maratyszcza/FP16/blob/master/include/fp16/fp16.h
// Union
float_bits_fp16 in;
// Take fp32 bits
in.f = x;
// Find and take sign bit
uint32_t x_sign_32 = in.u & uint32_t(0x80000000);
uint16_t x_sign_16 = (x_sign_32 >> 16);
if (std::isnan(x)) {
bits_ = x_sign_16 | uint16_t(__MLX_HALF_NAN__);
} else {
// Union
float_bits_fp16 inf_scale, zero_scale, magic_bits;
// Find exponent bits and take the max supported by half
uint32_t x_expo_32 = in.u & uint32_t(0x7f800000);
uint32_t max_expo_32 = uint32_t(0x38800000);
x_expo_32 = x_expo_32 < max_expo_32 ? max_expo_32 : x_expo_32;
x_expo_32 += uint32_t(15) << 23;
// Handle scaling to inf as needed
inf_scale.u = uint32_t(0x77800000);
zero_scale.u = uint32_t(0x08800000);
// Combine with magic and let addition do rounding
magic_bits.u = x_expo_32;
magic_bits.f += (std::abs(x) * inf_scale.f) * zero_scale.f;
// Take the lower 5 bits of the exponent
uint32_t x_expo_16 = ((magic_bits.u >> 13) & uint32_t(0x7c00));
// Collect the lower 12 bits which have the mantissa
uint32_t x_mant_16 = magic_bits.u & uint32_t(0x0fff);
// Combine sign, exp and mantissa
bits_ = (x_sign_16 | uint16_t(x_expo_16 + x_mant_16));
}
}
// To float32
operator float() const {
// Conversion following
// https://github.com/Maratyszcza/FP16/blob/master/include/fp16/fp16.h
// Union
float_bits_fp16 out;
uint32_t x_sign_32 = (bits_ << 16) & uint32_t(0x80000000);
uint32_t base = (bits_ << 16);
uint32_t two_base = base + base;
uint32_t denorm_max = 1u << 27;
if (two_base < denorm_max) {
out.u = uint32_t(126) << 23; // magic mask
out.u |= (two_base >> 17); // Bits from fp16
out.f -= 0.5f; // magic bias
} else {
out.u = uint32_t(0xE0) << 23; // exponent offset
out.u += (two_base >> 4); // Bits from fp16
float out_unscaled = out.f; // Store value
out.u = uint32_t(0x7800000); // exponent scale
out.f *= out_unscaled;
}
// Add sign
out.u |= x_sign_32;
return out.f;
}
};
#define half_binop_base(__op__, __operator__, otype, atype, btype, ctype) \
inline otype __operator__(atype lhs, btype rhs) { \
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
}
#define half_binop_helper(__op__, __operator__, otype, itype, ctype) \
inline otype __operator__(_MLX_Float16 lhs, itype rhs) { \
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
} \
inline otype __operator__(itype lhs, _MLX_Float16 rhs) { \
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
}
// Operators
#define half_binop(__op__, __operator__) \
half_binop_base( \
__op__, __operator__, _MLX_Float16, _MLX_Float16, _MLX_Float16, float); \
half_binop_helper(__op__, __operator__, float, float, float); \
half_binop_helper(__op__, __operator__, double, double, double); \
half_binop_helper(__op__, __operator__, _MLX_Float16, bool, float); \
half_binop_helper(__op__, __operator__, _MLX_Float16, int32_t, float); \
half_binop_helper(__op__, __operator__, _MLX_Float16, uint32_t, float); \
half_binop_helper(__op__, __operator__, _MLX_Float16, int64_t, float); \
half_binop_helper(__op__, __operator__, _MLX_Float16, uint64_t, float);
half_binop(+, operator+);
half_binop(-, operator-);
half_binop(*, operator*);
half_binop(/, operator/);
#undef half_binop
// Comparison ops
#define half_compop(__op__, __operator__) \
half_binop_base( \
__op__, __operator__, bool, _MLX_Float16, _MLX_Float16, float); \
half_binop_helper(__op__, __operator__, bool, float, float); \
half_binop_helper(__op__, __operator__, bool, double, double); \
half_binop_helper(__op__, __operator__, bool, int32_t, float); \
half_binop_helper(__op__, __operator__, bool, uint32_t, float); \
half_binop_helper(__op__, __operator__, bool, int64_t, float); \
half_binop_helper(__op__, __operator__, bool, uint64_t, float);
half_compop(>, operator>);
half_compop(<, operator<);
half_compop(>=, operator>=);
half_compop(<=, operator<=);
half_compop(==, operator==);
half_compop(!=, operator!=);
#undef half_compop
// Negative
inline _MLX_Float16 operator-(_MLX_Float16 lhs) {
return -static_cast<float>(lhs);
}
// Inplace ops
#define half_inplace_op(__op__, __operator__) \
inline _MLX_Float16& __operator__(_MLX_Float16& lhs, const float& rhs) { \
lhs = lhs __op__ rhs; \
return lhs; \
} \
inline float& __operator__(float& lhs, _MLX_Float16 rhs) { \
lhs = lhs __op__ rhs; \
return lhs; \
}
half_inplace_op(+, operator+=);
half_inplace_op(-, operator-=);
half_inplace_op(*, operator*=);
half_inplace_op(/, operator/=);
#undef half_inplace_op
// Bitwise ops
#define half_bitop(__op__, __operator__) \
inline _MLX_Float16 __operator__(_MLX_Float16 lhs, _MLX_Float16 rhs) { \
_MLX_Float16 out; \
out.bits_ = lhs.bits_ __op__ rhs.bits_; \
return out; \
} \
inline _MLX_Float16 __operator__(_MLX_Float16 lhs, uint16_t rhs) { \
_MLX_Float16 out; \
out.bits_ = lhs.bits_ __op__ rhs; \
return out; \
} \
inline _MLX_Float16 __operator__(uint16_t lhs, _MLX_Float16 rhs) { \
_MLX_Float16 out; \
out.bits_ = lhs __op__ rhs.bits_; \
return out; \
}
half_bitop(|, operator|);
half_bitop(&, operator&);
half_bitop(^, operator^);
#undef half_bitop
#define half_inplace_bitop(__op__, __operator__) \
inline _MLX_Float16& __operator__(_MLX_Float16& lhs, _MLX_Float16 rhs) { \
lhs.bits_ = lhs.bits_ __op__ rhs.bits_; \
return lhs; \
} \
inline _MLX_Float16& __operator__(_MLX_Float16& lhs, uint16_t rhs) { \
lhs.bits_ = lhs.bits_ __op__ rhs; \
return lhs; \
}
half_inplace_bitop(|, operator|=);
half_inplace_bitop(&, operator&=);
half_inplace_bitop(^, operator^=);
#undef half_inplace_bitop
} // namespace mlx::core
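
The float32-to-float16 conversion above packs a sign bit, a 5-bit exponent, and a 10-bit mantissa. A NumPy sketch of the same round trip, useful for sanity checking the bit layout (not the mlx implementation):

```
import numpy as np

x = np.array([3.14159], dtype=np.float32)
h = x.astype(np.float16)              # NumPy's half-precision conversion
print(hex(h.view(np.uint16)[0]))      # sign | 5-bit exponent | 10-bit mantissa
print(h.astype(np.float32)[0])        # ~3.1406: mantissa precision is lost
```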

3
pyproject.toml Normal file

@ -0,0 +1,3 @@
[build-system]
requires = ["setuptools>=42", "pybind11>=2.10", "cmake>=3.24"]
build-backend = "setuptools.build_meta"


@ -0,0 +1,124 @@
import math
from typing import Union
import mlx.core as mx
from mlx.nn.layers.base import Module
class Conv1d(Module):
"""Applies a 1-dimensional convolution over the multi-channel input sequence.
The channels are expected to be last i.e. the input shape should be ``NLC`` where:
- ``N`` is the batch dimension
- ``L`` is the sequence length
- ``C`` is the number of input channels
Args:
in_channels (int): The number of input channels
out_channels (int): The number of output channels
kernel_size (int): The size of the convolution filters
stride (int, optional): The stride when applying the filter.
Default: 1.
padding (int, optional): How many positions to 0-pad the input with.
Default: 0.
bias (bool, optional): If ``True`` add a learnable bias to the output.
Default: ``True``
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
padding: int = 0,
bias: bool = True,
):
super().__init__()
scale = math.sqrt(1 / (in_channels * kernel_size))
self.weight = mx.random.uniform(
low=-scale,
high=scale,
shape=(out_channels, kernel_size, in_channels),
)
if bias:
self.bias = mx.zeros((out_channels,))
self.padding = padding
self.stride = stride
def _extra_repr(self):
return (
f"{self.weight.shape[-1]}, {self.weight.shape[0]}, "
f"kernel_size={self.weight.shape[1]}, stride={self.stride}, "
f"padding={self.padding}, bias={'bias' in self}"
)
def __call__(self, x):
y = mx.conv1d(x, self.weight, self.stride, self.padding)
if "bias" in self:
y = y + self.bias
return y
class Conv2d(Module):
"""Applies a 2-dimensional convolution over the multi-channel input image.
The channels are expected to be last i.e. the input shape should be ``NHWC`` where:
- ``N`` is the batch dimension
- ``H`` is the input image height
- ``W`` is the input image width
- ``C`` is the number of input channels
Args:
in_channels (int): The number of input channels.
out_channels (int): The number of output channels.
kernel_size (int or tuple): The size of the convolution filters.
stride (int or tuple, optional): The size of the stride when
applying the filter. Default: 1.
padding (int or tuple, optional): How many positions to 0-pad
the input with. Default: 0.
bias (bool, optional): If ``True`` add a learnable bias to the
output. Default: ``True``
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, tuple],
stride: Union[int, tuple] = 1,
padding: Union[int, tuple] = 0,
bias: bool = True,
):
super().__init__()
kernel_size, stride, padding = map(
lambda x: (x, x) if isinstance(x, int) else x,
(kernel_size, stride, padding),
)
scale = math.sqrt(1 / (in_channels * kernel_size[0] * kernel_size[1]))
self.weight = mx.random.uniform(
low=-scale,
high=scale,
shape=(out_channels, *kernel_size, in_channels),
)
if bias:
self.bias = mx.zeros((out_channels,))
self.padding = padding
self.stride = stride
def _extra_repr(self):
return (
f"{self.weight.shape[-1]}, {self.weight.shape[0]}, "
f"kernel_size={self.weight.shape[1:2]}, stride={self.stride}, "
f"padding={self.padding}, bias={'bias' in self}"
)
def __call__(self, x):
y = mx.conv2d(x, self.weight, self.stride, self.padding)
if "bias" in self:
y = y + self.bias
return y
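
A hedged usage sketch of the channels-last convolution layers above, assuming they are re-exported as `mlx.nn.Conv1d` and `mlx.nn.Conv2d`; the shapes are hypothetical.

```
import mlx.core as mx
import mlx.nn as nn

x1 = mx.random.uniform(shape=(4, 32, 3))          # (N, L, C)
y1 = nn.Conv1d(in_channels=3, out_channels=8, kernel_size=5, padding=2)(x1)
print(y1.shape)                                   # (4, 32, 8)

x2 = mx.random.uniform(shape=(4, 28, 28, 3))      # (N, H, W, C)
y2 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)(x2)
print(y2.shape)                                   # (4, 28, 28, 16)
```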


@ -0,0 +1,28 @@
import math
import mlx.core as mx
from mlx.nn.layers.base import Module
class Embedding(Module):
"""Implements a simple lookup table that maps each input integer to a
high-dimensional vector.
Typically used to embed discrete tokens for processing by neural networks.
Args:
num_embeddings (int): How many possible discrete tokens can we embed.
Usually called the vocabulary size.
dims (int): The dimensionality of the embeddings.
"""
def __init__(self, num_embeddings: int, dims: int):
super().__init__()
scale = math.sqrt(1 / dims)
self.weight = mx.random.normal((num_embeddings, dims)) * scale
def _extra_repr(self):
return f"{self.weight.shape[0]}, {self.weight.shape[1]}"
def __call__(self, x):
return self.weight[x]
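
A short usage sketch, assuming the layer is re-exported as `mlx.nn.Embedding`; lookup is plain integer indexing into the weight table.

```
import mlx.core as mx
import mlx.nn as nn

emb = nn.Embedding(num_embeddings=100, dims=16)
tokens = mx.array([3, 7, 7, 42])
print(emb(tokens).shape)    # (4, 16): one 16-dimensional vector per token
```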


@ -0,0 +1,34 @@
import math
import mlx.core as mx
from mlx.nn.layers.base import Module
class Linear(Module):
"""Applies an affine transformation to the input.
Args:
input_dims (int): The dimensionality of the input features
output_dims (int): The dimensionality of the output features
bias (bool, optional): If set to ``False`` the layer will not use a bias. Default: ``True``.
"""
def __init__(self, input_dims: int, output_dims: int, bias: bool = True):
super().__init__()
scale = math.sqrt(1 / input_dims)
self.weight = mx.random.uniform(
low=-scale,
high=scale,
shape=(output_dims, input_dims),
)
if bias:
self.bias = mx.zeros((output_dims,))
def _extra_repr(self):
return f"input_dims={self.weight.shape[1]}, output_dims={self.weight.shape[0]}, bias={'bias' in self}"
def __call__(self, x):
x = x @ self.weight.T
if "bias" in self:
x = x + self.bias
return x
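
A minimal usage sketch, assuming the layer is re-exported as `mlx.nn.Linear`:

```
import mlx.core as mx
import mlx.nn as nn

lin = nn.Linear(input_dims=8, output_dims=4)
x = mx.ones((2, 8))
print(lin(x).shape)    # (2, 4), computed as x @ weight.T + bias
```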

19
python/src/load.h Normal file

@ -0,0 +1,19 @@
#pragma once
#include <pybind11/pybind11.h>
#include <unordered_map>
#include <variant>
#include "mlx/ops.h"
namespace py = pybind11;
using namespace mlx::core;
using DictOrArray = std::variant<array, std::unordered_map<std::string, array>>;
DictOrArray mlx_load_helper(py::object file, StreamOrDevice s);
void mlx_save_helper(py::object file, array a, bool retain_graph = true);
void mlx_savez_helper(
py::object file,
py::args args,
const py::kwargs& kwargs,
bool compressed = false);
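
A hedged sketch of the helpers declared above, assuming the Python bindings expose them as `mx.save`, `mx.load`, and `mx.savez` with NumPy-style `.npy`/`.npz` files; the file names are hypothetical.

```
import mlx.core as mx

a = mx.ones((3, 3))
mx.save("a.npy", a)                        # single array -> .npy file
b = mx.load("a.npy")
mx.savez("pair.npz", first=a, second=b)    # several arrays -> .npz archive
d = mx.load("pair.npz")                    # returns a dict of arrays
mx.eval(b, *d.values())
```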

31
python/src/mlx.cpp Normal file

@ -0,0 +1,31 @@
#include <pybind11/pybind11.h>
#define STRINGIFY(x) #x
#define TOSTRING(x) STRINGIFY(x)
namespace py = pybind11;
void init_array(py::module_&);
void init_device(py::module_&);
void init_stream(py::module_&);
void init_metal(py::module_&);
void init_ops(py::module_&);
void init_transforms(py::module_&);
void init_random(py::module_&);
void init_fft(py::module_&);
PYBIND11_MODULE(core, m) {
m.doc() = "mlx: A framework for machine learning on Apple Silicon.";
auto reprlib_fix = py::module_::import("mlx._reprlib_fix");
init_device(m);
init_stream(m);
init_array(m);
init_metal(m);
init_ops(m);
init_transforms(m);
init_random(m);
init_fft(m);
m.attr("__version__") = TOSTRING(_VERSION_);
}

723
python/src/transforms.cpp Normal file

@ -0,0 +1,723 @@
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <algorithm>
#include <fstream>
#include <numeric>
#include <sstream>
#include "mlx/array.h"
#include "mlx/graph_utils.h"
#include "mlx/transforms.h"
#include "mlx/transforms_impl.h"
namespace py = pybind11;
using namespace py::literals;
using namespace mlx::core;
using IntOrVec = std::variant<int, std::vector<int>>;
using StrOrVec = std::variant<std::string, std::vector<std::string>>;
template <typename T>
std::vector<T> to_vector(const std::variant<T, std::vector<T>>& v) {
std::vector<T> vals;
if (auto pv = std::get_if<T>(&v); pv) {
vals.push_back(*pv);
} else {
vals = std::get<std::vector<T>>(v);
}
return vals;
}
void tree_visit(py::object tree, std::function<void(py::handle)> visitor) {
std::function<void(py::handle)> recurse;
recurse = [&](py::handle subtree) {
if (py::isinstance<py::list>(subtree) ||
py::isinstance<py::tuple>(subtree)) {
for (auto item : subtree) {
recurse(item);
}
} else if (py::isinstance<py::dict>(subtree)) {
for (auto item : py::cast<py::dict>(subtree)) {
recurse(item.second);
}
} else {
visitor(subtree);
}
};
recurse(tree);
}
template <typename T, typename U, typename V>
void validate_subtrees(const std::vector<py::object>& subtrees) {
int len = py::cast<T>(subtrees[0]).size();
for (auto& subtree : subtrees) {
if ((py::isinstance<T>(subtree) && py::cast<T>(subtree).size() != len) ||
py::isinstance<U>(subtree) || py::isinstance<V>(subtree)) {
throw std::invalid_argument(
"[tree_map] Additional input tree is not a valid prefix of the first tree.");
}
}
}
py::object tree_map(
const std::vector<py::object>& trees,
std::function<py::object(const std::vector<py::object>&)> transform) {
std::function<py::object(const std::vector<py::object>&)> recurse;
recurse = [&](const std::vector<py::object>& subtrees) {
if (py::isinstance<py::list>(subtrees[0])) {
py::list l;
std::vector<py::object> items(subtrees.size());
validate_subtrees<py::list, py::tuple, py::dict>(subtrees);
for (int i = 0; i < py::cast<py::list>(subtrees[0]).size(); ++i) {
for (int j = 0; j < subtrees.size(); ++j) {
if (py::isinstance<py::list>(subtrees[j])) {
items[j] = py::cast<py::list>(subtrees[j])[i];
} else {
items[j] = subtrees[j];
}
}
l.append(recurse(items));
}
return py::cast<py::object>(l);
} else if (py::isinstance<py::tuple>(subtrees[0])) {
// Check the rest of the subtrees
std::vector<py::object> items(subtrees.size());
int len = py::cast<py::tuple>(subtrees[0]).size();
py::tuple l(len);
validate_subtrees<py::tuple, py::list, py::dict>(subtrees);
for (int i = 0; i < len; ++i) {
for (int j = 0; j < subtrees.size(); ++j) {
if (py::isinstance<py::tuple>(subtrees[j])) {
items[j] = py::cast<py::tuple>(subtrees[j])[i];
} else {
items[j] = subtrees[j];
}
}
l[i] = recurse(items);
}
return py::cast<py::object>(l);
} else if (py::isinstance<py::dict>(subtrees[0])) {
std::vector<py::object> items(subtrees.size());
validate_subtrees<py::dict, py::list, py::tuple>(subtrees);
py::dict d;
for (auto item : py::cast<py::dict>(subtrees[0])) {
for (int j = 0; j < subtrees.size(); ++j) {
if (py::isinstance<py::dict>(subtrees[j])) {
auto subdict = py::cast<py::dict>(subtrees[j]);
if (!subdict.contains(item.first)) {
throw std::invalid_argument(
"[tree_map] Tree is not a valid prefix tree of the first tree.");
}
items[j] = subdict[item.first];
} else {
items[j] = subtrees[j];
}
}
d[item.first] = recurse(items);
}
return py::cast<py::object>(d);
} else {
return transform(subtrees);
}
};
return recurse(trees);
}
py::object tree_map(
py::object tree,
std::function<py::object(py::handle)> transform) {
return tree_map({tree}, [&](std::vector<py::object> inputs) {
return transform(inputs[0]);
});
}
std::vector<array> tree_flatten(py::object tree, bool strict = true) {
std::vector<array> flat_tree;
tree_visit(tree, [&](py::handle obj) {
if (py::isinstance<array>(obj)) {
flat_tree.push_back(py::cast<array>(obj));
} else if (strict) {
throw std::invalid_argument("Argument is not an array");
}
});
return flat_tree;
}
py::object tree_unflatten(
py::object tree,
const std::vector<array>& values,
int index = 0) {
return tree_map(tree, [&](py::handle obj) {
if (py::isinstance<array>(obj)) {
return py::cast(values[index++]);
} else {
return py::cast<py::object>(obj);
}
});
}
auto validate_argnums_argnames(
const std::optional<IntOrVec>& argnums,
const StrOrVec& argnames) {
auto vec_names = to_vector(argnames);
if (!argnums.has_value()) {
// argnums was not provided and argnames was empty
if (vec_names.empty()) {
return std::make_pair(std::vector<int>{0}, vec_names);
} else {
return std::make_pair(std::vector<int>{}, vec_names);
}
}
return std::make_pair(to_vector(*argnums), vec_names);
}
auto py_value_and_grad(
const py::function& fun,
std::vector<int> argnums,
std::vector<std::string> argnames,
const std::string& error_msg_tag,
bool scalar_func_only) {
// Sanitize argnums
if (argnums.size() == 0 && argnames.size() == 0) {
throw std::invalid_argument(
error_msg_tag + " Gradient wrt no argument requested");
}
if (argnums.size() > 0) {
std::sort(argnums.begin(), argnums.end());
if (argnums[0] < 0) {
std::ostringstream msg;
msg << error_msg_tag
<< " Can't compute the gradient of negative argument index "
<< argnums[0];
throw std::invalid_argument(msg.str());
}
}
return [fun, argnums, argnames, error_msg_tag, scalar_func_only](
const py::args& args, const py::kwargs& kwargs) {
// Sanitize the input
if (argnums.size() > 0 && argnums.back() >= args.size()) {
std::ostringstream msg;
msg << error_msg_tag << " Can't compute the gradient of argument index "
<< argnums.back() << " because the function is called with only "
<< args.size() << " arguments.";
throw std::invalid_argument(msg.str());
}
for (auto& key : argnames) {
if (!kwargs.contains(key)) {
std::ostringstream msg;
msg << error_msg_tag
<< " Can't compute the gradient of keyword argument '" << key
<< "' because the function is called with the "
<< "following keyword arguments {";
for (auto item : kwargs) {
msg << item.first.cast<std::string>() << ",";
}
msg << "}";
throw std::invalid_argument(msg.str());
}
}
// Collect the arrays
std::vector<array> arrays;
std::vector<int> counts(1, 0);
for (auto i : argnums) {
auto argsi = tree_flatten(args[i]);
arrays.insert(arrays.end(), argsi.begin(), argsi.end());
counts.push_back(argsi.size());
}
for (auto& key : argnames) {
auto argsk = tree_flatten(kwargs[key.c_str()]);
arrays.insert(arrays.end(), argsk.begin(), argsk.end());
counts.push_back(argsk.size());
}
std::partial_sum(counts.cbegin(), counts.cend(), counts.begin());
std::vector<int> gradient_indices(arrays.size());
std::iota(gradient_indices.begin(), gradient_indices.end(), 0);
// value_out will hold the output of the python function in order to be
// able to reconstruct the python tree of extra return values
py::object py_value_out;
auto value_and_grads = value_and_grad(
[&fun,
&args,
&kwargs,
&argnums,
&argnames,
&counts,
&py_value_out,
&error_msg_tag,
scalar_func_only](const std::vector<array>& a) {
// Copy the arguments
py::args args_cpy = py::tuple(args.size());
py::kwargs kwargs_cpy = py::kwargs();
int j = 0;
for (int i = 0; i < args.size(); ++i) {
if (j < argnums.size() && i == argnums[j]) {
args_cpy[i] = tree_unflatten(args[i], a, counts[j]);
j++;
} else {
args_cpy[i] = args[i];
}
}
for (auto& key : argnames) {
kwargs_cpy[key.c_str()] =
tree_unflatten(kwargs[key.c_str()], a, counts[j]);
j++;
}
for (auto item : kwargs) {
if (kwargs_cpy.contains(item.first)) {
continue;
}
kwargs_cpy[item.first] = item.second;
}
// Call the python function
py_value_out = fun(*args_cpy, **kwargs_cpy);
// Validate the return value of the python function
if (!py::isinstance<array>(py_value_out)) {
if (scalar_func_only) {
std::ostringstream msg;
msg << error_msg_tag << " The return value of the function "
<< "whose gradient we want to compute should be a "
<< "scalar array; but " << py_value_out.get_type()
<< " was returned.";
throw std::invalid_argument(msg.str());
}
if (!py::isinstance<py::tuple>(py_value_out)) {
std::ostringstream msg;
msg << error_msg_tag << " The return value of the function "
<< "whose gradient we want to compute should be either a "
<< "scalar array or a tuple with the first value being a "
<< "scalar array (Union[array, Tuple[array, Any, ...]]); but "
<< py_value_out.get_type() << " was returned.";
throw std::invalid_argument(msg.str());
}
py::tuple ret = py::cast<py::tuple>(py_value_out);
if (ret.size() == 0) {
std::ostringstream msg;
msg << error_msg_tag << " The return value of the function "
<< "whose gradient we want to compute should be either a "
<< "scalar array or a non-empty tuple. The first value should be a "
<< "scalar array and the rest can be anything. Instead, "
<< "we got an empty tuple.";
throw std::invalid_argument(msg.str());
}
if (!py::isinstance<array>(ret[0])) {
std::ostringstream msg;
msg << error_msg_tag << " The return value of the function "
<< "whose gradient we want to compute should be either a "
<< "scalar array or a tuple with the first value being a "
<< "scalar array (Union[array, Tuple[array, Any, ...]]); but it "
<< "was a tuple with the first value being of type "
<< ret[0].get_type() << " .";
throw std::invalid_argument(msg.str());
}
}
return tree_flatten(py_value_out, false);
},
gradient_indices)(arrays);
auto value = value_and_grads.first;
auto gradients = value_and_grads.second;
// Put the gradients back in their container.
// We have the following cases:
//
// 1. Single python positional argument has a gradient (eg argnums=[0])
// 2. Many python positional arguments have gradients (eg argnums=[0, 1])
// 3. A python keyword argument has gradients
//
// In case 1 we return the original python variable but with the gradients.
// In case 2 we return a tuple of the above.
// In case 3 we return a tuple containing a tuple and dict (sth like
// (tuple(), dict(x=mx.array(5))) ).
py::object positional_grads;
py::object keyword_grads;
py::object py_grads;
// Collect the gradients for the positional arguments
if (argnums.size() == 1) {
positional_grads = tree_unflatten(args[argnums[0]], gradients, counts[0]);
} else if (argnums.size() > 1) {
py::tuple grads_(argnums.size());
for (int i = 0; i < argnums.size(); i++) {
grads_[i] = tree_unflatten(args[argnums[i]], gradients, counts[i]);
}
positional_grads = py::cast<py::object>(grads_);
} else {
positional_grads = py::none();
}
// No keyword argument gradients so return the tuple of gradients
if (argnames.size() == 0) {
py_grads = positional_grads;
} else {
py::dict grads_;
for (int i = 0; i < argnames.size(); i++) {
auto& k = argnames[i];
grads_[k.c_str()] = tree_unflatten(
kwargs[k.c_str()], gradients, counts[i + argnums.size()]);
}
keyword_grads = py::cast<py::object>(grads_);
py_grads =
py::cast<py::object>(py::make_tuple(positional_grads, keyword_grads));
}
// Put the values back in the container
py::object return_value = tree_unflatten(py_value_out, value);
return std::make_pair(return_value, py_grads);
};
}
auto py_vmap(
const py::function& fun,
const py::object& in_axes,
const py::object& out_axes) {
return [fun, in_axes, out_axes](const py::args& args) {
auto axes_to_flat_tree = [](const py::object& tree,
const py::object& axes) {
auto tree_axes = tree_map(
{tree, axes},
[](const std::vector<py::object>& inputs) { return inputs[1]; });
std::vector<int> flat_axes;
tree_visit(tree_axes, [&flat_axes](py::handle obj) {
if (obj.is_none()) {
flat_axes.push_back(-1);
} else if (py::isinstance<py::int_>(obj)) {
flat_axes.push_back(py::cast<int>(py::cast<py::int_>(obj)));
} else {
throw std::invalid_argument("[vmap] axis must be int or None.");
}
});
return flat_axes;
};
// Inputs must be array or tree of arrays
auto inputs = tree_flatten(args, true);
auto flat_in_axes = axes_to_flat_tree(args, in_axes);
// py_value_out will hold the output of the python function in order to be
// able to reconstruct the python tree of extra return values
py::object py_outputs;
auto vmap_fn =
[&fun, &args, &inputs, &py_outputs](const std::vector<array>& a) {
// Call the python function
py_outputs = fun(*tree_unflatten(args, a));
// Flatten the outputs
return tree_flatten(py_outputs, true);
};
auto [trace_inputs, trace_outputs] =
detail::vmap_trace(vmap_fn, inputs, flat_in_axes);
auto flat_out_axes = axes_to_flat_tree(py_outputs, out_axes);
// Perform the vmap
auto outputs = detail::vmap_replace(
inputs, trace_inputs, trace_outputs, flat_in_axes, flat_out_axes);
// Put the outputs back in the container
return tree_unflatten(py_outputs, outputs);
};
}
void init_transforms(py::module_& m) {
m.def(
"eval",
[](const py::args& args, bool retain_graph) {
std::vector<array> arrays = tree_flatten(args);
eval(arrays, retain_graph);
},
"retain_graph"_a = false,
R"pbdoc(
Evaluate an :class:`array` or tree of :class:`array`.
Args:
*args (arrays or trees of arrays): Each argument can be a single array
or a tree of arrays. If a tree is given the nodes can be a Python
:class:`list`, :class:`tuple` or :class:`dict` but the leaves must all be
an :class:`array`.
retain_graph (bool): Indicate that the graph structure should be
preserved. This option is intended to enable function transforms
which contain control flow based on the value of an array.
)pbdoc");
m.def(
"jvp",
[](const py::function& fun,
const std::vector<array>& primals,
const std::vector<array>& tangents) {
auto vfun = [&fun](const std::vector<array>& primals) {
py::args args = py::tuple(primals.size());
for (int i = 0; i < primals.size(); ++i) {
args[i] = primals[i];
}
auto out = fun(*args);
if (py::isinstance<array>(out)) {
return std::vector<array>{py::cast<array>(out)};
} else {
return py::cast<std::vector<array>>(out);
}
};
return jvp(vfun, primals, tangents);
},
"fun"_a,
"primals"_a,
"tangents"_a,
R"pbdoc(
Compute the Jacobian-vector product.
This computes the product of the Jacobian of a function ``fun`` evaluated
at ``primals`` with the ``tangents``.
Args:
fun (function): A function which takes a variable number of :class:`array`
and returns a single :class:`array` or list of :class:`array`.
primals (list(array)): A list of :class:`array` at which to
evaluate the Jacobian.
tangents (list(array)): A list of :class:`array` which are the
"vector" in the Jacobian-vector product. The ``tangents`` should be the
same in number, shape, and type as the inputs of ``fun`` (i.e. the ``primals``).
Returns:
list(array): A list of the Jacobian-vector products which
is the same in number, shape, and type as the inputs to ``fun``.
)pbdoc");
m.def(
"vjp",
[](const py::function& fun,
const std::vector<array>& primals,
const std::vector<array>& cotangents) {
auto vfun = [&fun](const std::vector<array>& primals) {
py::args args = py::tuple(primals.size());
for (int i = 0; i < primals.size(); ++i) {
args[i] = primals[i];
}
auto out = fun(*args);
if (py::isinstance<array>(out)) {
return std::vector<array>{py::cast<array>(out)};
} else {
return py::cast<std::vector<array>>(out);
}
};
return vjp(vfun, primals, cotangents);
},
"fun"_a,
"primals"_a,
"cotangents"_a,
R"pbdoc(
Compute the vector-Jacobian product.
Computes the product of the ``cotangents`` with the Jacobian of a
function ``fun`` evaluated at ``primals``.
Args:
fun (function): A function which takes a variable number of :class:`array`
and returns a single :class:`array` or list of :class:`array`.
primals (list(array)): A list of :class:`array` at which to
evaluate the Jacobian.
cotangents (list(array)): A list of :class:`array` which are the
"vector" in the vector-Jacobian product. The ``cotangents`` should be the
same in number, shape, and type as the outputs of ``fun``.
Returns:
list(array): A list of the vector-Jacobian products which
is the same in number, shape, and type as the outputs of ``fun``.
)pbdoc");
m.def(
"value_and_grad",
[](const py::function& fun,
const std::optional<IntOrVec>& argnums,
const StrOrVec& argnames) {
auto [argnums_vec, argnames_vec] =
validate_argnums_argnames(argnums, argnames);
return py::cpp_function(py_value_and_grad(
fun, argnums_vec, argnames_vec, "[value_and_grad]", false));
},
"fun"_a,
"argnums"_a = std::nullopt,
"argnames"_a = std::vector<std::string>{},
R"pbdoc(
Returns a function which computes the value and gradient of ``fun``.
The function passed to :func:`value_and_grad` should return either
a scalar loss or a tuple in which the first element is a scalar
loss and the remaining elements can be anything.
.. code-block:: python
import mlx.core as mx
def mse(params, inputs, targets):
outputs = forward(params, inputs)
lvalue = (outputs - targets).square().mean()
return lvalue
# Returns lvalue, dlvalue/dparams
lvalue, grads = mx.value_and_grad(mse)(params, inputs, targets)
def lasso(params, inputs, targets, a=1.0, b=1.0):
outputs = forward(params, inputs)
mse = (outputs - targets).square().mean()
l1 = mx.abs(outputs - targets).mean()
loss = a*mse + b*l1
return loss, mse, l1
(loss, mse, l1), grads = mx.value_and_grad(lasso)(params, inputs, targets)
Args:
fun (function): A function which takes a variable number of
:class:`array` or trees of :class:`array` and returns
a scalar output :class:`array` or a tuple the first element
of which should be a scalar :class:`array`.
argnums (int or list(int), optional): Specify the index (or indices)
of the positional arguments of ``fun`` to compute the gradient
with respect to. If neither ``argnums`` nor ``argnames`` are
provided ``argnums`` defaults to ``0`` indicating ``fun``'s first
argument.
argnames (str or list(str), optional): Specify keyword arguments of
``fun`` to compute gradients with respect to. It defaults to ``[]`` so that
no gradients are computed for keyword arguments by default.
Returns:
function: A function which returns a tuple where the first element
is the output of ``fun`` and the second element is the gradients of that
output with respect to the arguments selected by ``argnums`` and ``argnames``.
)pbdoc");
m.def(
"grad",
[](const py::function& fun,
const std::optional<IntOrVec>& argnums,
const StrOrVec& argnames) {
auto [argnums_vec, argnames_vec] =
validate_argnums_argnames(argnums, argnames);
auto fn =
py_value_and_grad(fun, argnums_vec, argnames_vec, "[grad]", true);
return py::cpp_function(
[fn](const py::args& args, const py::kwargs& kwargs) {
return fn(args, kwargs).second;
});
},
"fun"_a,
"argnums"_a = std::nullopt,
"argnames"_a = std::vector<std::string>{},
R"pbdoc(
Returns a function which computes the gradient of ``fun``.
Args:
fun (function): A function which takes a variable number of
:class:`array` or trees of :class:`array` and returns
a scalar output :class:`array`.
argnums (int or list(int), optional): Specify the index (or indices)
of the positional arguments of ``fun`` to compute the gradient
with respect to. If neither ``argnums`` nor ``argnames`` are
provided ``argnums`` defaults to ``0`` indicating ``fun``'s first
argument.
argnames (str or list(str), optional): Specify keyword arguments of
``fun`` to compute gradients with respect to. It defaults to ``[]`` so that
no gradients are computed for keyword arguments by default.
Returns:
function: A function which has the same input arguments as ``fun`` and
returns the gradient(s).
)pbdoc");
m.def(
"vmap",
[](const py::function& fun,
const py::object& in_axes,
const py::object& out_axes) {
return py::cpp_function(py_vmap(fun, in_axes, out_axes));
},
"fun"_a,
"in_axes"_a = 0,
"out_axes"_a = 0,
R"pbdoc(
Returns a vectorized version of ``fun``.
Args:
fun (function): A function which takes a variable number of
:class:`array` or a tree of :class:`array` and returns
a variable number of :class:`array` or a tree of :class:`array`.
in_axes (int, optional): An integer or a valid prefix tree of the
inputs to ``fun`` where each node specifies the vmapped axis. If
the value is ``None`` then the corresponding input(s) are not vmapped.
Defaults to ``0``.
out_axes (int, optional): An integer or a valid prefix tree of the
outputs of ``fun`` where each node specifies the vmapped axis. If
the value is ``None`` then the corresponding output(s) are not vmapped.
Defaults to ``0``.
Returns:
function: The vectorized function.
)pbdoc");
m.def(
"simplify",
[](const py::args& args) {
std::vector<array> arrays = tree_flatten(args);
simplify(arrays);
},
R"pbdoc(
Simplify the graph that computes the arrays.
Run a few fast graph simplification operations to reuse computation and
reduce memory consumption. This function is meant to be cheap enough to run
before every evaluation, so its overhead should be small: approximately 1ms
for a graph with a few thousand nodes.
.. code-block:: python
import mlx.core as mx
def foo(x):
y = x @ x
z = x @ x
return y + z
x = mx.ones((10, 10))
y = foo(x)
z = foo(x)
# Computes the matmul twice
mx.eval(y)
# Computes the matmul once
mx.simplify(z)
mx.eval(z)
Args:
args: Any number of arrays and/or trees of arrays to be simplified.
)pbdoc");
m.def(
"export_to_dot",
[](py::object file, const py::args& args) {
std::vector<array> arrays = tree_flatten(args);
if (py::isinstance<py::str>(file)) {
std::ofstream out(py::cast<std::string>(file));
export_to_dot(out, arrays);
} else if (py::hasattr(file, "write")) {
std::ostringstream out;
export_to_dot(out, arrays);
auto write = file.attr("write");
write(out.str());
} else {
throw std::invalid_argument(
"export_to_dot accepts file-like objects or strings to be used as filenames");
}
},
"file"_a);
}
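
A small end-to-end sketch of the bindings above (`value_and_grad`, `grad`, `vmap`, `eval`); the loss function and shapes are hypothetical.

```
import mlx.core as mx

def loss(w, x):
    return mx.sum((x @ w).square())    # scalar output, as grad requires

w = mx.random.uniform(shape=(3, 1))
x = mx.ones((4, 3))

value, dw = mx.value_and_grad(loss)(w, x)    # gradient w.r.t. argument 0
dw_only = mx.grad(loss)(w, x)

# vmap the per-row squared norm over the first axis of x.
row_norms = mx.vmap(lambda row: mx.sum(row.square()), in_axes=0)(x)
mx.eval(value, dw, dw_only, row_norms)
```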

16
python/tests/mlx_tests.py Normal file

@ -0,0 +1,16 @@
import os
import unittest
import mlx.core as mx
class MLXTestCase(unittest.TestCase):
def setUp(self):
self.default = mx.default_device()
device = os.getenv("DEVICE", None)
if device is not None:
device = getattr(mx, device)
mx.set_default_device(device)
def tearDown(self):
mx.set_default_device(self.default)

445
python/tests/test_blas.py Normal file

@ -0,0 +1,445 @@
import unittest
from itertools import permutations
import math
import mlx.core as mx
import numpy as np
import mlx_tests
class TestBlas(mlx_tests.MLXTestCase):
@property
def dtypes(self):
return ["float32", "float16"] if mx.metal.is_available() else ["float32"]
def __gemm_test(
self,
shape_a,
shape_b,
np_dtype=np.float32,
f_np_a=lambda x: x,
f_np_b=lambda x: x,
f_mx_a=lambda x: x,
f_mx_b=lambda x: x,
):
with self.subTest(
dtype=np.dtype(np_dtype).name, shape_a=shape_a, shape_b=shape_b
):
np.random.seed(42)
scale = max(np.sum(shape_a), 128)
a_np = np.random.normal(0.0, 1.0 / scale, shape_a).astype(np_dtype)
b_np = np.random.normal(0.0, 1.0 / scale, shape_b).astype(np_dtype)
a_mx = mx.array(a_np)
b_mx = mx.array(b_np)
a_np = f_np_a(a_np.astype(np.float32))
b_np = f_np_b(b_np.astype(np.float32))
a_mx = f_mx_a(a_mx)
b_mx = f_mx_b(b_mx)
out_npy = a_np @ b_np
out_mlx = a_mx @ b_mx
self.assertListEqual(list(out_npy.shape), list(out_mlx.shape))
self.assertTrue(np.allclose(out_mlx, out_npy.astype(np_dtype), atol=1e-5))
def test_matmul_unaligned(self):
if not mx.metal.is_available():
return
for dtype in self.dtypes:
np_dtype = getattr(np, dtype)
base_shapes = [4, 8, 16, 32, 64, 128]
perturbations = [-2, -1, 0, 1, 2]
for dim in base_shapes:
for p in perturbations:
shape_a = (dim + p, dim + p)
shape_b = (dim + p, dim + p)
self.__gemm_test(shape_a, shape_b, np_dtype)
def test_matmul_shapes(self):
if not mx.metal.is_available():
return
shapes = [
(1, 2, 1, 1),
(1, 1, 2, 1),
(3, 23, 457, 3),
]
if mx.default_device() == mx.gpu:
shapes += [
(16, 768, 768, 128),
]
for dtype in self.dtypes:
np_dtype = getattr(np, dtype)
for B, M, N, K in shapes:
with self.subTest(transpose="nn"):
shape_a = (B, M, K)
shape_b = (B, K, N)
self.__gemm_test(shape_a, shape_b, np_dtype)
with self.subTest(transpose="nt"):
shape_a = (B, M, K)
shape_b = (B, N, K)
self.__gemm_test(
shape_a,
shape_b,
np_dtype,
f_np_b=lambda x: np.transpose(x, (0, 2, 1)),
f_mx_b=lambda x: mx.transpose(x, (0, 2, 1)),
)
with self.subTest(transpose="tn"):
shape_a = (B, K, M)
shape_b = (B, K, N)
self.__gemm_test(
shape_a,
shape_b,
np_dtype,
f_np_a=lambda x: np.transpose(x, (0, 2, 1)),
f_mx_a=lambda x: mx.transpose(x, (0, 2, 1)),
)
with self.subTest(transpose="tt"):
shape_a = (B, K, M)
shape_b = (B, N, K)
self.__gemm_test(
shape_a,
shape_b,
np_dtype,
f_np_a=lambda x: np.transpose(x, (0, 2, 1)),
f_mx_a=lambda x: mx.transpose(x, (0, 2, 1)),
f_np_b=lambda x: np.transpose(x, (0, 2, 1)),
f_mx_b=lambda x: mx.transpose(x, (0, 2, 1)),
)
def test_matmul(self):
# Note: so far, matmul only works with floating-point types
a = mx.array([[1.0, 2.0], [3.0, 4.0]])
b = mx.array([[0.0, -1.0], [-3.0, 3.0]])
expected = [[-6.0, 5.0], [-12.0, 9.0]]
self.assertEqual((a @ b).tolist(), expected)
self.assertEqual(mx.matmul(a, b).tolist(), expected)
# Transposed matmul
np.random.seed(0)
a_npy = np.random.normal(0.0, 1.0 / 128, (128, 16)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (128, 16)).astype(np.float32)
c_npy = a_npy @ np.transpose(b_npy, (1, 0))
d_npy = np.transpose(a_npy, (1, 0)) @ b_npy
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)
c_mlx = a_mlx @ mx.transpose(b_mlx, (1, 0))
d_mlx = mx.transpose(a_mlx, (1, 0)) @ b_mlx
self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-6))
def test_matmul_dtypes(self):
for dt in self.dtypes:
a_npy = np.random.normal(0.0, 1.0 / 256, (16, 16, 16)).astype(
getattr(np, dt)
)
b_npy = np.random.normal(0.0, 1.0 / 256, (16, 16, 16)).astype(
getattr(np, dt)
)
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)
c_npy = np.matmul(a_npy, b_npy, dtype=getattr(np, dt))
c_mlx = a_mlx @ b_mlx
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))
def test_matmul_batched(self):
np.random.seed(0)
# Batched matmul
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 16)).astype(np.float32)
c_npy = a_npy @ b_npy
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)
c_mlx = a_mlx @ b_mlx
self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))
# Batched and transposed matmul
b_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
c_npy = a_npy @ np.transpose(b_npy, (0, 2, 1))
b_mlx = mx.array(b_npy)
c_mlx = a_mlx @ mx.transpose(b_mlx, (0, 2, 1))
self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))
# Batched matmul with simple broadcast
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (16, 16)).astype(np.float32)
c_npy = a_npy @ b_npy
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)
c_mlx = a_mlx @ b_mlx
self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))
# Both operands broadcasted
d_npy = np.broadcast_to(b_npy, (5, 16, 16))
d_mlx = mx.broadcast_to(b_mlx, (5, 16, 16))
e_npy = d_npy @ d_npy
e_mlx = d_mlx @ d_mlx
self.assertListEqual(list(e_npy.shape), list(e_mlx.shape))
self.assertTrue(np.allclose(e_mlx, e_npy, atol=1e-6))
# Batched and transposed matmul with simple broadcast
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (128, 16)).astype(np.float32)
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)
c_npy = a_npy @ np.transpose(b_npy, (1, 0))
c_mlx = a_mlx @ mx.transpose(b_mlx, (1, 0))
self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))
# Matmul with vector
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)
c_npy = a_npy @ b_npy
c_mlx = a_mlx @ b_mlx
self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))
# Test multi-headed attention style matmul
a_npy = np.random.normal(0.0, 1.0 / 128, (64, 16, 4, 32)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (64, 16, 4, 32)).astype(np.float32)
a_mlx = mx.array(a_npy)
b_mlx = mx.array(b_npy)
a_npy = np.transpose(a_npy, (0, 2, 1, 3))
b_npy = np.transpose(b_npy, (0, 2, 1, 3))
a_mlx = mx.transpose(a_mlx, (0, 2, 1, 3))
b_mlx = mx.transpose(b_mlx, (0, 2, 1, 3))
c_npy = a_npy @ np.transpose(b_npy, (0, 1, 3, 2))
c_mlx = a_mlx @ mx.transpose(b_mlx, (0, 1, 3, 2))
self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))
def __gemv_test(
self,
shape_mat,
shape_vec,
np_dtype=np.float32,
mat_first=True,
np_mat_f=lambda x: x,
np_vec_f=lambda x: x,
mlx_mat_f=lambda x: x,
mlx_vec_f=lambda x: x,
):
with self.subTest(shape=shape_mat):
np.random.seed(42)
scale = max(np.sum(shape_mat), 32)
mat_npy = np.random.normal(0.0, 1.0 / scale, shape_mat).astype(np_dtype)
vec_npy = np.random.normal(0.0, 1.0 / scale, shape_vec).astype(np_dtype)
mat_mlx = mx.array(mat_npy)
vec_mlx = mx.array(vec_npy)
mat_npy = np_mat_f(mat_npy)
vec_npy = np_vec_f(vec_npy)
mat_mlx = mlx_mat_f(mat_mlx)
vec_mlx = mlx_vec_f(vec_mlx)
if mat_first:
out_npy = mat_npy @ vec_npy
out_mlx = mat_mlx @ vec_mlx
else:
out_npy = vec_npy @ mat_npy
out_mlx = vec_mlx @ mat_mlx
self.assertListEqual(list(out_npy.shape), list(out_mlx.shape))
self.assertTrue(np.allclose(out_mlx, out_npy, atol=1e-5))
def test_matrix_vector(self):
for dtype in self.dtypes:
with self.subTest(dtype=dtype):
np_dtype = getattr(np, dtype)
# Basic square matrix test
self.__gemv_test(
shape_mat=(64, 64), shape_vec=(64, 1), np_dtype=np_dtype
)
self.__gemv_test(
shape_mat=(64, 64),
shape_vec=(64, 1),
np_dtype=np_dtype,
mat_first=False,
np_vec_f=lambda x: np.transpose(x, (1, 0)),
mlx_vec_f=lambda x: mx.transpose(x, (1, 0)),
)
# Vector matrix product with aligned and unaligned shapes
for in_len_base, out_len_base in (
(2, 2),
(32, 32),
(64, 64),
(2048, 2048),
):
for mi in (-1, 0, 1):
for mj in (-1, 0, 1):
# Vec mat
shape_mat = (in_len_base + mi, out_len_base + mj)
shape_vec = (1, in_len_base + mi)
self.__gemv_test(
shape_mat, shape_vec, mat_first=False, np_dtype=np_dtype
)
# Mat vec
shape_mat = (out_len_base + mj, in_len_base + mi)
shape_vec = (in_len_base + mi, 1)
self.__gemv_test(
shape_mat, shape_vec, mat_first=True, np_dtype=np_dtype
)
def test_matrix_vector_batched(self):
for dtype in self.dtypes:
with self.subTest(dtype=dtype):
np_dtype = getattr(np, dtype)
# Batched mat vec
for shape_mat, shape_vec in (
((32, 128, 64), (32, 64, 1)),
((128, 64), (32, 64, 1)),
((32, 128, 64), (64, 1)),
):
self.__gemv_test(
shape_mat, shape_vec, mat_first=True, np_dtype=np_dtype
)
# Batched vec mat
for shape_vec, shape_mat in (
((32, 1, 128), (32, 128, 64)),
((32, 1, 128), (128, 64)),
((1, 128), (32, 128, 64)),
):
self.__gemv_test(
shape_mat, shape_vec, mat_first=False, np_dtype=np_dtype
)
def test_matrix_vector_broadcast(self):
for dtype in self.dtypes:
with self.subTest(dtype=dtype):
np_dtype = getattr(np, dtype)
# Different broadcasts mat vec
for shape_mat, shape_vec in (
((32, 64, 64), (32, 64, 1)),
((64, 64), (32, 64, 1)),
((32, 64, 64), (64, 1)),
):
self.__gemv_test(
shape_mat=(64, 64),
shape_vec=(64, 1),
np_dtype=np_dtype,
np_mat_f=(lambda mat_npy: np.broadcast_to(mat_npy, shape_mat)),
np_vec_f=(lambda vec_npy: np.broadcast_to(vec_npy, shape_vec)),
mlx_mat_f=(lambda mat_mlx: mx.broadcast_to(mat_mlx, shape_mat)),
mlx_vec_f=(lambda vec_mlx: mx.broadcast_to(vec_mlx, shape_vec)),
)
# Different broadcasts vec mat
for shape_vec, shape_mat in (
((32, 1, 64), (32, 64, 64)),
((32, 1, 64), (64, 64)),
((1, 64), (32, 64, 64)),
):
self.__gemv_test(
shape_mat=(64, 64),
shape_vec=(1, 64),
np_dtype=np_dtype,
mat_first=False,
np_mat_f=lambda mat_npy: np.broadcast_to(mat_npy, shape_mat),
np_vec_f=lambda vec_npy: np.broadcast_to(vec_npy, shape_vec),
mlx_mat_f=lambda mat_mlx: mx.broadcast_to(mat_mlx, shape_mat),
mlx_vec_f=lambda vec_mlx: mx.broadcast_to(vec_mlx, shape_vec),
)
def test_matrix_vector_edgecases(self):
for dtype in self.dtypes:
with self.subTest(dtype=dtype):
np_dtype = getattr(np, dtype)
for in_vec_len in np.arange(1, 5):
for out_vec_len in np.arange(1, 5):
for batch_size in np.arange(1, 5):
with self.subTest(
problem_shape=(batch_size, in_vec_len, out_vec_len)
):
# Matrix vector
with self.subTest(transpose=False):
a_npy = np.ones(
(batch_size, out_vec_len, in_vec_len),
dtype=np_dtype,
)
b_npy = np.ones(
(batch_size, in_vec_len, 1), dtype=np_dtype
)
for i in range(batch_size):
b_npy[i] *= i + 1.0
a_mlx, b_mlx = map(mx.array, [a_npy, b_npy])
c_npy = a_npy @ b_npy
c_mlx = a_mlx @ b_mlx
self.assertListEqual(
list(c_npy.shape), list(c_mlx.shape)
)
self.assertTrue(np.array_equal(c_mlx, c_npy))
# Vector matrix
with self.subTest(transpose=True):
a_npy = np.ones(
(batch_size, out_vec_len, in_vec_len),
dtype=np_dtype,
)
b_npy = np.ones(
(batch_size, 1, out_vec_len), dtype=np_dtype
)
for i in range(batch_size):
b_npy[i] *= i + 1.0
a_mlx, b_mlx = map(mx.array, [a_npy, b_npy])
c_npy = b_npy @ a_npy
c_mlx = b_mlx @ a_mlx
self.assertListEqual(
list(c_npy.shape), list(c_mlx.shape)
)
self.assertTrue(np.array_equal(c_mlx, c_npy))

445
python/tests/test_conv.py Normal file

@ -0,0 +1,445 @@
import unittest
from itertools import permutations
import math
import mlx.core as mx
import numpy as np
import mlx_tests
try:
import torch
import torch.nn.functional as F
has_torch = True
except ImportError as e:
has_torch = False
class TestConv(mlx_tests.MLXTestCase):
def test_numpy_conv(self):
for dtype in (
"float16",
"float32",
):
np_dtype = getattr(np, dtype)
for M, N, mode in (
(1, 1, "full"),
(25, 5, "full"),
(24, 5, "same"),
(24, 4, "same"),
(24, 4, "valid"),
(4, 24, "full"),
(5, 25, "same"),
(4, 25, "valid"),
):
with self.subTest(dtype=dtype, M=M, N=N, mode=mode):
atol = 1e-6 if dtype == "float32" else 1e-5
a_np = np.random.rand(M).astype(np_dtype)
v_np = np.random.rand(N).astype(np_dtype)
a_mx = mx.array(a_np)
v_mx = mx.array(v_np)
c_np = np.convolve(a_np, v_np, mode=mode)
c_mx = mx.convolve(a_mx, v_mx, mode=mode)
self.assertListEqual(list(c_mx.shape), list(c_np.shape))
self.assertTrue(np.allclose(c_mx, c_np, atol=atol))
@unittest.skipIf(not has_torch, "requires Torch")
def test_torch_conv_1D(self):
def run_conv1D(
N,
C,
O,
iH,
kH,
stride,
padding,
dilation=1,
groups=1,
dtype="float32",
atol=1e-5,
):
with self.subTest(
dtype=dtype,
N=N,
C=C,
O=O,
iH=iH,
kH=kH,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
):
np_dtype = getattr(np, dtype)
np.random.seed(0)
in_np = np.random.normal(0, 1.0 / C, (N, iH, C)).astype(np_dtype)
wt_np = np.random.normal(0, 1.0 / C, (O, kH, C)).astype(np_dtype)
in_mx, wt_mx = map(mx.array, (in_np, wt_np))
in_pt, wt_pt = map(
lambda x: torch.from_numpy(x.transpose(0, 2, 1)), (in_np, wt_np)
)
out_mx = mx.conv1d(
in_mx,
wt_mx,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
)
out_pt = torch.conv1d(
in_pt,
wt_pt,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
)
out_pt = torch.transpose(out_pt, 2, 1)
self.assertListEqual(list(out_pt.shape), out_mx.shape)
self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=atol))
for dtype in ("float32",):
for N, C, O in (
(1, 1, 1),
(1, 6, 1),
(1, 1, 6),
(4, 32, 64),
):
for iH, kH, stride, padding in (
(1, 1, 1, 0),
(3, 3, 1, 0),
(31, 5, 5, 2),
):
run_conv1D(N, C, O, iH, kH, stride, padding, dtype=dtype)
# Strided inputs tests
for tpose_in, tpose_wt in (
((0, 2, 1), (0, 1, 2)),
((0, 2, 1), (0, 2, 1)),
):
with self.subTest(name="strided", tpose_in=tpose_in, tpose_wt=tpose_wt):
in_np = np.random.normal(0, 1.0 / 16, (16, 16, 16)).astype(np.float32)
wt_np = np.random.normal(0, 1.0 / 16, (16, 16, 16)).astype(np.float32)
in_mx, wt_mx = map(mx.array, (in_np, wt_np))
in_mx_t = mx.transpose(in_mx, tpose_in)
wt_mx_t = mx.transpose(wt_mx, tpose_wt)
out_mx = mx.conv1d(in_mx_t, wt_mx_t)
in_pt, wt_pt = map(
lambda x: torch.from_numpy(x.transpose(0, 2, 1)),
(in_np.transpose(tpose_in), wt_np.transpose(tpose_wt)),
)
out_pt = torch.conv1d(in_pt, wt_pt)
out_pt = torch.transpose(out_pt, 2, 1)
self.assertListEqual(list(out_pt.shape), out_mx.shape)
self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=1e-5))
@unittest.skipIf(not has_torch, "requires Torch")
def test_torch_conv_1D_grad(self):
def run_conv1D_grad(
N,
C,
O,
iH,
kH,
stride,
padding,
dilation=1,
groups=1,
dtype="float32",
atol=1e-5,
):
with self.subTest(
dtype=dtype,
N=N,
C=C,
O=O,
iH=iH,
kH=kH,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
):
np_dtype = getattr(np, dtype)
np.random.seed(0)
oH = 1 + ((iH + 2 * padding - dilation * (kH - 1) - 1) // stride)
in_np = np.random.normal(0, 1.0 / C, (N, iH, C)).astype(np_dtype)
wt_np = np.random.normal(0, 1.0 / C, (O, kH, C)).astype(np_dtype)
ct_np = np.random.normal(0, 1.0 / C, (N, oH, O)).astype(np_dtype)
in_mx, wt_mx, ct_mx = map(mx.array, (in_np, wt_np, ct_np))
in_pt, wt_pt, ct_pt = map(
lambda x: torch.from_numpy(x.transpose(0, 2, 1)),
(in_np, wt_np, ct_np),
)
def f(a, b):
return mx.conv1d(
a,
b,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
)
_, outs_mx = mx.vjp(
f,
[
in_mx,
wt_mx,
],
[
ct_mx,
],
)
pt_grad_in = F.grad.conv1d_input(
in_pt.shape,
wt_pt,
ct_pt,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
)
pt_grad_wt = F.grad.conv1d_weight(
in_pt,
wt_pt.shape,
ct_pt,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
)
pt_grad_in = torch.transpose(pt_grad_in, 2, 1).numpy()
pt_grad_wt = torch.transpose(pt_grad_wt, 2, 1).numpy()
mx_grad_in, mx_grad_wt = outs_mx
self.assertListEqual(list(pt_grad_in.shape), mx_grad_in.shape)
self.assertListEqual(list(in_mx.shape), mx_grad_in.shape)
self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))
self.assertListEqual(list(pt_grad_wt.shape), mx_grad_wt.shape)
self.assertListEqual(list(wt_mx.shape), mx_grad_wt.shape)
self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))
for dtype in ("float32",):
for N, C, O in (
(1, 1, 1),
(1, 6, 1),
(1, 1, 6),
(4, 32, 64),
):
for iH, kH, stride, padding in (
(1, 1, 1, 0),
(3, 3, 1, 0),
(31, 5, 5, 2),
):
run_conv1D_grad(N, C, O, iH, kH, stride, padding, dtype=dtype)
@unittest.skipIf(not has_torch, "requires Torch")
def test_torch_conv_2D(self):
def run_conv2D(
N,
C,
O,
idim,
kdim,
stride,
padding,
dilation=(1, 1),
groups=1,
dtype="float32",
atol=1e-5,
):
with self.subTest(
dtype=dtype,
N=N,
C=C,
O=O,
idim=idim,
kdim=kdim,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
):
np_dtype = getattr(np, dtype)
np.random.seed(0)
iH, iW = idim
kH, kW = kdim
scale = 1.0 / math.sqrt(kH * kW * C)
in_np = np.random.normal(0.0, scale, (N, iH, iW, C)).astype(np_dtype)
wt_np = np.random.normal(0.0, 1.0, (O, kH, kW, C)).astype(np_dtype)
in_mx, wt_mx = map(mx.array, (in_np, wt_np))
in_pt, wt_pt = map(
lambda x: torch.from_numpy(x.transpose(0, 3, 1, 2)).to("cpu"),
(in_np, wt_np),
)
out_mx = mx.conv2d(
in_mx,
wt_mx,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
)
out_pt = torch.conv2d(
in_pt,
wt_pt,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
)
out_pt = torch.permute(out_pt, (0, 2, 3, 1)).numpy(force=True)
self.assertListEqual(list(out_pt.shape), list(out_mx.shape))
self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))
for dtype in ("float32",):
for N, C, O in (
(1, 1, 1),
(1, 6, 1),
(1, 1, 6),
(4, 32, 64),
):
for idim, kdim, stride, padding in (
((1, 1), (1, 1), (1, 1), (0, 0)),
((3, 3), (3, 1), (1, 1), (0, 0)),
((31, 31), (5, 5), (5, 5), (2, 2)),
):
run_conv2D(N, C, O, idim, kdim, stride, padding, dtype=dtype)
@unittest.skipIf(not has_torch, "requires Torch")
def test_torch_conv_2D_grad(self):
def run_conv2D_grad(
N,
C,
O,
idim,
kdim,
stride,
padding,
dilation=(1, 1),
groups=1,
dtype="float32",
atol=1e-5,
):
with self.subTest(
dtype=dtype,
N=N,
C=C,
O=O,
idim=idim,
kdim=kdim,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
):
np_dtype = getattr(np, dtype)
np.random.seed(0)
iH, iW = idim
kH, kW = kdim
scale = 1.0 / math.sqrt(kH * kW * C)
oH = 1 + (
(iH + 2 * padding[0] - dilation[0] * (kH - 1) - 1) // stride[0]
)
oW = 1 + (
(iW + 2 * padding[1] - dilation[1] * (kW - 1) - 1) // stride[1]
)
in_np = np.random.normal(0.0, scale, (N, iH, iW, C)).astype(np_dtype)
wt_np = np.random.normal(0.0, scale, (O, kH, kW, C)).astype(np_dtype)
ct_np = np.random.normal(0.0, scale, (N, oH, oW, O)).astype(np_dtype)
in_mx, wt_mx, ct_mx = map(mx.array, (in_np, wt_np, ct_np))
in_pt, wt_pt, ct_pt = map(
lambda x: torch.from_numpy(x.transpose(0, 3, 1, 2)).to("cpu"),
(in_np, wt_np, ct_np),
)
def f(a, b):
return mx.conv2d(
a,
b,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
)
_, outs_mx = mx.vjp(
f,
[
in_mx,
wt_mx,
],
[
ct_mx,
],
)
pt_grad_in = F.grad.conv2d_input(
in_pt.shape,
wt_pt,
ct_pt,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
)
pt_grad_wt = F.grad.conv2d_weight(
in_pt,
wt_pt.shape,
ct_pt,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
)
pt_grad_in = torch.permute(pt_grad_in, (0, 2, 3, 1)).numpy()
pt_grad_wt = torch.permute(pt_grad_wt, (0, 2, 3, 1)).numpy()
mx_grad_in, mx_grad_wt = outs_mx
self.assertListEqual(list(pt_grad_in.shape), mx_grad_in.shape)
self.assertListEqual(list(in_mx.shape), mx_grad_in.shape)
self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))
self.assertListEqual(list(pt_grad_wt.shape), mx_grad_wt.shape)
self.assertListEqual(list(wt_mx.shape), mx_grad_wt.shape)
self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))
for dtype in ("float32",):
for N, C, O in (
(1, 1, 1),
(1, 6, 1),
(1, 1, 6),
(4, 32, 64),
):
for idim, kdim, stride, padding in (
((1, 1), (1, 1), (1, 1), (0, 0)),
((3, 3), (3, 1), (1, 1), (0, 0)),
((31, 31), (5, 5), (5, 5), (2, 2)),
):
run_conv2D_grad(N, C, O, idim, kdim, stride, padding, dtype=dtype)
if __name__ == "__main__":
unittest.main()
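For reference, the output sizes ``oH``/``oW`` computed in the gradient tests above follow standard convolution arithmetic; restated as a single formula (input size $i$, padding $p$, dilation $d$, kernel size $k$, stride $s$):

$$o = \left\lfloor \frac{i + 2p - d\,(k - 1) - 1}{s} \right\rfloor + 1$$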

157
python/tests/test_load.py Normal file

@ -0,0 +1,157 @@
import unittest
import os
import mlx.core as mx
import numpy as np
import tempfile
import mlx_tests
class TestLoad(mlx_tests.MLXTestCase):
dtypes = [
"uint8",
"uint16",
"uint32",
"uint64",
"int8",
"int16",
"int32",
"int64",
"float32",
"float16",
"complex64",
]
@classmethod
def setUpClass(cls):
cls.test_dir_fid = tempfile.TemporaryDirectory()
cls.test_dir = cls.test_dir_fid.name
@classmethod
def tearDownClass(cls):
cls.test_dir_fid.cleanup()
def test_save_and_load(self):
if not os.path.isdir(self.test_dir):
os.mkdir(self.test_dir)
for dt in self.dtypes:
with self.subTest(dtype=dt):
for i, shape in enumerate([(1,), (23,), (1024, 1024), (4, 6, 3, 1, 2)]):
with self.subTest(shape=shape):
save_file_mlx = os.path.join(self.test_dir, f"mlx_{dt}_{i}.npy")
save_file_npy = os.path.join(self.test_dir, f"npy_{dt}_{i}.npy")
save_arr = np.random.uniform(0.0, 32.0, size=shape)
save_arr_npy = save_arr.astype(getattr(np, dt))
save_arr_mlx = mx.array(save_arr_npy)
mx.save(save_file_mlx, save_arr_mlx)
np.save(save_file_npy, save_arr_npy)
# Load array saved by mlx as mlx array
load_arr_mlx_mlx = mx.load(save_file_mlx)
self.assertTrue(mx.array_equal(load_arr_mlx_mlx, save_arr_mlx))
# Load array saved by numpy as mlx array
load_arr_npy_mlx = mx.load(save_file_npy)
self.assertTrue(mx.array_equal(load_arr_npy_mlx, save_arr_mlx))
# Load array saved by mlx as numpy array
load_arr_mlx_npy = np.load(save_file_mlx)
self.assertTrue(np.array_equal(load_arr_mlx_npy, save_arr_npy))
def test_save_and_load_fs(self):
if not os.path.isdir(self.test_dir):
os.mkdir(self.test_dir)
for dt in self.dtypes:
with self.subTest(dtype=dt):
for i, shape in enumerate([(1,), (23,), (1024, 1024), (4, 6, 3, 1, 2)]):
with self.subTest(shape=shape):
save_file_mlx = os.path.join(
self.test_dir, f"mlx_{dt}_{i}_fs.npy"
)
save_file_npy = os.path.join(
self.test_dir, f"npy_{dt}_{i}_fs.npy"
)
save_arr = np.random.uniform(0.0, 32.0, size=shape)
save_arr_npy = save_arr.astype(getattr(np, dt))
save_arr_mlx = mx.array(save_arr_npy)
with open(save_file_mlx, "wb") as f:
mx.save(f, save_arr_mlx)
np.save(save_file_npy, save_arr_npy)
# Load array saved by mlx as mlx array
with open(save_file_mlx, "rb") as f:
load_arr_mlx_mlx = mx.load(f)
self.assertTrue(mx.array_equal(load_arr_mlx_mlx, save_arr_mlx))
# Load array saved by numpy as mlx array
with open(save_file_npy, "rb") as f:
load_arr_npy_mlx = mx.load(f)
self.assertTrue(mx.array_equal(load_arr_npy_mlx, save_arr_mlx))
# Load array saved by mlx as numpy array
load_arr_mlx_npy = np.load(save_file_mlx)
self.assertTrue(np.array_equal(load_arr_mlx_npy, save_arr_npy))
def test_savez_and_loadz(self):
if not os.path.isdir(self.test_dir):
os.mkdir(self.test_dir)
for dt in self.dtypes:
with self.subTest(dtype=dt):
shapes = [(6,), (6, 6), (4, 1, 3, 1, 2)]
save_file_mlx_uncomp = os.path.join(
self.test_dir, f"mlx_{dt}_uncomp.npz"
)
save_file_npy_uncomp = os.path.join(
self.test_dir, f"npy_{dt}_uncomp.npz"
)
save_file_mlx_comp = os.path.join(self.test_dir, f"mlx_{dt}_comp.npz")
save_file_npy_comp = os.path.join(self.test_dir, f"npy_{dt}_comp.npz")
# Make a dictionary of multiple arrays to save
save_arrs_npy = {
f"save_arr_{i}": np.random.uniform(
0.0, 32.0, size=shapes[i]
).astype(getattr(np, dt))
for i in range(len(shapes))
}
save_arrs_mlx = {k: mx.array(v) for k, v in save_arrs_npy.items()}
# Save as npz files
np.savez(save_file_npy_uncomp, **save_arrs_npy)
mx.savez(save_file_mlx_uncomp, **save_arrs_mlx)
np.savez_compressed(save_file_npy_comp, **save_arrs_npy)
mx.savez_compressed(save_file_mlx_comp, **save_arrs_mlx)
for save_file_npy, save_file_mlx in (
(save_file_npy_uncomp, save_file_mlx_uncomp),
(save_file_npy_comp, save_file_mlx_comp),
):
# Load array saved by mlx as mlx array
load_arr_mlx_mlx = mx.load(save_file_mlx)
for k, v in load_arr_mlx_mlx.items():
self.assertTrue(mx.array_equal(save_arrs_mlx[k], v))
# Load arrays saved by numpy as mlx arrays
load_arr_npy_mlx = mx.load(save_file_npy)
for k, v in load_arr_npy_mlx.items():
self.assertTrue(mx.array_equal(save_arrs_mlx[k], v))
# Load array saved by mlx as numpy array
load_arr_mlx_npy = np.load(save_file_mlx)
for k, v in load_arr_mlx_npy.items():
self.assertTrue(np.array_equal(save_arrs_npy[k], v))
if __name__ == "__main__":
unittest.main()

231
python/tests/test_nn.py Normal file

@ -0,0 +1,231 @@
import unittest
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten, tree_map, tree_unflatten
import numpy as np
import os
import tempfile
import mlx_tests
class TestNN(mlx_tests.MLXTestCase):
def test_linear(self):
inputs = mx.zeros((10, 4))
layer = nn.Linear(input_dims=4, output_dims=8)
outputs = layer(inputs)
self.assertEqual(tuple(outputs.shape), (10, 8))
def test_cross_entropy(self):
logits = mx.array([[0.0, -float("inf")], [-float("inf"), 0.0]])
targets = mx.array([0, 1])
losses = nn.losses.cross_entropy(logits, targets)
self.assertTrue(mx.array_equal(losses, mx.zeros((2,))))
def test_gelu(self):
inputs = [1.15286231, -0.81037411, 0.35816911, 0.77484438, 0.66276414]
# From: jax.nn.gelu(np.array(inputs), approximate=False)
expected = np.array(
[1.0093501, -0.16925684, 0.22918941, 0.60498625, 0.49459383]
)
out = nn.GELU()(mx.array(inputs))
self.assertTrue(np.allclose(out, expected))
# Crudely check the approximations
x = mx.arange(-6.0, 6.0, 12 / 100)
y = nn.gelu(x)
y_hat1 = nn.gelu_approx(x)
y_hat2 = nn.gelu_fast_approx(x)
self.assertLess(mx.abs(y - y_hat1).max(), 0.0003)
self.assertLess(mx.abs(y - y_hat2).max(), 0.02)
def test_group_norm(self):
x = mx.arange(100, dtype=mx.float32)
x = x.reshape(1, 10, 10, 1)
x = mx.broadcast_to(x, (2, 10, 10, 4))
x = mx.concatenate([x, 0.5 * x], axis=-1)
# Group norm in groups last mode
g = nn.GroupNorm(2, 8)
y = g(x)
means = y.reshape(2, -1, 2).mean(axis=1)
var = y.reshape(2, -1, 2).var(axis=1)
self.assertTrue(np.allclose(means, np.zeros_like(means), atol=1e-6))
self.assertTrue(np.allclose(var, np.ones_like(var), atol=1e-6))
g.weight = g.weight * 2
g.bias = g.bias + 3
y = g(x)
means = y.reshape(2, -1, 2).mean(axis=1)
var = y.reshape(2, -1, 2).var(axis=1)
self.assertTrue(np.allclose(means, 3 * np.ones_like(means), atol=1e-6))
self.assertTrue(np.allclose(var, 4 * np.ones_like(var), atol=1e-6))
# Group norm in groups first mode
g = nn.GroupNorm(2, 8, pytorch_compatible=True)
y = g(x)
means = y.reshape(2, -1, 2, 4).mean(axis=(1, -1))
var = y.reshape(2, -1, 2, 4).var(axis=(1, -1))
self.assertTrue(np.allclose(means, np.zeros_like(means), atol=1e-6))
self.assertTrue(np.allclose(var, np.ones_like(var), atol=1e-6))
g.weight = g.weight * 2
g.bias = g.bias + 3
y = g(x)
means = y.reshape(2, -1, 2, 4).mean(axis=(1, -1))
var = y.reshape(2, -1, 2, 4).var(axis=(1, -1))
self.assertTrue(np.allclose(means, 3 * np.ones_like(means), atol=1e-6))
self.assertTrue(np.allclose(var, 4 * np.ones_like(var), atol=1e-6))
def test_conv1d(self):
N = 5
L = 12
ks = 3
C_in = 2
C_out = 4
x = mx.ones((N, L, C_in))
c = nn.Conv1d(in_channels=C_in, out_channels=C_out, kernel_size=ks)
c.weight = mx.ones_like(c.weight)
y = c(x)
self.assertEqual(y.shape, [N, L - ks + 1, C_out])
self.assertTrue(mx.allclose(y, mx.full(y.shape, ks * C_in, mx.float32)))
c = nn.Conv1d(in_channels=C_in, out_channels=C_out, kernel_size=ks, stride=2)
y = c(x)
self.assertEqual(y.shape, [N, (L - ks + 1) // 2, C_out])
self.assertTrue("bias" in c.parameters())
c = nn.Conv1d(in_channels=C_in, out_channels=C_out, kernel_size=ks, bias=False)
self.assertTrue("bias" not in c.parameters())
def test_conv2d(self):
x = mx.ones((4, 8, 8, 3))
c = nn.Conv2d(3, 1, 8)
y = c(x)
self.assertEqual(y.shape, [4, 1, 1, 1])
c.weight = mx.ones_like(c.weight) / 8 / 8 / 3
y = c(x)
self.assertTrue(np.allclose(y[:, 0, 0, 0], x.mean(axis=(1, 2, 3))))
# 3x3 conv no padding stride 1
c = nn.Conv2d(3, 8, 3)
y = c(x)
self.assertEqual(y.shape, [4, 6, 6, 8])
self.assertLess(mx.abs(y - c.weight.sum((1, 2, 3))).max(), 1e-4)
# 3x3 conv padding 1 stride 1
c = nn.Conv2d(3, 8, 3, padding=1)
y = c(x)
self.assertEqual(y.shape, [4, 8, 8, 8])
self.assertLess(mx.abs(y[:, 1:7, 1:7] - c.weight.sum((1, 2, 3))).max(), 1e-4)
self.assertLess(
mx.abs(y[:, 0, 0] - c.weight[:, 1:, 1:].sum(axis=(1, 2, 3))).max(),
1e-4,
)
self.assertLess(
mx.abs(y[:, 7, 7] - c.weight[:, :-1, :-1].sum(axis=(1, 2, 3))).max(),
1e-4,
)
self.assertLess(
mx.abs(y[:, 1:7, 7] - c.weight[:, :, :-1].sum(axis=(1, 2, 3))).max(),
1e-4,
)
self.assertLess(
mx.abs(y[:, 7, 1:7] - c.weight[:, :-1, :].sum(axis=(1, 2, 3))).max(),
1e-4,
)
# 3x3 conv no padding stride 2
c = nn.Conv2d(3, 8, 3, padding=0, stride=2)
y = c(x)
self.assertEqual(y.shape, [4, 3, 3, 8])
self.assertLess(mx.abs(y - c.weight.sum((1, 2, 3))).max(), 1e-4)
def test_sequential(self):
x = mx.ones((10, 2))
m = nn.Sequential(nn.Linear(2, 10), nn.ReLU(), nn.Linear(10, 1))
y = m(x)
self.assertEqual(y.shape, [10, 1])
params = m.parameters()
self.assertTrue("layers" in params)
self.assertEqual(len(params["layers"]), 3)
self.assertTrue("weight" in params["layers"][0])
self.assertEqual(len(params["layers"][1]), 0)
self.assertTrue("weight" in params["layers"][2])
m.layers[1] = nn.relu
y2 = m(x)
self.assertTrue(mx.array_equal(y, y2))
def test_module_utilities(self):
m = nn.Sequential(
nn.Sequential(nn.Linear(2, 10), nn.relu),
nn.Sequential(nn.Linear(10, 10), nn.ReLU()),
nn.Linear(10, 1),
mx.sigmoid,
)
children = m.children()
self.assertTrue(isinstance(children, dict))
self.assertEqual(len(children), 1)
self.assertTrue(isinstance(children["layers"], list))
self.assertEqual(len(children["layers"]), 4)
self.assertEqual(children["layers"][3], {})
flat_children = tree_flatten(children, is_leaf=nn.Module.is_module)
self.assertEqual(len(flat_children), 3)
leaves = tree_flatten(m.leaf_modules(), is_leaf=nn.Module.is_module)
self.assertEqual(len(leaves), 4)
self.assertEqual(leaves[0][0], "layers.0.layers.0")
self.assertEqual(leaves[1][0], "layers.1.layers.0")
self.assertEqual(leaves[2][0], "layers.1.layers.1")
self.assertEqual(leaves[3][0], "layers.2")
self.assertTrue(leaves[0][1] is m.layers[0].layers[0])
self.assertTrue(leaves[1][1] is m.layers[1].layers[0])
self.assertTrue(leaves[2][1] is m.layers[1].layers[1])
self.assertTrue(leaves[3][1] is m.layers[2])
m.eval()
def assert_not_training(k, m):
self.assertFalse(m.training)
m.apply_to_modules(assert_not_training)
m.train()
def assert_training(k, m):
self.assertTrue(m.training)
m.apply_to_modules(assert_training)
def test_sin_pe(self):
m = nn.SinusoidalPositionalEncoding(16, min_freq=0.01)
x = mx.arange(10)
y = m(x)
self.assertEqual(y.shape, [10, 16])
similarities = y @ y.T
self.assertLess(
mx.abs(similarities[mx.arange(10), mx.arange(10)] - 1).max(), 1e-5
)
def test_io(self):
def make_model():
return nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2))
m = make_model()
tdir = tempfile.TemporaryDirectory()
file = os.path.join(tdir.name, "model.npz")
m.save_weights(file)
m_load = make_model()
m_load.load_weights(file)
tdir.cleanup()
eq_tree = tree_map(mx.array_equal, m.parameters(), m_load.parameters())
self.assertTrue(all(v for _, v in tree_flatten(eq_tree)))
if __name__ == "__main__":
unittest.main()

29
python/tests/test_optimizers.py Normal file

@ -0,0 +1,29 @@
import unittest
import mlx.core as mx
import mlx.optimizers as opt
import mlx.utils
import mlx_tests
class TestOptimizers(mlx_tests.MLXTestCase):
def test_optimizers(self):
params = {
"first": [mx.zeros((10,)), mx.zeros((1,))],
"second": mx.zeros((1,)),
}
grads = mlx.utils.tree_map(lambda x: mx.ones_like(x), params)
for optim in [opt.SGD(0.1), opt.Adam(0.1)]:
update = optim.apply_gradients(grads, params)
mx.eval(update)
equal_shape = mlx.utils.tree_map(
lambda x, y: x.shape == y.shape, params, update
)
all_equal = all(v for _, v in mlx.utils.tree_flatten(equal_shape))
self.assertTrue(all_equal)
if __name__ == "__main__":
unittest.main()

192
python/tests/test_random.py Normal file

@ -0,0 +1,192 @@
import unittest
import mlx.core as mx
import mlx_tests
class TestRandom(mlx_tests.MLXTestCase):
def test_global_rng(self):
mx.random.seed(3)
a = mx.random.uniform()
b = mx.random.uniform()
mx.random.seed(3)
x = mx.random.uniform()
y = mx.random.uniform()
self.assertEqual(a.item(), x.item())
self.assertEqual(y.item(), b.item())
def test_key(self):
k1 = mx.random.key(0)
k2 = mx.random.key(0)
self.assertTrue(mx.array_equal(k1, k2))
k2 = mx.random.key(1)
self.assertFalse(mx.array_equal(k1, k2))
def test_key_split(self):
key = mx.random.key(0)
k1, k2 = mx.random.split(key)
self.assertFalse(mx.array_equal(k1, k2))
r1, r2 = mx.random.split(key)
self.assertTrue(mx.array_equal(k1, r1))
self.assertTrue(mx.array_equal(k2, r2))
keys = mx.random.split(key, 10)
self.assertEqual(keys.shape, [10, 2])
def test_uniform(self):
key = mx.random.key(0)
a = mx.random.uniform(key=key)
self.assertEqual(a.shape, [])
self.assertEqual(a.dtype, mx.float32)
b = mx.random.uniform(key=key)
self.assertEqual(a.item(), b.item())
a = mx.random.uniform(shape=(2, 3))
self.assertEqual(a.shape, [2, 3])
a = mx.random.uniform(shape=(1000,), low=-1, high=5)
self.assertTrue(mx.all(a > -1).item() and mx.all(a < 5).item())
a = mx.random.uniform(shape=(1000,), low=mx.array(-1), high=5)
self.assertTrue(mx.all(a > -1).item() and mx.all(a < 5).item())
def test_normal(self):
key = mx.random.key(0)
a = mx.random.normal(key=key)
self.assertEqual(a.shape, [])
self.assertEqual(a.dtype, mx.float32)
b = mx.random.normal(key=key)
self.assertEqual(a.item(), b.item())
a = mx.random.normal(shape=(2, 3))
self.assertEqual(a.shape, [2, 3])
## Generate in float16 or bfloat16
for t in [mx.float16, mx.bfloat16]:
a = mx.random.normal(dtype=t)
self.assertEqual(a.dtype, t)
def test_randint(self):
a = mx.random.randint(0, 1, [])
self.assertEqual(a.shape, [])
self.assertEqual(a.dtype, mx.int32)
shape = [88]
low = mx.array(3)
high = mx.array(15)
key = mx.random.key(0)
a = mx.random.randint(low, high, shape, key=key)
self.assertEqual(a.shape, shape)
self.assertEqual(a.dtype, mx.int32)
# Check using the same key yields the same value
b = mx.random.randint(low, high, shape, key=key)
self.assertListEqual(a.tolist(), b.tolist())
shape = [3, 4]
low = mx.reshape(mx.array([0] * 3), [3, 1])
high = mx.reshape(mx.array([12, 13, 14, 15]), [1, 4])
a = mx.random.randint(low, high, shape)
self.assertEqual(a.shape, shape)
a = mx.random.randint(-10, 10, [1000, 1000])
self.assertTrue(mx.all(-10 <= a).item() and mx.all(a < 10).item())
a = mx.random.randint(10, -10, [1000, 1000])
self.assertTrue(mx.all(a == 10).item())
def test_bernoulli(self):
a = mx.random.bernoulli()
self.assertEqual(a.shape, [])
self.assertEqual(a.dtype, mx.bool_)
a = mx.random.bernoulli(mx.array(0.5), [5])
self.assertEqual(a.shape, [5])
a = mx.random.bernoulli(mx.array([2.0, -2.0]))
self.assertEqual(a.tolist(), [True, False])
self.assertEqual(a.shape, [2])
p = mx.array([0.1, 0.2, 0.3])
p = mx.reshape(p, [1, 3])
x = mx.random.bernoulli(p, [4, 3])
self.assertEqual(x.shape, [4, 3])
with self.assertRaises(ValueError):
mx.random.bernoulli(p, [2]) # Bad shape
with self.assertRaises(ValueError):
mx.random.bernoulli(0, [2]) # Bad type
def test_truncated_normal(self):
a = mx.random.truncated_normal(-2.0, 2.0)
self.assertEqual(a.size, 1)
self.assertEqual(a.dtype, mx.float32)
a = mx.random.truncated_normal(mx.array([]), mx.array([]))
self.assertEqual(a.dtype, mx.float32)
self.assertEqual(a.size, 0)
lower = mx.reshape(mx.array([-2.0, 0.0]), [1, 2])
upper = mx.reshape(mx.array([0.0, 1.0, 2.0]), [3, 1])
a = mx.random.truncated_normal(lower, upper)
self.assertEqual(a.shape, [3, 2])
self.assertTrue(mx.all(lower <= a).item() and mx.all(a <= upper).item())
a = mx.random.truncated_normal(2.0, -2.0)
self.assertTrue(mx.all(a == 2.0).item())
a = mx.random.truncated_normal(-3.0, 3.0, [542, 399])
self.assertEqual(a.shape, [542, 399])
lower = mx.array([-2.0, -1.0])
higher = mx.array([1.0, 2.0, 3.0])
with self.assertRaises(ValueError):
mx.random.truncated_normal(lower, higher) # Bad shape
def test_gumbel(self):
samples = mx.random.gumbel(shape=(100, 100))
self.assertEqual(samples.shape, [100, 100])
self.assertEqual(samples.dtype, mx.float32)
mean = 0.5772
# Std deviation of the sample mean is small (<0.02),
# so this test is pretty conservative
self.assertTrue(mx.abs(mx.mean(samples) - mean) < 0.2)
def test_categorical(self):
logits = mx.zeros((10, 20))
self.assertEqual(mx.random.categorical(logits, -1).shape, [10])
self.assertEqual(mx.random.categorical(logits, 0).shape, [20])
self.assertEqual(mx.random.categorical(logits, 1).shape, [10])
out = mx.random.categorical(logits)
self.assertEqual(out.shape, [10])
self.assertEqual(out.dtype, mx.uint32)
self.assertTrue(mx.max(out).item() < 20)
out = mx.random.categorical(logits, 0, [5, 20])
self.assertEqual(out.shape, [5, 20])
self.assertTrue(mx.max(out).item() < 10)
out = mx.random.categorical(logits, 1, num_samples=7)
self.assertEqual(out.shape, [10, 7])
out = mx.random.categorical(logits, 0, num_samples=7)
self.assertEqual(out.shape, [20, 7])
with self.assertRaises(ValueError):
mx.random.categorical(logits, shape=[10, 5], num_samples=5)
if __name__ == "__main__":
unittest.main()

26
python/tests/test_tree.py Normal file

@ -0,0 +1,26 @@
import unittest
import mlx.core as mx
import mlx.utils
import mlx_tests
class TestTreeUtils(mlx_tests.MLXTestCase):
def test_tree_map(self):
tree = {"a": 0, "b": 1, "c": 2}
tree = mlx.utils.tree_map(lambda x: x + 1, tree)
expected_tree = {"a": 1, "b": 2, "c": 3}
self.assertEqual(tree, expected_tree)
def test_tree_flatten(self):
tree = [{"a": 1, "b": 2}, "c"]
vals = (1, 2, "c")
flat_tree = mlx.utils.tree_flatten(tree)
self.assertEqual(list(zip(*flat_tree))[1], vals)
self.assertEqual(mlx.utils.tree_unflatten(flat_tree), tree)
if __name__ == "__main__":
unittest.main()

167
python/tests/test_vmap.py Normal file

@ -0,0 +1,167 @@
import unittest
import mlx.core as mx
import mlx_tests
class TestVmap(mlx_tests.MLXTestCase):
def test_basics(self):
# Can't vmap over scalars
with self.assertRaises(ValueError):
mx.vmap(mx.exp)(mx.array(1.0))
# Invalid input
with self.assertRaises(ValueError):
mx.vmap(mx.exp)("hello")
# Invalid axes
with self.assertRaises(ValueError):
mx.vmap(mx.exp, in_axes="hello")(mx.array([0, 1]))
with self.assertRaises(ValueError):
mx.vmap(mx.exp, in_axes=2)(mx.array([0, 1]))
with self.assertRaises(ValueError):
mx.vmap(mx.exp, out_axes="hello")(mx.array([0, 1]))
with self.assertRaises(ValueError):
mx.vmap(mx.exp, out_axes=2)(mx.array([0, 1]))
def test_unary(self):
ops = [
"abs",
"cos",
"erf",
"erfinv",
"exp",
"log",
"log1p",
"log2",
"log10",
"logical_not",
"negative",
"reciprocal",
"rsqrt",
"sigmoid",
"sign",
"sin",
"sqrt",
"square",
]
ops = ["erfinv"]
for opname in ops:
with self.subTest(op=opname):
op = getattr(mx, opname)
x = mx.arange(5)
y = mx.vmap(op)(x)
self.assertTrue(mx.array_equal(y, op(x), equal_nan=True))
x = mx.arange(8).reshape(2, 4)
y = mx.vmap(op)(x)
self.assertTrue(mx.array_equal(y, op(x), equal_nan=True))
y = mx.vmap(op, in_axes=1, out_axes=1)(x)
self.assertTrue(mx.array_equal(y, op(x), equal_nan=True))
def test_binary(self):
ops = [
"add",
"divide",
"equal",
"greater",
"greater_equal",
"less",
"less_equal",
"logaddexp",
"maximum",
"minimum",
"multiply",
"power",
"subtract",
]
for opname in ops:
with self.subTest(op=opname):
op = getattr(mx, opname)
x = mx.random.uniform(shape=(5,))
y = mx.random.uniform(shape=(5,))
out = mx.vmap(op)(x, y)
self.assertTrue(mx.array_equal(out, op(x, y)))
x = mx.random.uniform(shape=(2, 4))
y = mx.random.uniform(shape=(2, 4))
out = mx.vmap(op)(x, y)
self.assertTrue(mx.array_equal(out, op(x, y)))
out = mx.vmap(op, in_axes=(0, 0), out_axes=0)(x, y)
self.assertTrue(mx.array_equal(out, op(x, y)))
y = mx.random.uniform(shape=(4, 2))
out = mx.vmap(op, in_axes=(0, 1), out_axes=0)(x, y)
self.assertTrue(mx.array_equal(out, op(x, y.T)))
out = mx.vmap(op, in_axes=(0, 1), out_axes=1)(x, y)
self.assertTrue(mx.array_equal(out, op(x, y.T).T))
def test_tree(self):
def my_fun(tree):
return (tree["a"] + tree["b"][0]) * tree["b"][1]
tree = {
"a": mx.random.uniform(shape=(2, 4)),
"b": (
mx.random.uniform(shape=(2, 4)),
mx.random.uniform(shape=(2, 4)),
),
}
out = mx.vmap(my_fun)(tree)
expected = my_fun(tree)
self.assertTrue(mx.array_equal(out, expected))
with self.assertRaises(ValueError):
mx.vmap(my_fun, in_axes={"a": 0, "b": 0}, out_axes=0)(tree)
with self.assertRaises(ValueError):
mx.vmap(my_fun, in_axes={"a": 0, "b": ((0, 0), 0)}, out_axes=0)(tree)
out = mx.vmap(my_fun, in_axes=({"a": 0, "b": 0},), out_axes=0)(tree)
self.assertTrue(mx.array_equal(out, my_fun(tree)))
out = mx.vmap(my_fun, in_axes=({"a": 0, "b": (0, 0)},), out_axes=0)(tree)
self.assertTrue(mx.array_equal(out, my_fun(tree)))
tree = {
"a": mx.random.uniform(shape=(2, 4)),
"b": (
mx.random.uniform(shape=(4, 2)),
mx.random.uniform(shape=(4, 2)),
),
}
out = mx.vmap(my_fun, in_axes=({"a": 0, "b": (1, 1)},), out_axes=0)(tree)
expected = (tree["a"] + tree["b"][0].T) * tree["b"][1].T
self.assertTrue(mx.array_equal(out, expected))
def my_fun(x, y):
return {"a": x + y, "b": x * y}
x = mx.random.uniform(shape=(2, 4))
y = mx.random.uniform(shape=(2, 4))
out = mx.vmap(my_fun, in_axes=0, out_axes=0)(x, y)
expected = my_fun(x, y)
self.assertTrue(mx.array_equal(out["a"], expected["a"]))
self.assertTrue(mx.array_equal(out["b"], expected["b"]))
with self.assertRaises(ValueError):
mx.vmap(my_fun, in_axes=0, out_axes=(0, 1))(x, y)
with self.assertRaises(ValueError):
mx.vmap(my_fun, in_axes=0, out_axes={"a": 0, "c": 1})(x, y)
out = mx.vmap(my_fun, in_axes=0, out_axes={"a": 1, "b": 0})(x, y)
expected = my_fun(x, y)
self.assertTrue(mx.array_equal(out["a"].T, expected["a"]))
self.assertTrue(mx.array_equal(out["b"], expected["b"]))
if __name__ == "__main__":
unittest.main()

127
setup.py Normal file

@ -0,0 +1,127 @@
import os
import re
import subprocess
import sys
import sysconfig
from pathlib import Path
from setuptools import Extension, setup, find_namespace_packages
from setuptools.command.build_ext import build_ext
# A CMakeExtension needs a sourcedir instead of a file list.
# The name must be the _single_ output extension from the CMake build.
# If you need multiple extensions, see scikit-build.
class CMakeExtension(Extension):
def __init__(self, name: str, sourcedir: str = "") -> None:
super().__init__(name, sources=[])
self.sourcedir = os.fspath(Path(sourcedir).resolve())
class CMakeBuild(build_ext):
def build_extension(self, ext: CMakeExtension) -> None:
# Must be in this form due to bug in .resolve() only fixed in Python 3.10+
ext_fullpath = Path.cwd() / self.get_ext_fullpath(ext.name) # type: ignore[no-untyped-call]
extdir = ext_fullpath.parent.resolve()
debug = int(os.environ.get("DEBUG", 0)) if self.debug is None else self.debug
cfg = "Debug" if debug else "Release"
# CMake lets you override the generator - we need to check this.
# Can be set with Conda-Build, for example.
cmake_generator = os.environ.get("CMAKE_GENERATOR", "")
# Set Python_EXECUTABLE instead if you use PYBIND11_FINDPYTHON
# EXAMPLE_VERSION_INFO shows you how to pass a value into the C++ code
# from Python.
cmake_args = [
f"-DCMAKE_INSTALL_PREFIX={extdir}{os.sep}",
f"-DCMAKE_BUILD_TYPE={cfg}",
"-DBUILD_SHARED_LIBS=ON",
"-DMLX_BUILD_PYTHON_BINDINGS=ON",
"-DMLX_BUILD_TESTS=OFF",
"-DMLX_BUILD_BENCHMARKS=OFF",
"-DMLX_BUILD_EXAMPLES=OFF",
f"-DMLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY={extdir}{os.sep}",
]
build_args = []
# Adding CMake arguments set as environment variable
# (needed e.g. to build for ARM OSx on conda-forge)
if "CMAKE_ARGS" in os.environ:
cmake_args += [item for item in os.environ["CMAKE_ARGS"].split(" ") if item]
# Pass version to C++
cmake_args += [f"-DMLX_VERSION={self.distribution.get_version()}"] # type: ignore[attr-defined]
if sys.platform.startswith("darwin"):
# Cross-compile support for macOS - respect ARCHFLAGS if set
archs = re.findall(r"-arch (\S+)", os.environ.get("ARCHFLAGS", ""))
if archs:
cmake_args += ["-DCMAKE_OSX_ARCHITECTURES={}".format(";".join(archs))]
# Set CMAKE_BUILD_PARALLEL_LEVEL to control the parallel build level
# across all generators.
if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ:
# self.parallel is a Python 3 only way to set parallel jobs by hand
# using -j in the build_ext call, not supported by pip or PyPA-build.
if hasattr(self, "parallel") and self.parallel:
# CMake 3.12+ only.
build_args += [f"-j{self.parallel}"]
build_temp = Path(self.build_temp) / ext.name
if not build_temp.exists():
build_temp.mkdir(parents=True)
subprocess.run(
["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True
)
subprocess.run(
["cmake", "--build", ".", "--target", "install", *build_args],
cwd=build_temp,
check=True,
)
# Make sure to copy mlx.metallib for inplace builds
def run(self):
super().run()
# Based on https://github.com/pypa/setuptools/blob/main/setuptools/command/build_ext.py#L102
if self.inplace:
for ext in self.extensions:
if ext.name == "mlx.core":
# Resolve inplace package dir
build_py = self.get_finalized_command("build_py")
inplace_file, regular_file = self._get_inplace_equivalent(
build_py, ext
)
inplace_dir = str(Path(inplace_file).parent.resolve())
regular_dir = str(Path(regular_file).parent.resolve())
self.copy_tree(regular_dir, inplace_dir)
# The information here can also be placed in setup.cfg - better separation of
# logic and declaration, and simpler if you include description/version in a file.
if __name__ == "__main__":
packages = find_namespace_packages(
where="python", exclude=["src", "tests", "tests.*"]
)
package_dir = {"": "python"}
package_data = {"mlx": ["lib/*", "include/*", "share/*"]}
setup(
name="mlx",
version="0.0.2",
author="MLX Contributors",
author_email="mlx@group.apple.com",
description="A framework for machine learning on Apple Silicon.",
long_description="",
packages=packages,
package_dir=package_dir,
package_data=package_data,
include_package_data=True,
ext_modules=[CMakeExtension("mlx.core")],
cmdclass={"build_ext": CMakeBuild},
zip_safe=False,
python_requires=">=3.7",
)

41
tests/allocator_tests.cpp Normal file

@ -0,0 +1,41 @@
#include <stdexcept>
#include "doctest/doctest.h"
#include "mlx/allocator.h"
using namespace mlx::core;
TEST_CASE("test simple allocations") {
{
auto buffer = allocator::malloc(sizeof(float));
auto fptr = static_cast<float*>(buffer.raw_ptr());
*fptr = 0.5f;
CHECK_EQ(*fptr, 0.5f);
allocator::free(buffer);
}
{
auto buffer = allocator::malloc(128 * sizeof(int));
int* ptr = static_cast<int*>(buffer.raw_ptr());
for (int i = 0; i < 128; ++i) {
ptr[i] = i;
}
allocator::free(buffer);
}
{
auto buffer = allocator::malloc(0);
allocator::free(buffer);
}
}
TEST_CASE("test large allocations") {
size_t size = 1 << 30;
for (int i = 0; i < 100; ++i) {
auto buffer = allocator::malloc(size);
allocator::free(buffer);
}
// Shouldn't be able to allocate an exabyte anytime soon.
CHECK_THROWS_AS(allocator::malloc(1ull << 60), std::runtime_error);
}

589
tests/array_tests.cpp Normal file
View File

@ -0,0 +1,589 @@
#include <climits>
#include "doctest/doctest.h"
#include "mlx/mlx.h"
using namespace mlx::core;
TEST_CASE("test array basics") {
// Scalar
array x(1.0);
CHECK_EQ(x.size(), 1);
CHECK_EQ(x.ndim(), 0);
CHECK_EQ(x.shape(), std::vector<int>{});
CHECK_THROWS_AS(x.shape(0), std::out_of_range);
CHECK_THROWS_AS(x.shape(-1), std::out_of_range);
CHECK_EQ(x.strides(), std::vector<size_t>{});
CHECK_EQ(x.itemsize(), sizeof(float));
CHECK_EQ(x.nbytes(), sizeof(float));
CHECK_EQ(x.dtype(), float32);
CHECK_EQ(x.item<float>(), 1.0);
// Scalar with specified type
x = array(1, float32);
CHECK_EQ(x.dtype(), float32);
CHECK_EQ(x.item<float>(), 1.0);
// Scalar with specified type
x = array(1, bool_);
CHECK_EQ(x.dtype(), bool_);
CHECK_EQ(x.itemsize(), sizeof(bool));
CHECK_EQ(x.nbytes(), sizeof(bool));
CHECK_EQ(x.item<bool>(), true);
// Check shaped arrays
x = array({1.0});
CHECK_EQ(x.dtype(), float32);
CHECK_EQ(x.size(), 1);
CHECK_EQ(x.ndim(), 1);
CHECK_EQ(x.shape(), std::vector<int>{1});
CHECK_EQ(x.shape(0), 1);
CHECK_EQ(x.shape(-1), 1);
CHECK_THROWS_AS(x.shape(1), std::out_of_range);
CHECK_THROWS_AS(x.shape(-2), std::out_of_range);
CHECK_EQ(x.strides(), std::vector<size_t>{1});
CHECK_EQ(x.item<float>(), 1.0);
// Check empty array
x = array({});
CHECK_EQ(x.size(), 0);
CHECK_EQ(x.dtype(), float32);
CHECK_EQ(x.itemsize(), sizeof(float));
CHECK_EQ(x.nbytes(), 0);
CHECK_THROWS_AS(x.item<float>(), std::invalid_argument);
x = array({1.0, 1.0});
CHECK_EQ(x.size(), 2);
CHECK_EQ(x.shape(), std::vector<int>{2});
CHECK_EQ(x.itemsize(), sizeof(float));
CHECK_EQ(x.nbytes(), x.itemsize() * x.size());
// Accessing item in non-scalar array throws
CHECK_THROWS_AS(x.item<float>(), std::invalid_argument);
x = array({1.0, 1.0, 1.0}, {1, 3});
CHECK(x.size() == 3);
CHECK(x.shape() == std::vector<int>{1, 3});
CHECK(x.strides() == std::vector<size_t>{3, 1});
// Test wrong size/shapes throw:
CHECK_THROWS_AS(array({1.0, 1.0, 1.0}, {4}), std::invalid_argument);
CHECK_THROWS_AS(array({1.0, 1.0, 1.0}, {1, 4}), std::invalid_argument);
CHECK_THROWS_AS(array({1.0, 1.0, 1.0}, {1, 2}), std::invalid_argument);
// Test array ids work as expected
x = array(1.0);
auto y = x;
CHECK_EQ(y.id(), x.id());
array z(2.0);
CHECK_NE(z.id(), x.id());
z = x;
CHECK_EQ(z.id(), x.id());
// Array creation from pointer
float data[] = {0.0, 1.0, 2.0, 3.0};
x = array(data, {4});
CHECK_EQ(x.dtype(), float32);
CHECK(array_equal(x, array({0.0, 1.0, 2.0, 3.0})).item<bool>());
// Array creation from vectors
{
std::vector<int> data = {0, 1, 2, 3};
x = array(data.begin(), {4});
CHECK_EQ(x.dtype(), int32);
CHECK(array_equal(x, array({0, 1, 2, 3})).item<bool>());
}
{
std::vector<bool> data = {false, true, false, true};
x = array(data.begin(), {4});
CHECK_EQ(x.dtype(), bool_);
CHECK(array_equal(x, array({false, true, false, true})).item<bool>());
}
}
TEST_CASE("test array types") {
#define basic_dtype_test(T, mlx_type) \
T val = 42; \
array x(val); \
CHECK_EQ(x.dtype(), mlx_type); \
CHECK_EQ(x.item<T>(), val); \
x = array({val, val}); \
CHECK_EQ(x.dtype(), mlx_type);
// bool_
{
array x(true);
CHECK_EQ(x.dtype(), bool_);
CHECK_EQ(x.item<bool>(), true);
x = array({true, false});
CHECK_EQ(x.dtype(), bool_);
x = array({true, false}, float32);
CHECK_EQ(x.dtype(), float32);
CHECK(array_equal(x, array({1.0f, 0.0f})).item<bool>());
}
// uint8
{ basic_dtype_test(uint8_t, uint8); }
// uint16
{ basic_dtype_test(uint16_t, uint16); }
// uint32
{ basic_dtype_test(uint32_t, uint32); }
// uint64
{ basic_dtype_test(uint64_t, uint64); }
// int8
{ basic_dtype_test(int8_t, int8); }
// int16
{ basic_dtype_test(int16_t, int16); }
// int32
{ basic_dtype_test(int32_t, int32); }
// int64
{ basic_dtype_test(int64_t, int64); }
// float16
{ basic_dtype_test(float16_t, float16); }
// float32
{ basic_dtype_test(float, float32); }
// bfloat16
{ basic_dtype_test(bfloat16_t, bfloat16); }
// uint32
{
uint32_t val = UINT_MAX;
array x(val);
CHECK_EQ(x.dtype(), uint32);
CHECK_EQ(x.item<uint32_t>(), val);
x = array({1u, 2u});
CHECK_EQ(x.dtype(), uint32);
}
// int32
{
array x(-1);
CHECK_EQ(x.dtype(), int32);
CHECK_EQ(x.item<int>(), -1);
x = array({-1, 2});
CHECK_EQ(x.dtype(), int32);
std::vector<int> data{0, 1, 2};
x = array(data.data(), {static_cast<int>(data.size())}, bool_);
CHECK_EQ(x.dtype(), bool_);
CHECK(array_equal(x, array({false, true, true})).item<bool>());
}
// int64
{
int64_t val = static_cast<int64_t>(INT_MIN) - 1;
array x(val);
CHECK_EQ(x.dtype(), int64);
CHECK_EQ(x.item<int64_t>(), val);
x = array({val, val});
CHECK_EQ(x.dtype(), int64);
}
// float32
{
array x(3.14f);
CHECK_EQ(x.dtype(), float32);
CHECK_EQ(x.item<float>(), 3.14f);
x = array(1.25);
CHECK_EQ(x.dtype(), float32);
CHECK_EQ(x.item<float>(), 1.25f);
x = array({1.0f, 2.0f});
CHECK_EQ(x.dtype(), float32);
x = array({1.0, 2.0});
CHECK_EQ(x.dtype(), float32);
std::vector<double> data{1.0, 2.0, 4.0};
x = array(data.data(), {static_cast<int>(data.size())});
CHECK_EQ(x.dtype(), float32);
CHECK(array_equal(x, array({1.0f, 2.0f, 4.0f})).item<bool>());
}
// complex64
{
complex64_t v = {1.0f, 1.0f};
array x(v);
CHECK_EQ(x.dtype(), complex64);
CHECK_EQ(x.item<complex64_t>(), v);
array y(std::complex<float>{1.0f, 1.0f});
CHECK_EQ(y.dtype(), complex64);
CHECK_EQ(y.item<complex64_t>(), v);
}
#undef basic_dtype_test
#define basic_dtype_str_test(s, dtype) \
CHECK_EQ(s, dtype_to_array_protocol(dtype)); \
CHECK_EQ(dtype_from_array_protocol(s), dtype);
// To and from str
{
basic_dtype_str_test("|b1", bool_);
basic_dtype_str_test("|u1", uint8);
basic_dtype_str_test("<u2", uint16);
basic_dtype_str_test("<u4", uint32);
basic_dtype_str_test("<u8", uint64);
basic_dtype_str_test("|i1", int8);
basic_dtype_str_test("<i2", int16);
basic_dtype_str_test("<i4", int32);
basic_dtype_str_test("<i8", int64);
basic_dtype_str_test("<f2", float16);
basic_dtype_str_test("<f4", float32);
basic_dtype_str_test("<V2", bfloat16);
basic_dtype_str_test("<c8", complex64);
}
#undef basic_dtype_str_test
}
TEST_CASE("test array metadata") {
array x(1.0f);
CHECK_EQ(x.data_size(), 1);
CHECK_EQ(x.flags().contiguous, true);
CHECK_EQ(x.flags().row_contiguous, true);
CHECK_EQ(x.flags().col_contiguous, true);
x = array({1.0f}, {1, 1, 1});
CHECK_EQ(x.data_size(), 1);
CHECK_EQ(x.flags().contiguous, true);
CHECK_EQ(x.flags().row_contiguous, true);
CHECK_EQ(x.flags().col_contiguous, true);
x = array({1.0f, 1.0f}, {1, 2});
CHECK_EQ(x.data_size(), 2);
CHECK_EQ(x.flags().contiguous, true);
CHECK_EQ(x.flags().row_contiguous, true);
CHECK_EQ(x.flags().col_contiguous, true);
x = zeros({1, 1, 4});
eval(x);
CHECK_EQ(x.data_size(), 4);
CHECK_EQ(x.flags().contiguous, true);
CHECK_EQ(x.flags().row_contiguous, true);
CHECK_EQ(x.flags().col_contiguous, true);
x = zeros({2, 4});
eval(x);
CHECK_EQ(x.data_size(), 8);
CHECK_EQ(x.flags().contiguous, true);
CHECK_EQ(x.flags().row_contiguous, true);
CHECK_EQ(x.flags().col_contiguous, false);
x = array(1.0f);
auto y = broadcast_to(x, {1, 1, 1});
eval(y);
CHECK_EQ(y.data_size(), 1);
CHECK_EQ(y.flags().contiguous, true);
CHECK_EQ(y.flags().row_contiguous, true);
CHECK_EQ(y.flags().col_contiguous, true);
y = broadcast_to(x, {2, 8, 10});
eval(y);
CHECK_EQ(y.data_size(), 1);
CHECK_EQ(y.flags().contiguous, true);
CHECK_EQ(y.flags().row_contiguous, false);
CHECK_EQ(y.flags().col_contiguous, false);
y = broadcast_to(x, {1, 0});
eval(y);
CHECK_EQ(y.data_size(), 0);
CHECK_EQ(y.flags().contiguous, true);
CHECK_EQ(y.flags().row_contiguous, true);
CHECK_EQ(y.flags().col_contiguous, true);
y = broadcast_to(zeros({4, 2, 1}), {4, 2, 0});
eval(y);
CHECK_EQ(y.data_size(), 0);
CHECK_EQ(y.flags().contiguous, true);
CHECK_EQ(y.flags().row_contiguous, true);
CHECK_EQ(y.flags().col_contiguous, true);
x = array(1.0f);
y = transpose(x);
eval(y);
CHECK_EQ(y.data_size(), 1);
CHECK_EQ(y.flags().contiguous, true);
CHECK_EQ(y.flags().row_contiguous, true);
CHECK_EQ(y.flags().col_contiguous, true);
x = ones({1, 1, 1});
y = transpose(x);
eval(y);
CHECK_EQ(y.data_size(), 1);
CHECK_EQ(y.flags().contiguous, true);
CHECK_EQ(y.flags().row_contiguous, true);
CHECK_EQ(y.flags().col_contiguous, true);
x = ones({1, 1, 1});
y = transpose(x, {0, 1, 2});
eval(y);
CHECK_EQ(y.data_size(), 1);
CHECK_EQ(y.flags().contiguous, true);
CHECK_EQ(y.flags().row_contiguous, true);
CHECK_EQ(y.flags().col_contiguous, true);
x = ones({1, 1, 1});
y = transpose(x, {1, 2, 0});
eval(y);
CHECK_EQ(y.data_size(), 1);
CHECK_EQ(y.flags().contiguous, true);
CHECK_EQ(y.flags().row_contiguous, true);
CHECK_EQ(y.flags().col_contiguous, true);
x = ones({4, 1});
y = transpose(x);
eval(y);
CHECK_EQ(y.data_size(), 4);
CHECK_EQ(y.flags().contiguous, true);
CHECK_EQ(y.flags().row_contiguous, true);
CHECK_EQ(y.flags().col_contiguous, true);
x = ones({2, 3, 4});
y = transpose(x);
eval(y);
CHECK_EQ(y.data_size(), 24);
CHECK_EQ(y.flags().contiguous, true);
CHECK_EQ(y.flags().row_contiguous, false);
CHECK_EQ(y.flags().col_contiguous, true);
y = transpose(x, {0, 2, 1});
eval(y);
CHECK_EQ(y.data_size(), 24);
CHECK_EQ(y.flags().contiguous, true);
CHECK_EQ(y.flags().row_contiguous, false);
CHECK_EQ(y.flags().col_contiguous, false);
y = transpose(transpose(x, {0, 2, 1}), {0, 2, 1});
eval(y);
CHECK_EQ(y.data_size(), 24);
CHECK_EQ(y.flags().contiguous, true);
CHECK_EQ(y.flags().row_contiguous, true);
CHECK_EQ(y.flags().col_contiguous, false);
x = array(1.0f);
y = reshape(x, {1, 1, 1});
eval(y);
CHECK_EQ(y.data_size(), 1);
CHECK_EQ(y.flags().contiguous, true);
CHECK_EQ(y.flags().row_contiguous, true);
CHECK_EQ(y.flags().col_contiguous, true);
x = ones({2, 4});
y = reshape(x, {8});
eval(y);
CHECK_EQ(y.data_size(), 8);
CHECK_EQ(y.flags().contiguous, true);
CHECK_EQ(y.flags().row_contiguous, true);
CHECK_EQ(y.flags().col_contiguous, true);
y = reshape(x, {8, 1, 1});
eval(y);
CHECK_EQ(y.data_size(), 8);
CHECK_EQ(y.flags().contiguous, true);
CHECK_EQ(y.flags().row_contiguous, true);
CHECK_EQ(y.flags().col_contiguous, true);
y = reshape(x, {1, 8, 1});
eval(y);
CHECK_EQ(y.data_size(), 8);
CHECK_EQ(y.flags().contiguous, true);
CHECK_EQ(y.flags().row_contiguous, true);
CHECK_EQ(y.flags().col_contiguous, true);
x = ones({12});
y = reshape(x, {2, 3, 2});
eval(y);
CHECK_EQ(y.data_size(), 12);
CHECK_EQ(y.flags().contiguous, true);
CHECK_EQ(y.flags().row_contiguous, true);
CHECK_EQ(y.flags().col_contiguous, false);
x = array(1.0f);
y = slice(x, {}, {});
eval(y);
CHECK_EQ(y.data_size(), 1);
CHECK_EQ(y.flags().contiguous, true);
CHECK_EQ(y.flags().row_contiguous, true);
CHECK_EQ(y.flags().col_contiguous, true);
x = array({1.0f});
y = slice(x, {-10}, {10}, {10});
eval(y);
CHECK_EQ(y.data_size(), 1);
CHECK_EQ(y.flags().contiguous, true);
CHECK_EQ(y.flags().row_contiguous, true);
CHECK_EQ(y.flags().col_contiguous, true);
x = array({1.0f, 2.0f, 3.0f}, {1, 3});
y = slice(x, {0, 0}, {1, 3}, {1, 1});
eval(y);
CHECK_EQ(y.data_size(), 3);
CHECK_EQ(y.flags().contiguous, true);
CHECK_EQ(y.flags().row_contiguous, true);
CHECK_EQ(y.flags().col_contiguous, true);
x = array({1.0f, 2.0f, 3.0f}, {1, 3});
y = slice(x, {0, 0}, {1, 3}, {1, 1});
eval(y);
CHECK_EQ(y.data_size(), 3);
CHECK_EQ(y.flags().contiguous, true);
CHECK_EQ(y.flags().row_contiguous, true);
CHECK_EQ(y.flags().col_contiguous, true);
x = array({1.0f, 2.0f, 3.0f}, {1, 3});
y = slice(x, {0, 0}, {0, 3}, {1, 1});
eval(y);
CHECK_EQ(y.data_size(), 0);
CHECK_EQ(y.flags().contiguous, true);
CHECK_EQ(y.flags().row_contiguous, true);
CHECK_EQ(y.flags().col_contiguous, true);
x = array({1.0f, 2.0f, 3.0f}, {1, 3});
y = slice(x, {0, 0}, {1, 2}, {1, 1});
eval(y);
CHECK_EQ(y.data_size(), 2);
CHECK_EQ(y.flags().contiguous, true);
CHECK_EQ(y.flags().row_contiguous, true);
CHECK_EQ(y.flags().col_contiguous, true);
x = array({1.0f, 2.0f, 3.0f}, {1, 3});
y = slice(x, {0, 0}, {1, 2}, {2, 3});
eval(y);
CHECK_EQ(y.shape(), std::vector<int>{1, 1});
CHECK_EQ(y.data_size(), 1);
CHECK_EQ(y.flags().contiguous, true);
CHECK_EQ(y.flags().row_contiguous, true);
CHECK_EQ(y.flags().col_contiguous, true);
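// A step of 2 leaves gaps in the underlying buffer, so the result is no longer contiguous.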
x = array({0.0f, 1.0f, 2.0f, 3.0f}, {1, 4});
y = slice(x, {0, 0}, {1, 4}, {1, 2});
eval(y);
CHECK_EQ(y.shape(), std::vector<int>{1, 2});
CHECK_EQ(y.flags().contiguous, false);
CHECK_EQ(y.flags().row_contiguous, false);
CHECK_EQ(y.flags().col_contiguous, false);
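// Slicing a broadcast array only references the original elements: the broadcast scalar keeps data_size 1 and the broadcast row keeps data_size 2.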
x = broadcast_to(array(1.0f), {4, 10});
y = slice(x, {0, 0}, {4, 10}, {2, 2});
eval(y);
CHECK_EQ(y.shape(), std::vector<int>{2, 5});
CHECK_EQ(y.data_size(), 1);
CHECK_EQ(y.flags().contiguous, true);
CHECK_EQ(y.flags().row_contiguous, false);
CHECK_EQ(y.flags().col_contiguous, false);
x = broadcast_to(array({1.0f, 2.0f}), {4, 2});
y = slice(x, {0, 0}, {1, 2}, {1, 1});
eval(y);
CHECK_EQ(y.data_size(), 2);
CHECK_EQ(y.flags().contiguous, true);
CHECK_EQ(y.flags().row_contiguous, true);
CHECK_EQ(y.flags().col_contiguous, true);
y = slice(x, {1, 0}, {2, 2}, {1, 1});
eval(y);
CHECK_EQ(y.data_size(), 2);
CHECK_EQ(y.flags().contiguous, true);
CHECK_EQ(y.flags().row_contiguous, true);
CHECK_EQ(y.flags().col_contiguous, true);
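// Full-extent slices preserve the input layout: the row-major input stays row-contiguous, its transpose stays column-contiguous.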
x = array({0.0f, 1.0f, 2.0f, 3.0f}, {2, 2});
y = slice(x, {0, 0}, {2, 2}, {1, 1});
eval(y);
CHECK_EQ(y.data_size(), 4);
CHECK_EQ(y.flags().contiguous, true);
CHECK_EQ(y.flags().row_contiguous, true);
CHECK_EQ(y.flags().col_contiguous, false);
y = slice(transpose(x), {0, 0}, {2, 2}, {1, 1});
eval(y);
CHECK_EQ(y.data_size(), 4);
CHECK_EQ(y.flags().contiguous, true);
CHECK_EQ(y.flags().row_contiguous, false);
CHECK_EQ(y.flags().col_contiguous, true);
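// Splitting along axis 0 yields contiguous row-major pieces; splitting along the last axis yields non-contiguous pieces.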
x = ones({2, 4});
auto out = split(x, 2);
eval(out);
for (auto y : out) {
CHECK_EQ(y.data_size(), 4);
CHECK_EQ(y.flags().contiguous, true);
CHECK_EQ(y.flags().row_contiguous, true);
CHECK_EQ(y.flags().col_contiguous, true);
}
out = split(x, 4, 1);
eval(out);
for (auto y : out) {
CHECK_EQ(y.flags().contiguous, false);
CHECK_EQ(y.flags().row_contiguous, false);
CHECK_EQ(y.flags().col_contiguous, false);
}
}
TEST_CASE("test array iteration") {
// Dim 0 arrays
auto arr = array(1);
CHECK_THROWS(arr.begin());
// Iterated arrays are read-only
CHECK(std::is_const_v<decltype(*arr.begin())>);
arr = array({1, 2, 3, 4, 5});
int i = 0;
for (auto a : arr) {
i++;
CHECK_EQ(a.item<int>(), i);
}
CHECK_EQ(i, 5);
arr = array({1, 2, 3, 4}, {2, 2});
CHECK(array_equal(*arr.begin(), array({1, 2})).item<bool>());
CHECK(array_equal(*(arr.begin() + 1), array({3, 4})).item<bool>());
CHECK_EQ(arr.begin() + 2, arr.end());
}
TEST_CASE("test array shared buffer") {
std::vector<int> shape = {2, 2};
int n_elem = shape[0] * shape[1];
allocator::Buffer buf_b = allocator::malloc(n_elem * sizeof(float));
void* buf_b_ptr = buf_b.raw_ptr();
float* float_buf_b = (float*)buf_b_ptr;
for (int i = 0; i < n_elem; i++) {
float_buf_b[i] = 2.;
}
CHECK_EQ(float_buf_b[0], ((float*)buf_b_ptr)[0]);
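// The deleter must receive the same buffer the array was constructed with and is responsible for freeing it.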
auto deleter = [float_buf_b](allocator::Buffer buf) {
CHECK_EQ(float_buf_b, (float*)buf.raw_ptr());
CHECK_EQ(float_buf_b[0], ((float*)buf.raw_ptr())[0]);
allocator::free(buf);
};
array a = ones(shape, float32);
array b = array(buf_b, shape, float32, deleter);
eval(a + b);
}

1192
tests/autograd_tests.cpp Normal file

File diff suppressed because it is too large

33
tests/device_tests.cpp Normal file
View File

@@ -0,0 +1,33 @@
#include "doctest/doctest.h"
#include <cstdlib>
#include "mlx/mlx.h"
using namespace mlx::core;
TEST_CASE("test device placement") {
auto device = default_device();
Device d = metal::is_available() ? Device::gpu : Device::cpu;
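// Unless overridden via the DEVICE environment variable, the default device follows Metal availability.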
if (std::getenv("DEVICE") == nullptr) {
CHECK_EQ(device, d);
}
array x(1.0f);
array y(1.0f);
auto z = add(x, y, default_device());
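// Explicit GPU placement is only valid when Metal is available; otherwise it must throw.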
if (metal::is_available()) {
z = add(x, y, Device::gpu);
z = add(x, y, Device(Device::gpu, 0));
} else {
CHECK_THROWS_AS(set_default_device(Device::gpu), std::invalid_argument);
CHECK_THROWS_AS(add(x, y, Device::gpu), std::invalid_argument);
}
// Set the default device to the CPU
set_default_device(Device::cpu);
CHECK_EQ(default_device(), Device::cpu);
// Revert
set_default_device(device);
}

97
tests/eval_tests.cpp Normal file
View File

@@ -0,0 +1,97 @@
#include "doctest/doctest.h"
#include "mlx/mlx.h"
using namespace mlx::core;
TEST_CASE("test eval") {
{
array x(1.0);
array y(1);
array z(true);
eval({x, y, z});
CHECK_EQ(x.item<float>(), 1.0);
}
{
array x(1.0);
array y = ones({2, 2});
array z(true);
eval({x, y, z});
CHECK(array_equal(y, array({1.0, 1.0, 1.0, 1.0}, {2, 2})).item<bool>());
}
}
TEST_CASE("test eval multiple") {
auto x = ones({10, 10});
auto y = ones({10, 10});
eval({x, y});
CHECK(array_equal(x, y).item<bool>());
auto a = x + y;
auto b = x - y;
eval({a, b});
CHECK(array_equal(a, full({10, 10}, 2.0f)).item<bool>());
CHECK(array_equal(b, full({10, 10}, 0.0f)).item<bool>());
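// Repeat the same checks with the variadic eval overload.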
x = ones({10, 10});
y = ones({10, 10});
eval(x, y);
CHECK(array_equal(x, y).item<bool>());
a = x + y;
b = x - y;
eval(a, b);
CHECK(array_equal(a, full({10, 10}, 2.0f)).item<bool>());
CHECK(array_equal(b, full({10, 10}, 0.0f)).item<bool>());
}
TEST_CASE("test eval with tracer") {
auto x = array(1);
x.set_tracer(true);
// Ok, x is not a node
eval(x);
x = ones({2, 3});
x.set_tracer(true);
CHECK_THROWS(eval(x));
// Ok when retain_graph=true
eval({x}, true);
// Make sure all arguments are checked
auto y = ones({2, 3});
CHECK_THROWS(eval(x, y));
}
TEST_CASE("test eval graph retention") {
auto x = array(1);
auto y = array(2);
auto z = x + y;
eval({z}, true);
CHECK(z.has_primitive());
CHECK(z.is_evaled());
CHECK_EQ(z.item<int>(true), 3);
CHECK(z.has_primitive());
CHECK(z.is_evaled());
CHECK_EQ(z.item<int>(), 3);
CHECK(!z.has_primitive());
CHECK(z.is_evaled());
z = x + y;
auto a = z + x;
auto b = a + y;
eval({b}, true);
CHECK(z.has_primitive());
CHECK(z.is_evaled());
CHECK(a.has_primitive());
CHECK(a.is_evaled());
eval({b}, false);
CHECK(!z.has_primitive());
CHECK(z.is_evaled());
CHECK(!a.has_primitive());
CHECK(a.is_evaled());
}

81
tests/load_tests.cpp Normal file
View File

@@ -0,0 +1,81 @@
#include <filesystem>
#include <stdexcept>
#include <vector>
#include "doctest/doctest.h"
#include "mlx/mlx.h"
using namespace mlx::core;
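// Write test files into the system temporary directory.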
std::string get_temp_file(const std::string& name) {
return std::filesystem::temp_directory_path().append(name);
}
TEST_CASE("test single array serialization") {
// Basic test
{
auto a = random::uniform(-5.f, 5.f, {2, 5, 12}, float32);
std::string file_path = get_temp_file("test_arr.npy");
save(file_path, a);
auto b = load(file_path);
CHECK_EQ(a.dtype(), b.dtype());
CHECK_EQ(a.shape(), b.shape());
CHECK(array_equal(a, b).item<bool>());
}
// Other shapes
{
auto a = random::uniform(
-5.f,
5.f,
{
1,
},
float32);
std::string file_path = get_temp_file("test_arr_0.npy");
save(file_path, a);
auto b = load(file_path);
CHECK_EQ(a.dtype(), b.dtype());
CHECK_EQ(a.shape(), b.shape());
CHECK(array_equal(a, b).item<bool>());
}
{
auto a = random::uniform(
-5.f,
5.f,
{
46,
},
float32);
std::string file_path = get_temp_file("test_arr_1.npy");
save(file_path, a);
auto b = load(file_path);
CHECK_EQ(a.dtype(), b.dtype());
CHECK_EQ(a.shape(), b.shape());
CHECK(array_equal(a, b).item<bool>());
}
{
auto a = random::uniform(-5.f, 5.f, {5, 2, 1, 3, 4}, float32);
std::string file_path = get_temp_file("test_arr_2.npy");
save(file_path, a);
auto b = load(file_path);
CHECK_EQ(a.dtype(), b.dtype());
CHECK_EQ(a.shape(), b.shape());
CHECK(array_equal(a, b).item<bool>());
}
}

119
tests/scheduler_tests.cpp Normal file
View File

@@ -0,0 +1,119 @@
#include "doctest/doctest.h"
#include "mlx/mlx.h"
#include "mlx/scheduler.h"
using namespace mlx::core;
TEST_CASE("test stream management") {
auto s1 = default_stream(default_device());
CHECK_EQ(s1.device, default_device());
auto s2 = new_stream(default_device());
CHECK_EQ(s2.device, default_device());
CHECK_NE(s1, s2);
// Check that default streams have the correct devices
if (metal::is_available()) {
auto s_gpu = default_stream(Device::gpu);
CHECK_EQ(s_gpu.device, Device::gpu);
} else {
CHECK_THROWS_AS(default_stream(Device::gpu), std::invalid_argument);
}
auto s_cpu = default_stream(Device::cpu);
CHECK_EQ(s_cpu.device, Device::cpu);
s_cpu = new_stream(Device::cpu);
CHECK_EQ(s_cpu.device, Device::cpu);
if (metal::is_available()) {
auto s_gpu = new_stream(Device::gpu);
CHECK_EQ(s_gpu.device, Device::gpu);
} else {
CHECK_THROWS_AS(new_stream(Device::gpu), std::invalid_argument);
}
}
TEST_CASE("test asynchronous launch") {
auto s1 = default_stream(default_device());
auto s2 = new_stream(default_device());
// Make sure streams execute asynchronously
int x = 1;
auto p1 = std::make_shared<std::promise<void>>();
auto p2 = std::make_shared<std::promise<void>>();
auto f1 = p1->get_future().share();
auto f2 = p2->get_future().share();
auto fn1 = [&x, p = std::move(p1)]() {
x++;
p->set_value();
};
auto fn2 = [&x, p = std::move(p2), f = std::move(f1)]() {
f.wait();
x *= 5;
p->set_value();
};
// fn2 is launched first and is waiting on fn1, but since
// they are on different streams there is no deadlock.
scheduler::enqueue(s2, std::move(fn2));
scheduler::enqueue(s1, std::move(fn1));
f2.wait();
CHECK_EQ(x, 10);
}
TEST_CASE("test stream placement") {
auto s1 = default_stream(default_device());
auto s2 = new_stream(default_device());
{
// Wait on stream 1
auto p = std::make_shared<std::promise<void>>();
auto f = p->get_future().share();
scheduler::enqueue(s1, [f = std::move(f)]() { f.wait(); });
// Do some work on stream 2
auto x = zeros({100}, float32, s2);
auto y = ones({100}, float32, s2);
auto z = add(x, y, s2);
eval(z);
p->set_value();
}
{
// Wait on stream 1
auto p = std::make_shared<std::promise<void>>();
auto f = p->get_future().share();
scheduler::enqueue(s1, [f = std::move(f)]() { f.wait(); });
// Do some work on stream 2
auto fn = [&s2](array a) { return add(a, add(a, a, s2), s2); };
auto x = zeros({100}, s2);
// The whole vjp computation should happen
// on the second stream, otherwise this will hang.
auto [out, dout] = vjp(fn, x, ones({100}, s2));
// The whole jvp computation should happen on the
// second stream.
std::tie(out, dout) = jvp(fn, x, ones({100}, s2));
eval(out, dout);
p->set_value();
}
}
TEST_CASE("test scheduler races") {
auto x = zeros({1});
auto y = zeros({100});
eval(x, y);
auto a = exp(x);
eval(a);
a = exp(x);
for (int i = 0; i < 10000; ++i) {
y = exp(y);
}
eval(a, y);
}

26
tests/utils_tests.cpp Normal file
View File

@@ -0,0 +1,26 @@
#include "doctest/doctest.h"
#include "mlx/mlx.h"
using namespace mlx::core;
TEST_CASE("test type promotion") {
for (auto t : {bool_, uint32, int32, int64, float32}) {
auto a = array(0, t);
CHECK_EQ(result_type({a}), t);
std::vector<array> arrs = {array(0, t), array(0, t)};
CHECK_EQ(result_type(arrs), t);
}
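// Mixed dtypes promote upward: bool with int32 gives int32, and adding a float32 operand gives float32.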
{
std::vector<array> arrs = {array(false), array(0, int32)};
CHECK_EQ(result_type(arrs), int32);
}
{
std::vector<array> arrs = {array(0, int32), array(false), array(0.0f)};
CHECK_EQ(result_type(arrs), float32);
}
}