mirror of https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00

jagrit's commit files

This commit is contained in:
parent d1f86272a2
commit e6306cfee9
63  README.md
@@ -1,2 +1,61 @@
# mlx
MLX: An array framework for Apple silicon
# MLX

MLX is an array framework for machine learning specifically targeting Apple
Silicon. MLX is designed with inspiration from Jax, PyTorch, and ArrayFire.

[Documentation](https://at.apple.com/mlx)

## Build

```
mkdir -p build && cd build
cmake .. && make -j
```

Run the C++ tests with `make test` (or `./tests/tests` for more detailed output).

### Python bindings

To install, run:

```
env CMAKE_BUILD_PARALLEL_LEVEL="" pip install .
```

For development, use an editable install:

```
env CMAKE_BUILD_PARALLEL_LEVEL="" pip install -e .
```

To make sure the install is working, run the tests with:

```
python -m unittest discover python/tests
```

## Develop

- Fork and submit pull requests to the repo.

- Every PR should have passing tests and at least one review.

- If a change is likely to impact efficiency, run some of the benchmarks before
  and after the change. Examples of benchmarks can be found in `benchmarks/cpp/`.

- Install `pre-commit` using something like `pip install pre-commit` and run
  `pre-commit install`. This should install hooks for running `black` and
  `clang-format` to ensure consistent style for C++ and Python code.

You can also run the formatters manually as follows:

```
clang-format -i file.cpp
```

```
black file.py
```

or run `pre-commit run --all-files` to check all files in the repo.
11  benchmarks/cpp/CMakeLists.txt  Normal file
@@ -0,0 +1,11 @@
function(build_benchmark SRCFILE)
  get_filename_component(src_name ${SRCFILE} NAME_WE)
  set(target "${src_name}")
  add_executable(${target} ${SRCFILE})
  target_link_libraries(${target} PRIVATE mlx)
endfunction(build_benchmark)

build_benchmark(single_ops.cpp)
build_benchmark(irregular_strides.cpp)
build_benchmark(compare_devices.cpp)
build_benchmark(autograd.cpp)
37  benchmarks/cpp/autograd.cpp  Normal file
@@ -0,0 +1,37 @@
#include <iostream>

#include "mlx/mlx.h"
#include "time_utils.h"

using namespace mlx::core;

void time_value_and_grad() {
  auto x = ones({200, 1000});
  eval(x);
  auto fn = [](array x) {
    for (int i = 0; i < 20; ++i) {
      x = log(exp(x));
    }
    return sum(x);
  };

  auto grad_fn = grad(fn);
  auto independent_value_and_grad = [&]() {
    auto value = fn(x);
    auto dfdx = grad_fn(x);
    return std::vector<array>{value, dfdx};
  };
  TIME(independent_value_and_grad);

  auto value_and_grad_fn = value_and_grad(fn);
  auto combined_value_and_grad = [&]() {
    auto [value, dfdx] = value_and_grad_fn(x);
    return std::vector<array>{value, dfdx};
  };
  TIME(combined_value_and_grad);
}

int main() {
  std::cout << "Benchmarks for " << default_device() << std::endl;
  time_value_and_grad();
}
25  benchmarks/cpp/compare_devices.cpp  Normal file
@@ -0,0 +1,25 @@
#include <iostream>
#include "mlx/mlx.h"
#include "time_utils.h"

using namespace mlx::core;

void time_add_op() {
  std::vector<int> sizes(1, 1);
  for (int i = 0; i < 9; ++i) {
    sizes.push_back(10 * sizes.back());
  }
  set_default_device(Device::cpu);
  for (auto size : sizes) {
    auto a = random::uniform({size});
    auto b = random::uniform({size});
    eval(a, b);
    std::cout << "Size " << size << std::endl;
    TIMEM("cpu", add, a, b, Device::cpu);
    TIMEM("gpu", add, a, b, Device::gpu);
  }
}

int main() {
  time_add_op();
}
38  benchmarks/numpy/single_ops.py  Normal file
@@ -0,0 +1,38 @@
import numpy as np

from time_utils import time_fn


def time_add():
    a = np.ones((100, 100, 10), dtype=np.float32)
    b = np.ones((100, 100, 10), dtype=np.float32)
    time_fn(np.add, a, b)


def time_matmul():
    a = np.random.rand(1000, 500).astype(np.float32)
    b = np.random.rand(500, 1000).astype(np.float32)
    time_fn(np.matmul, a, b)


def time_exp():
    a = np.random.randn(1000, 100).astype(np.float32)
    time_fn(np.exp, a)


def time_take():
    a = np.random.rand(10000, 500)
    ids = np.random.randint(0, 10000, (20, 10))
    ids = [idx.reshape(-1) for idx in np.split(ids, 20)]

    def random_take():
        return [np.take(a, idx, 0) for idx in ids]

    time_fn(random_take)


if __name__ == "__main__":
    time_add()
    time_matmul()
    time_exp()
    time_take()
18  benchmarks/numpy/time_utils.py  Normal file
@@ -0,0 +1,18 @@
import time


def time_fn(fn, *args):
    print(f"Timing {fn.__name__} ...", end=" ")

    # warmup
    for _ in range(5):
        fn(*args)

    num_iters = 100
    tic = time.perf_counter()
    for _ in range(num_iters):
        x = fn(*args)
    toc = time.perf_counter()

    msec = 1e3 * (toc - tic) / num_iters
    print(f"{msec:.5f} msec")
190  benchmarks/python/blas/bench_gemm.py  Normal file
@@ -0,0 +1,190 @@
import numpy as np
import argparse
import mlx.core as mx
import time
import torch
import os
import math
import subprocess

device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
device_name = device_name.decode("utf-8").strip("\n")

N_warmup = 8
N_iter_bench = 80
N_iter_func = 5


def bench(f, a, b):
    for i in range(N_warmup):
        f(a, b)
    torch.mps.synchronize()

    s = time.perf_counter_ns()
    for i in range(N_iter_bench):
        f(a, b)
    e = time.perf_counter_ns()
    return (e - s) * 1e-9


def gemm_nn_mlx(a, b):
    ys = []
    for i in range(N_iter_func):
        y = a @ b
        ys.append(y)
    mx.eval(ys)
    return ys


def gemm_nt_mlx(a, b):
    ys = []
    for i in range(N_iter_func):
        y = a @ b.transpose((0, 2, 1))
        ys.append(y)
    mx.eval(ys)
    return ys


def gemm_tn_mlx(a, b):
    ys = []
    for i in range(N_iter_func):
        y = a.transpose((0, 2, 1)) @ b
        ys.append(y)
    mx.eval(ys)
    return ys


def gemm_tt_mlx(a, b):
    ys = []
    for i in range(N_iter_func):
        y = a.transpose((0, 2, 1)) @ b.transpose((0, 2, 1))
        ys.append(y)
    mx.eval(ys)
    return ys


@torch.no_grad()
def gemm_nn_torch(a, b):
    ys = []
    for i in range(N_iter_func):
        y = a @ b
        ys.append(y)
    torch.mps.synchronize()
    return ys


@torch.no_grad()
def gemm_nt_torch(a, b):
    ys = []
    for i in range(N_iter_func):
        y = a @ b.transpose(-1, -2)
        ys.append(y)
    torch.mps.synchronize()
    return ys


@torch.no_grad()
def gemm_tn_torch(a, b):
    ys = []
    for i in range(N_iter_func):
        y = a.transpose(-1, -2) @ b
        ys.append(y)
    torch.mps.synchronize()
    return ys


@torch.no_grad()
def gemm_tt_torch(a, b):
    ys = []
    for i in range(N_iter_func):
        y = a.transpose(-1, -2) @ b.transpose(-1, -2)
        ys.append(y)
    torch.mps.synchronize()
    return ys


def bench_shape(B, M, N, K, np_dtype, transpose="nn"):
    shape_a = (B, M, K) if transpose[0] == "n" else (B, K, M)
    shape_b = (B, K, N) if transpose[1] == "n" else (B, N, K)

    a_np = np.random.normal(0.0, 1.0 / math.sqrt(M + K), shape_a).astype(np_dtype)
    b_np = np.random.normal(0.0, 1.0 / math.sqrt(N + K), shape_b).astype(np_dtype)

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    a_pt = torch.from_numpy(a_np).to("mps")
    b_pt = torch.from_numpy(b_np).to("mps")

    torch.mps.synchronize()

    f_mx = {
        "nn": gemm_nn_mlx,
        "nt": gemm_nt_mlx,
        "tn": gemm_tn_mlx,
        "tt": gemm_tt_mlx,
    }[transpose]

    f_pt = {
        "nn": gemm_nn_torch,
        "nt": gemm_nt_torch,
        "tn": gemm_tn_torch,
        "tt": gemm_tt_torch,
    }[transpose]

    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    t_a = (0, 1, 2) if transpose[0] == "n" else (0, 2, 1)
    t_b = (0, 1, 2) if transpose[1] == "n" else (0, 2, 1)

    c_mlx = a_mx.transpose(t_a) @ b_mx.transpose(t_b)
    c_npy = a_np.transpose(t_a).astype(np.float32) @ b_np.transpose(t_b).astype(
        np.float32
    )

    atol = 1e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(c_mlx, c_npy.astype(np_dtype), atol=atol):
        print(
            f"Failed at {(B, M, N, K)} [transpose = {transpose}] with max(|a - b|) = {np.max(np.abs(c_npy - c_mlx))}"
        )

    return time_mlx, time_torch


def get_gflop_count(B, M, N, K):
    return float(2.0 * N_iter_bench * N_iter_func * B * M * N * K) / float(1024.0**3)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run gemm benchmarks")

    dtypes = ("float32", "float16")
    transposes = ("nn", "nt", "tn")
    shapes = (
        (16, 1024, 1024, 1024),
        (1, 1024, 1024, 2048),
        (4, 1024, 1024, 4096),
        (4, 1024, 4096, 1024),
        (1, 4096, 4096, 4096),
        (15, 1023, 1023, 1023),
        (17, 1025, 1025, 1025),
    )

    for dtype in dtypes:
        for transpose in transposes:
            for B, M, N, K in shapes:
                np_dtype = getattr(np, dtype)
                time_mlx, time_torch = bench_shape(B, M, N, K, np_dtype, transpose)

                gflop_count = get_gflop_count(B, M, N, K)
                gflops_mx = gflop_count / (time_mlx)
                gflops_pt = gflop_count / (time_torch)
                diff = gflops_mx / gflops_pt - 1.0

                print(
                    f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100. * diff:+5.2f}%"
                )
                if gflops_pt >= 2.0 * gflops_mx:
                    print("ATTENTION ^^^^^^^")
219  benchmarks/python/blas/bench_gemv.py  Normal file
@@ -0,0 +1,219 @@
import matplotlib.pyplot as plt
import numpy as np
import argparse
import mlx.core as mx
import time
import torch
import os
import subprocess


results_dir = "./results"

if not os.path.isdir(results_dir):
    os.mkdir(results_dir)

device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
device_name = device_name.decode("utf-8").strip("\n")

N_warmup = 5
N_iter_bench = 50
N_iter_func = 20

out_vec_sizes = [128, 512, 2048, 4096]
in_vec_sizes = [128, 512, 2048, 4096]

benchmark_vector_lens = []
benchmark_vector_lens += [(i + 1) * 4096 for i in range(8)][::2]
benchmark_vector_lens += [(i + 1) * 4095 for i in range(8)][::2]
benchmark_vector_lens += [(i + 1) * 4097 for i in range(8)][::2]
benchmark_vector_lens += [64, 128, 512, 1024, 2048, 11008, 32000]

benchmark_vector_lens.sort()


def bench(f, m, v):
    for i in range(N_warmup):
        f(m, v)
    torch.mps.synchronize()

    s = time.perf_counter_ns()
    for i in range(N_iter_bench):
        f(m, v)
    e = time.perf_counter_ns()
    return (e - s) * 1e-9


def gemv_mlx(m, v):
    ys = []
    for i in range(N_iter_func):
        y = m @ v
        ys.append(y)
    mx.eval(ys)
    return ys


def gemv_t_mlx(m, v):
    ys = []
    for i in range(N_iter_func):
        y = v @ m
        ys.append(y)
    mx.eval(ys)
    return ys


@torch.no_grad()
def gemv_torch(m, v):
    ys = []
    for i in range(N_iter_func):
        y = m @ v
        ys.append(y)
    torch.mps.synchronize()
    return ys


@torch.no_grad()
def gemv_t_torch(m, v):
    ys = []
    for i in range(N_iter_func):
        y = v @ m
        ys.append(y)
    torch.mps.synchronize()
    return ys


def bench_lens(in_vec_len, out_vec_len, np_dtype, transpose=False):
    shape_mat = (in_vec_len, out_vec_len) if transpose else (out_vec_len, in_vec_len)
    shape_vec = (1, in_vec_len) if transpose else (in_vec_len, 1)

    mat_npy = np.random.normal(0.0, 2.0 / in_vec_len, shape_mat).astype(np_dtype)
    vec_npy = np.random.normal(0.0, 2.0 / in_vec_len, shape_vec).astype(np_dtype)
    mat_mlx = mx.array(mat_npy)
    vec_mlx = mx.array(vec_npy)
    mat_trc = torch.from_numpy(mat_npy).to("mps")
    vec_trc = torch.from_numpy(vec_npy).to("mps")

    torch.mps.synchronize()

    time_torch = (
        bench(gemv_t_torch, mat_trc, vec_trc)
        if transpose
        else bench(gemv_torch, mat_trc, vec_trc)
    )
    time_mlx = (
        bench(gemv_t_mlx, mat_mlx, vec_mlx)
        if transpose
        else bench(gemv_mlx, mat_mlx, vec_mlx)
    )

    c_mlx = (
        np.asarray(vec_mlx @ mat_mlx) if transpose else np.asarray(mat_mlx @ vec_mlx)
    )
    c_npy = (vec_npy @ mat_npy) if transpose else (mat_npy @ vec_npy)

    if not np.allclose(c_mlx, c_npy, atol=2e-5):
        print(
            f"Failed at {shape_mat} [transpose = {transpose}] with max(|a - b|) = {np.max(np.abs(c_npy - c_mlx))}"
        )

    return time_mlx, time_torch


def get_gflop_count(in_vec_len, out_vec_len):
    return float(2.0 * N_iter_bench * N_iter_func * in_vec_len * out_vec_len) / float(
        1024**3
    )


def get_gbyte_size(in_vec_len, out_vec_len, np_dtype):
    n_elem = in_vec_len * out_vec_len + in_vec_len + out_vec_len
    item_size = 4 if np_dtype == np.float32 else 2
    return float(N_iter_bench * N_iter_func * n_elem * item_size) / float(1024**3)


def bench_with_in_len(ax, in_vec_len, out_vector_lens, dtype, transpose):
    np_dtype = getattr(np, dtype)
    mlx_gb_s = []
    mlx_gflops = []
    pyt_gb_s = []
    pyt_gflops = []

    for out_vec_len in out_vector_lens:
        gflop_count = get_gflop_count(in_vec_len, out_vec_len)
        gbyte_size = get_gbyte_size(in_vec_len, out_vec_len, np_dtype)

        time_mlx, time_torch = bench_lens(in_vec_len, out_vec_len, np_dtype, transpose)

        mlx_gb_s.append(gbyte_size / time_mlx)
        pyt_gb_s.append(gbyte_size / time_torch)

        mlx_gflops.append(gflop_count / time_mlx)
        pyt_gflops.append(gflop_count / time_torch)

    if transpose:
        title = f"gemv_t ([1, {in_vec_len}] [{in_vec_len}, out_vec_len]) | {dtype}"
    else:
        title = f"gemv ([out_vec_len, {in_vec_len}] X [{in_vec_len}, 1] ) | {dtype}"

    ax.plot(out_vector_lens, mlx_gb_s, "tab:blue", label="MLX")
    ax.plot(out_vector_lens, pyt_gb_s, "tab:red", label="Torch")
    ax.set_title(title)
    ax.set(xlabel="out_vector_len", ylabel="Performance (GB/s)")
    ax.legend()


def bench_with_out_len(ax, out_vec_len, in_vector_lens, dtype, transpose):
    np_dtype = getattr(np, dtype)
    mlx_gb_s = []
    mlx_gflops = []
    pyt_gb_s = []
    pyt_gflops = []

    for in_vec_len in in_vector_lens:
        gflop_count = get_gflop_count(in_vec_len, out_vec_len)
        gbyte_size = get_gbyte_size(in_vec_len, out_vec_len, np_dtype)

        time_mlx, time_torch = bench_lens(in_vec_len, out_vec_len, np_dtype, transpose)

        mlx_gb_s.append(gbyte_size / time_mlx)
        pyt_gb_s.append(gbyte_size / time_torch)

        mlx_gflops.append(gflop_count / time_mlx)
        pyt_gflops.append(gflop_count / time_torch)

    if transpose:
        title = f"([1, in_vec_len] [in_vec_len, {out_vec_len}])"
    else:
        title = f"([{out_vec_len}, in_vec_len] X [in_vec_len, 1] )"

    ax.plot(in_vector_lens, mlx_gb_s, "tab:blue", label="MLX")
    ax.plot(in_vector_lens, pyt_gb_s, "tab:red", label="Torch")
    ax.set_title(title)
    ax.set(xlabel="in_vector_len", ylabel="Performance (GB/s)")
    ax.legend()


for transpose in (False, True):
    for dtype in ("float32", "float16"):
        fig, axs = plt.subplots(
            len(in_vec_sizes), 2, figsize=(8.5, 11), layout="constrained"
        )

        for i, in_vec_len in enumerate(in_vec_sizes):
            bench_with_in_len(
                axs[i][0], in_vec_len, benchmark_vector_lens, dtype, transpose
            )

        for i, out_vec_len in enumerate(out_vec_sizes):
            bench_with_out_len(
                axs[i][1], out_vec_len, benchmark_vector_lens, dtype, transpose
            )

        op_name = "gemv_t" if transpose else "gemv"
        fig.suptitle(f"{device_name}: {dtype} {op_name}")
        fig.savefig(
            os.path.join(
                results_dir, f'{device_name.replace(" ", "_")}_{dtype}_{op_name}.pdf'
            )
        )
        plt.close(fig)
116  benchmarks/python/llama_mlx_bench.py  Normal file
@@ -0,0 +1,116 @@
import math
import time

import mlx.core as mx
import mlx.nn as nn
import mlx.utils


class LlamaAttention(nn.Module):
    def __init__(self, dims: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.rope = nn.RoPE(dims // num_heads, True)
        self.query_proj = nn.Linear(dims, dims, False)
        self.key_proj = nn.Linear(dims, dims, False)
        self.value_proj = nn.Linear(dims, dims, False)
        self.out_proj = nn.Linear(dims, dims, False)

    def __call__(self, queries, keys, values, mask=None, cache=None):
        queries = self.query_proj(queries)
        keys = self.key_proj(keys)
        values = self.value_proj(values)

        num_heads = self.num_heads
        B, L, D = queries.shape
        queries = mx.transpose(mx.reshape(queries, (B, L, num_heads, -1)), (0, 2, 1, 3))
        keys = mx.transpose(mx.reshape(keys, (B, L, num_heads, -1)), (0, 2, 1, 3))
        values = mx.transpose(mx.reshape(values, (B, L, num_heads, -1)), (0, 2, 1, 3))

        if cache is not None:
            key_cache, value_cache = cache
            queries = self.rope(queries, offset=key_cache.shape[2])
            keys = self.rope(keys, offset=key_cache.shape[2])
            keys = mx.concatenate([key_cache, keys], axis=2)
            values = mx.concatenate([value_cache, values], axis=2)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        # Dimensions are [batch x num heads x sequence x hidden dim]
        scale = mx.array(math.sqrt(1 / queries.shape[-1]), dtype=queries.dtype)
        scores = (queries * scale) @ mx.transpose(keys, (0, 1, 3, 2))
        if mask is not None:
            scores = scores + mask
        scores = mx.softmax(scores, axis=-1)
        values_hat = mx.reshape(mx.transpose(scores @ values, (0, 2, 1, 3)), (B, L, -1))

        return self.out_proj(values_hat), (keys, values)


class LlamaEncoderLayer(nn.Module):
    def __init__(self, dims: int, mlp_dims: int, num_heads: int):
        super().__init__()

        self.attention = LlamaAttention(dims, num_heads)

        self.norm1 = nn.RMSNorm(dims)
        self.norm2 = nn.RMSNorm(dims)

        self.linear1 = nn.Linear(dims, mlp_dims, False)
        self.linear2 = nn.Linear(dims, mlp_dims, False)
        self.linear3 = nn.Linear(mlp_dims, dims, False)

    def __call__(self, x, mask=None, cache=None):
        y = self.norm1(x)
        y, cache = self.attention(y, y, y, mask, cache)
        x = x + y

        y = self.norm2(x)
        a = self.linear1(y)
        b = self.linear2(y)
        y = a * mx.sigmoid(a) * b
        y = self.linear3(y)
        x = x + y

        return x, cache


def measure(model, x, cache):
    for i in range(5):
        y, c = model(x, mask=None, cache=cache)
        mx.eval(y, c)

    start = time.time()
    rs = []
    for i in range(5):
        y, c = model(x, mask=None, cache=cache)
        rs.append((y, c))
    mx.eval(rs)
    end = time.time()

    return (end - start) * 1000 / 5


if __name__ == "__main__":
    H = 32
    D = 4096
    F = 43 * 256
    C = 1000
    mx.set_default_device(mx.gpu)
    dtype = mx.float16

    layer = LlamaEncoderLayer(D, F, H)
    layer.update(mlx.utils.tree_map(lambda x: x.astype(dtype), layer.parameters()))
    k1, k2, k3 = mx.random.split(mx.random.key(0), 3)
    x = mx.random.normal([1, 1, D], dtype=dtype)
    cache = [
        mx.random.normal([1, H, C, D // H], dtype=dtype),
        mx.random.normal([1, H, C, D // H], dtype=dtype),
    ]
    mx.eval(x, cache)

    T = measure(layer, x, cache)

    print("Time per layer per token:", T, "ms")
    print("Lower bound total time per token:", T * 32, "ms")
56  cmake/extension.cmake  Normal file
@@ -0,0 +1,56 @@
include(CMakeParseArguments)

###############################################################################
# Build metal library
#
# Adds a custom target ${TARGET} to build ${OUTPUT_DIRECTORY}/{TITLE}.metallib
# from list ${SOURCES}, including list ${INCLUDE_DIRS}, depends on list ${DEPS}
#
# Args:
#   TARGET: Custom target to be added for the metal library
#   TITLE: Name of the .metallib
#   OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib
#   SOURCES: List of source files
#   INCLUDE_DIRS: List of include dirs
#   DEPS: List of dependency files (like headers)
#
macro(mlx_build_metallib)
  # Parse args
  set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY)
  set(multiValueArgs SOURCES INCLUDE_DIRS DEPS)
  cmake_parse_arguments(
    MTLLIB
    ""
    "${oneValueArgs}"
    "${multiValueArgs}"
    ${ARGN}
  )

  # Set output
  set(MTLLIB_BUILD_TARGET "${MTLLIB_OUTPUT_DIRECTORY}/${MTLLIB_TITLE}.metallib")

  # Collect compile options
  set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math)

  # Prepare metallib build command
  add_custom_command(
    OUTPUT ${MTLLIB_BUILD_TARGET}
    COMMAND xcrun -sdk macosx metal
            "$<LIST:TRANSFORM,${MTLLIB_INCLUDE_DIRS},PREPEND,-I>"
            ${MTLLIB_COMPILE_OPTIONS}
            ${MTLLIB_SOURCES}
            -o ${MTLLIB_BUILD_TARGET}
    DEPENDS ${MTLLIB_DEPS} ${MTLLIB_SOURCES}
    COMMAND_EXPAND_LISTS
    COMMENT "Building ${MTLLIB_TITLE}.metallib"
    VERBATIM
  )

  # Add metallib custom target
  add_custom_target(
    ${MTLLIB_TARGET}
    DEPENDS
    ${MTLLIB_BUILD_TARGET}
  )

endmacro(mlx_build_metallib)
0  docs/.nojekyll  Normal file

1  docs/index.html  Normal file
@@ -0,0 +1 @@
<meta http-equiv="refresh" content="0; url=./build/html/index.html" />
19  docs/src/_templates/nn-module-template.rst  Normal file
@@ -0,0 +1,19 @@
{{ fullname | escape | underline}}

.. currentmodule:: {{ module }}

.. autoclass:: {{ objname }}

{#{% block methods %}

   {% if methods %}
   .. rubric:: {{ _('Methods') }}

   .. autosummary::
      {% for item in methods %}
      {%- if item not in inherited_members and item != '__init__' %}
         ~{{ name }}.{{ item }}
      {%- endif %}
      {%- endfor %}
   {% endif %}
{% endblock %}#}
6  docs/src/cpp/ops.rst  Normal file
@@ -0,0 +1,6 @@
.. _cpp_ops:

Operations
==========
948  docs/src/dev/extensions.rst  Normal file
@@ -0,0 +1,948 @@
Developer Documentation
=======================

MLX provides an open and flexible backend to which users may add operations
and specialized implementations without much hassle. While the library supplies
efficient operations that can be used and composed for any number of
applications, there may arise cases where new functionalities or highly
optimized implementations are needed. For such cases, you may design and
implement your own operations that link to and build on top of :mod:`mlx.core`.
We will introduce the inner workings of MLX and go over a simple example to
learn the steps involved in adding new operations to MLX with your own CPU
and GPU implementations.

Introducing the Example
-----------------------

Let's say that you would like an operation that takes in two arrays,
``x`` and ``y``, scales them both by some coefficients ``alpha`` and ``beta``
respectively, and then adds them together to get the result
``z = alpha * x + beta * y``. Well, you can very easily do that by just
writing out a function as follows:

.. code-block:: python

    import mlx.core as mx

    def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
        return alpha * x + beta * y

This function performs that operation while leaving the implementations and
differentiation to MLX.
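Because ``simple_axpby`` is composed entirely of existing MLX operations,
transformations such as :meth:`grad` already work on it. Here is a minimal
sketch of that (using only the public :mod:`mlx.core` API; the :meth:`sum`
reduction is there just to produce a scalar to differentiate):

.. code-block:: python

    import mlx.core as mx

    def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
        return alpha * x + beta * y

    # mx.grad differentiates a scalar-valued function with respect to
    # its first argument, so reduce the output to a scalar first
    loss = lambda x, y: mx.sum(simple_axpby(x, y, 2.0, 0.5))

    x = mx.ones((4,))
    y = mx.ones((4,))
    print(mx.grad(loss)(x, y))  # d(loss)/dx = alpha = 2 everywhere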
However, you work with vector math libraries often and realize that the
``axpby`` routine defines the same operation ``Y = (alpha * X) + (beta * Y)``.
You would really like the part of your applications that does this operation
on the CPU to be very fast, so you decide that you want it to rely on the
``axpby`` routine provided by the Accelerate_ framework. Continuing to impose
our assumptions onto you, let's also assume that you want to learn how to add
your own implementation for the gradients of your new operation while going
over the ins and outs of the MLX framework.

Well, what a coincidence! You are in the right place. Over the course of this
example, we will learn:

* The structure of the MLX library from the frontend API to the backend implementations.
* How to implement your own CPU backend that redirects to Accelerate_ when appropriate (and a fallback if needed).
* How to add your own GPU implementation using metal.
* How to add your own ``vjp`` and ``jvp``.
* How to build your implementations, link them to MLX, and bind them to python.

Operations and Primitives
-------------------------

In one sentence, operations in MLX build the computation graph, and primitives
provide the rules for evaluation and transformations of said graph. Let's start
by discussing operations in more detail.

Operations
^^^^^^^^^^^

Operations are the frontend functions that operate on arrays. They are defined
in the C++ API (:ref:`cpp_ops`), and then we provide bindings to these
operations in the Python API (:ref:`ops`).

We would like an operation, :meth:`axpby`, that takes in two arrays ``x`` and ``y``,
and two scalars, ``alpha`` and ``beta``. This is how we would define it in the
C++ API:

.. code-block:: C++

    /**
    * Scale and sum two vectors elementwise
    * z = alpha * x + beta * y
    *
    * Follow numpy style broadcasting between x and y
    * Inputs are upcasted to floats if needed
    **/
    array axpby(
        const array& x, // Input array x
        const array& y, // Input array y
        const float alpha, // Scaling factor for x
        const float beta, // Scaling factor for y
        StreamOrDevice s = {} // Stream on which to schedule the operation
    );

This operation itself can call other operations within it if needed. So, the
simplest way to go about implementing this operation would be to do so in terms
of existing operations.

.. code-block:: C++

    array axpby(
        const array& x, // Input array x
        const array& y, // Input array y
        const float alpha, // Scaling factor for x
        const float beta, // Scaling factor for y
        StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
    ) {
      // Scale x and y on the provided stream
      auto ax = multiply(array(alpha), x, s);
      auto by = multiply(array(beta), y, s);

      // Add and return
      return add(ax, by, s);
    }

However, as we discussed earlier, this is not our goal. The operations themselves
do not contain the implementations that act on the data, nor do they contain the
rules of transformations. Rather, they are an easy-to-use interface that builds
on top of the building blocks we call :class:`Primitive`.

Primitives
^^^^^^^^^^^

A :class:`Primitive` is part of the computation graph of an :class:`array`. It
defines how to create an output given a set of input :class:`array` objects.
Further, a :class:`Primitive` is a class that contains rules on how it is
evaluated on the CPU or GPU, and how it acts under transformations such as
``vjp`` and ``jvp``. These words on their own can be a bit abstract, so let's
take a step back and go to our example to give ourselves a more concrete image.

.. code-block:: C++

    class Axpby : public Primitive {
     public:
      explicit Axpby(Stream stream, float alpha, float beta)
          : Primitive(stream), alpha_(alpha), beta_(beta){};

      /**
      * A primitive must know how to evaluate itself on the CPU/GPU
      * for the given inputs and populate the output array.
      *
      * To avoid unnecessary allocations, the evaluation function
      * is responsible for allocating space for the array.
      */
      void eval_cpu(const std::vector<array>& inputs, array& out) override;
      void eval_gpu(const std::vector<array>& inputs, array& out) override;

      /** The Jacobian-vector product. */
      array jvp(
          const std::vector<array>& primals,
          const std::vector<array>& tangents,
          const std::vector<int>& argnums) override;

      /** The vector-Jacobian product. */
      std::vector<array> vjp(
          const std::vector<array>& primals,
          const array& cotan,
          const std::vector<int>& argnums) override;

      /**
      * The primitive must know how to vectorize itself across
      * the given axes. The output is a pair containing the array
      * representing the vectorized computation and the axis which
      * corresponds to the output vectorized dimension.
      */
      std::pair<array, int> vmap(
          const std::vector<array>& inputs,
          const std::vector<int>& axes) override;

      /** Print the primitive. */
      void print(std::ostream& os) override {
        os << "Axpby";
      }

      /** Equivalence check **/
      bool is_equivalent(const Primitive& other) const override;

     private:
      float alpha_;
      float beta_;

      /** Fall back implementation for evaluation on CPU */
      void eval(const std::vector<array>& inputs, array& out);
    };

The :class:`Axpby` class derives from the base :class:`Primitive` class and
follows the above demonstrated interface. :class:`Axpby` treats ``alpha`` and
``beta`` as parameters. It then provides implementations of how the array ``out``
is produced given ``inputs`` through :meth:`Axpby::eval_cpu` and
:meth:`Axpby::eval_gpu`. Further, it provides rules of transformations in
:meth:`Axpby::jvp`, :meth:`Axpby::vjp`, and :meth:`Axpby::vmap`.

Using the Primitives
^^^^^^^^^^^^^^^^^^^^^

Operations can use this :class:`Primitive` to add a new :class:`array` to
the computation graph. An :class:`array` can be constructed by providing its
data type, shape, the :class:`Primitive` that computes it, and the
:class:`array` inputs that are passed to the primitive.

Let's re-implement our operation now in terms of our :class:`Axpby` primitive.

.. code-block:: C++

    array axpby(
        const array& x, // Input array x
        const array& y, // Input array y
        const float alpha, // Scaling factor for x
        const float beta, // Scaling factor for y
        StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
    ) {
      // Promote dtypes between x and y as needed
      auto promoted_dtype = promote_types(x.dtype(), y.dtype());

      // Upcast to float32 for non-floating point inputs x and y
      auto out_dtype = is_floating_point(promoted_dtype)
          ? promoted_dtype
          : promote_types(promoted_dtype, float32);

      // Cast x and y up to the determined dtype (on the same stream s)
      auto x_casted = astype(x, out_dtype, s);
      auto y_casted = astype(y, out_dtype, s);

      // Broadcast the shapes of x and y (on the same stream s)
      auto broadcasted_inputs = broadcast_arrays({x_casted, y_casted}, s);
      auto out_shape = broadcasted_inputs[0].shape();

      // Construct the array as the output of the Axpby primitive
      // with the broadcasted and upcasted arrays as inputs
      return array(
          /* const std::vector<int>& shape = */ out_shape,
          /* Dtype dtype = */ out_dtype,
          /* std::unique_ptr<Primitive> primitive = */
          std::make_unique<Axpby>(to_stream(s), alpha, beta),
          /* const std::vector<array>& inputs = */ broadcasted_inputs);
    }

This operation now handles the following:

#. Upcast inputs and resolve the output data type.
#. Broadcast the inputs and resolve the output shape.
#. Construct the primitive :class:`Axpby` using the given stream, ``alpha``, and ``beta``.
#. Construct the output :class:`array` using the primitive and the inputs.

Implementing the Primitive
--------------------------

No computation happens when we call the operation alone. In effect, the
operation only builds the computation graph. When we evaluate the output
array, MLX schedules the execution of the computation graph, and calls
:meth:`Axpby::eval_cpu` or :meth:`Axpby::eval_gpu` depending on the
stream/device specified by the user.
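This laziness is easy to observe from Python. A small sketch (again assuming
only the public :mod:`mlx.core` API):

.. code-block:: python

    import mlx.core as mx

    x = mx.ones((2, 2))
    y = mx.ones((2, 2))

    # This only records a node in the graph; nothing has been computed yet
    z = mx.add(x, y, stream=mx.cpu)

    # Evaluation schedules the graph and invokes the primitive's eval_cpu
    # (it would invoke eval_gpu had we scheduled the op on mx.gpu)
    mx.eval(z)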
.. warning::
    When :meth:`Primitive::eval_cpu` or :meth:`Primitive::eval_gpu` are called,
    no memory has been allocated for the output array. It falls on the
    implementation of these functions to allocate memory as needed.

Implementing the CPU Backend
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Let's start by trying to implement a naive and generic version of
:meth:`Axpby::eval_cpu`. We declared this as a private member function of
:class:`Axpby` earlier, called :meth:`Axpby::eval`.

Our naive method will go over each element of the output array, find the
corresponding input elements of ``x`` and ``y``, and perform the operation
pointwise. This is captured in the templated function :meth:`axpby_impl`.

.. code-block:: C++

    template <typename T>
    void axpby_impl(
        const array& x,
        const array& y,
        array& out,
        float alpha_,
        float beta_) {
      // We only allocate memory when we are ready to fill the output
      // malloc_or_wait synchronously allocates available memory
      // There may be a wait executed here if the allocation is requested
      // under memory-pressured conditions
      out.set_data(allocator::malloc_or_wait(out.nbytes()));

      // Collect input and output data pointers
      const T* x_ptr = x.data<T>();
      const T* y_ptr = y.data<T>();
      T* out_ptr = out.data<T>();

      // Cast alpha and beta to the relevant types
      T alpha = static_cast<T>(alpha_);
      T beta = static_cast<T>(beta_);

      // Do the elementwise operation for each output
      for (size_t out_idx = 0; out_idx < out.size(); out_idx++) {
        // Map linear indices to offsets in x and y
        auto x_offset = elem_to_loc(out_idx, x.shape(), x.strides());
        auto y_offset = elem_to_loc(out_idx, y.shape(), y.strides());

        // We allocate the output to be contiguous and regularly strided
        // (defaults to row major) and hence it doesn't need additional mapping
        out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
      }
    }

Now, we would like our implementation to be able to do this pointwise operation
for all incoming floating point arrays. Accordingly, we add dispatches for
``float32``, ``float16``, ``bfloat16`` and ``complex64``. We throw an error
if we encounter an unexpected type.

.. code-block:: C++

    /** Fall back implementation for evaluation on CPU */
    void Axpby::eval(const std::vector<array>& inputs, array& out) {
      // Check the inputs (registered in the op while constructing the out array)
      assert(inputs.size() == 2);
      auto& x = inputs[0];
      auto& y = inputs[1];

      // Dispatch to the correct dtype
      if (out.dtype() == float32) {
        return axpby_impl<float>(x, y, out, alpha_, beta_);
      } else if (out.dtype() == float16) {
        return axpby_impl<float16_t>(x, y, out, alpha_, beta_);
      } else if (out.dtype() == bfloat16) {
        return axpby_impl<bfloat16_t>(x, y, out, alpha_, beta_);
      } else if (out.dtype() == complex64) {
        return axpby_impl<complex64_t>(x, y, out, alpha_, beta_);
      } else {
        throw std::runtime_error(
            "Axpby is only supported for floating point types.");
      }
    }

We have a fallback implementation! Now, to do what we are really here to do.
Remember we wanted to use the ``axpby`` routine provided by the Accelerate_
framework? Well, there are 3 complications to keep in mind:

#. Accelerate does not provide implementations of ``axpby`` for half precision
   floats. We can only direct to it for ``float32`` types.
#. Accelerate assumes the inputs ``x`` and ``y`` are contiguous and all elements
   have fixed strides between them. Possibly due to broadcasts and transposes,
   we aren't guaranteed that the inputs fit this requirement. We can
   only direct to Accelerate if both ``x`` and ``y`` are row contiguous or
   column contiguous.
#. Accelerate performs the routine ``Y = (alpha * X) + (beta * Y)`` in place.
   MLX expects to write out the answer to a new array. We must copy the elements
   of ``y`` into the output array and use that as an input to ``axpby``.

Let's write out an implementation that uses Accelerate in the right conditions.
It must simply allocate data for the output, copy elements of ``y`` into it,
and then call :meth:`catlas_saxpby` from Accelerate.

.. code-block:: C++

    template <typename T>
    void axpby_impl_accelerate(
        const array& x,
        const array& y,
        array& out,
        float alpha_,
        float beta_) {
      // Accelerate library provides catlas_saxpby which does
      // Y = (alpha * X) + (beta * Y) in place
      // To use it, we first copy the data in y over to the output array

      // This specialization requires both x and y be contiguous in the same mode
      // i.e: corresponding linear indices in both point to corresponding elements
      // The data in the output array is allocated to match the strides in y
      // such that x, y, and out are contiguous in the same mode and
      // no transposition is needed
      out.set_data(
          allocator::malloc_or_wait(y.data_size() * out.itemsize()),
          y.data_size(),
          y.strides(),
          y.flags());

      // We then copy over the elements using the contiguous vector specialization
      copy_inplace(y, out, CopyType::Vector);

      // Get x and y pointers for catlas_saxpby
      const T* x_ptr = x.data<T>();
      T* y_ptr = out.data<T>();

      T alpha = static_cast<T>(alpha_);
      T beta = static_cast<T>(beta_);

      // Call the inplace accelerate operator
      catlas_saxpby(
          /* N = */ out.size(),
          /* ALPHA = */ alpha,
          /* X = */ x_ptr,
          /* INCX = */ 1,
          /* BETA = */ beta,
          /* Y = */ y_ptr,
          /* INCY = */ 1);
    }

Great! But what about the inputs that do not fit the criteria for Accelerate?
Luckily, we can always just direct back to :meth:`Axpby::eval`.

With this in mind, let's finally implement our :meth:`Axpby::eval_cpu`.

.. code-block:: C++

    /** Evaluate primitive on CPU using accelerate specializations */
    void Axpby::eval_cpu(const std::vector<array>& inputs, array& out) {
      assert(inputs.size() == 2);
      auto& x = inputs[0];
      auto& y = inputs[1];

      // Accelerate specialization for contiguous single precision float arrays
      if (out.dtype() == float32 &&
          ((x.flags().row_contiguous && y.flags().row_contiguous) ||
           (x.flags().col_contiguous && y.flags().col_contiguous))) {
        axpby_impl_accelerate<float>(x, y, out, alpha_, beta_);
        return;
      }

      // Fall back to common backend if specializations are not available
      eval(inputs, out);
    }

We have now hit a milestone! Just this much is enough to run the operation
:meth:`axpby` on a CPU stream!

If you do not plan on running the operation on the GPU or using transforms on
computation graphs that contain :class:`Axpby`, you can stop implementing the
primitive here and enjoy the speed-ups you get from the Accelerate library.
Implementing the GPU Backend
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Apple silicon devices address their GPUs using the Metal_ shading language, and
all GPU kernels in MLX are written using metal.

.. note::

    Here are some helpful resources if you are new to metal!

    * A walkthrough of the metal compute pipeline: `Metal Example`_
    * Documentation for metal shading language: `Metal Specification`_
    * Using metal from C++: `Metal-cpp`_

Let's keep the GPU algorithm simple. We will launch exactly as many threads
as there are elements in the output. Each thread will pick the element it needs
from ``x`` and ``y``, do the pointwise operation, and then update its assigned
element in the output.

.. code-block:: C++

    template <typename T>
    [[kernel]] void axpby_general(
        device const T* x [[buffer(0)]],
        device const T* y [[buffer(1)]],
        device T* out [[buffer(2)]],
        constant const float& alpha [[buffer(3)]],
        constant const float& beta [[buffer(4)]],
        constant const int* shape [[buffer(5)]],
        constant const size_t* x_strides [[buffer(6)]],
        constant const size_t* y_strides [[buffer(7)]],
        constant const int& ndim [[buffer(8)]],
        uint index [[thread_position_in_grid]]) {
      // Convert linear indices to offsets in array
      auto x_offset = elem_to_loc(index, shape, x_strides, ndim);
      auto y_offset = elem_to_loc(index, shape, y_strides, ndim);

      // Do the operation and update the output
      out[index] =
          static_cast<T>(alpha) * x[x_offset] + static_cast<T>(beta) * y[y_offset];
    }

We then need to instantiate this template for all floating point types and give
each instantiation a unique host name so we can identify the right kernel for
each data type.

.. code-block:: C++

    #define instantiate_axpby(type_name, type)               \
      template [[host_name("axpby_general_" #type_name)]]    \
      [[kernel]] void axpby_general<type>(                   \
          device const type* x [[buffer(0)]],                \
          device const type* y [[buffer(1)]],                \
          device type* out [[buffer(2)]],                    \
          constant const float& alpha [[buffer(3)]],         \
          constant const float& beta [[buffer(4)]],          \
          constant const int* shape [[buffer(5)]],           \
          constant const size_t* x_strides [[buffer(6)]],    \
          constant const size_t* y_strides [[buffer(7)]],    \
          constant const int& ndim [[buffer(8)]],            \
          uint index [[thread_position_in_grid]]);

    instantiate_axpby(float32, float);
    instantiate_axpby(float16, half);
    instantiate_axpby(bfloat16, bfloat16_t);
    instantiate_axpby(complex64, complex64_t);

This kernel will be compiled into a metal library ``mlx_ext.metallib`` as we
will see later in :ref:`Building with CMake`. In the following example, we
assume that the library ``mlx_ext.metallib`` will always be co-located with
the executable/shared-library calling the :meth:`register_library` function.
The :meth:`register_library` function takes the library's name and potential
path (or in this case, a function that can produce the path of the metal
library) and tries to load that library if it hasn't already been registered
by the relevant static :class:`mlx::core::metal::Device` object. This is why
it is important to package your C++ library with the metal library. We will
go over this process in more detail later.

The logic to determine the kernel, set the inputs, resolve the grid dimensions,
and dispatch it to the GPU is contained in :meth:`Axpby::eval_gpu` as shown
below.

.. code-block:: C++

    /** Evaluate primitive on GPU */
    void Axpby::eval_gpu(const std::vector<array>& inputs, array& out) {
      // Prepare inputs
      assert(inputs.size() == 2);
      auto& x = inputs[0];
      auto& y = inputs[1];

      // Each primitive carries the stream it should execute on
      // and each stream carries its device identifiers
      auto& s = stream();
      // We get the needed metal device using the stream
      auto& d = metal::device(s.device);

      // Allocate output memory
      out.set_data(allocator::malloc_or_wait(out.nbytes()));

      // Resolve name of kernel (corresponds to axpby.metal)
      std::ostringstream kname;
      kname << "axpby_" << "general_" << type_to_name(out);

      // Make sure the metal library is available and look for it
      // in the same folder as this executable if needed
      d.register_library("mlx_ext", metal::get_colocated_mtllib_path);

      // Make a kernel from this metal library
      auto kernel = d.get_kernel(kname.str(), "mlx_ext");

      // Prepare to encode kernel
      auto compute_encoder = d.get_command_encoder(s.index);
      compute_encoder->setComputePipelineState(kernel);

      // Kernel parameters are registered with buffer indices corresponding to
      // those in the kernel declaration at axpby.metal
      int ndim = out.ndim();
      size_t nelem = out.size();

      // Encode input arrays to kernel
      set_array_buffer(compute_encoder, x, 0);
      set_array_buffer(compute_encoder, y, 1);

      // Encode output arrays to kernel
      set_array_buffer(compute_encoder, out, 2);

      // Encode alpha and beta
      compute_encoder->setBytes(&alpha_, sizeof(float), 3);
      compute_encoder->setBytes(&beta_, sizeof(float), 4);

      // Encode shape, strides and ndim
      compute_encoder->setBytes(x.shape().data(), ndim * sizeof(int), 5);
      compute_encoder->setBytes(x.strides().data(), ndim * sizeof(size_t), 6);
      compute_encoder->setBytes(y.strides().data(), ndim * sizeof(size_t), 7);
      compute_encoder->setBytes(&ndim, sizeof(int), 8);

      // We launch 1 thread for each input and make sure that the number of
      // threads in any given threadgroup is not higher than the max allowed
      size_t tgp_size = std::min(nelem, kernel->maxTotalThreadsPerThreadgroup());

      // Fix the 3D size of each threadgroup (in terms of threads)
      MTL::Size group_dims = MTL::Size(tgp_size, 1, 1);

      // Fix the 3D size of the launch grid (in terms of threads)
      MTL::Size grid_dims = MTL::Size(nelem, 1, 1);

      // Launch the grid with the given number of threads divided among
      // the given threadgroups
      compute_encoder->dispatchThreads(grid_dims, group_dims);
    }

We can now call the :meth:`axpby` operation on both the CPU and the GPU!

A few things to note about MLX and metal before moving on. MLX keeps track
of the active ``compute_encoder``. We rely on :meth:`d.get_command_encoder`
to give us the active metal compute command encoder instead of building a
new one and calling :meth:`compute_encoder->end_encoding` at the end.
MLX keeps adding kernels (compute pipelines) to the active command encoder
until some specified limit is hit or the compute encoder needs to be flushed
for synchronization. MLX also handles enqueuing and committing the associated
command buffers as needed. We suggest taking a deeper dive into
:class:`metal::Device` if you would like to study this routine further.

Primitive Transforms
^^^^^^^^^^^^^^^^^^^^^

Now that we have come this far, let's also learn how to add implementations to
transformations in a :class:`Primitive`. These transformations can be built on
top of our operations, including the one we just defined. This gives us the
following :meth:`Axpby::jvp` and :meth:`Axpby::vjp` implementations.

.. code-block:: C++

    /** The Jacobian-vector product. */
    array Axpby::jvp(
        const std::vector<array>& primals,
        const std::vector<array>& tangents,
        const std::vector<int>& argnums) {
      // Forward mode diff that pushes along the tangents
      // The jvp transform on the primitive can be built with ops
      // that are scheduled on the same stream as the primitive

      // If argnums = {0}, we only push along x in which case the
      // jvp is just the tangent scaled by alpha
      // Similarly, if argnums = {1}, the jvp is just the tangent
      // scaled by beta
      if (argnums.size() == 1) {
        auto scale = argnums[0] == 0 ? alpha_ : beta_;
        auto scale_arr = array(scale, tangents[0].dtype());
        return multiply(scale_arr, tangents[0], stream());
      }
      // If argnums = {0, 1}, we take contributions from both
      // which gives us jvp = tangent_x * alpha + tangent_y * beta
      else {
        return axpby(tangents[0], tangents[1], alpha_, beta_, stream());
      }
    }

.. code-block:: C++

    /** The vector-Jacobian product. */
    std::vector<array> Axpby::vjp(
        const std::vector<array>& primals,
        const array& cotan,
        const std::vector<int>& argnums) {
      // Reverse mode diff
      std::vector<array> vjps;
      for (auto arg : argnums) {
        auto scale = arg == 0 ? alpha_ : beta_;
        auto scale_arr = array(scale, cotan.dtype());
        vjps.push_back(multiply(scale_arr, cotan, stream()));
      }
      return vjps;
    }

Finally, you need not have a transformation fully defined to start using your
own :class:`Primitive`.

.. code-block:: C++

    /** Vectorize primitive along given axis */
    std::pair<array, int> Axpby::vmap(
        const std::vector<array>& inputs,
        const std::vector<int>& axes) {
      throw std::runtime_error("Axpby has no vmap implementation.");
    }
Building and Binding
|
||||
--------------------
|
||||
|
||||
Let's look at the overall directory structure first.
|
||||
|
||||
| extensions
|
||||
| ├── axpby
|
||||
| │ ├── axpby.cpp
|
||||
| │ ├── axpby.h
|
||||
| │ └── axpby.metal
|
||||
| ├── mlx_sample_extensions
|
||||
| │ └── __init__.py
|
||||
| ├── bindings.cpp
|
||||
| ├── CMakeLists.txt
|
||||
| └── setup.py
|
||||
|
||||
* ``extensions/axpby/`` defines the C++ extension library
|
||||
* ``extensions/mlx_sample_extensions`` sets out the strucutre for the
|
||||
associated python package
|
||||
* ``extensions/bindings.cpp`` provides python bindings for our operation
|
||||
* ``extensions/CMakeLists.txt`` holds CMake rules to build the library and
|
||||
python bindings
|
||||
* ``extensions/setup.py`` holds the ``setuptools`` rules to build and install
|
||||
the python package
|
||||
|
||||
Binding to Python
|
||||
^^^^^^^^^^^^^^^^^^
|
||||
|
||||
We use PyBind11_ to build a Python API for the C++ library. Since bindings
|
||||
for all needed components such as `mlx.core.array`, `mlx.core.stream`, etc.
|
||||
are already provided, adding our :meth:`axpby` becomes very simple!
|
||||
|
||||
.. code-block:: C++
|
||||
|
||||
PYBIND11_MODULE(mlx_sample_extensions, m) {
|
||||
m.doc() = "Sample C++ and metal extensions for MLX";
|
||||
|
||||
m.def(
|
||||
"axpby",
|
||||
&axpby,
|
||||
"x"_a,
|
||||
"y"_a,
|
||||
py::pos_only(),
|
||||
"alpha"_a,
|
||||
"beta"_a,
|
||||
py::kw_only(),
|
||||
"stream"_a = py::none(),
|
||||
R"pbdoc(
|
||||
Scale and sum two vectors elementwise
|
||||
``z = alpha * x + beta * y``
|
||||
|
||||
Follows numpy style broadcasting between ``x`` and ``y``
|
||||
Inputs are upcasted to floats if needed
|
||||
|
||||
Args:
|
||||
x (array): Input array.
|
||||
y (array): Input array.
|
||||
alpha (float): Scaling factor for ``x``.
|
||||
beta (float): Scaling factor for ``y``.
|
||||
|
||||
Returns:
|
||||
array: ``alpha * x + beta * y``
|
||||
)pbdoc");
|
||||
}
|
||||
|
||||
Most of the complexity in the above example comes from additional bells and
|
||||
whistles such as the literal names and doc-strings.
|
||||
|
||||
.. warning::
|
||||
|
||||
:mod:`mlx.core` needs to be imported before importing
|
||||
:mod:`mlx_sample_extensions` as defined by the pybind11 module above to
|
||||
ensure that the casters for :mod:`mlx.core` components like
|
||||
:class:`mlx.core.array` are available.
|
||||
|
.. _Building with CMake:

Building with CMake
^^^^^^^^^^^^^^^^^^^

Building the C++ extension library itself is simple: it only requires that you
``find_package(MLX CONFIG)`` and then link it to your library.

.. code-block:: cmake

  # Add library
  add_library(mlx_ext)

  # Add sources
  target_sources(
    mlx_ext
    PUBLIC
    ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.cpp
  )

  # Add include headers
  target_include_directories(
    mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR}
  )

  # Link to mlx
  target_link_libraries(mlx_ext PUBLIC mlx)

We also need to build the attached metal library. For convenience, we provide a
:meth:`mlx_build_metallib` function that builds a ``.metallib`` target given
sources, headers, destinations, etc. (defined in ``cmake/extension.cmake`` and
automatically imported with the MLX package).

Here is what that looks like in practice!

.. code-block:: cmake

  # Build metallib
  if(MLX_BUILD_METAL)

    mlx_build_metallib(
      TARGET mlx_ext_metallib
      TITLE mlx_ext
      SOURCES ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.metal
      INCLUDE_DIRS ${PROJECT_SOURCE_DIR} ${MLX_INCLUDE_DIRS}
      OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}
    )

    add_dependencies(
      mlx_ext
      mlx_ext_metallib
    )

  endif()

Finally, we build the PyBind11_ bindings.

.. code-block:: cmake

  pybind11_add_module(
    mlx_sample_extensions
    ${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
  )
  target_link_libraries(mlx_sample_extensions PRIVATE mlx_ext)

  if(BUILD_SHARED_LIBS)
    target_link_options(mlx_sample_extensions PRIVATE -Wl,-rpath,@loader_path)
  endif()

Building with ``setuptools``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Once we have set out the CMake build rules as described above, we can use the
build utilities defined in :mod:`mlx.extension` for a simple build process.

.. code-block:: python

  from mlx import extension
  from setuptools import setup

  if __name__ == "__main__":
      setup(
          name="mlx_sample_extensions",
          version="0.0.0",
          description="Sample C++ and Metal extensions for MLX primitives.",
          ext_modules=[extension.CMakeExtension("mlx_sample_extensions")],
          cmdclass={"build_ext": extension.CMakeBuild},
          packages=["mlx_sample_extensions"],
          package_dir={"": "mlx_sample_extensions"},
          package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]},
          zip_safe=False,
          python_requires=">=3.7",
      )

.. note::
   We treat ``extensions/mlx_sample_extensions`` as the package directory
   even though it only contains a ``__init__.py`` to ensure the following:

   * :mod:`mlx.core` is always imported before importing :mod:`mlx_sample_extensions`
   * The C++ extension library and the metal library are co-located with the python
     bindings and copied together if the package is installed

You can build inplace for development using
``python setup.py build_ext -j8 --inplace`` (in ``extensions/``)

This will result in a directory structure as follows:

| extensions
| ├── mlx_sample_extensions
| │   ├── __init__.py
| │   ├── libmlx_ext.dylib # C++ extension library
| │   ├── mlx_ext.metallib # Metal library
| │   └── mlx_sample_extensions.cpython-3x-darwin.so # Python Binding
| ...

When you install using the command ``python -m pip install .``
(in ``extensions/``), the package will be installed with the same structure as
``extensions/mlx_sample_extensions``, and the C++ and metal library will be
copied along with the python binding since they are specified as ``package_data``.

Usage
-----

After installing the extension as described above, you should be able to simply
import the python package and play with it as you would any other MLX operation!

Let's look at a simple script and its results.

.. code-block:: python

  import mlx.core as mx
  from mlx_sample_extensions import axpby

  a = mx.ones((3, 4))
  b = mx.ones((3, 4))
  c = axpby(a, b, 4.0, 2.0, stream=mx.cpu)

  print(f"c shape: {c.shape}")
  print(f"c dtype: {c.dtype}")
  print(f"c correctness: {mx.all(c == 6.0).item()}")

Output:

.. code-block::

  c shape: [3, 4]
  c dtype: float32
  c correctness: True

Results
^^^^^^^

Let's run a quick benchmark and see how our new ``axpby`` operation compares
with the naive :meth:`simple_axpby` we defined at first on the CPU.

.. code-block:: python

  import mlx.core as mx
  from mlx_sample_extensions import axpby
  import time

  mx.set_default_device(mx.cpu)

  def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
      return alpha * x + beta * y

  M = 256
  N = 512

  x = mx.random.normal((M, N))
  y = mx.random.normal((M, N))
  alpha = 4.0
  beta = 2.0

  mx.eval((x, y))

  def bench(f):
      # Warm up
      for i in range(100):
          z = f(x, y, alpha, beta)
          mx.eval(z)

      # Timed run
      s = time.time()
      for i in range(5000):
          z = f(x, y, alpha, beta)
          mx.eval(z)
      e = time.time()
      return e - s

  simple_time = bench(simple_axpby)
  custom_time = bench(axpby)

  print(f"Simple axpby: {simple_time:.3f} s | Custom axpby: {custom_time:.3f} s")

Results:

.. code-block::

  Simple axpby: 0.114 s | Custom axpby: 0.109 s

We see some modest improvements right away!

This operation can now be used to build other operations, in
:class:`mlx.nn.Module` calls, and as part of graph transformations such as
:meth:`grad` and :meth:`simplify`!

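For example, since :meth:`grad` works through the primitive's VJP defined
earlier, you can differentiate through ``axpby`` directly. A minimal sketch
(the scalar ``loss`` wrapper here is our own illustration, not part of the
example scripts):

.. code-block:: python

  import mlx.core as mx
  from mlx_sample_extensions import axpby

  x = mx.ones((3, 4))
  y = mx.ones((3, 4))

  # Reduce to a scalar so we can take a gradient
  def loss(x):
      return mx.sum(axpby(x, y, 4.0, 2.0, stream=mx.cpu))

  dx = mx.grad(loss)(x)  # every entry should be alpha = 4.0
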
Scripts
-------

.. admonition:: Download the code

   The full example code is available in `mlx-examples <code>`_.

.. code: `TODO_LINK/extensions`_

.. _Accelerate: https://developer.apple.com/documentation/accelerate/blas?language=objc
.. _Metal: https://developer.apple.com/documentation/metal?language=objc
.. _Metal-cpp: https://developer.apple.com/metal/cpp/
.. _`Metal Specification`: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
.. _`Metal Example`: https://developer.apple.com/documentation/metal/performing_calculations_on_a_gpu?language=objc
.. _PyBind11: https://pybind11.readthedocs.io/en/stable/

77
docs/src/examples/linear_regression.rst
Normal file
77
docs/src/examples/linear_regression.rst
Normal file
@ -0,0 +1,77 @@
.. _linear_regression:

Linear Regression
-----------------

Let's implement a basic linear regression model as a starting point to
learn MLX. First import the core package and set up some problem metadata:

.. code-block:: python

  import mlx.core as mx

  num_features = 100
  num_examples = 1_000
  num_iters = 10_000  # iterations of SGD
  lr = 0.01  # learning rate for SGD

We'll generate a synthetic dataset by:

1. Sampling the design matrix ``X``.
2. Sampling a ground truth parameter vector ``w_star``.
3. Computing the dependent values ``y`` by adding Gaussian noise to ``X @ w_star``.

.. code-block:: python

  # True parameters
  w_star = mx.random.normal((num_features,))

  # Input examples (design matrix)
  X = mx.random.normal((num_examples, num_features))

  # Noisy labels
  eps = 1e-2 * mx.random.normal((num_examples,))
  y = X @ w_star + eps

We will use SGD to find the optimal weights. To start, define the squared loss
and get the gradient function of the loss with respect to the parameters.

.. code-block:: python

  def loss_fn(w):
      return 0.5 * mx.mean(mx.square(X @ w - y))

  grad_fn = mx.grad(loss_fn)

Start the optimization by initializing the parameters ``w`` randomly. Then
repeatedly update the parameters for ``num_iters`` iterations.

.. code-block:: python

  w = 1e-2 * mx.random.normal((num_features,))

  for _ in range(num_iters):
      grad = grad_fn(w)
      w = w - lr * grad
      mx.eval(w)

Finally, compute the loss of the learned parameters and verify that they are
close to the ground truth parameters.

.. code-block:: python

  loss = loss_fn(w)
  error_norm = mx.sum(mx.square(w - w_star)).item() ** 0.5

  print(
      f"Loss {loss.item():.5f}, |w-w*| = {error_norm:.5f}, "
  )
  # Should print something close to: Loss 0.00005, |w-w*| = 0.00364

Complete `linear regression
<https://github.com/ml-explore/mlx/tree/main/examples/python/linear_regression.py>`_
and `logistic regression
<https://github.com/ml-explore/mlx/tree/main/examples/python/logistic_regression.py>`_
examples are available in the MLX GitHub repo.

52
docs/src/python/data_types.rst
Normal file
52
docs/src/python/data_types.rst
Normal file
@ -0,0 +1,52 @@
.. _data_types:

:orphan:

Data Types
==========

.. currentmodule:: mlx.core

The default floating point type is ``float32`` and the default integer type is
``int32``. The table below shows supported values for :obj:`Dtype`.

.. list-table:: Supported Data Types
   :widths: 5 3 20
   :header-rows: 1

   * - Type
     - Bytes
     - Description
   * - ``bool_``
     - 1
     - Boolean (``True``, ``False``) data type
   * - ``uint8``
     - 1
     - 8-bit unsigned integer
   * - ``uint16``
     - 2
     - 16-bit unsigned integer
   * - ``uint32``
     - 4
     - 32-bit unsigned integer
   * - ``uint64``
     - 8
     - 64-bit unsigned integer
   * - ``int8``
     - 1
     - 8-bit signed integer
   * - ``int16``
     - 2
     - 16-bit signed integer
   * - ``int32``
     - 4
     - 32-bit signed integer
   * - ``int64``
     - 8
     - 64-bit signed integer
   * - ``float16``
     - 2
     - 16-bit float, only available with `ARM C language extensions <https://developer.arm.com/documentation/101028/0012/3--C-language-extensions?lang=en>`_
   * - ``float32``
     - 4
     - 32-bit float

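As a quick illustration of the defaults above (a minimal sketch; the expected
dtypes follow directly from the table):

.. code-block:: python

  import mlx.core as mx

  # Python floats and ints map to the default dtypes
  print(mx.array(1.0).dtype)        # float32
  print(mx.array([1, 2, 3]).dtype)  # int32

  # A dtype can also be requested explicitly
  x = mx.array([1.0, 2.0], dtype=mx.float16)
  print(x.dtype)                    # float16
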
17
docs/src/python/devices_and_streams.rst
Normal file
17
docs/src/python/devices_and_streams.rst
Normal file
@ -0,0 +1,17 @@
.. _devices_and_streams:

Devices and Streams
===================

.. currentmodule:: mlx.core

.. autosummary::
   :toctree: _autosummary

   Device
   default_device
   set_default_device
   Stream
   default_stream
   new_stream
   set_default_stream

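A short sketch of how these fit together (the printed device is
machine-dependent; the CPU device used here is always available):

.. code-block:: python

  import mlx.core as mx

  print(mx.default_device())  # e.g. the GPU on Apple silicon

  # Route all subsequent ops to the CPU by default
  mx.set_default_device(mx.cpu)

  # Streams order work on a device; ops accept an explicit stream
  s = mx.new_stream(mx.cpu)
  a = mx.add(mx.ones((2, 2)), 1.0, stream=s)
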
16
docs/src/python/transforms.rst
Normal file
16
docs/src/python/transforms.rst
Normal file
@ -0,0 +1,16 @@
.. _transforms:

Transforms
==========

.. currentmodule:: mlx.core

.. autosummary::
   :toctree: _autosummary

   eval
   grad
   value_and_grad
   jvp
   vjp
   vmap

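For instance, a minimal sketch using two of the transformations listed above:

.. code-block:: python

  import mlx.core as mx

  def f(x):
      return mx.sum(mx.square(x))

  x = mx.array([0.0, 1.0, 2.0])

  value, dfdx = mx.value_and_grad(f)(x)  # value = 5.0, dfdx = 2 * x
  mx.eval(dfdx)                          # MLX is lazy; eval forces computation
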
21
docs/src/python/tree_utils.rst
Normal file
21
docs/src/python/tree_utils.rst
Normal file
@ -0,0 +1,21 @@
.. _utils:

Tree Utils
==========

In MLX we consider a python tree to be an arbitrarily nested collection of
dictionaries, lists and tuples without cycles. Functions in this module that
return python trees use the default python ``dict``, ``list`` and ``tuple``,
but they can usually process objects that inherit from any of these.

.. note::
   Dictionaries should have keys that are valid python identifiers.

.. currentmodule:: mlx.utils

.. autosummary::
   :toctree: _autosummary

   tree_flatten
   tree_unflatten
   tree_map

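A minimal sketch of these utilities on a nested parameter dictionary (the
``params`` value is our own illustration):

.. code-block:: python

  from mlx.utils import tree_flatten, tree_map, tree_unflatten

  params = {"layer1": {"w": 1.0, "b": 2.0}, "scales": [3.0, 4.0]}

  # Apply a function to every leaf, preserving the tree structure
  doubled = tree_map(lambda x: 2 * x, params)

  # Flatten to a list of (dotted key, value) pairs and invert
  flat = tree_flatten(params)  # [("layer1.w", 1.0), ...]
  rebuilt = tree_unflatten(flat)
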
10
examples/cpp/CMakeLists.txt
Normal file
10
examples/cpp/CMakeLists.txt
Normal file
@ -0,0 +1,10 @@
function(build_example SRCFILE)
  get_filename_component(src_name ${SRCFILE} NAME_WE)
  set(target "${src_name}")
  add_executable(${target} ${SRCFILE})
  target_link_libraries(${target} PRIVATE mlx)
endfunction(build_example)

build_example(tutorial.cpp)
build_example(linear_regression.cpp)
build_example(logistic_regression.cpp)

52
examples/cpp/linear_regression.cpp
Normal file
52
examples/cpp/linear_regression.cpp
Normal file
@ -0,0 +1,52 @@
#include <chrono>
#include <cmath>
#include <iostream>

#include "mlx/mlx.h"
#include "timer.h"

/**
 * An example of linear regression with MLX.
 */
using namespace mlx::core;

int main() {
  int num_features = 100;
  int num_examples = 1'000;
  int num_iters = 10'000;
  float learning_rate = 0.01;

  // True parameters
  auto w_star = random::normal({num_features});

  // The input examples (design matrix)
  auto X = random::normal({num_examples, num_features});

  // Noisy labels
  auto eps = 1e-2 * random::normal({num_examples});
  auto y = matmul(X, w_star) + eps;

  // Initialize random parameters
  array w = 1e-2 * random::normal({num_features});

  auto loss_fn = [&](array w) {
    auto yhat = matmul(X, w);
    return (0.5f / num_examples) * sum(square(yhat - y));
  };

  auto grad_fn = grad(loss_fn);

  auto tic = timer::time();
  for (int it = 0; it < num_iters; ++it) {
    auto grad = grad_fn(w);
    w = w - learning_rate * grad;
    eval(w);
  }
  auto toc = timer::time();

  auto loss = loss_fn(w);
  auto error_norm = std::sqrt(sum(square(w - w_star)).item<float>());
  auto throughput = num_iters / timer::seconds(toc - tic);
  std::cout << "Loss " << loss << ", |w - w*| = " << error_norm
            << ", Throughput " << throughput << " (it/s)." << std::endl;
}

3581
mlx/3rdparty/pocketfft.h
vendored
Normal file
3581
mlx/3rdparty/pocketfft.h
vendored
Normal file
File diff suppressed because it is too large
64
mlx/allocator.h
Normal file
64
mlx/allocator.h
Normal file
@ -0,0 +1,64 @@
#pragma once

#include <cstdlib>

namespace mlx::core::allocator {

// Simple wrapper around buffer pointers
// WARNING: Only Buffer objects constructed from and those that wrap
// raw pointers from mlx::allocator are supported.
class Buffer {
 private:
  void* ptr_;

 public:
  Buffer(void* ptr) : ptr_(ptr) {}

  // Get the raw data pointer from the buffer
  void* raw_ptr();

  // Get the buffer pointer from the buffer
  const void* ptr() const {
    return ptr_;
  }
  void* ptr() {
    return ptr_;
  }
};

Buffer malloc(size_t size);

void free(Buffer buffer);

// Wait for running tasks to finish and free up memory
// if allocation fails
Buffer malloc_or_wait(size_t size);

class Allocator {
  /** Abstract base class for a memory allocator. */
 public:
  virtual Buffer malloc(size_t size) = 0;
  virtual void free(Buffer buffer) = 0;

  Allocator() = default;
  Allocator(const Allocator& other) = delete;
  Allocator(Allocator&& other) = delete;
  Allocator& operator=(const Allocator& other) = delete;
  Allocator& operator=(Allocator&& other) = delete;
  virtual ~Allocator() = default;
};

Allocator& allocator();

class CommonAllocator : public Allocator {
  /** A general CPU allocator. */
 public:
  virtual Buffer malloc(size_t size) override;
  virtual void free(Buffer buffer) override;

 private:
  CommonAllocator() = default;
  friend Allocator& allocator();
};

} // namespace mlx::core::allocator

18
mlx/backend/accelerate/conv.cpp
Normal file
18
mlx/backend/accelerate/conv.cpp
Normal file
@ -0,0 +1,18 @@
#include <cassert>

#include <simd/vector.h>
#include <vecLib/vDSP.h>

#include "mlx/backend/common/copy.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"

namespace mlx::core {

void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
  eval(inputs, out);

  // TODO: Add accelerate based optimizations for CPU conv
}

} // namespace mlx::core

26
mlx/backend/accelerate/utils.h
Normal file
26
mlx/backend/accelerate/utils.h
Normal file
@ -0,0 +1,26 @@
#pragma once

#include <stdexcept>

#include <vecLib/BNNS/bnns.h>
#include "mlx/dtype.h"

namespace mlx::core {

BNNSDataType to_bnns_dtype(Dtype mlx_dtype) {
  uint32_t size_bits = size_of(mlx_dtype) * 8;
  switch (kindof(mlx_dtype)) {
    case Dtype::Kind::b:
      return BNNSDataTypeBoolean;
    case Dtype::Kind::u:
      return BNNSDataType(BNNSDataTypeUIntBit | size_bits);
    case Dtype::Kind::i:
      return BNNSDataType(BNNSDataTypeIntBit | size_bits);
    case Dtype::Kind::f:
      return BNNSDataType(BNNSDataTypeFloatBit | size_bits);
    case Dtype::Kind::V:
      return BNNSDataTypeBFloat16;
    case Dtype::Kind::c:
      throw std::invalid_argument("BNNS does not support complex types");
  }
}

} // namespace mlx::core

27
mlx/backend/common/copy.h
Normal file
27
mlx/backend/common/copy.h
Normal file
@ -0,0 +1,27 @@
#pragma once

#include "mlx/array.h"
#include "mlx/backend/common/utils.h"

namespace mlx::core {

enum class CopyType {
  // Copy a raw scalar input into the full contiguous output
  Scalar,

  // Copy the raw input buffer contiguously into a raw output buffer of the
  // same size
  Vector,

  // Copy the full virtual input to the full contiguous output
  General,

  // Copy the full virtual input to the full virtual output. We assume the
  // input and output have the same shape.
  GeneralGeneral
};

void copy(const array& src, array& dst, CopyType ctype);
void copy_inplace(const array& src, array& dst, CopyType ctype);

} // namespace mlx::core

394
mlx/backend/common/sort.cpp
Normal file
394
mlx/backend/common/sort.cpp
Normal file
@ -0,0 +1,394 @@
#include <algorithm>
#include <cassert>
#include <cmath>
#include <numeric>

#include "mlx/backend/common/copy.h"
#include "mlx/backend/common/utils.h"

#include "mlx/primitives.h"

namespace mlx::core {

namespace {

template <typename T, typename IdxT = int32_t>
struct StridedIterator {
  using iterator_category = std::random_access_iterator_tag;
  using difference_type = IdxT;
  using value_type = T;
  using reference = value_type&;
  using pointer = value_type*;

  // Constructors
  StridedIterator() = default;

  explicit StridedIterator(T* ptr, size_t stride, difference_type offset = 0)
      : stride_(stride), ptr_(ptr + offset * stride) {}

  explicit StridedIterator(array& arr, int axis, difference_type offset = 0)
      : StridedIterator(arr.data<T>(), arr.strides()[axis], offset) {}

  // Accessors
  reference operator*() const {
    return ptr_[0];
  }

  reference operator[](difference_type idx) const {
    return ptr_[idx * stride_];
  }

  // Comparisons
  bool operator==(const StridedIterator& other) const {
    return ptr_ == other.ptr_ && stride_ == other.stride_;
  }

  bool operator!=(const StridedIterator& other) const {
    return ptr_ != other.ptr_;
  }

  bool operator<(const StridedIterator& other) const {
    return ptr_ < other.ptr_;
  }

  bool operator>(const StridedIterator& other) const {
    return ptr_ > other.ptr_;
  }

  bool operator<=(const StridedIterator& other) const {
    return ptr_ <= other.ptr_;
  }

  bool operator>=(const StridedIterator& other) const {
    return ptr_ >= other.ptr_;
  }

  difference_type operator-(const StridedIterator& other) const {
    return (ptr_ - other.ptr_) / stride_;
  }

  // Moving
  StridedIterator& operator++() {
    ptr_ += stride_;
    return *this;
  }

  StridedIterator& operator--() {
    ptr_ -= stride_;
    return *this;
  }

  StridedIterator& operator+=(difference_type diff) {
    ptr_ += diff * stride_;
    return *this;
  }

  StridedIterator& operator-=(difference_type diff) {
    ptr_ -= diff * stride_;
    return *this;
  }

  StridedIterator operator+(difference_type diff) {
    return StridedIterator(ptr_, stride_, diff);
  }

  StridedIterator operator-(difference_type diff) {
    return StridedIterator(ptr_, stride_, -diff);
  }

 private:
  size_t stride_;
  T* ptr_;
};

template <typename T, typename IdxT = uint32_t>
void sort(const array& in, array& out, int axis) {
  // Copy input to output
  CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
  copy(in, out, ctype);

  // Get axis, shape and stride info
  axis = axis < 0 ? axis + in.ndim() : axis;
  size_t n_rows = in.size() / in.shape(axis);

  auto remaining_shape = in.shape();
  remaining_shape.erase(remaining_shape.begin() + axis);

  auto remaining_strides = in.strides();
  remaining_strides.erase(remaining_strides.begin() + axis);

  size_t axis_stride = in.strides()[axis];
  int axis_size = in.shape(axis);

  // Perform sorting in place
  for (int i = 0; i < n_rows; i++) {
    size_t loc = elem_to_loc(i, remaining_shape, remaining_strides);
    T* data_ptr = out.data<T>() + loc;

    StridedIterator st(data_ptr, axis_stride, 0);
    StridedIterator ed(data_ptr, axis_stride, axis_size);

    std::stable_sort(st, ed);
  }
}

template <typename T, typename IdxT = uint32_t>
void argsort(const array& in, array& out, int axis) {
  // Allocate output
  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  // Get axis, shape and stride info
  axis = axis < 0 ? axis + in.ndim() : axis;
  size_t n_rows = in.size() / in.shape(axis);

  auto remaining_shape = in.shape();
  remaining_shape.erase(remaining_shape.begin() + axis);

  auto remaining_strides = in.strides();
  remaining_strides.erase(remaining_strides.begin() + axis);

  size_t axis_stride = in.strides()[axis];
  int axis_size = in.shape(axis);

  // Perform sorting
  for (int i = 0; i < n_rows; i++) {
    size_t loc = elem_to_loc(i, remaining_shape, remaining_strides);
    const T* data_ptr = in.data<T>() + loc;
    IdxT* idx_ptr = out.data<IdxT>() + loc;

    StridedIterator st_(idx_ptr, axis_stride, 0);
    StridedIterator ed_(idx_ptr, axis_stride, axis_size);

    // Initialize with iota
    std::iota(st_, ed_, IdxT(0));

    // Sort according to vals
    StridedIterator st(idx_ptr, axis_stride, 0);
    StridedIterator ed(idx_ptr, axis_stride, axis_size);

    std::stable_sort(st, ed, [data_ptr, axis_stride](IdxT a, IdxT b) {
      auto v1 = data_ptr[a * axis_stride];
      auto v2 = data_ptr[b * axis_stride];
      return v1 < v2 || (v1 == v2 && a < b);
    });
  }
}

template <typename T, typename IdxT = uint32_t>
void partition(const array& in, array& out, int axis, int kth) {
  // Copy input to output
  CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General;
  copy(in, out, ctype);

  // Get axis, shape and stride info
  axis = axis < 0 ? axis + in.ndim() : axis;
  size_t n_rows = in.size() / in.shape(axis);

  auto remaining_shape = in.shape();
  remaining_shape.erase(remaining_shape.begin() + axis);

  auto remaining_strides = in.strides();
  remaining_strides.erase(remaining_strides.begin() + axis);

  size_t axis_stride = in.strides()[axis];
  int axis_size = in.shape(axis);

  kth = kth < 0 ? kth + axis_size : kth;

  // Perform partition in place
  for (int i = 0; i < n_rows; i++) {
    size_t loc = elem_to_loc(i, remaining_shape, remaining_strides);
    T* data_ptr = out.data<T>() + loc;

    StridedIterator st(data_ptr, axis_stride, 0);
    StridedIterator md(data_ptr, axis_stride, kth);
    StridedIterator ed(data_ptr, axis_stride, axis_size);

    std::nth_element(st, md, ed);
  }
}

template <typename T, typename IdxT = uint32_t>
void argpartition(const array& in, array& out, int axis, int kth) {
  // Allocate output
  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  // Get axis, shape and stride info
  axis = axis < 0 ? axis + in.ndim() : axis;
  size_t n_rows = in.size() / in.shape(axis);

  auto remaining_shape = in.shape();
  remaining_shape.erase(remaining_shape.begin() + axis);

  auto remaining_strides = in.strides();
  remaining_strides.erase(remaining_strides.begin() + axis);

  size_t axis_stride = in.strides()[axis];
  int axis_size = in.shape(axis);

  kth = kth < 0 ? kth + axis_size : kth;

  // Perform partition
  for (int i = 0; i < n_rows; i++) {
    size_t loc = elem_to_loc(i, remaining_shape, remaining_strides);
    const T* data_ptr = in.data<T>() + loc;
    IdxT* idx_ptr = out.data<IdxT>() + loc;

    StridedIterator st_(idx_ptr, axis_stride, 0);
    StridedIterator ed_(idx_ptr, axis_stride, axis_size);

    // Initialize with iota
    std::iota(st_, ed_, IdxT(0));

    // Sort according to vals
    StridedIterator st(idx_ptr, axis_stride, 0);
    StridedIterator md(idx_ptr, axis_stride, kth);
    StridedIterator ed(idx_ptr, axis_stride, axis_size);

    std::nth_element(st, md, ed, [data_ptr, axis_stride](IdxT a, IdxT b) {
      auto v1 = data_ptr[a * axis_stride];
      auto v2 = data_ptr[b * axis_stride];
      return v1 < v2 || (v1 == v2 && a < b);
    });
  }
}

} // namespace

void ArgSort::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  auto& in = inputs[0];

  switch (in.dtype()) {
    case bool_:
      return argsort<bool>(in, out, axis_);
    case uint8:
      return argsort<uint8_t>(in, out, axis_);
    case uint16:
      return argsort<uint16_t>(in, out, axis_);
    case uint32:
      return argsort<uint32_t>(in, out, axis_);
    case uint64:
      return argsort<uint64_t>(in, out, axis_);
    case int8:
      return argsort<int8_t>(in, out, axis_);
    case int16:
      return argsort<int16_t>(in, out, axis_);
    case int32:
      return argsort<int32_t>(in, out, axis_);
    case int64:
      return argsort<int64_t>(in, out, axis_);
    case float32:
      return argsort<float>(in, out, axis_);
    case float16:
      return argsort<float16_t>(in, out, axis_);
    case bfloat16:
      return argsort<bfloat16_t>(in, out, axis_);
    case complex64:
      return argsort<complex64_t>(in, out, axis_);
  }
}

void Sort::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  auto& in = inputs[0];

  switch (in.dtype()) {
    case bool_:
      return sort<bool>(in, out, axis_);
    case uint8:
      return sort<uint8_t>(in, out, axis_);
    case uint16:
      return sort<uint16_t>(in, out, axis_);
    case uint32:
      return sort<uint32_t>(in, out, axis_);
    case uint64:
      return sort<uint64_t>(in, out, axis_);
    case int8:
      return sort<int8_t>(in, out, axis_);
    case int16:
      return sort<int16_t>(in, out, axis_);
    case int32:
      return sort<int32_t>(in, out, axis_);
    case int64:
      return sort<int64_t>(in, out, axis_);
    case float32:
      return sort<float>(in, out, axis_);
    case float16:
      return sort<float16_t>(in, out, axis_);
    case bfloat16:
      return sort<bfloat16_t>(in, out, axis_);
    case complex64:
      return sort<complex64_t>(in, out, axis_);
  }
}

void ArgPartition::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  auto& in = inputs[0];

  switch (in.dtype()) {
    case bool_:
      return argpartition<bool>(in, out, axis_, kth_);
    case uint8:
      return argpartition<uint8_t>(in, out, axis_, kth_);
    case uint16:
      return argpartition<uint16_t>(in, out, axis_, kth_);
    case uint32:
      return argpartition<uint32_t>(in, out, axis_, kth_);
    case uint64:
      return argpartition<uint64_t>(in, out, axis_, kth_);
    case int8:
      return argpartition<int8_t>(in, out, axis_, kth_);
    case int16:
      return argpartition<int16_t>(in, out, axis_, kth_);
    case int32:
      return argpartition<int32_t>(in, out, axis_, kth_);
    case int64:
      return argpartition<int64_t>(in, out, axis_, kth_);
    case float32:
      return argpartition<float>(in, out, axis_, kth_);
    case float16:
      return argpartition<float16_t>(in, out, axis_, kth_);
    case bfloat16:
      return argpartition<bfloat16_t>(in, out, axis_, kth_);
    case complex64:
      return argpartition<complex64_t>(in, out, axis_, kth_);
  }
}

void Partition::eval(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  auto& in = inputs[0];

  switch (in.dtype()) {
    case bool_:
      return partition<bool>(in, out, axis_, kth_);
    case uint8:
      return partition<uint8_t>(in, out, axis_, kth_);
    case uint16:
      return partition<uint16_t>(in, out, axis_, kth_);
    case uint32:
      return partition<uint32_t>(in, out, axis_, kth_);
    case uint64:
      return partition<uint64_t>(in, out, axis_, kth_);
    case int8:
      return partition<int8_t>(in, out, axis_, kth_);
    case int16:
      return partition<int16_t>(in, out, axis_, kth_);
    case int32:
      return partition<int32_t>(in, out, axis_, kth_);
    case int64:
      return partition<int64_t>(in, out, axis_, kth_);
    case float32:
      return partition<float>(in, out, axis_, kth_);
    case float16:
      return partition<float16_t>(in, out, axis_, kth_);
    case bfloat16:
      return partition<bfloat16_t>(in, out, axis_, kth_);
    case complex64:
      return partition<complex64_t>(in, out, axis_, kth_);
  }
}

} // namespace mlx::core

257
mlx/backend/metal/device.cpp
Normal file
257
mlx/backend/metal/device.cpp
Normal file
@ -0,0 +1,257 @@
#include <dlfcn.h>
#include <cstdlib>
#include <filesystem>
#include <sstream>

#define NS_PRIVATE_IMPLEMENTATION
#define CA_PRIVATE_IMPLEMENTATION
#define MTL_PRIVATE_IMPLEMENTATION

#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/metal.h"
#include "mlx/backend/metal/mps/gemm.h"

namespace fs = std::filesystem;

namespace mlx::core::metal {

static Device metal_device_;

namespace {

// TODO nicer way to set this or possibly expose as an environment variable
static constexpr int MAX_BUFFERS_PER_QUEUE = 12;

static constexpr const char* default_mtllib_path = METAL_PATH;

auto load_device() {
  MTL::Device* device = MTL::CreateSystemDefaultDevice();
  if (!device) {
    throw std::runtime_error("Failed to load device");
  }
  return device;
}

std::pair<MTL::Library*, NS::Error*> load_library_from_path(
    MTL::Device* device,
    const char* path) {
  auto library = NS::String::string(path, NS::UTF8StringEncoding);
  NS::Error* error;
  auto lib = device->newLibrary(library, &error);

  return std::make_pair(lib, error);
}

MTL::Library* load_library(
    MTL::Device* device,
    const std::string& lib_name = "mlx",
    const char* lib_path = default_mtllib_path) {
  // Firstly, search for the metallib in the same path as this binary
  std::string first_path = get_colocated_mtllib_path(lib_name);
  if (first_path.size() != 0) {
    auto [lib, error] = load_library_from_path(device, first_path.c_str());
    if (lib) {
      return lib;
    }
  }

  // Couldn't find it so let's load it from default_mtllib_path
  {
    auto [lib, error] = load_library_from_path(device, lib_path);
    if (!lib) {
      std::ostringstream msg;
      msg << error->localizedDescription()->utf8String() << "\n"
          << "Failed to load device library from <" << lib_path << ">"
          << " or <" << first_path << ">.";
      throw std::runtime_error(msg.str());
    }
    return lib;
  }
}

} // namespace

Device::Device()
    : pool_(NS::AutoreleasePool::alloc()->init()),
      device_(load_device()),
      library_map_({{"mlx", load_library(device_)}}) {}

Device::~Device() {
  for (auto& q : queue_map_) {
    q.second->release();
  }
  for (auto& k : kernel_map_) {
    k.second->release();
  }
  for (auto& l : library_map_) {
    l.second->release();
  }
  device_->release();
  pool_->release();
}

void Device::new_queue(int index) {
  // Multiple threads can ask the device for queues
  // We lock this as a critical section for safety
  const std::lock_guard<std::mutex> lock(mtx_);
  auto q = device_->newCommandQueue(MAX_BUFFERS_PER_QUEUE);
  if (!q) {
    throw std::runtime_error(
        "[metal::Device] Failed to make new command queue.");
  }
  queue_map_.insert({index, q});
}

int Device::get_command_buffer_ops(int index) {
  auto bit = buffer_map_.find(index);
  return bit->second.first;
}

void Device::increment_command_buffer_ops(int index) {
  auto bit = buffer_map_.find(index);
  bit->second.first++;
}

MTL::CommandBuffer* Device::get_command_buffer(int index) {
  auto bit = buffer_map_.find(index);
  return (bit == buffer_map_.end()) ? nullptr : bit->second.second;
}

MTL::CommandBuffer* Device::new_command_buffer(int index) {
  auto qit = queue_map_.find(index);
  if (qit == queue_map_.end()) {
    throw std::runtime_error(
        "[metal::Device] Attempting to get command buffer for invalid queue.");
  }

  auto cb = qit->second->commandBufferWithUnretainedReferences();

  if (!cb) {
    throw std::runtime_error(
        "[metal::Device] Unable to create new command buffer");
  }

  // Increment ref count so the buffer is not garbage collected
  cb->retain();

  return buffer_map_.insert({index, {0, cb}}).first->second.second;
}

void Device::commit_command_buffer(int index) {
  auto bit = buffer_map_.find(index);
  bit->second.second->commit();
  bit->second.second->release();
  buffer_map_.erase(bit);
}

void Device::end_encoding(int index) {
  auto eit = encoder_map_.find(index);
  if (eit != encoder_map_.end()) {
    eit->second->endEncoding();
    eit->second->release();
    encoder_map_.erase(eit);
  }
}

MTL::ComputeCommandEncoder* Device::get_command_encoder(int index) {
  auto eit = encoder_map_.find(index);
  if (eit == encoder_map_.end()) {
    auto cb = get_command_buffer(index);
    auto compute_encoder = cb->computeCommandEncoder();
    // Increment ref count so the buffer is not garbage collected
    compute_encoder->retain();
    eit = encoder_map_.insert({index, compute_encoder}).first;
  }
  return eit->second;
}

MTL::ArgumentEncoder* Device::argument_encoder(
    const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const {
  // NB array here is already autoreleased but the returned argument
  // encoder is owned by the caller and must be released/autoreleased
  NS::Array* arg_desc_arr = NS::Array::array(
      reinterpret_cast<NS::Object* const*>(arg_descs.data()), arg_descs.size());
  return device_->newArgumentEncoder(arg_desc_arr);
}

void Device::register_library(
    const std::string& lib_name,
    const std::string& lib_path) {
  if (auto it = library_map_.find(lib_name); it == library_map_.end()) {
    auto new_lib = load_library(device_, lib_name, lib_path.c_str());
    library_map_.insert({lib_name, new_lib});
  }
}

void Device::register_library(
    const std::string& lib_name,
    const std::function<std::string(const std::string&)>& lib_path_func) {
  if (auto it = library_map_.find(lib_name); it == library_map_.end()) {
    std::string new_lib_path = lib_path_func(lib_name);
    auto new_lib = load_library(device_, lib_name, new_lib_path.c_str());
    library_map_.insert({lib_name, new_lib});
  }
}

MTL::ComputePipelineState* Device::get_kernel(
    const std::string& name,
    const std::string& lib_name /* = "mlx" */) {
  // Look for cached kernel
  if (auto it = kernel_map_.find(name); it != kernel_map_.end()) {
    return it->second;
  }

  // Prepare new kernel

  // Search for cached metal lib
  MTL::Library* mtl_lib;
  if (auto it = library_map_.find(name); it != library_map_.end()) {
    mtl_lib = it->second;
  } else { // Look for metallib alongside library
    register_library(lib_name);
    mtl_lib = library_map_[lib_name];
  }

  // Pull kernel from library
  auto ns_name = NS::String::string(name.c_str(), NS::ASCIIStringEncoding);
  auto mtl_function = mtl_lib->newFunction(ns_name);

  // Compile kernel to compute pipeline
  NS::Error* error = nullptr;
  MTL::ComputePipelineState* kernel = nullptr;
  if (mtl_function) {
    kernel = device_->newComputePipelineState(mtl_function, &error);
    mtl_function->release();
  }
  if (!mtl_function || !kernel) {
    std::ostringstream msg;
    msg << "[metal::Device] Unable to load kernel " << name << "\n";
    if (error) {
      msg << error->localizedDescription()->utf8String() << "\n";
    }
    throw std::runtime_error(msg.str());
  }

  // Add kernel to cache
  kernel_map_.insert({name, kernel});
  return kernel;
}

Device& device(mlx::core::Device) {
  return metal_device_;
}

NS::AutoreleasePool*& thread_autorelease_pool() {
  static thread_local NS::AutoreleasePool* p =
      NS::AutoreleasePool::alloc()->init();
  return p;
}

void new_stream(Stream stream) {
  thread_autorelease_pool();
  if (stream.device == mlx::core::Device::gpu) {
    device(stream.device).new_queue(stream.index);
  }
}

} // namespace mlx::core::metal

10
mlx/backend/metal/fft.cpp
Normal file
10
mlx/backend/metal/fft.cpp
Normal file
@ -0,0 +1,10 @@
#include "mlx/primitives.h"

namespace mlx::core {

void FFT::eval_gpu(const std::vector<array>& inputs, array& out) {
  auto& in = inputs[0];
  throw std::runtime_error("[FFT] NYI for Metal backend.");
}

} // namespace mlx::core

83
mlx/backend/metal/kernels/CMakeLists.txt
Normal file
83
mlx/backend/metal/kernels/CMakeLists.txt
Normal file
@ -0,0 +1,83 @@
set(
  HEADERS
  ${CMAKE_CURRENT_SOURCE_DIR}/bf16.h
  ${CMAKE_CURRENT_SOURCE_DIR}/bf16_math.h
  ${CMAKE_CURRENT_SOURCE_DIR}/complex.h
  ${CMAKE_CURRENT_SOURCE_DIR}/defines.h
  ${CMAKE_CURRENT_SOURCE_DIR}/erf.h
  ${CMAKE_CURRENT_SOURCE_DIR}/reduce.h
  ${CMAKE_CURRENT_SOURCE_DIR}/utils.h
)

set(
  KERNELS
  "arange"
  "arg_reduce"
  "binary"
  "conv"
  "copy"
  "gemm"
  "gemv"
  "random"
  "reduce"
  "scan"
  "softmax"
  "sort"
  "unary"
  "indexing"
)

function(build_kernel KERNEL)
  set(SRCFILE ${CMAKE_CURRENT_SOURCE_DIR}/${KERNEL}.metal)
  set(HEADERS_PADDED ${HEADERS})
  if(${KERNEL} STREQUAL "gemm")
    set(HEADERS_PADDED ${HEADERS_PADDED} ${CMAKE_CURRENT_SOURCE_DIR}/gemm/gemm.h)
  endif()
  if(${KERNEL} STREQUAL "conv")
    set(HEADERS_PADDED ${HEADERS_PADDED} ${CMAKE_CURRENT_SOURCE_DIR}/gemm/conv.h)
  endif()
  add_custom_command(
    COMMAND xcrun -sdk macosx metal -Wall -Wextra
                  -fno-fast-math
                  -c ${SRCFILE}
                  -I${PROJECT_SOURCE_DIR}
                  -o ${KERNEL}.air
    DEPENDS ${SRCFILE} ${HEADERS_PADDED}
    OUTPUT ${KERNEL}.air
    COMMENT "Building ${KERNEL}.air"
    VERBATIM
  )
endfunction(build_kernel)

foreach(KERNEL ${KERNELS})
  build_kernel(${KERNEL})
  set(KERNEL_AIR ${KERNEL}.air ${KERNEL_AIR})
endforeach()

add_custom_command(
  OUTPUT ${MLX_METAL_PATH}/mlx.metallib
  COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o ${MLX_METAL_PATH}/mlx.metallib
  DEPENDS ${KERNEL_AIR}
  COMMENT "Building mlx.metallib"
  VERBATIM
)

add_custom_target(
  mlx-metallib
  DEPENDS
  ${MLX_METAL_PATH}/mlx.metallib
)

add_dependencies(
  mlx
  mlx-metallib
)

# Install metallib
include(GNUInstallDirs)

install(
  FILES ${MLX_METAL_PATH}/mlx.metallib
  DESTINATION ${CMAKE_INSTALL_LIBDIR}
  COMPONENT metallib
)

17
mlx/backend/metal/kernels/conv_params.h
Normal file
17
mlx/backend/metal/kernels/conv_params.h
Normal file
@ -0,0 +1,17 @@
#pragma once

template <int NDIM>
struct MLXConvParams {
  const int N; // Batch size
  const int C; // In channels
  const int O; // Out channels
  const int iS[NDIM]; // Input spatial dim
  const int wS[NDIM]; // Weight spatial dim
  const int oS[NDIM]; // Output spatial dim
  const int str[NDIM]; // Kernel strides
  const int pad[NDIM]; // Input padding
  const int dil[NDIM]; // Kernel dilation
  const size_t in_strides[NDIM + 2]; // In strides
  const size_t wt_strides[NDIM + 2]; // Wt strides
  const size_t out_strides[NDIM + 2]; // Out strides
};

253
mlx/backend/metal/kernels/indexing.metal
Normal file
253
mlx/backend/metal/kernels/indexing.metal
Normal file
@ -0,0 +1,253 @@
#include <metal_atomic>
#include <metal_texture>

#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/reduce.h"
#include "mlx/backend/metal/kernels/utils.h"

using namespace metal;

/////////////////////////////////////////////////////////////////////
// Gather kernel
/////////////////////////////////////////////////////////////////////

template <typename IdxT, int NIDX>
struct Indices {
  const array<device IdxT*, NIDX> buffers [[id(0)]];
  device int* shapes [[id(NIDX + 1)]];
  device size_t* strides [[id(NIDX + 2)]];
  const int ndim [[id(NIDX + 3)]];
};

template <typename IdxT>
inline size_t offset_neg_idx(IdxT idx, size_t size) {
  return (idx < 0) ? idx + size : idx;
}

template <>
inline size_t offset_neg_idx(bool idx, size_t) {
  return idx;
}

template <>
inline size_t offset_neg_idx(uint32_t idx, size_t) {
  return idx;
}

template <typename T, typename IdxT, int NIDX>
[[kernel]] void gather(
    const device T *src [[buffer(0)]],
    const device Indices<IdxT, NIDX>& indices [[buffer(1)]],
    device T *out [[buffer(2)]],
    const device int *src_shape [[buffer(3)]],
    const device size_t *src_strides [[buffer(4)]],
    const device size_t& src_ndim [[buffer(5)]],
    const device int *slice_sizes [[buffer(6)]],
    const device size_t& slice_size [[buffer(7)]],
    const device int *axes [[buffer(8)]],
    uint gid [[thread_position_in_grid]]) {

  auto ind_idx = gid / slice_size;
  auto ind_offset = gid % slice_size;

  size_t src_idx = 0;
  for (int i = 0; i < NIDX; ++i) {
    auto idx_loc = elem_to_loc(
        ind_idx,
        &indices.shapes[indices.ndim * i],
        &indices.strides[indices.ndim * i],
        indices.ndim);
    auto ax = axes[i];
    auto idx_val = offset_neg_idx(
        indices.buffers[i][idx_loc], src_shape[ax]);
    src_idx += idx_val * src_strides[ax];
  }

  auto src_offset = elem_to_loc(
      ind_offset, slice_sizes, src_strides, src_ndim);
  out[gid] = src[src_idx + src_offset];
}

#define instantiate_gather4(name, src_type, ind_type, nindex) \
template [[host_name("gather" name "_" #nindex)]] \
[[kernel]] void gather( \
    const device src_type *src [[buffer(0)]], \
    const device Indices<ind_type, nindex>& indices [[buffer(1)]], \
    device src_type *out [[buffer(2)]], \
    const device int *src_shape [[buffer(3)]], \
    const device size_t *src_strides [[buffer(4)]], \
    const device size_t& src_ndim [[buffer(5)]], \
    const device int *slice_sizes [[buffer(6)]], \
    const device size_t& slice_size [[buffer(7)]], \
    const device int* axes [[buffer(8)]], \
    uint gid [[thread_position_in_grid]]);

// Special case for NIDX=0
instantiate_gather4("bool_", bool, bool, 0)
instantiate_gather4("uint8", uint8_t, bool, 0)
instantiate_gather4("uint16", uint16_t, bool, 0)
instantiate_gather4("uint32", uint32_t, bool, 0)
instantiate_gather4("uint64", uint64_t, bool, 0)
instantiate_gather4("int8", int8_t, bool, 0)
instantiate_gather4("int16", int16_t, bool, 0)
instantiate_gather4("int32", int32_t, bool, 0)
instantiate_gather4("int64", int64_t, bool, 0)
instantiate_gather4("float16", half, bool, 0)
instantiate_gather4("float32", float, bool, 0)
instantiate_gather4("bfloat16", bfloat16_t, bool, 0)

#define instantiate_gather3(name, src_type, ind_type) \
  instantiate_gather4(name, src_type, ind_type, 1) \
  instantiate_gather4(name, src_type, ind_type, 2) \
  instantiate_gather4(name, src_type, ind_type, 3) \
  instantiate_gather4(name, src_type, ind_type, 4) \
  instantiate_gather4(name, src_type, ind_type, 5) \
  instantiate_gather4(name, src_type, ind_type, 6) \
  instantiate_gather4(name, src_type, ind_type, 7) \
  instantiate_gather4(name, src_type, ind_type, 8) \
  instantiate_gather4(name, src_type, ind_type, 9) \
  instantiate_gather4(name, src_type, ind_type, 10)

#define instantiate_gather(name, src_type) \
  instantiate_gather3(#name "bool_", src_type, bool) \
  instantiate_gather3(#name "uint8", src_type, uint8_t) \
  instantiate_gather3(#name "uint16", src_type, uint16_t) \
  instantiate_gather3(#name "uint32", src_type, uint32_t) \
  instantiate_gather3(#name "uint64", src_type, uint64_t) \
  instantiate_gather3(#name "int8", src_type, int8_t) \
  instantiate_gather3(#name "int16", src_type, int16_t) \
  instantiate_gather3(#name "int32", src_type, int32_t) \
  instantiate_gather3(#name "int64", src_type, int64_t)

instantiate_gather(bool_, bool)
instantiate_gather(uint8, uint8_t)
instantiate_gather(uint16, uint16_t)
instantiate_gather(uint32, uint32_t)
instantiate_gather(uint64, uint64_t)
instantiate_gather(int8, int8_t)
instantiate_gather(int16, int16_t)
instantiate_gather(int32, int32_t)
instantiate_gather(int64, int64_t)
instantiate_gather(float16, half)
instantiate_gather(float32, float)
instantiate_gather(bfloat16, bfloat16_t)

/////////////////////////////////////////////////////////////////////
// Scatter kernel
/////////////////////////////////////////////////////////////////////

template <typename T, typename IdxT, typename Op, int NIDX>
[[kernel]] void scatter(
    const device Indices<IdxT, NIDX>& indices [[buffer(0)]],
    const device T *updates [[buffer(1)]],
    device mlx_atomic<T> *out [[buffer(2)]],
    const device int *upd_shape [[buffer(3)]],
    const device size_t *upd_strides [[buffer(4)]],
    const device size_t& upd_ndim [[buffer(5)]],
    const device size_t& upd_size [[buffer(6)]],
    const device int *out_shape [[buffer(7)]],
    const device size_t *out_strides [[buffer(8)]],
    const device size_t& out_ndim [[buffer(9)]],
    const device int* axes [[buffer(10)]],
    uint gid [[thread_position_in_grid]]) {

  Op op;
  auto ind_idx = gid / upd_size;
  auto ind_offset = gid % upd_size;

  size_t out_idx = 0;
  for (int i = 0; i < NIDX; ++i) {
    auto idx_loc = elem_to_loc(
        ind_idx,
        &indices.shapes[indices.ndim * i],
        &indices.strides[indices.ndim * i],
        indices.ndim);
    auto ax = axes[i];
    auto idx_val = offset_neg_idx(
        indices.buffers[i][idx_loc], out_shape[ax]);
    out_idx += idx_val * out_strides[ax];
  }

  auto out_offset = elem_to_loc(
      ind_offset, upd_shape + indices.ndim, out_strides, out_ndim);
  auto upd_idx = elem_to_loc(gid, upd_shape, upd_strides, upd_ndim);

  op.atomic_update(out + out_idx + out_offset, updates[upd_idx]);
}

#define instantiate_scatter4(name, type, ind_type, op_type, nindex) \
template [[host_name("scatter" name "_" #nindex)]] \
[[kernel]] void scatter<type, ind_type, op_type, nindex>( \
    const device Indices<ind_type, nindex>& indices [[buffer(0)]], \
    const device type *updates [[buffer(1)]], \
    device mlx_atomic<type> *out [[buffer(2)]], \
    const device int *upd_shape [[buffer(3)]], \
    const device size_t *upd_strides [[buffer(4)]], \
    const device size_t& upd_ndim [[buffer(5)]], \
    const device size_t& upd_size [[buffer(6)]], \
    const device int *out_shape [[buffer(7)]], \
    const device size_t *out_strides [[buffer(8)]], \
    const device size_t& out_ndim [[buffer(9)]], \
    const device int* axes [[buffer(10)]], \
    uint gid [[thread_position_in_grid]]);

// Special case NINDEX=0
#define instantiate_scatter_nd0(name, type) \
  instantiate_scatter4(#name "none", type, bool, None, 0) \
  instantiate_scatter4(#name "_sum", type, bool, Sum<type>, 0) \
  instantiate_scatter4(#name "_prod", type, bool, Prod<type>, 0) \
  instantiate_scatter4(#name "_max", type, bool, Max<type>, 0) \
  instantiate_scatter4(#name "_min", type, bool, Min<type>, 0)

#define instantiate_scatter3(name, type, ind_type, op_type) \
  instantiate_scatter4(name, type, ind_type, op_type, 1) \
  instantiate_scatter4(name, type, ind_type, op_type, 2) \
  instantiate_scatter4(name, type, ind_type, op_type, 3) \
  instantiate_scatter4(name, type, ind_type, op_type, 4) \
  instantiate_scatter4(name, type, ind_type, op_type, 5) \
  instantiate_scatter4(name, type, ind_type, op_type, 6) \
  instantiate_scatter4(name, type, ind_type, op_type, 7) \
  instantiate_scatter4(name, type, ind_type, op_type, 8) \
  instantiate_scatter4(name, type, ind_type, op_type, 9) \
  instantiate_scatter4(name, type, ind_type, op_type, 10)

#define instantiate_scatter2(name, type, ind_type) \
  instantiate_scatter3(name "_none", type, ind_type, None) \
  instantiate_scatter3(name "_sum", type, ind_type, Sum<type>) \
  instantiate_scatter3(name "_prod", type, ind_type, Prod<type>) \
  instantiate_scatter3(name "_max", type, ind_type, Max<type>) \
  instantiate_scatter3(name "_min", type, ind_type, Min<type>)

#define instantiate_scatter(name, type) \
  instantiate_scatter2(#name "bool_", type, bool) \
  instantiate_scatter2(#name "uint8", type, uint8_t) \
  instantiate_scatter2(#name "uint16", type, uint16_t) \
  instantiate_scatter2(#name "uint32", type, uint32_t) \
  instantiate_scatter2(#name "uint64", type, uint64_t) \
  instantiate_scatter2(#name "int8", type, int8_t) \
  instantiate_scatter2(#name "int16", type, int16_t) \
  instantiate_scatter2(#name "int32", type, int32_t) \
  instantiate_scatter2(#name "int64", type, int64_t)

// TODO uint64 and int64 unsupported
instantiate_scatter_nd0(bool_, bool)
instantiate_scatter_nd0(uint8, uint8_t)
instantiate_scatter_nd0(uint16, uint16_t)
instantiate_scatter_nd0(uint32, uint32_t)
instantiate_scatter_nd0(int8, int8_t)
instantiate_scatter_nd0(int16, int16_t)
instantiate_scatter_nd0(int32, int32_t)
instantiate_scatter_nd0(float16, half)
instantiate_scatter_nd0(float32, float)
instantiate_scatter_nd0(bfloat16, bfloat16_t)

instantiate_scatter(bool_, bool)
instantiate_scatter(uint8, uint8_t)
instantiate_scatter(uint16, uint16_t)
instantiate_scatter(uint32, uint32_t)
instantiate_scatter(int8, int8_t)
instantiate_scatter(int16, int16_t)
instantiate_scatter(int32, int32_t)
instantiate_scatter(float16, half)
instantiate_scatter(float32, float)
instantiate_scatter(bfloat16, bfloat16_t)

174
mlx/backend/metal/kernels/reduce.h
Normal file
174
mlx/backend/metal/kernels/reduce.h
Normal file
@ -0,0 +1,174 @@
#pragma once

#include <metal_atomic>
#include <metal_simdgroup>

#include "mlx/backend/metal/kernels/atomic.h"
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/utils.h"

union bool4_or_uint {
  bool4 b;
  unsigned int i;
};

struct None {
  template <typename T>
  void atomic_update(device mlx_atomic<T>* out, T val, int offset = 0) {
    mlx_atomic_store_explicit(out, val, offset);
  }
};

struct And {
  bool simd_reduce(bool val) {
    return simd_all(val);
  }

  static constexpr constant bool init = true;

  void atomic_update(
      device mlx_atomic<unsigned int>* out,
      bool val,
      int elem_idx,
      int offset = 0) {
    if (!val) {
      bool4_or_uint update;
      update.b = {true, true, true, true};
      update.b[elem_idx] = false;
      mlx_atomic_fetch_and_explicit(out, update.i, offset);
    }
  }

  void atomic_update(device mlx_atomic<bool>* out, bool val, int offset = 0) {
    if (!val) {
      mlx_atomic_store_explicit(out, val, offset);
    }
  }

  // Non atomic update
  void update(device bool* out, bool val) {
    *out &= val;
  }

  // Operator
  bool operator()(bool a, bool b) {
    return a && b;
  }
};

struct Or {
  bool simd_reduce(bool val) {
    return simd_any(val);
  }

  static constexpr constant bool init = false;

  void atomic_update(
      device mlx_atomic<unsigned int>* out,
      bool val,
      int elem_idx,
      int offset = 0) {
    if (val) {
      bool4_or_uint update;
      update.b = {false, false, false, false};
      update.b[elem_idx] = true;
      mlx_atomic_fetch_or_explicit(out, update.i, offset);
    }
  }

  void atomic_update(device mlx_atomic<bool>* out, bool val, int offset = 0) {
    if (val) {
      mlx_atomic_store_explicit(out, val, offset);
    }
  }

  // Non atomic update
  void update(device bool* out, bool val) {
    *out |= val;
  }

  // Operator
  bool operator()(bool a, bool b) {
    return a || b;
  }
};

template <typename U>
struct Sum {
  template <typename T>
  T simd_reduce(T val) {
    return simd_sum(val);
  }

  static constexpr constant U init = U(0);

  template <typename T>
  void atomic_update(device mlx_atomic<T>* out, T val, int offset = 0) {
    mlx_atomic_fetch_add_explicit(out, val, offset);
  }

  // Operator
  U operator()(U a, U b) {
    return a + b;
  }
};

template <typename U>
struct Prod {
  template <typename T>
  T simd_reduce(T val) {
    return simd_product(val);
  }

  static constexpr constant U init = U(1);

  template <typename T>
  void atomic_update(device mlx_atomic<T>* out, T val, int offset = 0) {
    mlx_atomic_fetch_mul_explicit(out, val, offset);
  }

  // Operator
  U operator()(U a, U b) {
    return a * b;
  }
};

template <typename U>
struct Min {
  template <typename T>
  T simd_reduce(T val) {
    return simd_min(val);
  }

  static constexpr constant U init = Limits<U>::max;

  template <typename T>
  void atomic_update(device mlx_atomic<T>* out, T val, int offset = 0) {
    mlx_atomic_fetch_min_explicit(out, val, offset);
  }

  // Operator
  U operator()(U a, U b) {
    return a < b ? a : b;
  }
};

template <typename U>
struct Max {
  template <typename T>
  T simd_reduce(T val) {
    return simd_max(val);
  }

  static constexpr constant U init = Limits<U>::min;

  template <typename T>
  void atomic_update(device mlx_atomic<T>* out, T val, int offset = 0) {
    mlx_atomic_fetch_max_explicit(out, val, offset);
  }

  // Operator
  U operator()(U a, U b) {
    return a > b ? a : b;
  }
};

492
mlx/backend/metal/kernels/scan.metal
Normal file
@ -0,0 +1,492 @@
#include <metal_math>
#include <metal_simdgroup>

#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/kernels/utils.h"

using namespace metal;

template <typename U>
struct CumSum {
  static constexpr constant U init = static_cast<U>(0);

  template <typename T>
  U operator()(U a, T b) {
    return a + b;
  }

  U simd_scan(U x) {
    return simd_prefix_inclusive_sum(x);
  }

  U simd_exclusive_scan(U x) {
    return simd_prefix_exclusive_sum(x);
  }
};

template <typename U>
struct CumProd {
  static constexpr constant U init = static_cast<U>(1.0f);

  template <typename T>
  U operator()(U a, T b) {
    return a * b;
  }

  U simd_scan(U x) {
    return simd_prefix_inclusive_product(x);
  }

  U simd_exclusive_scan(U x) {
    return simd_prefix_exclusive_product(x);
  }
};

template <>
struct CumProd<bool> {
  static constexpr constant bool init = true;

  template <typename T>
  bool operator()(bool a, T b) {
    return a & static_cast<bool>(b);
  }

  bool simd_scan(bool x) {
    for (int i = 1; i <= 16; i *= 2) {
      bool other = simd_shuffle_up(x, i);
      x &= other;
    }
    return x;
  }

  bool simd_exclusive_scan(bool x) {
    x = simd_scan(x);
    return simd_shuffle_and_fill_up(x, init, 1);
  }
};
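// The shuffle-up loops above (and in CumMax/CumMin below) implement a
// Hillis-Steele inclusive scan by hand: offsets 1, 2, 4, 8, 16 cover a
// 32-lane simdgroup, for ops that have no built-in simd_prefix_* intrinsic.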

template <typename U>
struct CumMax {
  static constexpr constant U init = Limits<U>::min;

  template <typename T>
  U operator()(U a, T b) {
    return (a >= b) ? a : b;
  }

  U simd_scan(U x) {
    for (int i = 1; i <= 16; i *= 2) {
      U other = simd_shuffle_up(x, i);
      x = (x >= other) ? x : other;
    }
    return x;
  }

  U simd_exclusive_scan(U x) {
    x = simd_scan(x);
    return simd_shuffle_and_fill_up(x, init, 1);
  }
};

template <typename U>
struct CumMin {
  static constexpr constant U init = Limits<U>::max;

  template <typename T>
  U operator()(U a, T b) {
    return (a <= b) ? a : b;
  }

  U simd_scan(U x) {
    for (int i = 1; i <= 16; i *= 2) {
      U other = simd_shuffle_up(x, i);
      x = (x <= other) ? x : other;
    }
    return x;
  }

  U simd_exclusive_scan(U x) {
    x = simd_scan(x);
    return simd_shuffle_and_fill_up(x, init, 1);
  }
};

template <typename T, typename U, int N_READS, bool reverse>
inline void load_unsafe(U values[N_READS], const device T* input) {
  if (reverse) {
    for (int i = 0; i < N_READS; i++) {
      values[N_READS - i - 1] = input[i];
    }
  } else {
    for (int i = 0; i < N_READS; i++) {
      values[i] = input[i];
    }
  }
}

template <typename T, typename U, int N_READS, bool reverse>
inline void load_safe(
    U values[N_READS],
    const device T* input,
    int start,
    int total,
    U init) {
  if (reverse) {
    for (int i = 0; i < N_READS; i++) {
      values[N_READS - i - 1] =
          (start + N_READS - i - 1 < total) ? input[i] : init;
    }
  } else {
    for (int i = 0; i < N_READS; i++) {
      values[i] = (start + i < total) ? input[i] : init;
    }
  }
}
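// Note: callers pass `input` already offset to the block start, so the guard
// on the logical position (start + N_READS - i - 1 in the reverse case)
// decides whether input[i] is read at all; out-of-range lanes get `init`.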

template <typename U, int N_READS, bool reverse>
inline void write_unsafe(U values[N_READS], device U* out) {
  if (reverse) {
    for (int i = 0; i < N_READS; i++) {
      out[i] = values[N_READS - i - 1];
    }
  } else {
    for (int i = 0; i < N_READS; i++) {
      out[i] = values[i];
    }
  }
}

template <typename U, int N_READS, bool reverse>
inline void write_safe(U values[N_READS], device U* out, int start, int total) {
  if (reverse) {
    for (int i = 0; i < N_READS; i++) {
      if (start + N_READS - i - 1 < total) {
        out[i] = values[N_READS - i - 1];
      }
    }
  } else {
    for (int i = 0; i < N_READS; i++) {
      if (start + i < total) {
        out[i] = values[i];
      }
    }
  }
}

template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool reverse>
[[kernel]] void contiguous_scan(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant size_t& axis_size [[buffer(2)]],
    uint gid [[thread_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint lsize [[threads_per_threadgroup]],
    uint simd_size [[threads_per_simdgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  Op op;

  // Position the pointers
  in += (gid / lsize) * axis_size;
  out += (gid / lsize) * axis_size;

  // Compute the number of simd_groups
  uint simd_groups = lsize / simd_size;

  // Allocate memory
  U prefix = Op::init;
  U values[N_READS];
  threadgroup U simdgroup_sums[32];

  // Loop over the reduced axis in blocks of size ceildiv(axis_size, N_READS*lsize)
  //   Read block
  //   Compute inclusive scan of the block
  //     Compute inclusive scan per thread
  //     Compute exclusive scan of thread sums in simdgroup
  //     Write simdgroup sums in SM
  //     Compute exclusive scan of simdgroup sums
  //     Compute the output by scanning prefix, prev_simdgroup, prev_thread, value
  //   Write block

  for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize); r++) {
    // Compute the block offset
    uint offset = r * lsize * N_READS + lid * N_READS;

    // Read the values
    if (reverse) {
      if ((offset + N_READS) < axis_size) {
        load_unsafe<T, U, N_READS, reverse>(
            values, in + axis_size - offset - N_READS);
      } else {
        load_safe<T, U, N_READS, reverse>(
            values, in + axis_size - offset - N_READS, offset, axis_size, Op::init);
      }
    } else {
      if ((offset + N_READS) < axis_size) {
        load_unsafe<T, U, N_READS, reverse>(values, in + offset);
      } else {
        load_safe<T, U, N_READS, reverse>(
            values, in + offset, offset, axis_size, Op::init);
      }
    }

    // Compute an inclusive scan per thread
    for (int i = 1; i < N_READS; i++) {
      values[i] = op(values[i], values[i - 1]);
    }

    // Compute exclusive scan of thread sums
    U prev_thread = op.simd_exclusive_scan(values[N_READS - 1]);

    // Write simdgroup_sums to SM
    if (simd_lane_id == simd_size - 1) {
      simdgroup_sums[simd_group_id] = op(prev_thread, values[N_READS - 1]);
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Compute exclusive scan of simdgroup_sums
    if (simd_group_id == 0) {
      U prev_simdgroup = op.simd_exclusive_scan(simdgroup_sums[simd_lane_id]);
      simdgroup_sums[simd_lane_id] = prev_simdgroup;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Compute the output
    for (int i = 0; i < N_READS; i++) {
      values[i] = op(values[i], prefix);
      values[i] = op(values[i], simdgroup_sums[simd_group_id]);
      values[i] = op(values[i], prev_thread);
    }

    // Write the values
    if (reverse) {
      if (inclusive) {
        if ((offset + N_READS) < axis_size) {
          write_unsafe<U, N_READS, reverse>(
              values, out + axis_size - offset - N_READS);
        } else {
          write_safe<U, N_READS, reverse>(
              values, out + axis_size - offset - N_READS, offset, axis_size);
        }
      } else {
        if (lid == 0 && offset == 0) {
          out[axis_size - 1] = Op::init;
        }
        if ((offset + N_READS + 1) < axis_size) {
          write_unsafe<U, N_READS, reverse>(
              values, out + axis_size - offset - 1 - N_READS);
        } else {
          write_safe<U, N_READS, reverse>(
              values, out + axis_size - offset - 1 - N_READS, offset + 1, axis_size);
        }
      }
    } else {
      if (inclusive) {
        if ((offset + N_READS) < axis_size) {
          write_unsafe<U, N_READS, reverse>(values, out + offset);
        } else {
          write_safe<U, N_READS, reverse>(values, out + offset, offset, axis_size);
        }
      } else {
        if (lid == 0 && offset == 0) {
          out[0] = Op::init;
        }
        if ((offset + N_READS + 1) < axis_size) {
          write_unsafe<U, N_READS, reverse>(values, out + offset + 1);
        } else {
          write_safe<U, N_READS, reverse>(
              values, out + offset + 1, offset + 1, axis_size);
        }
      }
    }

    // Share the prefix
    if (simd_group_id == simd_groups - 1 && simd_lane_id == simd_size - 1) {
      simdgroup_sums[0] = values[N_READS - 1];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    prefix = simdgroup_sums[0];
  }
}
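// Editorial note: this kernel assumes the threadgroup size is a multiple of
// the simdgroup size and uses at most 32 simdgroups (simdgroup_sums[32]),
// i.e. up to 1024 threads per threadgroup with 32-wide simdgroups.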

template <typename T, typename U, typename Op, int N_READS, bool inclusive, bool reverse>
[[kernel]] void strided_scan(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant size_t& axis_size [[buffer(2)]],
    const constant size_t& stride [[buffer(3)]],
    uint2 gid [[threadgroup_position_in_grid]],
    uint2 lid [[thread_position_in_threadgroup]],
    uint2 lsize [[threads_per_threadgroup]],
    uint simd_size [[threads_per_simdgroup]]) {
  Op op;

  // Allocate memory
  threadgroup U read_buffer[N_READS * 32 * 32 + N_READS * 32];
  U values[N_READS];
  U prefix[N_READS];
  for (int i = 0; i < N_READS; i++) {
    prefix[i] = Op::init;
  }

  // Compute offsets
  int offset = gid.y * axis_size * stride;
  int global_index_x = gid.x * lsize.y * N_READS;

  for (uint j = 0; j < axis_size; j += simd_size) {
    // Calculate the indices for the current thread
    uint index_y = j + lid.y;
    uint check_index_y = index_y;
    uint index_x = global_index_x + lid.x * N_READS;
    if (reverse) {
      index_y = axis_size - 1 - index_y;
    }

    // Read in SM
    if (check_index_y < axis_size && (index_x + N_READS) < stride) {
      for (int i = 0; i < N_READS; i++) {
        read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] =
            in[offset + index_y * stride + index_x + i];
      }
    } else {
      for (int i = 0; i < N_READS; i++) {
        if (check_index_y < axis_size && (index_x + i) < stride) {
          read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] =
              in[offset + index_y * stride + index_x + i];
        } else {
          read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i] =
              Op::init;
        }
      }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Read strided into registers
    for (int i = 0; i < N_READS; i++) {
      values[i] = read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i];
    }
    // Do we need the following barrier? Shouldn't all simd threads execute simultaneously?
    simdgroup_barrier(mem_flags::mem_threadgroup);

    // Perform the scan
    for (int i = 0; i < N_READS; i++) {
      values[i] = op.simd_scan(values[i]);
      values[i] = op(values[i], prefix[i]);
      prefix[i] = simd_shuffle(values[i], simd_size - 1);
    }

    // Write to SM
    for (int i = 0; i < N_READS; i++) {
      read_buffer[lid.x * simd_size * N_READS + lid.y * N_READS + i] = values[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Write to device memory
    if (!inclusive) {
      if (check_index_y == 0) {
        if ((index_x + N_READS) < stride) {
          for (int i = 0; i < N_READS; i++) {
            out[offset + index_y * stride + index_x + i] = Op::init;
          }
        } else {
          for (int i = 0; i < N_READS; i++) {
            if ((index_x + i) < stride) {
              out[offset + index_y * stride + index_x + i] = Op::init;
            }
          }
        }
      }
      if (reverse) {
        index_y -= 1;
        check_index_y += 1;
      } else {
        index_y += 1;
        check_index_y += 1;
      }
    }
    if (check_index_y < axis_size && (index_x + N_READS) < stride) {
      for (int i = 0; i < N_READS; i++) {
        out[offset + index_y * stride + index_x + i] =
            read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i];
      }
    } else {
      for (int i = 0; i < N_READS; i++) {
        if (check_index_y < axis_size && (index_x + i) < stride) {
          out[offset + index_y * stride + index_x + i] =
              read_buffer[lid.y * simd_size * N_READS + lid.x * N_READS + i];
        }
      }
    }
  }
}
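// Editorial note: read_buffer holds a 32x32 tile of N_READS-wide vectors; the
// extra N_READS * 32 entries are presumably padding to reduce shared-memory
// bank conflicts when the tile is read back transposed.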

#define instantiate_contiguous_scan(name, itype, otype, op, inclusive, reverse, nreads) \
  template [[host_name("contiguous_scan_" #name)]] \
  [[kernel]] void contiguous_scan<itype, otype, op<otype>, nreads, inclusive, reverse>( \
      const device itype* in [[buffer(0)]], \
      device otype* out [[buffer(1)]], \
      const constant size_t& axis_size [[buffer(2)]], \
      uint gid [[thread_position_in_grid]], \
      uint lid [[thread_position_in_threadgroup]], \
      uint lsize [[threads_per_threadgroup]], \
      uint simd_size [[threads_per_simdgroup]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]]);

#define instantiate_strided_scan(name, itype, otype, op, inclusive, reverse, nreads) \
  template [[host_name("strided_scan_" #name)]] \
  [[kernel]] void strided_scan<itype, otype, op<otype>, nreads, inclusive, reverse>( \
      const device itype* in [[buffer(0)]], \
      device otype* out [[buffer(1)]], \
      const constant size_t& axis_size [[buffer(2)]], \
      const constant size_t& stride [[buffer(3)]], \
      uint2 gid [[thread_position_in_grid]], \
      uint2 lid [[thread_position_in_threadgroup]], \
      uint2 lsize [[threads_per_threadgroup]], \
      uint simd_size [[threads_per_simdgroup]]);

#define instantiate_scan_helper(name, itype, otype, op, nreads) \
  instantiate_contiguous_scan(inclusive_##name, itype, otype, op, true, false, nreads) \
  instantiate_contiguous_scan(exclusive_##name, itype, otype, op, false, false, nreads) \
  instantiate_contiguous_scan(reverse_inclusive_##name, itype, otype, op, true, true, nreads) \
  instantiate_contiguous_scan(reverse_exclusive_##name, itype, otype, op, false, true, nreads) \
  instantiate_strided_scan(inclusive_##name, itype, otype, op, true, false, nreads) \
  instantiate_strided_scan(exclusive_##name, itype, otype, op, false, false, nreads) \
  instantiate_strided_scan(reverse_inclusive_##name, itype, otype, op, true, true, nreads) \
  instantiate_strided_scan(reverse_exclusive_##name, itype, otype, op, false, true, nreads)

instantiate_scan_helper(sum_bool__int32, bool, int32_t, CumSum, 4)
instantiate_scan_helper(sum_uint8_uint8, uint8_t, uint8_t, CumSum, 4)
instantiate_scan_helper(sum_uint16_uint16, uint16_t, uint16_t, CumSum, 4)
instantiate_scan_helper(sum_uint32_uint32, uint32_t, uint32_t, CumSum, 4)
//instantiate_scan_helper(sum_uint64_uint64, uint64_t, uint64_t, CumSum, 2)
instantiate_scan_helper(sum_int8_int8, int8_t, int8_t, CumSum, 4)
instantiate_scan_helper(sum_int16_int16, int16_t, int16_t, CumSum, 4)
instantiate_scan_helper(sum_int32_int32, int32_t, int32_t, CumSum, 4)
//instantiate_scan_helper(sum_int64_int64, int64_t, int64_t, CumSum, 2)
instantiate_scan_helper(sum_float16_float16, half, half, CumSum, 4)
instantiate_scan_helper(sum_float32_float32, float, float, CumSum, 4)
//instantiate_scan_helper(sum_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumSum, 4)
//instantiate_scan_helper(sum_complex64_complex64, complex64_t, complex64_t, CumSum)
//instantiate_scan_helper(prod_bool__bool_, bool, bool, CumProd, 4)
instantiate_scan_helper(prod_uint8_uint8, uint8_t, uint8_t, CumProd, 4)
instantiate_scan_helper(prod_uint16_uint16, uint16_t, uint16_t, CumProd, 4)
instantiate_scan_helper(prod_uint32_uint32, uint32_t, uint32_t, CumProd, 4)
//instantiate_scan_helper(prod_uint64_uint64, uint64_t, uint64_t, CumProd, 2)
instantiate_scan_helper(prod_int8_int8, int8_t, int8_t, CumProd, 4)
instantiate_scan_helper(prod_int16_int16, int16_t, int16_t, CumProd, 4)
instantiate_scan_helper(prod_int32_int32, int32_t, int32_t, CumProd, 4)
//instantiate_scan_helper(prod_int64_int64, int64_t, int64_t, CumProd, 2)
instantiate_scan_helper(prod_float16_float16, half, half, CumProd, 4)
instantiate_scan_helper(prod_float32_float32, float, float, CumProd, 4)
//instantiate_scan_helper(prod_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumProd, 4)
//instantiate_scan_helper(prod_complex64_complex64, complex64_t, complex64_t, CumProd)
//instantiate_scan_helper(max_bool__bool_, bool, bool, CumMax, 4)
instantiate_scan_helper(max_uint8_uint8, uint8_t, uint8_t, CumMax, 4)
instantiate_scan_helper(max_uint16_uint16, uint16_t, uint16_t, CumMax, 4)
instantiate_scan_helper(max_uint32_uint32, uint32_t, uint32_t, CumMax, 4)
//instantiate_scan_helper(max_uint64_uint64, uint64_t, uint64_t, CumMax, 2)
instantiate_scan_helper(max_int8_int8, int8_t, int8_t, CumMax, 4)
instantiate_scan_helper(max_int16_int16, int16_t, int16_t, CumMax, 4)
instantiate_scan_helper(max_int32_int32, int32_t, int32_t, CumMax, 4)
//instantiate_scan_helper(max_int64_int64, int64_t, int64_t, CumMax, 2)
instantiate_scan_helper(max_float16_float16, half, half, CumMax, 4)
instantiate_scan_helper(max_float32_float32, float, float, CumMax, 4)
//instantiate_scan_helper(max_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMax, 4)
//instantiate_scan_helper(max_complex64_complex64, complex64_t, complex64_t, CumMax)
//instantiate_scan_helper(min_bool__bool_, bool, bool, CumMin, 4)
instantiate_scan_helper(min_uint8_uint8, uint8_t, uint8_t, CumMin, 4)
instantiate_scan_helper(min_uint16_uint16, uint16_t, uint16_t, CumMin, 4)
instantiate_scan_helper(min_uint32_uint32, uint32_t, uint32_t, CumMin, 4)
//instantiate_scan_helper(min_uint64_uint64, uint64_t, uint64_t, CumMin, 2)
instantiate_scan_helper(min_int8_int8, int8_t, int8_t, CumMin, 4)
instantiate_scan_helper(min_int16_int16, int16_t, int16_t, CumMin, 4)
instantiate_scan_helper(min_int32_int32, int32_t, int32_t, CumMin, 4)
//instantiate_scan_helper(min_int64_int64, int64_t, int64_t, CumMin, 2)
instantiate_scan_helper(min_float16_float16, half, half, CumMin, 4)
instantiate_scan_helper(min_float32_float32, float, float, CumMin, 4)
//instantiate_scan_helper(min_bfloat16_bfloat16, bfloat16_t, bfloat16_t, CumMin, 4)
//instantiate_scan_helper(min_complex64_complex64, complex64_t, complex64_t, CumMin)
88
mlx/backend/metal/metal.cpp
Normal file
@ -0,0 +1,88 @@
#include <cstdlib>
#include <future>
#include <memory>

#include "mlx/array.h"
#include "mlx/backend/metal/device.h"
#include "mlx/primitives.h"
#include "mlx/scheduler.h"

namespace mlx::core::metal {

int max_ops_per_buffer() {
  auto get_val = []() {
    if (const char* buff_str = std::getenv("MLX_MAX_OPS_PER_BUFFER")) {
      return atoi(buff_str);
    } else {
      return 10;
    }
  };
  static int max_ops_per_buffer_ = get_val();
  return max_ops_per_buffer_;
}

#define MAX_OPS_PER_BUFFER max_ops_per_buffer()
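// The MLX_MAX_OPS_PER_BUFFER environment variable tunes how many ops are
// batched into one Metal command buffer before it is committed (default 10),
// e.g. `MLX_MAX_OPS_PER_BUFFER=20 python script.py`.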

MTL::CommandBuffer* increment_command_buffer(Stream s) {
  auto& d = metal::device(s.device);
  auto command_buffer = d.get_command_buffer(s.index);
  if (command_buffer == nullptr ||
      d.get_command_buffer_ops(s.index) >= MAX_OPS_PER_BUFFER) {
    if (command_buffer != nullptr) {
      d.end_encoding(s.index);
      scheduler::notify_new_task(s);
      command_buffer->addCompletedHandler(
          [s](MTL::CommandBuffer*) { scheduler::notify_task_completion(s); });
      d.commit_command_buffer(s.index);
    }
    command_buffer = d.new_command_buffer(s.index);
  }
  d.increment_command_buffer_ops(s.index);
  return command_buffer;
}

std::function<void()> make_task(
    array& arr,
    std::vector<std::shared_future<void>> deps,
    std::shared_ptr<std::promise<void>> p,
    bool retain_graph) {
  auto task =
      [retain_graph, arr, deps = std::move(deps), p = std::move(p)]() mutable {
        for (auto& d : deps) {
          d.wait();
        }
        auto s = arr.primitive().stream();
        auto command_buffer = increment_command_buffer(s);
        arr.primitive().eval_gpu(arr.inputs(), arr);
        if (p) {
          metal::device(s.device).end_encoding(s.index);
          scheduler::notify_new_task(s);
          command_buffer->addCompletedHandler(
              [retain_graph, s, arr, p = std::move(p)](
                  MTL::CommandBuffer*) mutable {
                if (!retain_graph) {
                  arr.detach();
                }
                p->set_value();
                // Signal this thread to clear the pool on a synchronization.
                scheduler::enqueue(s, []() {
                  thread_autorelease_pool()->release();
                  thread_autorelease_pool() =
                      NS::AutoreleasePool::alloc()->init();
                });
                scheduler::notify_task_completion(s);
              });
          metal::device(s.device).commit_command_buffer(s.index);
        } else {
          command_buffer->addCompletedHandler(
              [retain_graph, s, arr](MTL::CommandBuffer*) mutable {
                if (!retain_graph) {
                  arr.detach();
                }
              });
        }
      };
  return task;
}
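// Each task produced by make_task waits on its dependencies, encodes the
// array's GPU work into the stream's command buffer, and registers a
// completion handler that detaches the array (unless the graph is retained)
// and fulfills the promise used for synchronization points.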

} // namespace mlx::core::metal
28
mlx/backend/metal/metal.h
Normal file
@ -0,0 +1,28 @@
#pragma once

#include <future>
#include <memory>
#include <vector>

#include "mlx/array.h"
#include "mlx/stream.h"

namespace mlx::core::metal {

constexpr bool is_available() {
#ifdef _METAL_
  return true;
#else
  return false;
#endif
}
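// _METAL_ is presumably defined by the build system when the Metal backend is
// compiled in, so availability resolves at compile time.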

void new_stream(Stream stream);

std::function<void()> make_task(
    array& arr,
    std::vector<std::shared_future<void>> deps,
    std::shared_ptr<std::promise<void>> p,
    bool retain_graph);

} // namespace mlx::core::metal
82
mlx/backend/metal/softmax.cpp
Normal file
@ -0,0 +1,82 @@
#include <algorithm>

#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels/defines.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"

namespace mlx::core {

void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);
  if (!is_floating_point(out.dtype())) {
    throw std::runtime_error(
        "[softmax] Does not support non-floating point types.");
  }
  auto& s = stream();
  auto& d = metal::device(s.device);

  // Make sure that the last dimension is contiguous
  std::vector<array> copies;
  auto check_input = [&copies, &s](const array& x) {
    if (x.strides()[x.ndim() - 1] == 1) {
      return x;
    } else {
      array x_copy(x.shape(), x.dtype(), nullptr, {});
      copy_gpu(x, x_copy, CopyType::General, s);
      copies.push_back(x_copy);
      return x_copy;
    }
  };
  const array& in = check_input(inputs[0]);
  out.set_data(
      allocator::malloc_or_wait(in.data_size() * in.itemsize()),
      in.data_size(),
      in.strides(),
      in.flags());

  int axis_size = in.shape().back();
  int n_rows = in.data_size() / axis_size;

  const int simd_size = 32;
  const int n_reads = SOFTMAX_N_READS;
  const int looped_limit = SOFTMAX_LOOPED_LIMIT;
  std::string op_name = "softmax_";
  if (axis_size > looped_limit) {
    op_name += "looped_";
  }
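  // Rows up to SOFTMAX_LOOPED_LIMIT are handled with one thread per n_reads
  // elements; longer rows fall back to a "looped" kernel variant that
  // presumably iterates over the row with a full threadgroup.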
  op_name += type_to_name(out);
  auto compute_encoder = d.get_command_encoder(s.index);
  {
    auto kernel = d.get_kernel(op_name);

    MTL::Size grid_dims, group_dims;
    if (axis_size <= looped_limit) {
      size_t threadgroup_needed = (axis_size + n_reads - 1) / n_reads;
      size_t simds_needed = (threadgroup_needed + simd_size - 1) / simd_size;
      size_t threadgroup_size = simd_size * simds_needed;
      assert(threadgroup_size <= kernel->maxTotalThreadsPerThreadgroup());
      size_t n_threads = n_rows * threadgroup_size;
      grid_dims = MTL::Size(n_threads, 1, 1);
      group_dims = MTL::Size(threadgroup_size, 1, 1);
    } else {
      size_t threadgroup_size = kernel->maxTotalThreadsPerThreadgroup();
      size_t n_threads = n_rows * threadgroup_size;
      grid_dims = MTL::Size(n_threads, 1, 1);
      group_dims = MTL::Size(threadgroup_size, 1, 1);
    }

    compute_encoder->setComputePipelineState(kernel);
    set_array_buffer(compute_encoder, in, 0);
    set_array_buffer(compute_encoder, out, 1);
    compute_encoder->setBytes(&axis_size, sizeof(int), 2);
    compute_encoder->setThreadgroupMemoryLength(simd_size * in.itemsize(), 0);
    compute_encoder->setThreadgroupMemoryLength(simd_size * in.itemsize(), 1);
    compute_encoder->dispatchThreads(grid_dims, group_dims);
  }
  d.get_command_buffer(s.index)->addCompletedHandler(
      [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}

} // namespace mlx::core
336
mlx/backend/metal/sort.cpp
Normal file
@ -0,0 +1,336 @@
#include <algorithm>

#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"

namespace mlx::core {

namespace {

template <bool ARGSORT>
void single_block_sort(
    const Stream& s,
    metal::Device& d,
    const array& in,
    array& out,
    int axis,
    int bn,
    int tn) {
  // Prepare shapes
  int n_rows = in.size() / in.shape(axis);

  std::vector<size_t> nc_str = in.strides();
  nc_str.erase(nc_str.begin() + axis);

  std::vector<int> nc_shape = in.shape();
  nc_shape.erase(nc_shape.begin() + axis);

  int nc_dim = nc_shape.size();

  int size_sorted_axis = in.shape(axis);
  int stride_sorted_axis = in.strides()[axis];
  int stride_segment_axis = *std::min_element(nc_str.begin(), nc_str.end());

  // Check if remaining strides are contiguous
  bool contiguous_write = true;
  if (axis != in.ndim() - 1 && axis != 0) {
    for (int i = 0; i < nc_str.size() - 1; ++i) {
      size_t expected = nc_str[i + 1] * nc_shape[i + 1];
      contiguous_write &= (nc_str[i] == expected);
    }
  }

  // Prepare kernel name
  std::ostringstream kname;
  if (ARGSORT) {
    kname << "arg_";
  }
  kname << "block_merge_sort_" << type_to_name(in) << "_" << type_to_name(out)
        << "_bn" << bn << "_tn" << tn;

  if (!contiguous_write) {
    kname << "_nc";
  }

  // Prepare command encoder
  auto compute_encoder = d.get_command_encoder(s.index);
  auto kernel = d.get_kernel(kname.str());
  compute_encoder->setComputePipelineState(kernel);

  // Set inputs
  set_array_buffer(compute_encoder, in, 0);
  set_array_buffer(compute_encoder, out, 1);
  compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 2);
  compute_encoder->setBytes(&stride_sorted_axis, sizeof(int), 3);

  if (contiguous_write) {
    compute_encoder->setBytes(&stride_segment_axis, sizeof(int), 4);
  } else {
    compute_encoder->setBytes(&nc_dim, sizeof(int), 4);
    compute_encoder->setBytes(nc_shape.data(), nc_dim * sizeof(int), 5);
    compute_encoder->setBytes(nc_str.data(), nc_dim * sizeof(size_t), 6);
  }

  MTL::Size group_dims = MTL::Size(bn, 1, 1);
  MTL::Size grid_dims = MTL::Size(1, n_rows, 1);

  compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
}

template <bool ARGSORT>
void multi_block_sort(
    const Stream& s,
    metal::Device& d,
    const array& in,
    array& out,
    int axis,
    int bn,
    int tn,
    int n_blocks) {
  // Prepare shapes
  int n_rows = in.size() / in.shape(axis);

  std::vector<size_t> nc_str = in.strides();
  nc_str.erase(nc_str.begin() + axis);

  std::vector<int> nc_shape = in.shape();
  nc_shape.erase(nc_shape.begin() + axis);

  int nc_dim = nc_shape.size();

  int size_sorted_axis = in.shape(axis);
  int stride_sorted_axis = in.strides()[axis];

  // Make temporary copies
  array dev_vals_0({n_rows, size_sorted_axis}, in.dtype(), nullptr, {});
  array dev_vals_1({n_rows, size_sorted_axis}, in.dtype(), nullptr, {});

  array dev_idxs_0({n_rows, size_sorted_axis}, uint32, nullptr, {});
  array dev_idxs_1({n_rows, size_sorted_axis}, uint32, nullptr, {});

  array block_partitions({n_rows, n_blocks + 1}, uint32, nullptr, {});

  // Do allocations
  dev_vals_0.set_data(allocator::malloc_or_wait(dev_vals_0.nbytes()));
  dev_vals_1.set_data(allocator::malloc_or_wait(dev_vals_1.nbytes()));
  dev_idxs_0.set_data(allocator::malloc_or_wait(dev_idxs_0.nbytes()));
  dev_idxs_1.set_data(allocator::malloc_or_wait(dev_idxs_1.nbytes()));
  block_partitions.set_data(
      allocator::malloc_or_wait(block_partitions.nbytes()));

  std::vector<array> copies = {
      dev_vals_0, dev_vals_1, dev_idxs_0, dev_idxs_1, block_partitions};

  // Prepare command encoder
  auto compute_encoder = d.get_command_encoder(s.index);

  // Do blockwise sort
  {
    std::ostringstream kname;
    kname << "mb_block_sort_" << type_to_name(dev_vals_0) << "_"
          << type_to_name(dev_idxs_0) << "_bn" << bn << "_tn" << tn;

    auto kernel = d.get_kernel(kname.str());
    compute_encoder->setComputePipelineState(kernel);

    set_array_buffer(compute_encoder, in, 0);
    set_array_buffer(compute_encoder, dev_vals_0, 1);
    set_array_buffer(compute_encoder, dev_idxs_0, 2);
    compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 3);
    compute_encoder->setBytes(&stride_sorted_axis, sizeof(int), 4);
    compute_encoder->setBytes(&nc_dim, sizeof(int), 5);
    compute_encoder->setBytes(nc_shape.data(), nc_dim * sizeof(int), 6);
    compute_encoder->setBytes(nc_str.data(), nc_dim * sizeof(size_t), 7);

    MTL::Size group_dims = MTL::Size(bn, 1, 1);
    MTL::Size grid_dims = MTL::Size(n_blocks, n_rows, 1);

    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
  }

  // Do merges
  bool ping = false;
  array dev_vals_in = dev_vals_0;
  array dev_idxs_in = dev_idxs_0;
  array dev_vals_out = dev_vals_1;
  array dev_idxs_out = dev_idxs_1;
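  // The merge passes below ping-pong between the two value/index buffer
  // pairs: each pass doubles merge_tiles (the number of sorted blocks merged
  // per output tile) and swaps which pair is input and which is output.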
  for (int merge_tiles = 2; merge_tiles <= n_blocks; merge_tiles *= 2) {
    dev_vals_in = ping ? dev_vals_1 : dev_vals_0;
    dev_idxs_in = ping ? dev_idxs_1 : dev_idxs_0;
    dev_vals_out = ping ? dev_vals_0 : dev_vals_1;
    dev_idxs_out = ping ? dev_idxs_0 : dev_idxs_1;
    ping = !ping;

    // Do partition
    {
      std::ostringstream kname;
      kname << "mb_block_partiton_" << type_to_name(dev_vals_in) << "_"
            << type_to_name(dev_idxs_in) << "_bn" << bn << "_tn" << tn;

      auto kernel = d.get_kernel(kname.str());
      compute_encoder->setComputePipelineState(kernel);

      set_array_buffer(compute_encoder, block_partitions, 0);
      set_array_buffer(compute_encoder, dev_vals_in, 1);
      set_array_buffer(compute_encoder, dev_idxs_in, 2);
      compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 3);
      compute_encoder->setBytes(&merge_tiles, sizeof(int), 4);

      MTL::Size group_dims = MTL::Size(n_blocks + 1, 1, 1);
      MTL::Size grid_dims = MTL::Size(1, n_rows, 1);

      compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
    }

    // Do merge
    {
      std::ostringstream kname;
      kname << "mb_block_merge_" << type_to_name(dev_vals_in) << "_"
            << type_to_name(dev_idxs_in) << "_bn" << bn << "_tn" << tn;

      auto kernel = d.get_kernel(kname.str());
      compute_encoder->setComputePipelineState(kernel);

      set_array_buffer(compute_encoder, block_partitions, 0);
      set_array_buffer(compute_encoder, dev_vals_in, 1);
      set_array_buffer(compute_encoder, dev_idxs_in, 2);
      set_array_buffer(compute_encoder, dev_vals_out, 3);
      set_array_buffer(compute_encoder, dev_idxs_out, 4);
      compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 5);
      compute_encoder->setBytes(&merge_tiles, sizeof(int), 6);
      compute_encoder->setBytes(&n_blocks, sizeof(int), 7);

      MTL::Size group_dims = MTL::Size(bn, 1, 1);
      MTL::Size grid_dims = MTL::Size(n_blocks, n_rows, 1);

      compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
    }
  }

  // Copy outputs with appropriate strides
  array strided_out_arr = ARGSORT ? dev_idxs_out : dev_vals_out;

  if (axis == strided_out_arr.ndim() - 1) {
    copy_gpu_inplace(strided_out_arr, out, CopyType::Vector, s);
  } else {
    std::vector<int> strided_out_shape = strided_out_arr.shape();
    std::vector<size_t> strided_out_str = strided_out_arr.strides();

    int out_axis_shape = strided_out_shape[axis];
    int out_axis_str = strided_out_str[axis];

    strided_out_shape.erase(strided_out_shape.begin() + axis);
    strided_out_str.erase(strided_out_str.begin() + axis);

    strided_out_shape.push_back(out_axis_shape);
    strided_out_str.push_back(out_axis_str);

    array strided_out_slice(strided_out_shape, out.dtype(), nullptr, {});
    strided_out_slice.copy_shared_buffer(
        strided_out_arr,
        strided_out_str,
        strided_out_arr.flags(),
        strided_out_arr.size(),
        0);

    copy_gpu_inplace(strided_out_slice, out, CopyType::General, s);
  }

  // Clear copies
  d.get_command_buffer(s.index)->addCompletedHandler(
      [copies](MTL::CommandBuffer*) mutable { copies.clear(); });
}

template <bool ARGSORT>
void gpu_merge_sort(
    const Stream& s,
    metal::Device& d,
    const array& in,
    array& out,
    int axis_) {
  // Get size info
  int axis = axis_ < 0 ? axis_ + in.ndim() : axis_;
  int size_sorted_axis = in.shape(axis);

  // Get kernel size
  int tn = 8;
  int bn = 128;
  int potential_bn = (size_sorted_axis + tn - 1) / tn;

  if (potential_bn > 256) {
    bn = 512;
  } else if (potential_bn > 128) {
    bn = 256;
  } else {
    bn = 128;
  }

  if (bn == 512 && size_of(in.dtype()) > 4) {
    bn = 256;
  }
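  // Each sort block handles bn * tn elements (threads per group x elements
  // per thread); dropping bn back to 256 for dtypes wider than 4 bytes
  // presumably keeps the block's threadgroup memory within device limits.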

  int n_per_block = bn * tn;
  int n_blocks = (size_sorted_axis + n_per_block - 1) / n_per_block;

  if (n_blocks > 1) {
    return multi_block_sort<ARGSORT>(s, d, in, out, axis, bn, tn, n_blocks);
  } else {
    return single_block_sort<ARGSORT>(s, d, in, out, axis, bn, tn);
  }
}

} // namespace

void ArgSort::eval_gpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);

  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  auto& s = stream();
  auto& d = metal::device(s.device);
  auto& in = inputs[0];

  gpu_merge_sort<true>(s, d, in, out, axis_);
}

void Sort::eval_gpu(const std::vector<array>& inputs, array& out) {
  assert(inputs.size() == 1);

  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  auto& s = stream();
  auto& d = metal::device(s.device);
  auto& in = inputs[0];

  gpu_merge_sort<false>(s, d, in, out, axis_);
}

void ArgPartition::eval_gpu(const std::vector<array>& inputs, array& out) {
  // We direct arg partition to sort for now
  assert(inputs.size() == 1);

  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  auto& s = stream();
  auto& d = metal::device(s.device);
  auto& in = inputs[0];

  gpu_merge_sort<true>(s, d, in, out, axis_);
}

void Partition::eval_gpu(const std::vector<array>& inputs, array& out) {
  // We direct partition to sort for now
  assert(inputs.size() == 1);

  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  auto& s = stream();
  auto& d = metal::device(s.device);
  auto& in = inputs[0];

  gpu_merge_sort<false>(s, d, in, out, axis_);
}

} // namespace mlx::core
7
mlx/backend/no_metal/CMakeLists.txt
Normal file
@ -0,0 +1,7 @@
target_sources(
  mlx
  PRIVATE
  ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
)
77
mlx/backend/no_metal/primitives.cpp
Normal file
@ -0,0 +1,77 @@

#include "mlx/primitives.h"

#define NO_GPU(func) \
  void func::eval_gpu(const std::vector<array>& inputs, array& out) { \
    throw std::runtime_error(#func " has no GPU implementation."); \
  }
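// NO_GPU stubs out eval_gpu for every primitive so CPU-only builds still
// link, while any attempt to run a primitive on the GPU fails loudly with a
// runtime error naming that primitive.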

namespace mlx::core {

NO_GPU(Abs)
NO_GPU(Add)
NO_GPU(Arange)
NO_GPU(ArcCos)
NO_GPU(ArcCosh)
NO_GPU(ArcSin)
NO_GPU(ArcSinh)
NO_GPU(ArcTan)
NO_GPU(ArcTanh)
NO_GPU(ArgPartition)
NO_GPU(ArgReduce)
NO_GPU(ArgSort)
NO_GPU(AsType)
NO_GPU(AsStrided)
NO_GPU(Broadcast)
NO_GPU(Concatenate)
NO_GPU(Convolution)
NO_GPU(Copy)
NO_GPU(Cos)
NO_GPU(Cosh)
NO_GPU(Divide)
NO_GPU(Equal)
NO_GPU(Erf)
NO_GPU(ErfInv)
NO_GPU(Exp)
NO_GPU(FFT)
NO_GPU(Full)
NO_GPU(Gather)
NO_GPU(Greater)
NO_GPU(GreaterEqual)
NO_GPU(Less)
NO_GPU(LessEqual)
NO_GPU(Load)
NO_GPU(Log)
NO_GPU(Log1p)
NO_GPU(LogicalNot)
NO_GPU(LogAddExp)
NO_GPU(Matmul)
NO_GPU(Maximum)
NO_GPU(Minimum)
NO_GPU(Multiply)
NO_GPU(Negative)
NO_GPU(NotEqual)
NO_GPU(Pad)
NO_GPU(Partition)
NO_GPU(Power)
NO_GPU(RandomBits)
NO_GPU(Reduce)
NO_GPU(Reshape)
NO_GPU(Scan)
NO_GPU(Scatter)
NO_GPU(Sigmoid)
NO_GPU(Sign)
NO_GPU(Sin)
NO_GPU(Sinh)
NO_GPU(Slice)
NO_GPU(Softmax)
NO_GPU(Sort)
NO_GPU(Square)
NO_GPU(Sqrt)
NO_GPU(StopGradient)
NO_GPU(Subtract)
NO_GPU(Tan)
NO_GPU(Tanh)
NO_GPU(Transpose)

} // namespace mlx::core
29
mlx/device.cpp
Normal file
@ -0,0 +1,29 @@
#include "mlx/device.h"
#include "mlx/backend/metal/metal.h"

namespace mlx::core {

static Device default_device_{
    metal::is_available() ? Device::gpu : Device::cpu};

const Device& default_device() {
  return default_device_;
}

void set_default_device(const Device& d) {
  if (!metal::is_available() && d == Device::gpu) {
    throw std::invalid_argument(
        "[set_default_device] Cannot set gpu device without gpu backend.");
  }
  default_device_ = d;
}

bool operator==(const Device& lhs, const Device& rhs) {
  return lhs.type == rhs.type && lhs.index == rhs.index;
}

bool operator!=(const Device& lhs, const Device& rhs) {
  return !(lhs == rhs);
}

} // namespace mlx::core
99
mlx/dtype.h
Normal file
@ -0,0 +1,99 @@
#pragma once

#include <complex>
#include <cstdint>
#include <ostream>
#include <string>

#include "mlx/types/complex.h"
#include "mlx/types/half_types.h"

namespace mlx::core {

struct Dtype {
  enum class Val {
    bool_,
    uint8,
    uint16,
    uint32,
    uint64,
    int8,
    int16,
    int32,
    int64,
    float16,
    float32,
    bfloat16,
    complex64,
  };

  enum class Kind {
    b, /* bool */
    u, /* unsigned int */
    i, /* signed int */
    f, /* float */
    c, /* complex */
    V, /* void - used for brain float */
  };

  Val val;
  const uint8_t size;
  constexpr explicit Dtype(Val val, uint8_t size) : val(val), size(size){};
  constexpr operator Val() const {
    return val;
  };
};
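// The implicit conversion to Val lets a Dtype be compared and switched on
// directly, e.g. `switch (t) { case Dtype::Val::float32: ... }`.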

inline bool is_available(const Dtype& dtype) {
  return true;
}

static constexpr Dtype bool_{Dtype::Val::bool_, sizeof(bool)};

static constexpr Dtype uint8{Dtype::Val::uint8, sizeof(uint8_t)};
static constexpr Dtype uint16{Dtype::Val::uint16, sizeof(uint16_t)};
static constexpr Dtype uint32{Dtype::Val::uint32, sizeof(uint32_t)};
static constexpr Dtype uint64{Dtype::Val::uint64, sizeof(uint64_t)};

static constexpr Dtype int8{Dtype::Val::int8, sizeof(int8_t)};
static constexpr Dtype int16{Dtype::Val::int16, sizeof(int16_t)};
static constexpr Dtype int32{Dtype::Val::int32, sizeof(int32_t)};
static constexpr Dtype int64{Dtype::Val::int64, sizeof(int64_t)};

static constexpr Dtype float16{Dtype::Val::float16, sizeof(uint16_t)};
static constexpr Dtype float32{Dtype::Val::float32, sizeof(float)};
static constexpr Dtype bfloat16{Dtype::Val::bfloat16, sizeof(uint16_t)};
static constexpr Dtype complex64{Dtype::Val::complex64, sizeof(complex64_t)};

Dtype promote_types(const Dtype& t1, const Dtype& t2);

inline uint8_t size_of(const Dtype& t) {
  return t.size;
}

Dtype::Kind kindof(const Dtype& t);

inline bool is_unsigned(const Dtype& t) {
  return kindof(t) == Dtype::Kind::u || kindof(t) == Dtype::Kind::b;
}

inline bool is_floating_point(const Dtype& t) {
  return kindof(t) == Dtype::Kind::f || kindof(t) == Dtype::Kind::V ||
      kindof(t) == Dtype::Kind::c;
}

inline bool is_integral(const Dtype& t) {
  return !(is_floating_point(t));
}

template <typename T>
struct TypeToDtype {
  operator Dtype();
};

// Array protocol typestring for Dtype
std::string dtype_to_array_protocol(const Dtype& t);
// Dtype from array protocol type string
Dtype dtype_from_array_protocol(const std::string& t);

} // namespace mlx::core
190
mlx/fft.cpp
Normal file
@ -0,0 +1,190 @@
#include <numeric>
#include <set>

#include "mlx/fft.h"
#include "mlx/ops.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"

namespace mlx::core::fft {

array fft_impl(
    const array& a,
    std::vector<int> n,
    const std::vector<int>& axes,
    bool real,
    bool inverse,
    StreamOrDevice s) {
  if (a.ndim() < 1) {
    throw std::invalid_argument(
        "[fftn] Requires array with at least one dimension.");
  }
  if (n.size() != axes.size()) {
    throw std::invalid_argument("[fftn] Shape and axes have different sizes.");
  }
  if (axes.empty()) {
    return a;
  }

  std::vector<size_t> valid_axes;
  for (int ax : axes) {
    valid_axes.push_back(ax < 0 ? ax + a.ndim() : ax);
  }
  std::set<int> unique_axes(valid_axes.begin(), valid_axes.end());
  if (unique_axes.size() != axes.size()) {
    std::ostringstream msg;
    msg << "[fftn] Duplicated axis received " << axes;
    throw std::invalid_argument(msg.str());
  }
  if (*unique_axes.begin() < 0 || *unique_axes.rbegin() >= a.ndim()) {
    std::ostringstream msg;
    msg << "[fftn] Invalid axis received for array with " << a.ndim()
        << " dimensions.";
    throw std::invalid_argument(msg.str());
  }

  // In the following shape manipulations there are three cases to consider:
  // 1. In a complex to complex transform (fftn / ifftn) the output
  //    and input shapes are the same.
  // 2. In a real to complex transform (rfftn) n specifies the input dims
  //    and the output dims are n[i] / 2 + 1
  // 3. In a complex to real transform (irfftn) n specifies the output dims
  //    and the input dims are n[i] / 2 + 1
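  // For example, rfftn on a real input of length n = 8 along the last axis
  // yields 8 / 2 + 1 = 5 complex outputs, and irfftn inverts that: 5 complex
  // inputs produce 8 real outputs.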

  if (std::any_of(n.begin(), n.end(), [](auto i) { return i <= 0; })) {
    std::ostringstream msg;
    msg << "[fftn] Invalid FFT output size requested " << n;
    throw std::invalid_argument(msg.str());
  }

  std::vector<int> in_shape = a.shape();
  for (int i = 0; i < valid_axes.size(); ++i) {
    in_shape[valid_axes[i]] = n[i];
  }
  if (real && inverse) {
    in_shape[valid_axes.back()] = n.back() / 2 + 1;
  }

  bool any_greater = false;
  bool any_less = false;
  for (int i = 0; i < in_shape.size(); ++i) {
    any_greater |= in_shape[i] > a.shape()[i];
    any_less |= in_shape[i] < a.shape()[i];
  }

  auto in = a;
  if (any_less) {
    in = slice(in, std::vector<int>(in.ndim(), 0), in_shape, s);
  }
  if (any_greater) {
    // Pad with zeros
    auto tmp = zeros(in_shape, a.dtype(), s);
    in = scatter(tmp, std::vector<array>{}, in, std::vector<int>{}, s);
  }

  auto out_shape = in_shape;
  if (real) {
    auto ax = valid_axes.back();
    out_shape[ax] = inverse ? n.back() : out_shape[ax] / 2 + 1;
  }

  auto in_type = real && !inverse ? float32 : complex64;
  auto out_type = real && inverse ? float32 : complex64;
  return array(
      out_shape,
      out_type,
      std::make_unique<FFT>(to_stream(s), valid_axes, inverse, real),
      {astype(in, in_type, s)});
}

array fft_impl(
    const array& a,
    const std::vector<int>& axes,
    bool real,
    bool inverse,
    StreamOrDevice s) {
  std::vector<int> n;
  for (auto ax : axes) {
    n.push_back(a.shape(ax));
  }
  if (real && inverse) {
    n.back() = (n.back() - 1) * 2;
  }
  return fft_impl(a, n, axes, real, inverse, s);
}

array fft_impl(const array& a, bool real, bool inverse, StreamOrDevice s) {
  std::vector<int> axes(a.ndim());
  std::iota(axes.begin(), axes.end(), 0);
  return fft_impl(a, axes, real, inverse, s);
}

array fftn(
    const array& a,
    const std::vector<int>& n,
    const std::vector<int>& axes,
    StreamOrDevice s /* = {} */) {
  return fft_impl(a, n, axes, false, false, s);
}
array fftn(
    const array& a,
    const std::vector<int>& axes,
    StreamOrDevice s /* = {} */) {
  return fft_impl(a, axes, false, false, s);
}
array fftn(const array& a, StreamOrDevice s /* = {} */) {
  return fft_impl(a, false, false, s);
}

array ifftn(
    const array& a,
    const std::vector<int>& n,
    const std::vector<int>& axes,
    StreamOrDevice s /* = {} */) {
  return fft_impl(a, n, axes, false, true, s);
}
array ifftn(
    const array& a,
    const std::vector<int>& axes,
    StreamOrDevice s /* = {} */) {
  return fft_impl(a, axes, false, true, s);
}
array ifftn(const array& a, StreamOrDevice s /* = {} */) {
  return fft_impl(a, false, true, s);
}

array rfftn(
    const array& a,
    const std::vector<int>& n,
    const std::vector<int>& axes,
    StreamOrDevice s /* = {} */) {
  return fft_impl(a, n, axes, true, false, s);
}
array rfftn(
    const array& a,
    const std::vector<int>& axes,
    StreamOrDevice s /* = {} */) {
  return fft_impl(a, axes, true, false, s);
}
array rfftn(const array& a, StreamOrDevice s /* = {} */) {
  return fft_impl(a, true, false, s);
}

array irfftn(
    const array& a,
    const std::vector<int>& n,
    const std::vector<int>& axes,
    StreamOrDevice s /* = {} */) {
  return fft_impl(a, n, axes, true, true, s);
}
array irfftn(
    const array& a,
    const std::vector<int>& axes,
    StreamOrDevice s /* = {} */) {
  return fft_impl(a, axes, true, true, s);
}
array irfftn(const array& a, StreamOrDevice s /* = {} */) {
  return fft_impl(a, true, true, s);
}

} // namespace mlx::core::fft
21
mlx/graph_utils.h
Normal file
@ -0,0 +1,21 @@
#pragma once

#include "mlx/array.h"

namespace mlx::core {

void print_graph(std::ostream& os, const std::vector<array>& outputs);

template <typename... Arrays>
void print_graph(std::ostream& os, Arrays... outputs) {
  print_graph(os, std::vector<array>{std::forward<Arrays>(outputs)...});
}

void export_to_dot(std::ostream& os, const std::vector<array>& outputs);

template <typename... Arrays>
void export_to_dot(std::ostream& os, Arrays... outputs) {
  export_to_dot(os, std::vector<array>{std::forward<Arrays>(outputs)...});
}

} // namespace mlx::core
2323
mlx/ops.cpp
Normal file
File diff suppressed because it is too large
16
mlx/transforms_impl.h
Normal file
@ -0,0 +1,16 @@

namespace mlx::core::detail {

std::pair<std::vector<array>, std::vector<array>> vmap_trace(
    const std::function<std::vector<array>(const std::vector<array>&)>& fun,
    const std::vector<array>& inputs,
    const std::vector<int>& in_axes);

std::vector<array> vmap_replace(
    const std::vector<array>& inputs,
    const std::vector<array>& s_inputs,
    const std::vector<array>& s_outputs,
    const std::vector<int>& in_axes,
    const std::vector<int>& out_axes);
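// Judging by the signatures, vmap is implemented in two phases: vmap_trace
// runs `fun` on placeholder inputs to record a traced graph, and vmap_replace
// then rewrites that graph against the real batched inputs and axes.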

} // namespace mlx::core::detail
75
mlx/types/complex.h
Normal file
@ -0,0 +1,75 @@
#pragma once
#include <complex>
#include "mlx/types/half_types.h"

namespace mlx::core {

struct complex64_t;

template <typename T>
static constexpr bool can_convert_to_complex64 =
    !std::is_same_v<T, complex64_t> && std::is_convertible_v<T, float>;

struct complex64_t : public std::complex<float> {
  complex64_t(float v, float u) : std::complex<float>(v, u){};
  complex64_t(std::complex<float> v) : std::complex<float>(v){};

  template <
      typename T,
      typename = typename std::enable_if<can_convert_to_complex64<T>>::type>
  complex64_t(T x) : std::complex<float>(x){};

  operator float() const {
    return real();
  };
};

inline bool operator>=(const complex64_t& a, const complex64_t& b) {
  return (a.real() > b.real()) ||
      (a.real() == b.real() && a.imag() >= b.imag());
}

inline bool operator>(const complex64_t& a, const complex64_t& b) {
  return (a.real() > b.real()) || (a.real() == b.real() && a.imag() > b.imag());
}

inline bool operator<=(const complex64_t& a, const complex64_t& b) {
  return operator>=(b, a);
}

inline bool operator<(const complex64_t& a, const complex64_t& b) {
  return operator>(b, a);
}

inline complex64_t operator-(const complex64_t& v) {
  return -static_cast<std::complex<float>>(v);
}

// clang-format off
#define complex_binop_helper(_op_, _operator_, itype) \
  inline complex64_t _operator_(itype x, const complex64_t& y) { \
    return x _op_ static_cast<std::complex<float>>(y); \
  } \
  inline complex64_t _operator_(const complex64_t& x, itype y) { \
    return static_cast<std::complex<float>>(x) _op_ y; \
  }

#define complex_binop(_op_, _operator_) \
  inline complex64_t _operator_(const complex64_t& x, const complex64_t& y) { \
    return static_cast<std::complex<float>>(x) \
        _op_ static_cast<std::complex<float>>(y); \
  } \
  complex_binop_helper(_op_, _operator_, bool) \
  complex_binop_helper(_op_, _operator_, uint32_t) \
  complex_binop_helper(_op_, _operator_, uint64_t) \
  complex_binop_helper(_op_, _operator_, int32_t) \
  complex_binop_helper(_op_, _operator_, int64_t) \
  complex_binop_helper(_op_, _operator_, float16_t) \
  complex_binop_helper(_op_, _operator_, bfloat16_t) \
  complex_binop_helper(_op_, _operator_, const std::complex<float>&) \
  complex_binop_helper(_op_, _operator_, float)
// clang-format on

complex_binop(+, operator+)
|
||||
|
||||
} // namespace mlx::core
|
232
mlx/types/fp16.h
Normal file
@ -0,0 +1,232 @@
#pragma once

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

#define __MLX_HALF_NAN__ 0x7D00

namespace mlx::core {

namespace {
union float_bits_fp16 {
  float f;
  uint32_t u;
};
} // namespace

struct _MLX_Float16 {
  uint16_t bits_;

  // Default constructor
  _MLX_Float16() = default;

  // Default copy constructor
  _MLX_Float16(_MLX_Float16 const&) = default;

  // Appease std::vector<bool> for being special
  _MLX_Float16& operator=(std::vector<bool>::reference x) {
    bits_ = x;
    return *this;
  }

  _MLX_Float16& operator=(const float& x) {
    return (*this = _MLX_Float16(x));
  }

  // From float32
  _MLX_Float16(const float& x) : bits_(0) {
    // Conversion following
    // https://github.com/Maratyszcza/FP16/blob/master/include/fp16/fp16.h

    // Union
    float_bits_fp16 in;

    // Take fp32 bits
    in.f = x;

    // Find and take sign bit
    uint32_t x_sign_32 = in.u & uint32_t(0x80000000);
    uint16_t x_sign_16 = (x_sign_32 >> 16);

    if (std::isnan(x)) {
      bits_ = x_sign_16 | uint16_t(__MLX_HALF_NAN__);
    } else {
      // Union
      float_bits_fp16 inf_scale, zero_scale, magic_bits;

      // Find exponent bits and take the max supported by half
      uint32_t x_expo_32 = in.u & uint32_t(0x7f800000);
      uint32_t max_expo_32 = uint32_t(0x38800000);
      x_expo_32 = x_expo_32 < max_expo_32 ? max_expo_32 : x_expo_32;
      x_expo_32 += uint32_t(15) << 23;

      // Handle scaling to inf as needed
      inf_scale.u = uint32_t(0x77800000);
      zero_scale.u = uint32_t(0x08800000);

      // Combine with magic and let addition do rounding
      magic_bits.u = x_expo_32;
      magic_bits.f += (std::abs(x) * inf_scale.f) * zero_scale.f;

      // Take the lower 5 bits of the exponent
      uint32_t x_expo_16 = ((magic_bits.u >> 13) & uint32_t(0x7c00));

      // Collect the lower 12 bits which have the mantissa
      uint32_t x_mant_16 = magic_bits.u & uint32_t(0x0fff);

      // Combine sign, exp and mantissa
      bits_ = (x_sign_16 | uint16_t(x_expo_16 + x_mant_16));
    }
  }

  // To float32
  operator float() const {
    // Conversion following
    // https://github.com/Maratyszcza/FP16/blob/master/include/fp16/fp16.h

    // Union
    float_bits_fp16 out;

    uint32_t x_sign_32 = (bits_ << 16) & uint32_t(0x80000000);
    uint32_t base = (bits_ << 16);
    uint32_t two_base = base + base;

    uint32_t denorm_max = 1u << 27;
    if (two_base < denorm_max) {
      out.u = uint32_t(126) << 23; // magic mask
      out.u |= (two_base >> 17); // Bits from fp16
      out.f -= 0.5f; // magic bias
    } else {
      out.u = uint32_t(0xE0) << 23; // exponent offset
      out.u += (two_base >> 4); // Bits from fp16
      float out_unscaled = out.f; // Store value
      out.u = uint32_t(0x7800000); // exponent scale
      out.f *= out_unscaled;
    }

    // Add sign
    out.u |= x_sign_32;

    return out.f;
  }
};

#define half_binop_base(__op__, __operator__, otype, atype, btype, ctype) \
  inline otype __operator__(atype lhs, btype rhs) {                       \
    return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs);        \
  }

#define half_binop_helper(__op__, __operator__, otype, itype, ctype) \
  inline otype __operator__(_MLX_Float16 lhs, itype rhs) {           \
    return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs);   \
  }                                                                  \
  inline otype __operator__(itype lhs, _MLX_Float16 rhs) {           \
    return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs);   \
  }

// Operators
#define half_binop(__op__, __operator__)                                      \
  half_binop_base(                                                            \
      __op__, __operator__, _MLX_Float16, _MLX_Float16, _MLX_Float16, float); \
  half_binop_helper(__op__, __operator__, float, float, float);               \
  half_binop_helper(__op__, __operator__, double, double, double);            \
  half_binop_helper(__op__, __operator__, _MLX_Float16, bool, float);         \
  half_binop_helper(__op__, __operator__, _MLX_Float16, int32_t, float);      \
  half_binop_helper(__op__, __operator__, _MLX_Float16, uint32_t, float);     \
  half_binop_helper(__op__, __operator__, _MLX_Float16, int64_t, float);      \
  half_binop_helper(__op__, __operator__, _MLX_Float16, uint64_t, float);

half_binop(+, operator+);
half_binop(-, operator-);
half_binop(*, operator*);
half_binop(/, operator/);

#undef half_binop

// Comparison ops
#define half_compop(__op__, __operator__)                             \
  half_binop_base(                                                    \
      __op__, __operator__, bool, _MLX_Float16, _MLX_Float16, float); \
  half_binop_helper(__op__, __operator__, bool, float, float);        \
  half_binop_helper(__op__, __operator__, bool, double, double);      \
  half_binop_helper(__op__, __operator__, bool, int32_t, float);      \
  half_binop_helper(__op__, __operator__, bool, uint32_t, float);     \
  half_binop_helper(__op__, __operator__, bool, int64_t, float);      \
  half_binop_helper(__op__, __operator__, bool, uint64_t, float);

half_compop(>, operator>);
half_compop(<, operator<);
half_compop(>=, operator>=);
half_compop(<=, operator<=);
half_compop(==, operator==);
half_compop(!=, operator!=);

#undef half_compop

// Negative
inline _MLX_Float16 operator-(_MLX_Float16 lhs) {
  return -static_cast<float>(lhs);
}

// Inplace ops
#define half_inplace_op(__op__, __operator__)                              \
  inline _MLX_Float16& __operator__(_MLX_Float16& lhs, const float& rhs) { \
    lhs = lhs __op__ rhs;                                                  \
    return lhs;                                                            \
  }                                                                        \
  inline float& __operator__(float& lhs, _MLX_Float16 rhs) {               \
    lhs = lhs __op__ rhs;                                                  \
    return lhs;                                                            \
  }

half_inplace_op(+, operator+=);
half_inplace_op(-, operator-=);
half_inplace_op(*, operator*=);
half_inplace_op(/, operator/=);

#undef half_inplace_op

// Bitwise ops

#define half_bitop(__op__, __operator__)                                 \
  inline _MLX_Float16 __operator__(_MLX_Float16 lhs, _MLX_Float16 rhs) { \
    _MLX_Float16 out;                                                    \
    out.bits_ = lhs.bits_ __op__ rhs.bits_;                              \
    return out;                                                          \
  }                                                                      \
  inline _MLX_Float16 __operator__(_MLX_Float16 lhs, uint16_t rhs) {     \
    _MLX_Float16 out;                                                    \
    out.bits_ = lhs.bits_ __op__ rhs;                                    \
    return out;                                                          \
  }                                                                      \
  inline _MLX_Float16 __operator__(uint16_t lhs, _MLX_Float16 rhs) {     \
    _MLX_Float16 out;                                                    \
    out.bits_ = lhs __op__ rhs.bits_;                                    \
    return out;                                                          \
  }

half_bitop(|, operator|);
half_bitop(&, operator&);
half_bitop(^, operator^);

#undef half_bitop

#define half_inplace_bitop(__op__, __operator__)                           \
  inline _MLX_Float16& __operator__(_MLX_Float16& lhs, _MLX_Float16 rhs) { \
    lhs.bits_ = lhs.bits_ __op__ rhs.bits_;                                \
    return lhs;                                                            \
  }                                                                        \
  inline _MLX_Float16& __operator__(_MLX_Float16& lhs, uint16_t rhs) {     \
    lhs.bits_ = lhs.bits_ __op__ rhs;                                      \
    return lhs;                                                            \
  }

half_inplace_bitop(|, operator|=);
half_inplace_bitop(&, operator&=);
half_inplace_bitop(^, operator^=);

#undef half_inplace_bitop

} // namespace mlx::core
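To make the magic-number trick above concrete, here is a small Python re-implementation of the float32-to-float16 path. This is a sketch for illustration only: it rounds intermediates to float32 via `struct`, and it ignores overflow to infinity (which the C version handles through the `inf_scale` multiply, but which `struct.pack` would reject):

```python
import math
import struct


def as_f32(v: float) -> float:
    # Round a Python double to the nearest float32, mimicking C float math.
    return struct.unpack("<f", struct.pack("<f", v))[0]


def f32_bits(v: float) -> int:
    # Bit pattern of v as a float32.
    return struct.unpack("<I", struct.pack("<f", v))[0]


def bits_f32(u: int) -> float:
    # float32 value of a 32-bit pattern.
    return struct.unpack("<f", struct.pack("<I", u))[0]


def float_to_half_bits(x: float) -> int:
    u = f32_bits(x)
    sign16 = (u & 0x80000000) >> 16
    if math.isnan(x):
        return sign16 | 0x7D00  # __MLX_HALF_NAN__
    # Clamp to half's smallest normal exponent, then re-bias by 15.
    expo = max(u & 0x7F800000, 0x38800000) + (15 << 23)
    inf_scale = bits_f32(0x77800000)  # 2**112
    zero_scale = bits_f32(0x08800000)  # 2**-110
    # The float32 addition performs the round-to-nearest step.
    magic = as_f32(bits_f32(expo) + as_f32(as_f32(abs(x) * inf_scale) * zero_scale))
    mu = f32_bits(magic)
    expo16 = (mu >> 13) & 0x7C00  # lower 5 exponent bits
    mant16 = mu & 0x0FFF  # 12 bits carrying the mantissa (plus any carry)
    return (sign16 | ((expo16 + mant16) & 0xFFFF)) & 0xFFFF


assert float_to_half_bits(1.5) == 0x3E00  # fp16 bit pattern of 1.5
```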
3
pyproject.toml
Normal file
@ -0,0 +1,3 @@
[build-system]
requires = ["setuptools>=42", "pybind11>=2.10", "cmake>=3.24"]
build-backend = "setuptools.build_meta"
124
python/mlx/nn/layers/convolution.py
Normal file
@ -0,0 +1,124 @@
import math
from typing import Union

import mlx.core as mx
from mlx.nn.layers.base import Module


class Conv1d(Module):
    """Applies a 1-dimensional convolution over the multi-channel input sequence.

    The channels are expected to be last i.e. the input shape should be ``NLC`` where:
        - ``N`` is the batch dimension
        - ``L`` is the sequence length
        - ``C`` is the number of input channels

    Args:
        in_channels (int): The number of input channels
        out_channels (int): The number of output channels
        kernel_size (int): The size of the convolution filters
        stride (int, optional): The stride when applying the filter.
            Default: 1.
        padding (int, optional): How many positions to 0-pad the input with.
            Default: 0.
        bias (bool, optional): If ``True`` add a learnable bias to the output.
            Default: ``True``
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        bias: bool = True,
    ):
        super().__init__()

        scale = math.sqrt(1 / (in_channels * kernel_size))
        self.weight = mx.random.uniform(
            low=-scale,
            high=scale,
            shape=(out_channels, kernel_size, in_channels),
        )
        if bias:
            self.bias = mx.zeros((out_channels,))

        self.padding = padding
        self.stride = stride

    def _extra_repr(self):
        return (
            f"{self.weight.shape[-1]}, {self.weight.shape[0]}, "
            f"kernel_size={self.weight.shape[1]}, stride={self.stride}, "
            f"padding={self.padding}, bias={'bias' in self}"
        )

    def __call__(self, x):
        y = mx.conv1d(x, self.weight, self.stride, self.padding)
        if "bias" in self:
            y = y + self.bias
        return y


class Conv2d(Module):
    """Applies a 2-dimensional convolution over the multi-channel input image.

    The channels are expected to be last i.e. the input shape should be ``NHWC`` where:
        - ``N`` is the batch dimension
        - ``H`` is the input image height
        - ``W`` is the input image width
        - ``C`` is the number of input channels

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        kernel_size (int or tuple): The size of the convolution filters.
        stride (int or tuple, optional): The size of the stride when
            applying the filter. Default: 1.
        padding (int or tuple, optional): How many positions to 0-pad
            the input with. Default: 0.
        bias (bool, optional): If ``True`` add a learnable bias to the
            output. Default: ``True``
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, tuple],
        stride: Union[int, tuple] = 1,
        padding: Union[int, tuple] = 0,
        bias: bool = True,
    ):
        super().__init__()

        kernel_size, stride, padding = map(
            lambda x: (x, x) if isinstance(x, int) else x,
            (kernel_size, stride, padding),
        )
        scale = math.sqrt(1 / (in_channels * kernel_size[0] * kernel_size[1]))
        self.weight = mx.random.uniform(
            low=-scale,
            high=scale,
            shape=(out_channels, *kernel_size, in_channels),
        )
        if bias:
            self.bias = mx.zeros((out_channels,))

        self.padding = padding
        self.stride = stride

    def _extra_repr(self):
        return (
            f"{self.weight.shape[-1]}, {self.weight.shape[0]}, "
            f"kernel_size={self.weight.shape[1:3]}, stride={self.stride}, "
            f"padding={self.padding}, bias={'bias' in self}"
        )

    def __call__(self, x):
        y = mx.conv2d(x, self.weight, self.stride, self.padding)
        if "bias" in self:
            y = y + self.bias
        return y
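A quick usage sketch for these channels-last layers, assuming the package re-exports them as `mlx.nn` (as the `python/mlx/nn/layers` layout suggests):

```python
import mlx.core as mx
import mlx.nn as nn

conv = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3, padding=1)
x = mx.zeros((4, 32, 32, 3))  # NHWC, unlike PyTorch's NCHW
y = conv(x)
print(y.shape)  # (4, 32, 32, 8) with stride 1 and padding 1
```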
28
python/mlx/nn/layers/embedding.py
Normal file
@ -0,0 +1,28 @@
import math

import mlx.core as mx
from mlx.nn.layers.base import Module


class Embedding(Module):
    """Implements a simple lookup table that maps each input integer to a
    high-dimensional vector.

    Typically used to embed discrete tokens for processing by neural networks.

    Args:
        num_embeddings (int): How many possible discrete tokens can we embed.
            Usually called the vocabulary size.
        dims (int): The dimensionality of the embeddings.
    """

    def __init__(self, num_embeddings: int, dims: int):
        super().__init__()
        scale = math.sqrt(1 / dims)
        self.weight = mx.random.normal((num_embeddings, dims)) * scale

    def _extra_repr(self):
        return f"{self.weight.shape[0]}, {self.weight.shape[1]}"

    def __call__(self, x):
        return self.weight[x]
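Embedding lookup is just integer indexing into the weight table; a sketch under the same `mlx.nn` re-export assumption as above:

```python
import mlx.core as mx
import mlx.nn as nn

emb = nn.Embedding(num_embeddings=100, dims=16)
tokens = mx.array([[1, 5, 9], [2, 2, 7]])  # (batch, sequence)
vectors = emb(tokens)
print(vectors.shape)  # (2, 3, 16)
```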
34
python/mlx/nn/layers/linear.py
Normal file
@ -0,0 +1,34 @@
import math

import mlx.core as mx
from mlx.nn.layers.base import Module


class Linear(Module):
    """Applies an affine transformation to the input.

    Args:
        input_dims (int): The dimensionality of the input features
        output_dims (int): The dimensionality of the output features
        bias (bool): If set to ``False`` then the layer will not use a bias
    """

    def __init__(self, input_dims: int, output_dims: int, bias: bool = True):
        super().__init__()
        scale = math.sqrt(1 / input_dims)
        self.weight = mx.random.uniform(
            low=-scale,
            high=scale,
            shape=(output_dims, input_dims),
        )
        if bias:
            self.bias = mx.zeros((output_dims,))

    def _extra_repr(self):
        return f"input_dims={self.weight.shape[1]}, output_dims={self.weight.shape[0]}, bias={'bias' in self}"

    def __call__(self, x):
        x = x @ self.weight.T
        if "bias" in self:
            x = x + self.bias
        return x
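The layer computes `x @ weight.T + bias`; a sketch under the same `mlx.nn` re-export assumption:

```python
import mlx.core as mx
import mlx.nn as nn

layer = nn.Linear(input_dims=32, output_dims=8)
x = mx.zeros((16, 32))
y = layer(x)
print(y.shape)  # (16, 8)
```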
19
python/src/load.h
Normal file
@ -0,0 +1,19 @@
#pragma once

#include <pybind11/pybind11.h>
#include <unordered_map>
#include <variant>
#include "mlx/ops.h"

namespace py = pybind11;
using namespace mlx::core;

using DictOrArray = std::variant<array, std::unordered_map<std::string, array>>;

DictOrArray mlx_load_helper(py::object file, StreamOrDevice s);
void mlx_save_helper(py::object file, array a, bool retain_graph = true);
void mlx_savez_helper(
    py::object file,
    py::args args,
    const py::kwargs& kwargs,
    bool compressed = false);
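This diff does not show where these helpers are bound, so the Python names below are an assumption based on the helper names (`mlx_save_helper`, `mlx_savez_helper`, `mlx_load_helper`); treat this as a sketch rather than a confirmed API:

```python
import mlx.core as mx

a = mx.ones((2, 2))
mx.save("a.npy", a)             # assumed binding of mlx_save_helper
mx.savez("arrays.npz", a=a)     # assumed binding of mlx_savez_helper
loaded = mx.load("arrays.npz")  # returns a dict of arrays, per DictOrArray
print(loaded["a"])
```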
31
python/src/mlx.cpp
Normal file
@ -0,0 +1,31 @@
#include <pybind11/pybind11.h>

#define STRINGIFY(x) #x
#define TOSTRING(x) STRINGIFY(x)

namespace py = pybind11;

void init_array(py::module_&);
void init_device(py::module_&);
void init_stream(py::module_&);
void init_metal(py::module_&);
void init_ops(py::module_&);
void init_transforms(py::module_&);
void init_random(py::module_&);
void init_fft(py::module_&);

PYBIND11_MODULE(core, m) {
  m.doc() = "mlx: A framework for machine learning on Apple Silicon.";

  auto reprlib_fix = py::module_::import("mlx._reprlib_fix");

  init_device(m);
  init_stream(m);
  init_array(m);
  init_metal(m);
  init_ops(m);
  init_transforms(m);
  init_random(m);
  init_fft(m);
  m.attr("__version__") = TOSTRING(_VERSION_);
}
723
python/src/transforms.cpp
Normal file
@ -0,0 +1,723 @@
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <algorithm>
#include <fstream>
#include <numeric>
#include <sstream>

#include "mlx/array.h"
#include "mlx/graph_utils.h"
#include "mlx/transforms.h"
#include "mlx/transforms_impl.h"

namespace py = pybind11;
using namespace py::literals;
using namespace mlx::core;

using IntOrVec = std::variant<int, std::vector<int>>;
using StrOrVec = std::variant<std::string, std::vector<std::string>>;

template <typename T>
std::vector<T> to_vector(const std::variant<T, std::vector<T>>& v) {
  std::vector<T> vals;
  if (auto pv = std::get_if<T>(&v); pv) {
    vals.push_back(*pv);
  } else {
    vals = std::get<std::vector<T>>(v);
  }
  return vals;
}

void tree_visit(py::object tree, std::function<void(py::handle)> visitor) {
  std::function<void(py::handle)> recurse;
  recurse = [&](py::handle subtree) {
    if (py::isinstance<py::list>(subtree) ||
        py::isinstance<py::tuple>(subtree)) {
      for (auto item : subtree) {
        recurse(item);
      }
    } else if (py::isinstance<py::dict>(subtree)) {
      for (auto item : py::cast<py::dict>(subtree)) {
        recurse(item.second);
      }
    } else {
      visitor(subtree);
    }
  };

  recurse(tree);
}

template <typename T, typename U, typename V>
void validate_subtrees(const std::vector<py::object>& subtrees) {
  int len = py::cast<T>(subtrees[0]).size();
  for (auto& subtree : subtrees) {
    if ((py::isinstance<T>(subtree) && py::cast<T>(subtree).size() != len) ||
        py::isinstance<U>(subtree) || py::isinstance<V>(subtree)) {
      throw std::invalid_argument(
          "[tree_map] Additional input tree is not a valid prefix of the first tree.");
    }
  }
}

py::object tree_map(
    const std::vector<py::object>& trees,
    std::function<py::object(const std::vector<py::object>&)> transform) {
  std::function<py::object(const std::vector<py::object>&)> recurse;

  recurse = [&](const std::vector<py::object>& subtrees) {
    if (py::isinstance<py::list>(subtrees[0])) {
      py::list l;
      std::vector<py::object> items(subtrees.size());
      validate_subtrees<py::list, py::tuple, py::dict>(subtrees);
      for (int i = 0; i < py::cast<py::list>(subtrees[0]).size(); ++i) {
        for (int j = 0; j < subtrees.size(); ++j) {
          if (py::isinstance<py::list>(subtrees[j])) {
            items[j] = py::cast<py::list>(subtrees[j])[i];
          } else {
            items[j] = subtrees[j];
          }
        }
        l.append(recurse(items));
      }
      return py::cast<py::object>(l);
    } else if (py::isinstance<py::tuple>(subtrees[0])) {
      // Check the rest of the subtrees
      std::vector<py::object> items(subtrees.size());
      int len = py::cast<py::tuple>(subtrees[0]).size();
      py::tuple l(len);
      validate_subtrees<py::tuple, py::list, py::dict>(subtrees);
      for (int i = 0; i < len; ++i) {
        for (int j = 0; j < subtrees.size(); ++j) {
          if (py::isinstance<py::tuple>(subtrees[j])) {
            items[j] = py::cast<py::tuple>(subtrees[j])[i];
          } else {
            items[j] = subtrees[j];
          }
        }
        l[i] = recurse(items);
      }
      return py::cast<py::object>(l);
    } else if (py::isinstance<py::dict>(subtrees[0])) {
      std::vector<py::object> items(subtrees.size());
      validate_subtrees<py::dict, py::list, py::tuple>(subtrees);
      py::dict d;
      for (auto item : py::cast<py::dict>(subtrees[0])) {
        for (int j = 0; j < subtrees.size(); ++j) {
          if (py::isinstance<py::dict>(subtrees[j])) {
            auto subdict = py::cast<py::dict>(subtrees[j]);
            if (!subdict.contains(item.first)) {
              throw std::invalid_argument(
                  "[tree_map] Tree is not a valid prefix tree of the first tree.");
            }
            items[j] = subdict[item.first];
          } else {
            items[j] = subtrees[j];
          }
        }
        d[item.first] = recurse(items);
      }
      return py::cast<py::object>(d);
    } else {
      return transform(subtrees);
    }
  };
  return recurse(trees);
}

py::object tree_map(
    py::object tree,
    std::function<py::object(py::handle)> transform) {
  return tree_map({tree}, [&](std::vector<py::object> inputs) {
    return transform(inputs[0]);
  });
}

std::vector<array> tree_flatten(py::object tree, bool strict = true) {
  std::vector<array> flat_tree;

  tree_visit(tree, [&](py::handle obj) {
    if (py::isinstance<array>(obj)) {
      flat_tree.push_back(py::cast<array>(obj));
    } else if (strict) {
      throw std::invalid_argument("Argument is not an array");
    }
  });

  return flat_tree;
}

py::object tree_unflatten(
    py::object tree,
    const std::vector<array>& values,
    int index = 0) {
  return tree_map(tree, [&](py::handle obj) {
    if (py::isinstance<array>(obj)) {
      return py::cast(values[index++]);
    } else {
      return py::cast<py::object>(obj);
    }
  });
}

auto validate_argnums_argnames(
    const std::optional<IntOrVec>& argnums,
    const StrOrVec& argnames) {
  auto vec_names = to_vector(argnames);

  if (!argnums.has_value()) {
    // argnums was not provided and argnames was empty
    if (vec_names.empty()) {
      return std::make_pair(std::vector<int>{0}, vec_names);
    } else {
      return std::make_pair(std::vector<int>{}, vec_names);
    }
  }

  return std::make_pair(to_vector(*argnums), vec_names);
}

auto py_value_and_grad(
    const py::function& fun,
    std::vector<int> argnums,
    std::vector<std::string> argnames,
    const std::string& error_msg_tag,
    bool scalar_func_only) {
  // Sanitize argnums
  if (argnums.size() == 0 && argnames.size() == 0) {
    throw std::invalid_argument(
        error_msg_tag + " Gradient wrt no argument requested");
  }
  if (argnums.size() > 0) {
    std::sort(argnums.begin(), argnums.end());
    if (argnums[0] < 0) {
      std::ostringstream msg;
      msg << error_msg_tag
          << " Can't compute the gradient of negative argument index "
          << argnums[0];
      throw std::invalid_argument(msg.str());
    }
  }

  return [fun, argnums, argnames, error_msg_tag, scalar_func_only](
             const py::args& args, const py::kwargs& kwargs) {
    // Sanitize the input
    if (argnums.size() > 0 && argnums.back() >= args.size()) {
      std::ostringstream msg;
      msg << error_msg_tag << " Can't compute the gradient of argument index "
          << argnums.back() << " because the function is called with only "
          << args.size() << " arguments.";
      throw std::invalid_argument(msg.str());
    }

    for (auto& key : argnames) {
      if (!kwargs.contains(key)) {
        std::ostringstream msg;
        msg << error_msg_tag
            << " Can't compute the gradient of keyword argument '" << key
            << "' because the function is called with the "
            << "following keyword arguments {";
        for (auto item : kwargs) {
          msg << item.first.cast<std::string>() << ",";
        }
        msg << "}";
        throw std::invalid_argument(msg.str());
      }
    }

    // Collect the arrays
    std::vector<array> arrays;
    std::vector<int> counts(1, 0);
    for (auto i : argnums) {
      auto argsi = tree_flatten(args[i]);
      arrays.insert(arrays.end(), argsi.begin(), argsi.end());
      counts.push_back(argsi.size());
    }
    for (auto& key : argnames) {
      auto argsk = tree_flatten(kwargs[key.c_str()]);
      arrays.insert(arrays.end(), argsk.begin(), argsk.end());
      counts.push_back(argsk.size());
    }
    std::partial_sum(counts.cbegin(), counts.cend(), counts.begin());
    std::vector<int> gradient_indices(arrays.size());
    std::iota(gradient_indices.begin(), gradient_indices.end(), 0);

    // py_value_out will hold the output of the python function in order to be
    // able to reconstruct the python tree of extra return values
    py::object py_value_out;
    auto value_and_grads = value_and_grad(
        [&fun,
         &args,
         &kwargs,
         &argnums,
         &argnames,
         &counts,
         &py_value_out,
         &error_msg_tag,
         scalar_func_only](const std::vector<array>& a) {
          // Copy the arguments
          py::args args_cpy = py::tuple(args.size());
          py::kwargs kwargs_cpy = py::kwargs();
          int j = 0;
          for (int i = 0; i < args.size(); ++i) {
            if (j < argnums.size() && i == argnums[j]) {
              args_cpy[i] = tree_unflatten(args[i], a, counts[j]);
              j++;
            } else {
              args_cpy[i] = args[i];
            }
          }
          for (auto& key : argnames) {
            kwargs_cpy[key.c_str()] =
                tree_unflatten(kwargs[key.c_str()], a, counts[j]);
            j++;
          }
          for (auto item : kwargs) {
            if (kwargs_cpy.contains(item.first)) {
              continue;
            }
            kwargs_cpy[item.first] = item.second;
          }

          // Call the python function
          py_value_out = fun(*args_cpy, **kwargs_cpy);

          // Validate the return value of the python function
          if (!py::isinstance<array>(py_value_out)) {
            if (scalar_func_only) {
              std::ostringstream msg;
              msg << error_msg_tag << " The return value of the function "
                  << "whose gradient we want to compute should be a "
                  << "scalar array; but " << py_value_out.get_type()
                  << " was returned.";
              throw std::invalid_argument(msg.str());
            }
            if (!py::isinstance<py::tuple>(py_value_out)) {
              std::ostringstream msg;
              msg << error_msg_tag << " The return value of the function "
                  << "whose gradient we want to compute should be either a "
                  << "scalar array or a tuple with the first value being a "
                  << "scalar array (Union[array, Tuple[array, Any, ...]]); but "
                  << py_value_out.get_type() << " was returned.";
              throw std::invalid_argument(msg.str());
            }
            py::tuple ret = py::cast<py::tuple>(py_value_out);
            if (ret.size() == 0) {
              std::ostringstream msg;
              msg << error_msg_tag << " The return value of the function "
                  << "whose gradient we want to compute should be either a "
                  << "scalar array or a non-empty tuple. The first value should be a "
                  << "scalar array and the rest can be anything. Instead, "
                  << "we got an empty tuple.";
              throw std::invalid_argument(msg.str());
            }
            if (!py::isinstance<array>(ret[0])) {
              std::ostringstream msg;
              msg << error_msg_tag << " The return value of the function "
                  << "whose gradient we want to compute should be either a "
                  << "scalar array or a tuple with the first value being a "
                  << "scalar array (Union[array, Tuple[array, Any, ...]]); but it "
                  << "was a tuple with the first value being of type "
                  << ret[0].get_type() << " .";
              throw std::invalid_argument(msg.str());
            }
          }

          return tree_flatten(py_value_out, false);
        },
        gradient_indices)(arrays);

    auto value = value_and_grads.first;
    auto gradients = value_and_grads.second;

    // Put the gradients back in their container.
    // We have the following cases:
    //
    // 1. Single python positional argument has a gradient (eg argnums=[0])
    // 2. Many python positional arguments have gradients (eg argnums=[0, 1])
    // 3. A python keyword argument has gradients
    //
    // In case 1 we return the original python variable but with the gradients.
    // In case 2 we return a tuple of the above.
    // In case 3 we return a tuple containing a tuple and dict (something like
    // (tuple(), dict(x=mx.array(5))) ).
    py::object positional_grads;
    py::object keyword_grads;
    py::object py_grads;

    // Collect the gradients for the positional arguments
    if (argnums.size() == 1) {
      positional_grads = tree_unflatten(args[argnums[0]], gradients, counts[0]);
    } else if (argnums.size() > 1) {
      py::tuple grads_(argnums.size());
      for (int i = 0; i < argnums.size(); i++) {
        grads_[i] = tree_unflatten(args[argnums[i]], gradients, counts[i]);
      }
      positional_grads = py::cast<py::object>(grads_);
    } else {
      positional_grads = py::none();
    }

    // No keyword argument gradients so return the tuple of gradients
    if (argnames.size() == 0) {
      py_grads = positional_grads;
    } else {
      py::dict grads_;
      for (int i = 0; i < argnames.size(); i++) {
        auto& k = argnames[i];
        grads_[k.c_str()] = tree_unflatten(
            kwargs[k.c_str()], gradients, counts[i + argnums.size()]);
      }
      keyword_grads = py::cast<py::object>(grads_);

      py_grads =
          py::cast<py::object>(py::make_tuple(positional_grads, keyword_grads));
    }

    // Put the values back in the container
    py::object return_value = tree_unflatten(py_value_out, value);
    return std::make_pair(return_value, py_grads);
  };
}

auto py_vmap(
    const py::function& fun,
    const py::object& in_axes,
    const py::object& out_axes) {
  return [fun, in_axes, out_axes](const py::args& args) {
    auto axes_to_flat_tree = [](const py::object& tree,
                                const py::object& axes) {
      auto tree_axes = tree_map(
          {tree, axes},
          [](const std::vector<py::object>& inputs) { return inputs[1]; });
      std::vector<int> flat_axes;
      tree_visit(tree_axes, [&flat_axes](py::handle obj) {
        if (obj.is_none()) {
          flat_axes.push_back(-1);
        } else if (py::isinstance<py::int_>(obj)) {
          flat_axes.push_back(py::cast<int>(py::cast<py::int_>(obj)));
        } else {
          throw std::invalid_argument("[vmap] axis must be int or None.");
        }
      });
      return flat_axes;
    };

    // Inputs must be array or tree of arrays
    auto inputs = tree_flatten(args, true);
    auto flat_in_axes = axes_to_flat_tree(args, in_axes);

    // py_outputs will hold the output of the python function in order to be
    // able to reconstruct the python tree of extra return values
    py::object py_outputs;

    auto vmap_fn =
        [&fun, &args, &inputs, &py_outputs](const std::vector<array>& a) {
          // Call the python function
          py_outputs = fun(*tree_unflatten(args, a));

          // Flatten the outputs
          return tree_flatten(py_outputs, true);
        };

    auto [trace_inputs, trace_outputs] =
        detail::vmap_trace(vmap_fn, inputs, flat_in_axes);

    auto flat_out_axes = axes_to_flat_tree(py_outputs, out_axes);

    // Perform the vmap
    auto outputs = detail::vmap_replace(
        inputs, trace_inputs, trace_outputs, flat_in_axes, flat_out_axes);

    // Put the outputs back in the container
    return tree_unflatten(py_outputs, outputs);
  };
}

void init_transforms(py::module_& m) {
  m.def(
      "eval",
      [](const py::args& args, bool retain_graph) {
        std::vector<array> arrays = tree_flatten(args);
        eval(arrays, retain_graph);
      },
      "retain_graph"_a = false,
      R"pbdoc(
        Evaluate an :class:`array` or tree of :class:`array`.

        Args:
            *args (arrays or trees of arrays): Each argument can be a single array
              or a tree of arrays. If a tree is given the nodes can be a Python
              :class:`list`, :class:`tuple` or :class:`dict` but the leaves must all be
              an :class:`array`.
            retain_graph (bool): Indicate that the graph structure should be
              preserved. This option is intended to enable function transforms
              which contain control flow based on the value of an array.
      )pbdoc");
  m.def(
      "jvp",
      [](const py::function& fun,
         const std::vector<array>& primals,
         const std::vector<array>& tangents) {
        auto vfun = [&fun](const std::vector<array>& primals) {
          py::args args = py::tuple(primals.size());
          for (int i = 0; i < primals.size(); ++i) {
            args[i] = primals[i];
          }
          auto out = fun(*args);
          if (py::isinstance<array>(out)) {
            return std::vector<array>{py::cast<array>(out)};
          } else {
            return py::cast<std::vector<array>>(out);
          }
        };
        return jvp(vfun, primals, tangents);
      },
      "fun"_a,
      "primals"_a,
      "tangents"_a,
      R"pbdoc(
        Compute the Jacobian-vector product.

        This computes the product of the Jacobian of a function ``fun`` evaluated
        at ``primals`` with the ``tangents``.

        Args:
            fun (function): A function which takes a variable number of :class:`array`
              and returns a single :class:`array` or list of :class:`array`.
            primals (list(array)): A list of :class:`array` at which to
              evaluate the Jacobian.
            tangents (list(array)): A list of :class:`array` which are the
              "vector" in the Jacobian-vector product. The ``tangents`` should be the
              same in number, shape, and type as the inputs of ``fun`` (i.e. the ``primals``).

        Returns:
            list(array): A list of the Jacobian-vector products which
            is the same in number, shape, and type as the inputs to ``fun``.
      )pbdoc");
  m.def(
      "vjp",
      [](const py::function& fun,
         const std::vector<array>& primals,
         const std::vector<array>& cotangents) {
        auto vfun = [&fun](const std::vector<array>& primals) {
          py::args args = py::tuple(primals.size());
          for (int i = 0; i < primals.size(); ++i) {
            args[i] = primals[i];
          }
          auto out = fun(*args);
          if (py::isinstance<array>(out)) {
            return std::vector<array>{py::cast<array>(out)};
          } else {
            return py::cast<std::vector<array>>(out);
          }
        };
        return vjp(vfun, primals, cotangents);
      },
      "fun"_a,
      "primals"_a,
      "cotangents"_a,
      R"pbdoc(
        Compute the vector-Jacobian product.

        Computes the product of the ``cotangents`` with the Jacobian of a
        function ``fun`` evaluated at ``primals``.

        Args:
            fun (function): A function which takes a variable number of :class:`array`
              and returns a single :class:`array` or list of :class:`array`.
            primals (list(array)): A list of :class:`array` at which to
              evaluate the Jacobian.
            cotangents (list(array)): A list of :class:`array` which are the
              "vector" in the vector-Jacobian product. The ``cotangents`` should be the
              same in number, shape, and type as the outputs of ``fun``.

        Returns:
            list(array): A list of the vector-Jacobian products which
            is the same in number, shape, and type as the outputs of ``fun``.
      )pbdoc");
  m.def(
      "value_and_grad",
      [](const py::function& fun,
         const std::optional<IntOrVec>& argnums,
         const StrOrVec& argnames) {
        auto [argnums_vec, argnames_vec] =
            validate_argnums_argnames(argnums, argnames);
        return py::cpp_function(py_value_and_grad(
            fun, argnums_vec, argnames_vec, "[value_and_grad]", false));
      },
      "fun"_a,
      "argnums"_a = std::nullopt,
      "argnames"_a = std::vector<std::string>{},
      R"pbdoc(
        Returns a function which computes the value and gradient of ``fun``.

        The function passed to :func:`value_and_grad` should return either
        a scalar loss or a tuple in which the first element is a scalar
        loss and the remaining elements can be anything.

        .. code-block:: python

            import mlx.core as mx

            def mse(params, inputs, targets):
                outputs = forward(params, inputs)
                lvalue = (outputs - targets).square().mean()
                return lvalue

            # Returns lvalue, dlvalue/dparams
            lvalue, grads = mx.value_and_grad(mse)(params, inputs, targets)

            def lasso(params, inputs, targets, a=1.0, b=1.0):
                outputs = forward(params, inputs)
                mse = (outputs - targets).square().mean()
                l1 = mx.abs(outputs - targets).mean()

                loss = a*mse + b*l1

                return loss, mse, l1

            (loss, mse, l1), grads = mx.value_and_grad(lasso)(params, inputs, targets)

        Args:
            fun (function): A function which takes a variable number of
              :class:`array` or trees of :class:`array` and returns
              a scalar output :class:`array` or a tuple the first element
              of which should be a scalar :class:`array`.
            argnums (int or list(int), optional): Specify the index (or indices)
              of the positional arguments of ``fun`` to compute the gradient
              with respect to. If neither ``argnums`` nor ``argnames`` are
              provided ``argnums`` defaults to ``0`` indicating ``fun``'s first
              argument.
            argnames (str or list(str), optional): Specify keyword arguments of
              ``fun`` to compute gradients with respect to. It defaults to ``[]``
              so by default no gradients are computed for keyword arguments.

        Returns:
            function: A function which returns a tuple where the first element
            is the output of ``fun`` and the second element is the gradient(s) of
            the loss with respect to the selected arguments.
      )pbdoc");
  m.def(
      "grad",
      [](const py::function& fun,
         const std::optional<IntOrVec>& argnums,
         const StrOrVec& argnames) {
        auto [argnums_vec, argnames_vec] =
            validate_argnums_argnames(argnums, argnames);
        auto fn =
            py_value_and_grad(fun, argnums_vec, argnames_vec, "[grad]", true);
        return py::cpp_function(
            [fn](const py::args& args, const py::kwargs& kwargs) {
              return fn(args, kwargs).second;
            });
      },
      "fun"_a,
      "argnums"_a = std::nullopt,
      "argnames"_a = std::vector<std::string>{},
      R"pbdoc(
        Returns a function which computes the gradient of ``fun``.

        Args:
            fun (function): A function which takes a variable number of
              :class:`array` or trees of :class:`array` and returns
              a scalar output :class:`array`.
            argnums (int or list(int), optional): Specify the index (or indices)
              of the positional arguments of ``fun`` to compute the gradient
              with respect to. If neither ``argnums`` nor ``argnames`` are
              provided ``argnums`` defaults to ``0`` indicating ``fun``'s first
              argument.
            argnames (str or list(str), optional): Specify keyword arguments of
              ``fun`` to compute gradients with respect to. It defaults to ``[]``
              so by default no gradients are computed for keyword arguments.

        Returns:
            function: A function which has the same input arguments as ``fun`` and
            returns the gradient(s).
      )pbdoc");
  m.def(
      "vmap",
      [](const py::function& fun,
         const py::object& in_axes,
         const py::object& out_axes) {
        return py::cpp_function(py_vmap(fun, in_axes, out_axes));
      },
      "fun"_a,
      "in_axes"_a = 0,
      "out_axes"_a = 0,
      R"pbdoc(
        Returns a vectorized version of ``fun``.

        Args:
            fun (function): A function which takes a variable number of
              :class:`array` or a tree of :class:`array` and returns
              a variable number of :class:`array` or a tree of :class:`array`.
            in_axes (int, optional): An integer or a valid prefix tree of the
              inputs to ``fun`` where each node specifies the vmapped axis. If
              the value is ``None`` then the corresponding input(s) are not vmapped.
              Defaults to ``0``.
            out_axes (int, optional): An integer or a valid prefix tree of the
              outputs of ``fun`` where each node specifies the vmapped axis. If
              the value is ``None`` then the corresponding output(s) are not vmapped.
              Defaults to ``0``.

        Returns:
            function: The vectorized function.
      )pbdoc");
  m.def(
      "simplify",
      [](const py::args& args) {
        std::vector<array> arrays = tree_flatten(args);
        simplify(arrays);
      },
      R"pbdoc(
        Simplify the graph that computes the arrays.

        Run a few fast graph simplification operations to reuse computation and
        reduce memory consumption. This function is meant to be run every time
        so its overhead should be small, approximately 1ms for a graph with a
        few thousand nodes.

        .. code-block:: python

            import mlx.core as mx

            def foo(x):
                y = x @ x
                z = x @ x
                return y + z

            x = mx.ones((10, 10))
            y = foo(x)
            z = foo(x)

            # Computes the matmul twice
            mx.eval(y)

            # Computes the matmul once
            mx.simplify(z)
            mx.eval(z)

        Args:
            args: Any number of arrays and/or trees of arrays to be simplified.
      )pbdoc");
  m.def(
      "export_to_dot",
      [](py::object file, const py::args& args) {
        std::vector<array> arrays = tree_flatten(args);
        if (py::isinstance<py::str>(file)) {
          std::ofstream out(py::cast<std::string>(file));
          export_to_dot(out, arrays);
        } else if (py::hasattr(file, "write")) {
          std::ostringstream out;
          export_to_dot(out, arrays);
          auto write = file.attr("write");
          write(out.str());
        } else {
          throw std::invalid_argument(
              "export_to_dot accepts file-like objects or strings to be used as filenames");
        }
      },
      "file"_a);
}
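Since this file defines the whole Python-facing transform surface, a compact usage sketch ties the bindings together (using only functions bound above):

```python
import mlx.core as mx


def loss(w, x):
    return mx.sum(w * x)


w = mx.ones((4,))
x = mx.ones((4,))

# Gradient with respect to argument 0 by default.
dloss_dw = mx.grad(loss)(w, x)

# Value and gradient in one pass, as in the docstring above.
value, grads = mx.value_and_grad(loss)(w, x)

# Vectorize a per-example function over the leading axis.
batched = mx.vmap(lambda v: v * 2, in_axes=0, out_axes=0)
print(batched(mx.ones((8, 4))).shape)  # (8, 4)
```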
16
python/tests/mlx_tests.py
Normal file
@ -0,0 +1,16 @@
import os
import unittest

import mlx.core as mx


class MLXTestCase(unittest.TestCase):
    def setUp(self):
        self.default = mx.default_device()
        device = os.getenv("DEVICE", None)
        if device is not None:
            device = getattr(mx, device)
            mx.set_default_device(device)

    def tearDown(self):
        mx.set_default_device(self.default)
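Because `setUp` resolves the device with `getattr(mx, device)`, the `DEVICE` environment variable must name an attribute of `mlx.core` such as `cpu` or `gpu`; for example, `DEVICE=gpu python -m unittest discover python/tests` runs the suite on the GPU.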
445
python/tests/test_blas.py
Normal file
@ -0,0 +1,445 @@
import unittest
from itertools import permutations

import math
import mlx.core as mx
import numpy as np

import mlx_tests


class TestBlas(mlx_tests.MLXTestCase):
    @property
    def dtypes(self):
        return ["float32", "float16"] if mx.metal.is_available() else ["float32"]

    def __gemm_test(
        self,
        shape_a,
        shape_b,
        np_dtype=np.float32,
        f_np_a=lambda x: x,
        f_np_b=lambda x: x,
        f_mx_a=lambda x: x,
        f_mx_b=lambda x: x,
    ):
        with self.subTest(
            dtype=np.dtype(np_dtype).name, shape_a=shape_a, shape_b=shape_b
        ):
            np.random.seed(42)
            scale = max(np.sum(shape_a), 128)
            a_np = np.random.normal(0.0, 1.0 / scale, shape_a).astype(np_dtype)
            b_np = np.random.normal(0.0, 1.0 / scale, shape_b).astype(np_dtype)

            a_mx = mx.array(a_np)
            b_mx = mx.array(b_np)

            a_np = f_np_a(a_np.astype(np.float32))
            b_np = f_np_b(b_np.astype(np.float32))
            a_mx = f_mx_a(a_mx)
            b_mx = f_mx_b(b_mx)

            out_npy = a_np @ b_np
            out_mlx = a_mx @ b_mx

            self.assertListEqual(list(out_npy.shape), list(out_mlx.shape))
            self.assertTrue(np.allclose(out_mlx, out_npy.astype(np_dtype), atol=1e-5))

    def test_matmul_unaligned(self):
        if not mx.metal.is_available():
            return

        for dtype in self.dtypes:
            np_dtype = getattr(np, dtype)
            base_shapes = [4, 8, 16, 32, 64, 128]
            perturbations = [-2, -1, 0, 1, 2]

            for dim in base_shapes:
                for p in perturbations:
                    shape_a = (dim + p, dim + p)
                    shape_b = (dim + p, dim + p)
                    self.__gemm_test(shape_a, shape_b, np_dtype)

    def test_matmul_shapes(self):
        if not mx.metal.is_available():
            return

        shapes = [
            (1, 2, 1, 1),
            (1, 1, 2, 1),
            (3, 23, 457, 3),
        ]

        if mx.default_device() == mx.gpu:
            shapes += [
                (16, 768, 768, 128),
            ]

        for dtype in self.dtypes:
            np_dtype = getattr(np, dtype)

            for B, M, N, K in shapes:

                with self.subTest(transpose="nn"):
                    shape_a = (B, M, K)
                    shape_b = (B, K, N)
                    self.__gemm_test(shape_a, shape_b, np_dtype)

                with self.subTest(transpose="nt"):
                    shape_a = (B, M, K)
                    shape_b = (B, N, K)
                    self.__gemm_test(
                        shape_a,
                        shape_b,
                        np_dtype,
                        f_np_b=lambda x: np.transpose(x, (0, 2, 1)),
                        f_mx_b=lambda x: mx.transpose(x, (0, 2, 1)),
                    )

                with self.subTest(transpose="tn"):
                    shape_a = (B, K, M)
                    shape_b = (B, K, N)
                    self.__gemm_test(
                        shape_a,
                        shape_b,
                        np_dtype,
                        f_np_a=lambda x: np.transpose(x, (0, 2, 1)),
                        f_mx_a=lambda x: mx.transpose(x, (0, 2, 1)),
                    )

                with self.subTest(transpose="tt"):
                    shape_a = (B, K, M)
                    shape_b = (B, N, K)
                    self.__gemm_test(
                        shape_a,
                        shape_b,
                        np_dtype,
                        f_np_a=lambda x: np.transpose(x, (0, 2, 1)),
                        f_mx_a=lambda x: mx.transpose(x, (0, 2, 1)),
                        f_np_b=lambda x: np.transpose(x, (0, 2, 1)),
                        f_mx_b=lambda x: mx.transpose(x, (0, 2, 1)),
                    )

    def test_matmul(self):
        # Note: so far, matmul only works with floating-point types
        a = mx.array([[1.0, 2.0], [3.0, 4.0]])

        b = mx.array([[0.0, -1.0], [-3.0, 3.0]])

        expected = [[-6.0, 5.0], [-12.0, 9.0]]

        self.assertEqual((a @ b).tolist(), expected)
        self.assertEqual(mx.matmul(a, b).tolist(), expected)

        # Transposed matmul
        np.random.seed(0)
        a_npy = np.random.normal(0.0, 1.0 / 128, (128, 16)).astype(np.float32)
        b_npy = np.random.normal(0.0, 1.0 / 128, (128, 16)).astype(np.float32)
        c_npy = a_npy @ np.transpose(b_npy, (1, 0))
        d_npy = np.transpose(a_npy, (1, 0)) @ b_npy

        a_mlx = mx.array(a_npy)
        b_mlx = mx.array(b_npy)
        c_mlx = a_mlx @ mx.transpose(b_mlx, (1, 0))
        d_mlx = mx.transpose(a_mlx, (1, 0)) @ b_mlx

        self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
        self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))

        self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))
        self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-6))

    def test_matmul_dtypes(self):

        for dt in self.dtypes:
            a_npy = np.random.normal(0.0, 1.0 / 256, (16, 16, 16)).astype(
                getattr(np, dt)
            )
            b_npy = np.random.normal(0.0, 1.0 / 256, (16, 16, 16)).astype(
                getattr(np, dt)
            )
            a_mlx = mx.array(a_npy)
            b_mlx = mx.array(b_npy)

            c_npy = np.matmul(a_npy, b_npy, dtype=getattr(np, dt))
            c_mlx = a_mlx @ b_mlx

            self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))

    def test_matmul_batched(self):
        np.random.seed(0)
        # Batched matmul
        a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
        b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 16)).astype(np.float32)
        c_npy = a_npy @ b_npy

        a_mlx = mx.array(a_npy)
        b_mlx = mx.array(b_npy)
        c_mlx = a_mlx @ b_mlx

        self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
        self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))

        # Batched and transposed matmul
        b_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
        c_npy = a_npy @ np.transpose(b_npy, (0, 2, 1))

        b_mlx = mx.array(b_npy)
        c_mlx = a_mlx @ mx.transpose(b_mlx, (0, 2, 1))

        self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
        self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))

        # Batched matmul with simple broadcast
        a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
        b_npy = np.random.normal(0.0, 1.0 / 128, (16, 16)).astype(np.float32)
        c_npy = a_npy @ b_npy

        a_mlx = mx.array(a_npy)
        b_mlx = mx.array(b_npy)
        c_mlx = a_mlx @ b_mlx

        self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
        self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))

        # Both operands broadcasted
        d_npy = np.broadcast_to(b_npy, (5, 16, 16))
        d_mlx = mx.broadcast_to(b_mlx, (5, 16, 16))

        e_npy = d_npy @ d_npy
        e_mlx = d_mlx @ d_mlx

        self.assertListEqual(list(e_npy.shape), list(e_mlx.shape))
        self.assertTrue(np.allclose(e_mlx, e_npy, atol=1e-6))

        # Batched and transposed matmul with simple broadcast
        a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
        b_npy = np.random.normal(0.0, 1.0 / 128, (128, 16)).astype(np.float32)
        a_mlx = mx.array(a_npy)
        b_mlx = mx.array(b_npy)

        c_npy = a_npy @ np.transpose(b_npy, (1, 0))
        c_mlx = a_mlx @ mx.transpose(b_mlx, (1, 0))

        self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
        self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))

        # Matmul with vector
        a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
        b_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
        a_mlx = mx.array(a_npy)
        b_mlx = mx.array(b_npy)

        c_npy = a_npy @ b_npy
        c_mlx = a_mlx @ b_mlx

        self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
        self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))

        # Test multi-headed attention style matmul
        a_npy = np.random.normal(0.0, 1.0 / 128, (64, 16, 4, 32)).astype(np.float32)
        b_npy = np.random.normal(0.0, 1.0 / 128, (64, 16, 4, 32)).astype(np.float32)
        a_mlx = mx.array(a_npy)
        b_mlx = mx.array(b_npy)

        a_npy = np.transpose(a_npy, (0, 2, 1, 3))
        b_npy = np.transpose(b_npy, (0, 2, 1, 3))
        a_mlx = mx.transpose(a_mlx, (0, 2, 1, 3))
        b_mlx = mx.transpose(b_mlx, (0, 2, 1, 3))

        c_npy = a_npy @ np.transpose(b_npy, (0, 1, 3, 2))
        c_mlx = a_mlx @ mx.transpose(b_mlx, (0, 1, 3, 2))
        self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
        self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))

    def __gemv_test(
        self,
        shape_mat,
        shape_vec,
        np_dtype=np.float32,
        mat_first=True,
        np_mat_f=lambda x: x,
        np_vec_f=lambda x: x,
        mlx_mat_f=lambda x: x,
        mlx_vec_f=lambda x: x,
    ):
        with self.subTest(shape=shape_mat):
            np.random.seed(42)
            scale = max(np.sum(shape_mat), 32)
            mat_npy = np.random.normal(0.0, 1.0 / scale, shape_mat).astype(np_dtype)
            vec_npy = np.random.normal(0.0, 1.0 / scale, shape_vec).astype(np_dtype)

            mat_mlx = mx.array(mat_npy)
            vec_mlx = mx.array(vec_npy)

            mat_npy = np_mat_f(mat_npy)
            vec_npy = np_vec_f(vec_npy)
            mat_mlx = mlx_mat_f(mat_mlx)
            vec_mlx = mlx_vec_f(vec_mlx)

            if mat_first:
                out_npy = mat_npy @ vec_npy
                out_mlx = mat_mlx @ vec_mlx
            else:
                out_npy = vec_npy @ mat_npy
                out_mlx = vec_mlx @ mat_mlx

            self.assertListEqual(list(out_npy.shape), list(out_mlx.shape))
            self.assertTrue(np.allclose(out_mlx, out_npy, atol=1e-5))

    def test_matrix_vector(self):
        for dtype in self.dtypes:
            with self.subTest(dtype=dtype):
                np_dtype = getattr(np, dtype)

                # Basic square matrix test
                self.__gemv_test(
                    shape_mat=(64, 64), shape_vec=(64, 1), np_dtype=np_dtype
                )
                self.__gemv_test(
                    shape_mat=(64, 64),
                    shape_vec=(64, 1),
                    np_dtype=np_dtype,
                    mat_first=False,
                    np_vec_f=lambda x: np.transpose(x, (1, 0)),
                    mlx_vec_f=lambda x: mx.transpose(x, (1, 0)),
                )

                # Vector matrix product with aligned and unaligned shapes
                for in_len_base, out_len_base in (
                    (2, 2),
                    (32, 32),
                    (64, 64),
                    (2048, 2048),
                ):
                    for mi in (-1, 0, 1):
                        for mj in (-1, 0, 1):
                            # Vec mat
                            shape_mat = (in_len_base + mi, out_len_base + mj)
                            shape_vec = (1, in_len_base + mi)
                            self.__gemv_test(
                                shape_mat, shape_vec, mat_first=False, np_dtype=np_dtype
                            )

                            # Mat vec
                            shape_mat = (out_len_base + mj, in_len_base + mi)
                            shape_vec = (in_len_base + mi, 1)
                            self.__gemv_test(
                                shape_mat, shape_vec, mat_first=True, np_dtype=np_dtype
                            )

    def test_matrix_vector_batched(self):
        for dtype in self.dtypes:
            with self.subTest(dtype=dtype):
                np_dtype = getattr(np, dtype)

                # Batched mat vec
                for shape_mat, shape_vec in (
                    ((32, 128, 64), (32, 64, 1)),
                    ((128, 64), (32, 64, 1)),
                    ((32, 128, 64), (64, 1)),
                ):
                    self.__gemv_test(
                        shape_mat, shape_vec, mat_first=True, np_dtype=np_dtype
                    )

                # Batched vec mat
                for shape_vec, shape_mat in (
                    ((32, 1, 128), (32, 128, 64)),
                    ((32, 1, 128), (128, 64)),
                    ((1, 128), (32, 128, 64)),
                ):
                    self.__gemv_test(
                        shape_mat, shape_vec, mat_first=False, np_dtype=np_dtype
                    )

    def test_matrix_vector_broadcast(self):
        for dtype in self.dtypes:
            with self.subTest(dtype=dtype):
                np_dtype = getattr(np, dtype)

                # Different broadcasts mat vec
                for shape_mat, shape_vec in (
                    ((32, 64, 64), (32, 64, 1)),
                    ((64, 64), (32, 64, 1)),
                    ((32, 64, 64), (64, 1)),
                ):
                    self.__gemv_test(
                        shape_mat=(64, 64),
                        shape_vec=(64, 1),
                        np_dtype=np_dtype,
                        np_mat_f=(lambda mat_npy: np.broadcast_to(mat_npy, shape_mat)),
                        np_vec_f=(lambda vec_npy: np.broadcast_to(vec_npy, shape_vec)),
                        mlx_mat_f=(lambda mat_mlx: mx.broadcast_to(mat_mlx, shape_mat)),
                        mlx_vec_f=(lambda vec_mlx: mx.broadcast_to(vec_mlx, shape_vec)),
                    )

                # Different broadcasts vec mat
                for shape_vec, shape_mat in (
                    ((32, 1, 64), (32, 64, 64)),
                    ((32, 1, 64), (64, 64)),
                    ((1, 64), (32, 64, 64)),
                ):
                    self.__gemv_test(
                        shape_mat=(64, 64),
                        shape_vec=(1, 64),
                        np_dtype=np_dtype,
                        mat_first=False,
                        np_mat_f=lambda mat_npy: np.broadcast_to(mat_npy, shape_mat),
                        np_vec_f=lambda vec_npy: np.broadcast_to(vec_npy, shape_vec),
                        mlx_mat_f=lambda mat_mlx: mx.broadcast_to(mat_mlx, shape_mat),
                        mlx_vec_f=lambda vec_mlx: mx.broadcast_to(vec_mlx, shape_vec),
                    )

    def test_matrix_vector_edgecases(self):
        for dtype in self.dtypes:
            with self.subTest(dtype=dtype):
                np_dtype = getattr(np, dtype)

                for in_vec_len in np.arange(1, 5):
                    for out_vec_len in np.arange(1, 5):
                        for batch_size in np.arange(1, 5):
                            with self.subTest(
                                problem_shape=(batch_size, in_vec_len, out_vec_len)
                            ):
                                # Matrix vector
                                with self.subTest(transpose=False):
                                    a_npy = np.ones(
                                        (batch_size, out_vec_len, in_vec_len),
                                        dtype=np_dtype,
                                    )
                                    b_npy = np.ones(
                                        (batch_size, in_vec_len, 1), dtype=np_dtype
                                    )
                                    for i in range(batch_size):
                                        b_npy[i] *= i + 1.0

                                    a_mlx, b_mlx = map(mx.array, [a_npy, b_npy])
                                    c_npy = a_npy @ b_npy
                                    c_mlx = a_mlx @ b_mlx

                                    self.assertListEqual(
                                        list(c_npy.shape), list(c_mlx.shape)
                                    )
                                    self.assertTrue(np.array_equal(c_mlx, c_npy))

                                # Vector matrix
                                with self.subTest(transpose=True):
                                    a_npy = np.ones(
                                        (batch_size, out_vec_len, in_vec_len),
                                        dtype=np_dtype,
                                    )
                                    b_npy = np.ones(
                                        (batch_size, 1, out_vec_len), dtype=np_dtype
                                    )
                                    for i in range(batch_size):
                                        b_npy[i] *= i + 1.0

                                    a_mlx, b_mlx = map(mx.array, [a_npy, b_npy])
                                    c_npy = b_npy @ a_npy
                                    c_mlx = b_mlx @ a_mlx

                                    self.assertListEqual(
                                        list(c_npy.shape), list(c_mlx.shape)
                                    )
                                    self.assertTrue(np.array_equal(c_mlx, c_npy))
445
python/tests/test_conv.py
Normal file
@ -0,0 +1,445 @@
import unittest
from itertools import permutations

import math
import mlx.core as mx
import numpy as np

import mlx_tests

try:
    import torch
    import torch.nn.functional as F

    has_torch = True
except ImportError as e:
    has_torch = False


class TestConv(mlx_tests.MLXTestCase):
    def test_numpy_conv(self):
        for dtype in (
            "float16",
            "float32",
        ):
            np_dtype = getattr(np, dtype)
            for M, N, mode in (
                (1, 1, "full"),
                (25, 5, "full"),
                (24, 5, "same"),
                (24, 4, "same"),
                (24, 4, "valid"),
                (4, 24, "full"),
                (5, 25, "same"),
                (4, 25, "valid"),
            ):
                with self.subTest(dtype=dtype, M=M, N=N, mode=mode):
                    atol = 1e-6 if dtype == "float32" else 1e-5
                    a_np = np.random.rand(M).astype(np_dtype)
                    v_np = np.random.rand(N).astype(np_dtype)
                    a_mx = mx.array(a_np)
                    v_mx = mx.array(v_np)

                    c_np = np.convolve(a_np, v_np, mode=mode)
                    c_mx = mx.convolve(a_mx, v_mx, mode=mode)

                    self.assertListEqual(list(c_mx.shape), list(c_np.shape))
                    self.assertTrue(np.allclose(c_mx, c_np, atol=atol))

    @unittest.skipIf(not has_torch, "requires Torch")
    def test_torch_conv_1D(self):
        def run_conv1D(
            N,
            C,
            O,
            iH,
            kH,
            stride,
            padding,
            dilation=1,
            groups=1,
            dtype="float32",
            atol=1e-5,
        ):
            with self.subTest(
                dtype=dtype,
                N=N,
                C=C,
                O=O,
                iH=iH,
                kH=kH,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
            ):
                np_dtype = getattr(np, dtype)
                np.random.seed(0)
                in_np = np.random.normal(0, 1.0 / C, (N, iH, C)).astype(np_dtype)
                wt_np = np.random.normal(0, 1.0 / C, (O, kH, C)).astype(np_dtype)

                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
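                # MLX convolutions take channels-last inputs (N, L, C) while
                # torch.conv1d expects channels-first (N, C, L), hence the
                # transposes around the torch reference below.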
                in_pt, wt_pt = map(
                    lambda x: torch.from_numpy(x.transpose(0, 2, 1)), (in_np, wt_np)
                )

                out_mx = mx.conv1d(
                    in_mx,
                    wt_mx,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    groups=groups,
                )
                out_pt = torch.conv1d(
                    in_pt,
                    wt_pt,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    groups=groups,
                )
                out_pt = torch.transpose(out_pt, 2, 1)

                self.assertListEqual(list(out_pt.shape), out_mx.shape)
                self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=atol))

        for dtype in ("float32",):
            for N, C, O in (
                (1, 1, 1),
                (1, 6, 1),
                (1, 1, 6),
                (4, 32, 64),
            ):
                for iH, kH, stride, padding in (
                    (1, 1, 1, 0),
                    (3, 3, 1, 0),
                    (31, 5, 5, 2),
                ):
                    run_conv1D(N, C, O, iH, kH, stride, padding, dtype=dtype)

        # Strided inputs tests
        for tpose_in, tpose_wt in (
            ((0, 2, 1), (0, 1, 2)),
            ((0, 2, 1), (0, 2, 1)),
        ):
            with self.subTest(name="strided", tpose_in=tpose_in, tpose_wt=tpose_wt):
                in_np = np.random.normal(0, 1.0 / 16, (16, 16, 16)).astype(np.float32)
                wt_np = np.random.normal(0, 1.0 / 16, (16, 16, 16)).astype(np.float32)

                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
                in_mx_t = mx.transpose(in_mx, tpose_in)
                wt_mx_t = mx.transpose(wt_mx, tpose_wt)
                out_mx = mx.conv1d(in_mx_t, wt_mx_t)

                in_pt, wt_pt = map(
                    lambda x: torch.from_numpy(x.transpose(0, 2, 1)),
                    (in_np.transpose(tpose_in), wt_np.transpose(tpose_wt)),
                )

                out_pt = torch.conv1d(in_pt, wt_pt)
                out_pt = torch.transpose(out_pt, 2, 1)

                self.assertListEqual(list(out_pt.shape), out_mx.shape)
                self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=1e-5))

    @unittest.skipIf(not has_torch, "requires Torch")
    def test_torch_conv_1D_grad(self):
        def run_conv1D_grad(
            N,
            C,
            O,
            iH,
            kH,
            stride,
            padding,
            dilation=1,
            groups=1,
            dtype="float32",
            atol=1e-5,
        ):
            with self.subTest(
                dtype=dtype,
                N=N,
                C=C,
                O=O,
                iH=iH,
                kH=kH,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
            ):
                np_dtype = getattr(np, dtype)
                np.random.seed(0)
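                # Output length of a strided, dilated convolution:
                #   oH = floor((iH + 2 * padding - dilation * (kH - 1) - 1) / stride) + 1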
                oH = 1 + ((iH + 2 * padding - dilation * (kH - 1) - 1) // stride)

                in_np = np.random.normal(0, 1.0 / C, (N, iH, C)).astype(np_dtype)
                wt_np = np.random.normal(0, 1.0 / C, (O, kH, C)).astype(np_dtype)
                ct_np = np.random.normal(0, 1.0 / C, (N, oH, O)).astype(np_dtype)

                in_mx, wt_mx, ct_mx = map(mx.array, (in_np, wt_np, ct_np))
                in_pt, wt_pt, ct_pt = map(
                    lambda x: torch.from_numpy(x.transpose(0, 2, 1)),
                    (in_np, wt_np, ct_np),
                )

                def f(a, b):
                    return mx.conv1d(
                        a,
                        b,
                        stride=stride,
                        padding=padding,
                        dilation=dilation,
                        groups=groups,
                    )

                _, outs_mx = mx.vjp(
                    f,
                    [
                        in_mx,
                        wt_mx,
                    ],
                    [
                        ct_mx,
                    ],
                )
                pt_grad_in = F.grad.conv1d_input(
                    in_pt.shape,
                    wt_pt,
                    ct_pt,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    groups=groups,
                )
                pt_grad_wt = F.grad.conv1d_weight(
                    in_pt,
                    wt_pt.shape,
                    ct_pt,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    groups=groups,
                )
                pt_grad_in = torch.transpose(pt_grad_in, 2, 1).numpy()
                pt_grad_wt = torch.transpose(pt_grad_wt, 2, 1).numpy()

                mx_grad_in, mx_grad_wt = outs_mx

                self.assertListEqual(list(pt_grad_in.shape), mx_grad_in.shape)
                self.assertListEqual(list(in_mx.shape), mx_grad_in.shape)
                self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))

                self.assertListEqual(list(pt_grad_wt.shape), mx_grad_wt.shape)
                self.assertListEqual(list(wt_mx.shape), mx_grad_wt.shape)
                self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))

        for dtype in ("float32",):
            for N, C, O in (
                (1, 1, 1),
                (1, 6, 1),
                (1, 1, 6),
                (4, 32, 64),
            ):
                for iH, kH, stride, padding in (
                    (1, 1, 1, 0),
                    (3, 3, 1, 0),
                    (31, 5, 5, 2),
                ):
                    run_conv1D_grad(N, C, O, iH, kH, stride, padding, dtype=dtype)

    @unittest.skipIf(not has_torch, "requires Torch")
    def test_torch_conv_2D(self):
        def run_conv2D(
            N,
            C,
            O,
            idim,
            kdim,
            stride,
            padding,
            dilation=(1, 1),
            groups=1,
            dtype="float32",
            atol=1e-5,
        ):
            with self.subTest(
                dtype=dtype,
                N=N,
                C=C,
                O=O,
                idim=idim,
                kdim=kdim,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
            ):
                np_dtype = getattr(np, dtype)
                np.random.seed(0)
                iH, iW = idim
                kH, kW = kdim
                scale = 1.0 / math.sqrt(kH * kW * C)
                in_np = np.random.normal(0.0, scale, (N, iH, iW, C)).astype(np_dtype)
                wt_np = np.random.normal(0.0, 1.0, (O, kH, kW, C)).astype(np_dtype)

                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
                in_pt, wt_pt = map(
                    lambda x: torch.from_numpy(x.transpose(0, 3, 1, 2)).to("cpu"),
                    (in_np, wt_np),
                )

                out_mx = mx.conv2d(
                    in_mx,
                    wt_mx,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    groups=groups,
                )
                out_pt = torch.conv2d(
                    in_pt,
                    wt_pt,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    groups=groups,
                )
                out_pt = torch.permute(out_pt, (0, 2, 3, 1)).numpy(force=True)

                self.assertListEqual(list(out_pt.shape), list(out_mx.shape))
                self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))

        for dtype in ("float32",):
            for N, C, O in (
                (1, 1, 1),
                (1, 6, 1),
                (1, 1, 6),
                (4, 32, 64),
            ):
                for idim, kdim, stride, padding in (
                    ((1, 1), (1, 1), (1, 1), (0, 0)),
                    ((3, 3), (3, 1), (1, 1), (0, 0)),
                    ((31, 31), (5, 5), (5, 5), (2, 2)),
                ):
                    run_conv2D(N, C, O, idim, kdim, stride, padding, dtype=dtype)

    @unittest.skipIf(not has_torch, "requires Torch")
    def test_torch_conv_2D_grad(self):
        def run_conv2D_grad(
            N,
            C,
            O,
            idim,
            kdim,
            stride,
            padding,
            dilation=(1, 1),
            groups=1,
            dtype="float32",
            atol=1e-5,
        ):
            with self.subTest(
                dtype=dtype,
                N=N,
                C=C,
                O=O,
                idim=idim,
                kdim=kdim,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
            ):
                np_dtype = getattr(np, dtype)
                np.random.seed(0)
                iH, iW = idim
                kH, kW = kdim
                scale = 1.0 / math.sqrt(kH * kW * C)

                oH = 1 + (
                    (iH + 2 * padding[0] - dilation[0] * (kH - 1) - 1) // stride[0]
                )
                oW = 1 + (
                    (iW + 2 * padding[1] - dilation[1] * (kW - 1) - 1) // stride[1]
                )

                in_np = np.random.normal(0.0, scale, (N, iH, iW, C)).astype(np_dtype)
                wt_np = np.random.normal(0.0, scale, (O, kH, kW, C)).astype(np_dtype)
                ct_np = np.random.normal(0.0, scale, (N, oH, oW, O)).astype(np_dtype)

                in_mx, wt_mx, ct_mx = map(mx.array, (in_np, wt_np, ct_np))
                in_pt, wt_pt, ct_pt = map(
                    lambda x: torch.from_numpy(x.transpose(0, 3, 1, 2)).to("cpu"),
                    (in_np, wt_np, ct_np),
                )

                def f(a, b):
                    return mx.conv2d(
                        a,
                        b,
                        stride=stride,
                        padding=padding,
                        dilation=dilation,
                        groups=groups,
                    )

                _, outs_mx = mx.vjp(
                    f,
                    [
                        in_mx,
                        wt_mx,
                    ],
                    [
                        ct_mx,
                    ],
                )
                pt_grad_in = F.grad.conv2d_input(
                    in_pt.shape,
                    wt_pt,
                    ct_pt,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    groups=groups,
                )
                pt_grad_wt = F.grad.conv2d_weight(
                    in_pt,
                    wt_pt.shape,
                    ct_pt,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    groups=groups,
                )
                pt_grad_in = torch.permute(pt_grad_in, (0, 2, 3, 1)).numpy()
                pt_grad_wt = torch.permute(pt_grad_wt, (0, 2, 3, 1)).numpy()

                mx_grad_in, mx_grad_wt = outs_mx

                self.assertListEqual(list(pt_grad_in.shape), mx_grad_in.shape)
                self.assertListEqual(list(in_mx.shape), mx_grad_in.shape)
                self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))

                self.assertListEqual(list(pt_grad_wt.shape), mx_grad_wt.shape)
                self.assertListEqual(list(wt_mx.shape), mx_grad_wt.shape)
                self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))

        for dtype in ("float32",):
            for N, C, O in (
                (1, 1, 1),
                (1, 6, 1),
                (1, 1, 6),
                (4, 32, 64),
            ):
                for idim, kdim, stride, padding in (
                    ((1, 1), (1, 1), (1, 1), (0, 0)),
                    ((3, 3), (3, 1), (1, 1), (0, 0)),
                    ((31, 31), (5, 5), (5, 5), (2, 2)),
                ):
                    run_conv2D_grad(N, C, O, idim, kdim, stride, padding, dtype=dtype)


if __name__ == "__main__":
    unittest.main()
157
python/tests/test_load.py
Normal file
@ -0,0 +1,157 @@
import unittest
import os
import mlx.core as mx
import numpy as np
import tempfile

import mlx_tests


class TestLoad(mlx_tests.MLXTestCase):
    dtypes = [
        "uint8",
        "uint16",
        "uint32",
        "uint64",
        "int8",
        "int16",
        "int32",
        "int64",
        "float32",
        "float16",
        "complex64",
    ]

    @classmethod
    def setUpClass(cls):
        cls.test_dir_fid = tempfile.TemporaryDirectory()
        cls.test_dir = cls.test_dir_fid.name

    @classmethod
    def tearDownClass(cls):
        cls.test_dir_fid.cleanup()

    def test_save_and_load(self):

        if not os.path.isdir(self.test_dir):
            os.mkdir(self.test_dir)

        for dt in self.dtypes:
            with self.subTest(dtype=dt):
                for i, shape in enumerate([(1,), (23,), (1024, 1024), (4, 6, 3, 1, 2)]):
                    with self.subTest(shape=shape):
                        save_file_mlx = os.path.join(self.test_dir, f"mlx_{dt}_{i}.npy")
                        save_file_npy = os.path.join(self.test_dir, f"npy_{dt}_{i}.npy")

                        save_arr = np.random.uniform(0.0, 32.0, size=shape)
                        save_arr_npy = save_arr.astype(getattr(np, dt))
                        save_arr_mlx = mx.array(save_arr_npy)

                        mx.save(save_file_mlx, save_arr_mlx)
                        np.save(save_file_npy, save_arr_npy)

                        # Load array saved by mlx as mlx array
                        load_arr_mlx_mlx = mx.load(save_file_mlx)
                        self.assertTrue(mx.array_equal(load_arr_mlx_mlx, save_arr_mlx))

                        # Load array saved by numpy as mlx array
                        load_arr_npy_mlx = mx.load(save_file_npy)
                        self.assertTrue(mx.array_equal(load_arr_npy_mlx, save_arr_mlx))

                        # Load array saved by mlx as numpy array
                        load_arr_mlx_npy = np.load(save_file_mlx)
                        self.assertTrue(np.array_equal(load_arr_mlx_npy, save_arr_npy))

    def test_save_and_load_fs(self):

        if not os.path.isdir(self.test_dir):
            os.mkdir(self.test_dir)

        for dt in self.dtypes:
            with self.subTest(dtype=dt):
                for i, shape in enumerate([(1,), (23,), (1024, 1024), (4, 6, 3, 1, 2)]):
                    with self.subTest(shape=shape):
                        save_file_mlx = os.path.join(
                            self.test_dir, f"mlx_{dt}_{i}_fs.npy"
                        )
                        save_file_npy = os.path.join(
                            self.test_dir, f"npy_{dt}_{i}_fs.npy"
                        )

                        save_arr = np.random.uniform(0.0, 32.0, size=shape)
                        save_arr_npy = save_arr.astype(getattr(np, dt))
                        save_arr_mlx = mx.array(save_arr_npy)

                        with open(save_file_mlx, "wb") as f:
                            mx.save(f, save_arr_mlx)

                        np.save(save_file_npy, save_arr_npy)

                        # Load array saved by mlx as mlx array
                        with open(save_file_mlx, "rb") as f:
                            load_arr_mlx_mlx = mx.load(f)
                        self.assertTrue(mx.array_equal(load_arr_mlx_mlx, save_arr_mlx))

                        # Load array saved by numpy as mlx array
                        with open(save_file_npy, "rb") as f:
                            load_arr_npy_mlx = mx.load(f)
                        self.assertTrue(mx.array_equal(load_arr_npy_mlx, save_arr_mlx))

                        # Load array saved by mlx as numpy array
                        load_arr_mlx_npy = np.load(save_file_mlx)
                        self.assertTrue(np.array_equal(load_arr_mlx_npy, save_arr_npy))
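    # An .npz archive is a zip of .npy entries, so arrays written with
    # mx.savez / mx.savez_compressed round trip through np.load and vice
    # versa, which is what this test checks.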
    def test_savez_and_loadz(self):
        if not os.path.isdir(self.test_dir):
            os.mkdir(self.test_dir)

        for dt in self.dtypes:
            with self.subTest(dtype=dt):
                shapes = [(6,), (6, 6), (4, 1, 3, 1, 2)]
                save_file_mlx_uncomp = os.path.join(
                    self.test_dir, f"mlx_{dt}_uncomp.npz"
                )
                save_file_npy_uncomp = os.path.join(
                    self.test_dir, f"npy_{dt}_uncomp.npz"
                )
                save_file_mlx_comp = os.path.join(self.test_dir, f"mlx_{dt}_comp.npz")
                save_file_npy_comp = os.path.join(self.test_dir, f"npy_{dt}_comp.npz")

                # Make dictionary of multiple arrays
                save_arrs_npy = {
                    f"save_arr_{i}": np.random.uniform(
                        0.0, 32.0, size=shapes[i]
                    ).astype(getattr(np, dt))
                    for i in range(len(shapes))
                }
                save_arrs_mlx = {k: mx.array(v) for k, v in save_arrs_npy.items()}

                # Save as npz files
                np.savez(save_file_npy_uncomp, **save_arrs_npy)
                mx.savez(save_file_mlx_uncomp, **save_arrs_mlx)
                np.savez_compressed(save_file_npy_comp, **save_arrs_npy)
                mx.savez_compressed(save_file_mlx_comp, **save_arrs_mlx)

                for save_file_npy, save_file_mlx in (
                    (save_file_npy_uncomp, save_file_mlx_uncomp),
                    (save_file_npy_comp, save_file_mlx_comp),
                ):
                    # Load arrays saved by mlx as mlx arrays
                    load_arr_mlx_mlx = mx.load(save_file_mlx)
                    for k, v in load_arr_mlx_mlx.items():
                        self.assertTrue(mx.array_equal(save_arrs_mlx[k], v))

                    # Load arrays saved by numpy as mlx arrays
                    load_arr_npy_mlx = mx.load(save_file_npy)
                    for k, v in load_arr_npy_mlx.items():
                        self.assertTrue(mx.array_equal(save_arrs_mlx[k], v))

                    # Load arrays saved by mlx as numpy arrays
                    load_arr_mlx_npy = np.load(save_file_mlx)
                    for k, v in load_arr_mlx_npy.items():
                        self.assertTrue(np.array_equal(save_arrs_npy[k], v))


if __name__ == "__main__":
    unittest.main()
231
python/tests/test_nn.py
Normal file
@ -0,0 +1,231 @@
import unittest

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten, tree_map, tree_unflatten
import numpy as np
import os
import tempfile

import mlx_tests


class TestNN(mlx_tests.MLXTestCase):
    def test_linear(self):
        inputs = mx.zeros((10, 4))
        layer = nn.Linear(input_dims=4, output_dims=8)
        outputs = layer(inputs)
        self.assertEqual(tuple(outputs.shape), (10, 8))

    def test_cross_entropy(self):
        logits = mx.array([[0.0, -float("inf")], [-float("inf"), 0.0]])
        targets = mx.array([0, 1])
        losses = nn.losses.cross_entropy(logits, targets)
        self.assertTrue(mx.array_equal(losses, mx.zeros((2,))))

    def test_gelu(self):
        inputs = [1.15286231, -0.81037411, 0.35816911, 0.77484438, 0.66276414]

        # From: jax.nn.gelu(np.array(inputs), approximate=False)
        expected = np.array(
            [1.0093501, -0.16925684, 0.22918941, 0.60498625, 0.49459383]
        )

        out = nn.GELU()(mx.array(inputs))
        self.assertTrue(np.allclose(out, expected))

        # Crudely check the approximations
        x = mx.arange(-6.0, 6.0, 12 / 100)
        y = nn.gelu(x)
        y_hat1 = nn.gelu_approx(x)
        y_hat2 = nn.gelu_fast_approx(x)
        self.assertLess(mx.abs(y - y_hat1).max(), 0.0003)
        self.assertLess(mx.abs(y - y_hat2).max(), 0.02)

    def test_group_norm(self):
        x = mx.arange(100, dtype=mx.float32)
        x = x.reshape(1, 10, 10, 1)
        x = mx.broadcast_to(x, (2, 10, 10, 4))
        x = mx.concatenate([x, 0.5 * x], axis=-1)

        # Group norm in groups last mode
        g = nn.GroupNorm(2, 8)
        y = g(x)
        means = y.reshape(2, -1, 2).mean(axis=1)
        var = y.reshape(2, -1, 2).var(axis=1)
        self.assertTrue(np.allclose(means, np.zeros_like(means), atol=1e-6))
        self.assertTrue(np.allclose(var, np.ones_like(var), atol=1e-6))
        g.weight = g.weight * 2
        g.bias = g.bias + 3
        y = g(x)
        means = y.reshape(2, -1, 2).mean(axis=1)
        var = y.reshape(2, -1, 2).var(axis=1)
        self.assertTrue(np.allclose(means, 3 * np.ones_like(means), atol=1e-6))
        self.assertTrue(np.allclose(var, 4 * np.ones_like(var), atol=1e-6))

        # Group norm in groups first mode
        g = nn.GroupNorm(2, 8, pytorch_compatible=True)
        y = g(x)
        means = y.reshape(2, -1, 2, 4).mean(axis=(1, -1))
        var = y.reshape(2, -1, 2, 4).var(axis=(1, -1))
        self.assertTrue(np.allclose(means, np.zeros_like(means), atol=1e-6))
        self.assertTrue(np.allclose(var, np.ones_like(var), atol=1e-6))
        g.weight = g.weight * 2
        g.bias = g.bias + 3
        y = g(x)
        means = y.reshape(2, -1, 2, 4).mean(axis=(1, -1))
        var = y.reshape(2, -1, 2, 4).var(axis=(1, -1))
        self.assertTrue(np.allclose(means, 3 * np.ones_like(means), atol=1e-6))
        self.assertTrue(np.allclose(var, 4 * np.ones_like(var), atol=1e-6))

    def test_conv1d(self):
        N = 5
        L = 12
        ks = 3
        C_in = 2
        C_out = 4
        x = mx.ones((N, L, C_in))
        c = nn.Conv1d(in_channels=C_in, out_channels=C_out, kernel_size=ks)
        c.weight = mx.ones_like(c.weight)
        y = c(x)
        self.assertEqual(y.shape, [N, L - ks + 1, C_out])
        self.assertTrue(mx.allclose(y, mx.full(y.shape, ks * C_in, mx.float32)))

        c = nn.Conv1d(in_channels=C_in, out_channels=C_out, kernel_size=ks, stride=2)
        y = c(x)
        self.assertEqual(y.shape, [N, (L - ks + 1) // 2, C_out])
        self.assertTrue("bias" in c.parameters())

        c = nn.Conv1d(in_channels=C_in, out_channels=C_out, kernel_size=ks, bias=False)
        self.assertTrue("bias" not in c.parameters())

    def test_conv2d(self):
        x = mx.ones((4, 8, 8, 3))
        c = nn.Conv2d(3, 1, 8)
        y = c(x)
        self.assertEqual(y.shape, [4, 1, 1, 1])
        c.weight = mx.ones_like(c.weight) / 8 / 8 / 3
        y = c(x)
        self.assertTrue(np.allclose(y[:, 0, 0, 0], x.mean(axis=(1, 2, 3))))

        # 3x3 conv no padding stride 1
        c = nn.Conv2d(3, 8, 3)
        y = c(x)
        self.assertEqual(y.shape, [4, 6, 6, 8])
        self.assertLess(mx.abs(y - c.weight.sum((1, 2, 3))).max(), 1e-4)

        # 3x3 conv padding 1 stride 1
        c = nn.Conv2d(3, 8, 3, padding=1)
        y = c(x)
        self.assertEqual(y.shape, [4, 8, 8, 8])
        self.assertLess(mx.abs(y[:, 1:7, 1:7] - c.weight.sum((1, 2, 3))).max(), 1e-4)
        self.assertLess(
            mx.abs(y[:, 0, 0] - c.weight[:, 1:, 1:].sum(axis=(1, 2, 3))).max(),
            1e-4,
        )
        self.assertLess(
            mx.abs(y[:, 7, 7] - c.weight[:, :-1, :-1].sum(axis=(1, 2, 3))).max(),
            1e-4,
        )
        self.assertLess(
            mx.abs(y[:, 1:7, 7] - c.weight[:, :, :-1].sum(axis=(1, 2, 3))).max(),
            1e-4,
        )
        self.assertLess(
            mx.abs(y[:, 7, 1:7] - c.weight[:, :-1, :].sum(axis=(1, 2, 3))).max(),
            1e-4,
        )

        # 3x3 conv no padding stride 2
        c = nn.Conv2d(3, 8, 3, padding=0, stride=2)
        y = c(x)
        self.assertEqual(y.shape, [4, 3, 3, 8])
        self.assertLess(mx.abs(y - c.weight.sum((1, 2, 3))).max(), 1e-4)

    def test_sequential(self):
        x = mx.ones((10, 2))
        m = nn.Sequential(nn.Linear(2, 10), nn.ReLU(), nn.Linear(10, 1))
        y = m(x)
        self.assertEqual(y.shape, [10, 1])
        params = m.parameters()
        self.assertTrue("layers" in params)
        self.assertEqual(len(params["layers"]), 3)
        self.assertTrue("weight" in params["layers"][0])
        self.assertEqual(len(params["layers"][1]), 0)
        self.assertTrue("weight" in params["layers"][2])

        m.layers[1] = nn.relu
        y2 = m(x)
        self.assertTrue(mx.array_equal(y, y2))

    def test_module_utilities(self):
        m = nn.Sequential(
            nn.Sequential(nn.Linear(2, 10), nn.relu),
            nn.Sequential(nn.Linear(10, 10), nn.ReLU()),
            nn.Linear(10, 1),
            mx.sigmoid,
        )

        children = m.children()
        self.assertTrue(isinstance(children, dict))
        self.assertEqual(len(children), 1)
        self.assertTrue(isinstance(children["layers"], list))
        self.assertEqual(len(children["layers"]), 4)
        self.assertEqual(children["layers"][3], {})
        flat_children = tree_flatten(children, is_leaf=nn.Module.is_module)
        self.assertEqual(len(flat_children), 3)

        leaves = tree_flatten(m.leaf_modules(), is_leaf=nn.Module.is_module)
        self.assertEqual(len(leaves), 4)
        self.assertEqual(leaves[0][0], "layers.0.layers.0")
        self.assertEqual(leaves[1][0], "layers.1.layers.0")
        self.assertEqual(leaves[2][0], "layers.1.layers.1")
        self.assertEqual(leaves[3][0], "layers.2")
        self.assertTrue(leaves[0][1] is m.layers[0].layers[0])
        self.assertTrue(leaves[1][1] is m.layers[1].layers[0])
        self.assertTrue(leaves[2][1] is m.layers[1].layers[1])
        self.assertTrue(leaves[3][1] is m.layers[2])

        m.eval()

        def assert_not_training(k, m):
            self.assertFalse(m.training)

        m.apply_to_modules(assert_not_training)

        m.train()

        def assert_training(k, m):
            self.assertTrue(m.training)

        m.apply_to_modules(assert_training)

    def test_sin_pe(self):
        m = nn.SinusoidalPositionalEncoding(16, min_freq=0.01)
        x = mx.arange(10)
        y = m(x)

        self.assertEqual(y.shape, [10, 16])
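        # The encoding is scaled so each position's embedding has unit norm,
        # hence the diagonal of the Gram matrix should be ~1.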
        similarities = y @ y.T
        self.assertLess(
            mx.abs(similarities[mx.arange(10), mx.arange(10)] - 1).max(), 1e-5
        )

    def test_io(self):
        def make_model():
            return nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2))

        m = make_model()
        tdir = tempfile.TemporaryDirectory()
        file = os.path.join(tdir.name, "model.npz")
        m.save_weights(file)
        m_load = make_model()
        m_load.load_weights(file)
        tdir.cleanup()

        eq_tree = tree_map(mx.array_equal, m.parameters(), m_load.parameters())
        self.assertTrue(all(v.item() for _, v in tree_flatten(eq_tree)))


if __name__ == "__main__":
    unittest.main()
29
python/tests/test_optimizers.py
Normal file
@ -0,0 +1,29 @@
import unittest

import mlx.core as mx
import mlx.optimizers as opt
import mlx.utils

import mlx_tests


class TestOptimizers(mlx_tests.MLXTestCase):
    def test_optimizers(self):
        params = {
            "first": [mx.zeros((10,)), mx.zeros((1,))],
            "second": mx.zeros((1,)),
        }
        grads = mlx.utils.tree_map(lambda x: mx.ones_like(x), params)

        for optim in [opt.SGD(0.1), opt.Adam(0.1)]:
            update = optim.apply_gradients(grads, params)
            mx.eval(update)
            equal_shape = mlx.utils.tree_map(
                lambda x, y: x.shape == y.shape, params, update
            )
            all_equal = all(v for _, v in mlx.utils.tree_flatten(equal_shape))
            self.assertTrue(all_equal)


if __name__ == "__main__":
    unittest.main()
192
python/tests/test_random.py
Normal file
@ -0,0 +1,192 @@
import unittest

import mlx.core as mx

import mlx_tests


class TestRandom(mlx_tests.MLXTestCase):
    def test_global_rng(self):
        mx.random.seed(3)
        a = mx.random.uniform()
        b = mx.random.uniform()

        mx.random.seed(3)
        x = mx.random.uniform()
        y = mx.random.uniform()

        self.assertEqual(a.item(), x.item())
        self.assertEqual(y.item(), b.item())

    def test_key(self):
        k1 = mx.random.key(0)
        k2 = mx.random.key(0)
        self.assertTrue(mx.array_equal(k1, k2))

        k2 = mx.random.key(1)
        self.assertFalse(mx.array_equal(k1, k2))
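    # Key splitting is deterministic: splitting the same key again must
    # reproduce the same subkeys, and the two subkeys must differ.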
    def test_key_split(self):
        key = mx.random.key(0)

        k1, k2 = mx.random.split(key)
        self.assertFalse(mx.array_equal(k1, k2))

        r1, r2 = mx.random.split(key)
        self.assertTrue(mx.array_equal(k1, r1))
        self.assertTrue(mx.array_equal(k2, r2))

        keys = mx.random.split(key, 10)
        self.assertEqual(keys.shape, [10, 2])

    def test_uniform(self):
        key = mx.random.key(0)
        a = mx.random.uniform(key=key)
        self.assertEqual(a.shape, [])
        self.assertEqual(a.dtype, mx.float32)

        b = mx.random.uniform(key=key)
        self.assertEqual(a.item(), b.item())

        a = mx.random.uniform(shape=(2, 3))
        self.assertEqual(a.shape, [2, 3])

        a = mx.random.uniform(shape=(1000,), low=-1, high=5)
        self.assertTrue(mx.all(a > -1).item() and mx.all(a < 5).item())

        a = mx.random.uniform(shape=(1000,), low=mx.array(-1), high=5)
        self.assertTrue(mx.all(a > -1).item() and mx.all(a < 5).item())

    def test_normal(self):
        key = mx.random.key(0)
        a = mx.random.normal(key=key)
        self.assertEqual(a.shape, [])
        self.assertEqual(a.dtype, mx.float32)

        b = mx.random.normal(key=key)
        self.assertEqual(a.item(), b.item())

        a = mx.random.normal(shape=(2, 3))
        self.assertEqual(a.shape, [2, 3])

        # Generate in float16 or bfloat16
        for t in [mx.float16, mx.bfloat16]:
            a = mx.random.normal(dtype=t)
            self.assertEqual(a.dtype, t)

    def test_randint(self):
        a = mx.random.randint(0, 1, [])
        self.assertEqual(a.shape, [])
        self.assertEqual(a.dtype, mx.int32)

        shape = [88]
        low = mx.array(3)
        high = mx.array(15)

        key = mx.random.key(0)
        a = mx.random.randint(low, high, shape, key=key)
        self.assertEqual(a.shape, shape)
        self.assertEqual(a.dtype, mx.int32)

        # Check using the same key yields the same value
        b = mx.random.randint(low, high, shape, key=key)
        self.assertListEqual(a.tolist(), b.tolist())

        shape = [3, 4]
        low = mx.reshape(mx.array([0] * 3), [3, 1])
        high = mx.reshape(mx.array([12, 13, 14, 15]), [1, 4])

        a = mx.random.randint(low, high, shape)
        self.assertEqual(a.shape, shape)

        a = mx.random.randint(-10, 10, [1000, 1000])
        self.assertTrue(mx.all(-10 <= a).item() and mx.all(a < 10).item())

        a = mx.random.randint(10, -10, [1000, 1000])
        self.assertTrue(mx.all(a == 10).item())

    def test_bernoulli(self):
        a = mx.random.bernoulli()
        self.assertEqual(a.shape, [])
        self.assertEqual(a.dtype, mx.bool_)

        a = mx.random.bernoulli(mx.array(0.5), [5])
        self.assertEqual(a.shape, [5])

        a = mx.random.bernoulli(mx.array([2.0, -2.0]))
        self.assertEqual(a.tolist(), [True, False])
        self.assertEqual(a.shape, [2])

        p = mx.array([0.1, 0.2, 0.3])
        p = mx.reshape(p, [1, 3])
        x = mx.random.bernoulli(p, [4, 3])
        self.assertEqual(x.shape, [4, 3])

        with self.assertRaises(ValueError):
            mx.random.bernoulli(p, [2])  # Bad shape

        with self.assertRaises(ValueError):
            mx.random.bernoulli(0, [2])  # Bad type

    def test_truncated_normal(self):
        a = mx.random.truncated_normal(-2.0, 2.0)
        self.assertEqual(a.size, 1)
        self.assertEqual(a.dtype, mx.float32)

        a = mx.random.truncated_normal(mx.array([]), mx.array([]))
        self.assertEqual(a.dtype, mx.float32)
        self.assertEqual(a.size, 0)

        lower = mx.reshape(mx.array([-2.0, 0.0]), [1, 2])
        upper = mx.reshape(mx.array([0.0, 1.0, 2.0]), [3, 1])
        a = mx.random.truncated_normal(lower, upper)

        self.assertEqual(a.shape, [3, 2])
        self.assertTrue(mx.all(lower <= a).item() and mx.all(a <= upper).item())

        a = mx.random.truncated_normal(2.0, -2.0)
        self.assertTrue(mx.all(a == 2.0).item())

        a = mx.random.truncated_normal(-3.0, 3.0, [542, 399])
        self.assertEqual(a.shape, [542, 399])

        lower = mx.array([-2.0, -1.0])
        higher = mx.array([1.0, 2.0, 3.0])
        with self.assertRaises(ValueError):
            mx.random.truncated_normal(lower, higher)  # Bad shape

    def test_gumbel(self):
        samples = mx.random.gumbel(shape=(100, 100))
        self.assertEqual(samples.shape, [100, 100])
        self.assertEqual(samples.dtype, mx.float32)
        mean = 0.5772
        # Std deviation of the sample mean is small (<0.02),
        # so this test is pretty conservative
        self.assertTrue(mx.abs(mx.mean(samples) - mean) < 0.2)

    def test_categorical(self):
        logits = mx.zeros((10, 20))
        self.assertEqual(mx.random.categorical(logits, -1).shape, [10])
        self.assertEqual(mx.random.categorical(logits, 0).shape, [20])
        self.assertEqual(mx.random.categorical(logits, 1).shape, [10])

        out = mx.random.categorical(logits)
        self.assertEqual(out.shape, [10])
        self.assertEqual(out.dtype, mx.uint32)
        self.assertTrue(mx.max(out).item() < 20)

        out = mx.random.categorical(logits, 0, [5, 20])
        self.assertEqual(out.shape, [5, 20])
        self.assertTrue(mx.max(out).item() < 10)

        out = mx.random.categorical(logits, 1, num_samples=7)
        self.assertEqual(out.shape, [10, 7])
        out = mx.random.categorical(logits, 0, num_samples=7)
        self.assertEqual(out.shape, [20, 7])

        with self.assertRaises(ValueError):
            mx.random.categorical(logits, shape=[10, 5], num_samples=5)


if __name__ == "__main__":
    unittest.main()
26
python/tests/test_tree.py
Normal file
@ -0,0 +1,26 @@
import unittest

import mlx.core as mx
import mlx.utils

import mlx_tests


class TestTreeUtils(mlx_tests.MLXTestCase):
    def test_tree_map(self):
        tree = {"a": 0, "b": 1, "c": 2}
        tree = mlx.utils.tree_map(lambda x: x + 1, tree)

        expected_tree = {"a": 1, "b": 2, "c": 3}
        self.assertEqual(tree, expected_tree)

    def test_tree_flatten(self):
        tree = [{"a": 1, "b": 2}, "c"]
        vals = (1, 2, "c")
        flat_tree = mlx.utils.tree_flatten(tree)
        self.assertEqual(list(zip(*flat_tree))[1], vals)
        self.assertEqual(mlx.utils.tree_unflatten(flat_tree), tree)


if __name__ == "__main__":
    unittest.main()
167
python/tests/test_vmap.py
Normal file
@ -0,0 +1,167 @@
import unittest

import mlx.core as mx

import mlx_tests


class TestVmap(mlx_tests.MLXTestCase):
    def test_basics(self):
        # Can't vmap over scalars
        with self.assertRaises(ValueError):
            mx.vmap(mx.exp)(mx.array(1.0))

        # Invalid input
        with self.assertRaises(ValueError):
            mx.vmap(mx.exp)("hello")

        # Invalid axes
        with self.assertRaises(ValueError):
            mx.vmap(mx.exp, in_axes="hello")(mx.array([0, 1]))

        with self.assertRaises(ValueError):
            mx.vmap(mx.exp, in_axes=2)(mx.array([0, 1]))

        with self.assertRaises(ValueError):
            mx.vmap(mx.exp, out_axes="hello")(mx.array([0, 1]))

        with self.assertRaises(ValueError):
            mx.vmap(mx.exp, out_axes=2)(mx.array([0, 1]))

    def test_unary(self):
        ops = [
            "abs",
            "cos",
            "erf",
            "erfinv",
            "exp",
            "log",
            "log1p",
            "log2",
            "log10",
            "logical_not",
            "negative",
            "reciprocal",
            "rsqrt",
            "sigmoid",
            "sign",
            "sin",
            "sqrt",
            "square",
        ]
        for opname in ops:
            with self.subTest(op=opname):
                op = getattr(mx, opname)
                x = mx.arange(5)
                y = mx.vmap(op)(x)
                self.assertTrue(mx.array_equal(y, op(x), equal_nan=True))

                x = mx.arange(8).reshape(2, 4)
                y = mx.vmap(op)(x)
                self.assertTrue(mx.array_equal(y, op(x), equal_nan=True))

                y = mx.vmap(op, in_axes=1, out_axes=1)(x)
                self.assertTrue(mx.array_equal(y, op(x), equal_nan=True))

    def test_binary(self):
        ops = [
            "add",
            "divide",
            "equal",
            "greater",
            "greater_equal",
            "less",
            "less_equal",
            "logaddexp",
            "maximum",
            "minimum",
            "multiply",
            "power",
            "subtract",
        ]
        for opname in ops:
            with self.subTest(op=opname):
                op = getattr(mx, opname)
                x = mx.random.uniform(shape=(5,))
                y = mx.random.uniform(shape=(5,))
                out = mx.vmap(op)(x, y)
                self.assertTrue(mx.array_equal(out, op(x, y)))

                x = mx.random.uniform(shape=(2, 4))
                y = mx.random.uniform(shape=(2, 4))
                out = mx.vmap(op)(x, y)
                self.assertTrue(mx.array_equal(out, op(x, y)))

                out = mx.vmap(op, in_axes=(0, 0), out_axes=0)(x, y)
                self.assertTrue(mx.array_equal(out, op(x, y)))

                y = mx.random.uniform(shape=(4, 2))
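                # in_axes=(0, 1) maps axis 0 of x and axis 1 of y, so the
                # unvectorized reference is op(x, y.T).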
                out = mx.vmap(op, in_axes=(0, 1), out_axes=0)(x, y)
                self.assertTrue(mx.array_equal(out, op(x, y.T)))

                out = mx.vmap(op, in_axes=(0, 1), out_axes=1)(x, y)
                self.assertTrue(mx.array_equal(out, op(x, y.T).T))

    def test_tree(self):
        def my_fun(tree):
            return (tree["a"] + tree["b"][0]) * tree["b"][1]

        tree = {
            "a": mx.random.uniform(shape=(2, 4)),
            "b": (
                mx.random.uniform(shape=(2, 4)),
                mx.random.uniform(shape=(2, 4)),
            ),
        }
        out = mx.vmap(my_fun)(tree)
        expected = my_fun(tree)
        self.assertTrue(mx.array_equal(out, expected))

        with self.assertRaises(ValueError):
            mx.vmap(my_fun, in_axes={"a": 0, "b": 0}, out_axes=0)(tree)

        with self.assertRaises(ValueError):
            mx.vmap(my_fun, in_axes={"a": 0, "b": ((0, 0), 0)}, out_axes=0)(tree)

        out = mx.vmap(my_fun, in_axes=({"a": 0, "b": 0},), out_axes=0)(tree)
        self.assertTrue(mx.array_equal(out, my_fun(tree)))

        out = mx.vmap(my_fun, in_axes=({"a": 0, "b": (0, 0)},), out_axes=0)(tree)
        self.assertTrue(mx.array_equal(out, my_fun(tree)))

        tree = {
            "a": mx.random.uniform(shape=(2, 4)),
            "b": (
                mx.random.uniform(shape=(4, 2)),
                mx.random.uniform(shape=(4, 2)),
            ),
        }
        out = mx.vmap(my_fun, in_axes=({"a": 0, "b": (1, 1)},), out_axes=0)(tree)
        expected = (tree["a"] + tree["b"][0].T) * tree["b"][1].T
        self.assertTrue(mx.array_equal(out, expected))

        def my_fun(x, y):
            return {"a": x + y, "b": x * y}

        x = mx.random.uniform(shape=(2, 4))
        y = mx.random.uniform(shape=(2, 4))
        out = mx.vmap(my_fun, in_axes=0, out_axes=0)(x, y)
        expected = my_fun(x, y)
        self.assertTrue(mx.array_equal(out["a"], expected["a"]))
        self.assertTrue(mx.array_equal(out["b"], expected["b"]))

        with self.assertRaises(ValueError):
            mx.vmap(my_fun, in_axes=0, out_axes=(0, 1))(x, y)

        with self.assertRaises(ValueError):
            mx.vmap(my_fun, in_axes=0, out_axes={"a": 0, "c": 1})(x, y)

        out = mx.vmap(my_fun, in_axes=0, out_axes={"a": 1, "b": 0})(x, y)
        expected = my_fun(x, y)
        self.assertTrue(mx.array_equal(out["a"].T, expected["a"]))
        self.assertTrue(mx.array_equal(out["b"], expected["b"]))


if __name__ == "__main__":
    unittest.main()
127
setup.py
Normal file
@ -0,0 +1,127 @@
import os
import re
import subprocess
import sys
import sysconfig
from pathlib import Path

from setuptools import Extension, setup, find_namespace_packages
from setuptools.command.build_ext import build_ext


# A CMakeExtension needs a sourcedir instead of a file list.
# The name must be the _single_ output extension from the CMake build.
# If you need multiple extensions, see scikit-build.
class CMakeExtension(Extension):
    def __init__(self, name: str, sourcedir: str = "") -> None:
        super().__init__(name, sources=[])
        self.sourcedir = os.fspath(Path(sourcedir).resolve())


class CMakeBuild(build_ext):
    def build_extension(self, ext: CMakeExtension) -> None:
        # Must be in this form due to bug in .resolve() only fixed in Python 3.10+
        ext_fullpath = Path.cwd() / self.get_ext_fullpath(ext.name)  # type: ignore[no-untyped-call]
        extdir = ext_fullpath.parent.resolve()

        debug = int(os.environ.get("DEBUG", 0)) if self.debug is None else self.debug
        cfg = "Debug" if debug else "Release"

        # CMake lets you override the generator - we need to check this.
        # Can be set with Conda-Build, for example.
        cmake_generator = os.environ.get("CMAKE_GENERATOR", "")

        # Set Python_EXECUTABLE instead if you use PYBIND11_FINDPYTHON
        # EXAMPLE_VERSION_INFO shows you how to pass a value into the C++ code
        # from Python.
        cmake_args = [
            f"-DCMAKE_INSTALL_PREFIX={extdir}{os.sep}",
            f"-DCMAKE_BUILD_TYPE={cfg}",
            "-DBUILD_SHARED_LIBS=ON",
            "-DMLX_BUILD_PYTHON_BINDINGS=ON",
            "-DMLX_BUILD_TESTS=OFF",
            "-DMLX_BUILD_BENCHMARKS=OFF",
            "-DMLX_BUILD_EXAMPLES=OFF",
            f"-DMLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY={extdir}{os.sep}",
        ]
        build_args = []
        # Adding CMake arguments set as environment variable
        # (needed e.g. to build for ARM OSx on conda-forge)
        if "CMAKE_ARGS" in os.environ:
            cmake_args += [item for item in os.environ["CMAKE_ARGS"].split(" ") if item]

        # Pass version to C++
        cmake_args += [f"-DMLX_VERSION={self.distribution.get_version()}"]  # type: ignore[attr-defined]

        if sys.platform.startswith("darwin"):
            # Cross-compile support for macOS - respect ARCHFLAGS if set
            archs = re.findall(r"-arch (\S+)", os.environ.get("ARCHFLAGS", ""))
            if archs:
                cmake_args += ["-DCMAKE_OSX_ARCHITECTURES={}".format(";".join(archs))]

        # Set CMAKE_BUILD_PARALLEL_LEVEL to control the parallel build level
        # across all generators.
        if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ:
            # self.parallel is a Python 3 only way to set parallel jobs by hand
            # using -j in the build_ext call, not supported by pip or PyPA-build.
            if hasattr(self, "parallel") and self.parallel:
                # CMake 3.12+ only.
                build_args += [f"-j{self.parallel}"]

        build_temp = Path(self.build_temp) / ext.name
        if not build_temp.exists():
            build_temp.mkdir(parents=True)

        subprocess.run(
            ["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True
        )
        subprocess.run(
            ["cmake", "--build", ".", "--target", "install", *build_args],
            cwd=build_temp,
            check=True,
        )

    # Make sure to copy mlx.metallib for inplace builds
    def run(self):
        super().run()

        # Based on https://github.com/pypa/setuptools/blob/main/setuptools/command/build_ext.py#L102
        if self.inplace:
            for ext in self.extensions:
                if ext.name == "mlx.core":
                    # Resolve inplace package dir
                    build_py = self.get_finalized_command("build_py")
                    inplace_file, regular_file = self._get_inplace_equivalent(
                        build_py, ext
                    )

                    inplace_dir = str(Path(inplace_file).parent.resolve())
                    regular_dir = str(Path(regular_file).parent.resolve())

                    self.copy_tree(regular_dir, inplace_dir)


# The information here can also be placed in setup.cfg - better separation of
# logic and declaration, and simpler if you include description/version in a file.
if __name__ == "__main__":
    packages = find_namespace_packages(
        where="python", exclude=["src", "tests", "tests.*"]
    )
    package_dir = {"": "python"}
    package_data = {"mlx": ["lib/*", "include/*", "share/*"]}
    setup(
        name="mlx",
        version="0.0.2",
        author="MLX Contributors",
        author_email="mlx@group.apple.com",
        description="A framework for machine learning on Apple Silicon.",
        long_description="",
        packages=packages,
        package_dir=package_dir,
        package_data=package_data,
        include_package_data=True,
        ext_modules=[CMakeExtension("mlx.core")],
        cmdclass={"build_ext": CMakeBuild},
        zip_safe=False,
        python_requires=">=3.7",
    )
41
tests/allocator_tests.cpp
Normal file
@ -0,0 +1,41 @@
#include <stdexcept>

#include "doctest/doctest.h"

#include "mlx/allocator.h"

using namespace mlx::core;

TEST_CASE("test simple allocations") {
  {
    auto buffer = allocator::malloc(sizeof(float));
    auto fptr = static_cast<float*>(buffer.raw_ptr());
    *fptr = 0.5f;
    CHECK_EQ(*fptr, 0.5f);
    allocator::free(buffer);
  }

  {
    auto buffer = allocator::malloc(128 * sizeof(int));
    int* ptr = static_cast<int*>(buffer.raw_ptr());
    for (int i = 0; i < 128; ++i) {
      ptr[i] = i;
    }
    allocator::free(buffer);
  }

  {
    auto buffer = allocator::malloc(0);
    allocator::free(buffer);
  }
}

TEST_CASE("test large allocations") {
  size_t size = 1 << 30;
  for (int i = 0; i < 100; ++i) {
    auto buffer = allocator::malloc(size);
    allocator::free(buffer);
  }
  // Shouldn't be able to allocate an exabyte anytime soon.
  CHECK_THROWS_AS(allocator::malloc(1ull << 60), std::runtime_error);
}
589
tests/array_tests.cpp
Normal file
@ -0,0 +1,589 @@
#include <climits>

#include "doctest/doctest.h"

#include "mlx/mlx.h"

using namespace mlx::core;

TEST_CASE("test array basics") {
  // Scalar
  array x(1.0);
  CHECK_EQ(x.size(), 1);
  CHECK_EQ(x.ndim(), 0);
  CHECK_EQ(x.shape(), std::vector<int>{});
  CHECK_THROWS_AS(x.shape(0), std::out_of_range);
  CHECK_THROWS_AS(x.shape(-1), std::out_of_range);
  CHECK_EQ(x.strides(), std::vector<size_t>{});
  CHECK_EQ(x.itemsize(), sizeof(float));
  CHECK_EQ(x.nbytes(), sizeof(float));
  CHECK_EQ(x.dtype(), float32);
  CHECK_EQ(x.item<float>(), 1.0);

  // Scalar with specified type
  x = array(1, float32);
  CHECK_EQ(x.dtype(), float32);
  CHECK_EQ(x.item<float>(), 1.0);

  // Scalar with specified type
  x = array(1, bool_);
  CHECK_EQ(x.dtype(), bool_);
  CHECK_EQ(x.itemsize(), sizeof(bool));
  CHECK_EQ(x.nbytes(), sizeof(bool));
  CHECK_EQ(x.item<bool>(), true);

  // Check shaped arrays
  x = array({1.0});
  CHECK_EQ(x.dtype(), float32);
  CHECK_EQ(x.size(), 1);
  CHECK_EQ(x.ndim(), 1);
  CHECK_EQ(x.shape(), std::vector<int>{1});
  CHECK_EQ(x.shape(0), 1);
  CHECK_EQ(x.shape(-1), 1);
  CHECK_THROWS_AS(x.shape(1), std::out_of_range);
  CHECK_THROWS_AS(x.shape(-2), std::out_of_range);
  CHECK_EQ(x.strides(), std::vector<size_t>{1});
  CHECK_EQ(x.item<float>(), 1.0);

  // Check empty array
  x = array({});
  CHECK_EQ(x.size(), 0);
  CHECK_EQ(x.dtype(), float32);
  CHECK_EQ(x.itemsize(), sizeof(float));
  CHECK_EQ(x.nbytes(), 0);
  CHECK_THROWS_AS(x.item<float>(), std::invalid_argument);

  x = array({1.0, 1.0});
  CHECK_EQ(x.size(), 2);
  CHECK_EQ(x.shape(), std::vector<int>{2});
  CHECK_EQ(x.itemsize(), sizeof(float));
  CHECK_EQ(x.nbytes(), x.itemsize() * x.size());

  // Accessing item in non-scalar array throws
  CHECK_THROWS_AS(x.item<float>(), std::invalid_argument);

  x = array({1.0, 1.0, 1.0}, {1, 3});
  CHECK(x.size() == 3);
  CHECK(x.shape() == std::vector<int>{1, 3});
  CHECK(x.strides() == std::vector<size_t>{3, 1});

  // Test wrong size/shapes throw:
  CHECK_THROWS_AS(array({1.0, 1.0, 1.0}, {4}), std::invalid_argument);
  CHECK_THROWS_AS(array({1.0, 1.0, 1.0}, {1, 4}), std::invalid_argument);
  CHECK_THROWS_AS(array({1.0, 1.0, 1.0}, {1, 2}), std::invalid_argument);

  // Test array ids work as expected
  x = array(1.0);
  auto y = x;
  CHECK_EQ(y.id(), x.id());
  array z(2.0);
  CHECK_NE(z.id(), x.id());
  z = x;
  CHECK_EQ(z.id(), x.id());

  // Array creation from pointer
  float data[] = {0.0, 1.0, 2.0, 3.0};
  x = array(data, {4});
  CHECK_EQ(x.dtype(), float32);
  CHECK(array_equal(x, array({0.0, 1.0, 2.0, 3.0})).item<bool>());

  // Array creation from vectors
  {
    std::vector<int> data = {0, 1, 2, 3};
    x = array(data.begin(), {4});
    CHECK_EQ(x.dtype(), int32);
    CHECK(array_equal(x, array({0, 1, 2, 3})).item<bool>());
  }

  {
    std::vector<bool> data = {false, true, false, true};
    x = array(data.begin(), {4});
    CHECK_EQ(x.dtype(), bool_);
    CHECK(array_equal(x, array({false, true, false, true})).item<bool>());
  }
}

TEST_CASE("test array types") {
#define basic_dtype_test(T, mlx_type) \
  T val = 42;                         \
  array x(val);                       \
  CHECK_EQ(x.dtype(), mlx_type);      \
  CHECK_EQ(x.item<T>(), val);         \
  x = array({val, val});              \
  CHECK_EQ(x.dtype(), mlx_type);

  // bool_
  {
    array x(true);
    CHECK_EQ(x.dtype(), bool_);
    CHECK_EQ(x.item<bool>(), true);

    x = array({true, false});
    CHECK_EQ(x.dtype(), bool_);

    x = array({true, false}, float32);
    CHECK_EQ(x.dtype(), float32);
    CHECK(array_equal(x, array({1.0f, 0.0f})).item<bool>());
  }

  // uint8
  { basic_dtype_test(uint8_t, uint8); }

  // uint16
  { basic_dtype_test(uint16_t, uint16); }

  // uint32
  { basic_dtype_test(uint32_t, uint32); }

  // uint64
  { basic_dtype_test(uint64_t, uint64); }

  // int8
  { basic_dtype_test(int8_t, int8); }

  // int16
  { basic_dtype_test(int16_t, int16); }

  // int32
  { basic_dtype_test(int32_t, int32); }

  // int64
  { basic_dtype_test(int64_t, int64); }

  // float16
  { basic_dtype_test(float16_t, float16); }

  // float32
  { basic_dtype_test(float, float32); }

  // bfloat16
  { basic_dtype_test(bfloat16_t, bfloat16); }

  // uint32
  {
    uint32_t val = UINT_MAX;
    array x(val);
    CHECK_EQ(x.dtype(), uint32);
    CHECK_EQ(x.item<uint32_t>(), val);

    x = array({1u, 2u});
    CHECK_EQ(x.dtype(), uint32);
  }

  // int32
  {
    array x(-1);
    CHECK_EQ(x.dtype(), int32);
    CHECK_EQ(x.item<int>(), -1);

    x = array({-1, 2});
    CHECK_EQ(x.dtype(), int32);

    std::vector<int> data{0, 1, 2};
    x = array(data.data(), {static_cast<int>(data.size())}, bool_);
    CHECK_EQ(x.dtype(), bool_);
    CHECK(array_equal(x, array({false, true, true})).item<bool>());
  }

  // int64
  {
    int64_t val = static_cast<int64_t>(INT_MIN) - 1;
    array x(val);
    CHECK_EQ(x.dtype(), int64);
    CHECK_EQ(x.item<int64_t>(), val);

    x = array({val, val});
    CHECK_EQ(x.dtype(), int64);
  }

  // float32
  {
    array x(3.14f);
    CHECK_EQ(x.dtype(), float32);
    CHECK_EQ(x.item<float>(), 3.14f);

    x = array(1.25);
    CHECK_EQ(x.dtype(), float32);
    CHECK_EQ(x.item<float>(), 1.25f);

    x = array({1.0f, 2.0f});
    CHECK_EQ(x.dtype(), float32);

    x = array({1.0, 2.0});
    CHECK_EQ(x.dtype(), float32);

    std::vector<double> data{1.0, 2.0, 4.0};
    x = array(data.data(), {static_cast<int>(data.size())});
    CHECK_EQ(x.dtype(), float32);
    CHECK(array_equal(x, array({1.0f, 2.0f, 4.0f})).item<bool>());
  }

  // complex64
  {
    complex64_t v = {1.0f, 1.0f};
    array x(v);
    CHECK_EQ(x.dtype(), complex64);
    CHECK_EQ(x.item<complex64_t>(), v);

    array y(std::complex<float>{1.0f, 1.0f});
    CHECK_EQ(y.dtype(), complex64);
    CHECK_EQ(y.item<complex64_t>(), v);
  }

#undef basic_dtype_test

#define basic_dtype_str_test(s, dtype)         \
  CHECK_EQ(s, dtype_to_array_protocol(dtype)); \
  CHECK_EQ(dtype_from_array_protocol(s), dtype);

  // To and from str
  {
    basic_dtype_str_test("|b1", bool_);
    basic_dtype_str_test("|u1", uint8);
    basic_dtype_str_test("<u2", uint16);
    basic_dtype_str_test("<u4", uint32);
    basic_dtype_str_test("<u8", uint64);
    basic_dtype_str_test("|i1", int8);
    basic_dtype_str_test("<i2", int16);
    basic_dtype_str_test("<i4", int32);
    basic_dtype_str_test("<i8", int64);
    basic_dtype_str_test("<f2", float16);
    basic_dtype_str_test("<f4", float32);
    basic_dtype_str_test("<V2", bfloat16);
    basic_dtype_str_test("<c8", complex64);
  }

#undef basic_dtype_str_test
}
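// The flags checked below describe memory layout: `contiguous` means the
// buffer is densely packed, while `row_contiguous` / `col_contiguous`
// additionally require C-order or Fortran-order strides, respectively.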
TEST_CASE("test array metadata") {
|
||||
array x(1.0f);
|
||||
CHECK_EQ(x.data_size(), 1);
|
||||
CHECK_EQ(x.flags().contiguous, true);
|
||||
CHECK_EQ(x.flags().row_contiguous, true);
|
||||
CHECK_EQ(x.flags().col_contiguous, true);
|
||||
|
||||
x = array({1.0f}, {1, 1, 1});
|
||||
CHECK_EQ(x.data_size(), 1);
|
||||
CHECK_EQ(x.flags().contiguous, true);
|
||||
CHECK_EQ(x.flags().row_contiguous, true);
|
||||
CHECK_EQ(x.flags().col_contiguous, true);
|
||||
|
||||
x = array({1.0f, 1.0f}, {1, 2});
|
||||
CHECK_EQ(x.data_size(), 2);
|
||||
CHECK_EQ(x.flags().contiguous, true);
|
||||
CHECK_EQ(x.flags().row_contiguous, true);
|
||||
CHECK_EQ(x.flags().col_contiguous, true);
|
||||
|
||||
x = zeros({1, 1, 4});
|
||||
eval(x);
|
||||
CHECK_EQ(x.data_size(), 4);
|
||||
CHECK_EQ(x.flags().contiguous, true);
|
||||
CHECK_EQ(x.flags().row_contiguous, true);
|
||||
CHECK_EQ(x.flags().col_contiguous, true);
|
||||
|
||||
x = zeros({2, 4});
|
||||
eval(x);
|
||||
CHECK_EQ(x.data_size(), 8);
|
||||
CHECK_EQ(x.flags().contiguous, true);
|
||||
CHECK_EQ(x.flags().row_contiguous, true);
|
||||
CHECK_EQ(x.flags().col_contiguous, false);
|
||||
|
||||
x = array(1.0f);
|
||||
auto y = broadcast_to(x, {1, 1, 1});
|
||||
eval(y);
|
||||
CHECK_EQ(y.data_size(), 1);
|
||||
CHECK_EQ(y.flags().contiguous, true);
|
||||
CHECK_EQ(y.flags().row_contiguous, true);
|
||||
CHECK_EQ(y.flags().col_contiguous, true);
|
||||
|
||||
y = broadcast_to(x, {2, 8, 10});
|
||||
eval(y);
|
||||
CHECK_EQ(y.data_size(), 1);
|
||||
CHECK_EQ(y.flags().contiguous, true);
|
||||
CHECK_EQ(y.flags().row_contiguous, false);
|
||||
CHECK_EQ(y.flags().col_contiguous, false);
|
||||
|
||||
y = broadcast_to(x, {1, 0});
|
||||
eval(y);
|
||||
CHECK_EQ(y.data_size(), 0);
|
||||
CHECK_EQ(y.flags().contiguous, true);
|
||||
CHECK_EQ(y.flags().row_contiguous, true);
|
||||
CHECK_EQ(y.flags().col_contiguous, true);
|
||||
|
||||
y = broadcast_to(zeros({4, 2, 1}), {4, 2, 0});
|
||||
eval(y);
|
||||
CHECK_EQ(y.data_size(), 0);
|
||||
CHECK_EQ(y.flags().contiguous, true);
|
||||
CHECK_EQ(y.flags().row_contiguous, true);
|
||||
CHECK_EQ(y.flags().col_contiguous, true);
|
||||
|
||||
x = array(1.0f);
|
||||
y = transpose(x);
|
||||
eval(y);
|
||||
CHECK_EQ(y.data_size(), 1);
|
||||
CHECK_EQ(y.flags().contiguous, true);
|
||||
CHECK_EQ(y.flags().row_contiguous, true);
|
||||
CHECK_EQ(y.flags().col_contiguous, true);
|
||||
|
||||
x = ones({1, 1, 1});
|
||||
y = transpose(x);
|
||||
eval(y);
|
||||
CHECK_EQ(y.data_size(), 1);
|
||||
CHECK_EQ(y.flags().contiguous, true);
|
||||
CHECK_EQ(y.flags().row_contiguous, true);
|
||||
CHECK_EQ(y.flags().col_contiguous, true);
|
||||
|
||||
x = ones({1, 1, 1});
|
||||
y = transpose(x, {0, 1, 2});
|
||||
eval(y);
|
||||
CHECK_EQ(y.data_size(), 1);
|
||||
CHECK_EQ(y.flags().contiguous, true);
|
||||
CHECK_EQ(y.flags().row_contiguous, true);
|
||||
CHECK_EQ(y.flags().col_contiguous, true);
|
||||
|
||||
x = ones({1, 1, 1});
|
||||
y = transpose(x, {1, 2, 0});
|
||||
eval(y);
|
||||
CHECK_EQ(y.data_size(), 1);
|
||||
CHECK_EQ(y.flags().contiguous, true);
|
||||
CHECK_EQ(y.flags().row_contiguous, true);
|
||||
CHECK_EQ(y.flags().col_contiguous, true);
|
||||
|
||||
x = ones({4, 1});
|
||||
y = transpose(x);
|
||||
eval(y);
|
||||
CHECK_EQ(y.data_size(), 4);
|
||||
CHECK_EQ(y.flags().contiguous, true);
|
||||
CHECK_EQ(y.flags().row_contiguous, true);
|
||||
CHECK_EQ(y.flags().col_contiguous, true);
|
||||
|
||||
x = ones({2, 3, 4});
|
||||
y = transpose(x);
|
||||
eval(y);
|
||||
CHECK_EQ(y.data_size(), 24);
|
||||
CHECK_EQ(y.flags().contiguous, true);
|
||||
CHECK_EQ(y.flags().row_contiguous, false);
|
||||
CHECK_EQ(y.flags().col_contiguous, true);
|
||||
|
||||
y = transpose(x, {0, 2, 1});
|
||||
eval(y);
|
||||
CHECK_EQ(y.data_size(), 24);
|
||||
CHECK_EQ(y.flags().contiguous, true);
|
||||
CHECK_EQ(y.flags().row_contiguous, false);
|
||||
CHECK_EQ(y.flags().col_contiguous, false);
|
||||
|
||||
y = transpose(transpose(x, {0, 2, 1}), {0, 2, 1});
|
||||
eval(y);
|
||||
CHECK_EQ(y.data_size(), 24);
|
||||
CHECK_EQ(y.flags().contiguous, true);
|
||||
CHECK_EQ(y.flags().row_contiguous, true);
|
||||
CHECK_EQ(y.flags().col_contiguous, false);
|
||||
|
||||
x = array(1.0f);
|
||||
y = reshape(x, {1, 1, 1});
|
||||
eval(y);
|
||||
CHECK_EQ(y.data_size(), 1);
|
||||
CHECK_EQ(y.flags().contiguous, true);
|
||||
CHECK_EQ(y.flags().row_contiguous, true);
|
||||
CHECK_EQ(y.flags().col_contiguous, true);
|
||||
|
||||
x = ones({2, 4});
|
||||
y = reshape(x, {8});
|
||||
eval(y);
|
||||
CHECK_EQ(y.data_size(), 8);
|
||||
CHECK_EQ(y.flags().contiguous, true);
|
||||
CHECK_EQ(y.flags().row_contiguous, true);
|
||||
CHECK_EQ(y.flags().col_contiguous, true);
|
||||
|
||||
y = reshape(x, {8, 1, 1});
|
||||
eval(y);
|
||||
CHECK_EQ(y.data_size(), 8);
|
||||
CHECK_EQ(y.flags().contiguous, true);
|
||||
CHECK_EQ(y.flags().row_contiguous, true);
|
||||
CHECK_EQ(y.flags().col_contiguous, true);
|
||||
|
||||
y = reshape(x, {1, 8, 1});
|
||||
eval(y);
|
||||
CHECK_EQ(y.data_size(), 8);
|
||||
CHECK_EQ(y.flags().contiguous, true);
|
||||
CHECK_EQ(y.flags().row_contiguous, true);
|
||||
CHECK_EQ(y.flags().col_contiguous, true);
|
||||
|
||||
x = ones({12});
|
||||
y = reshape(x, {2, 3, 2});
|
||||
eval(y);
|
||||
CHECK_EQ(y.data_size(), 12);
|
||||
CHECK_EQ(y.flags().contiguous, true);
|
||||
CHECK_EQ(y.flags().row_contiguous, true);
|
||||
CHECK_EQ(y.flags().col_contiguous, false);
|
||||
|
||||
x = array(1.0f);
|
||||
y = slice(x, {}, {});
|
||||
eval(y);
|
||||
CHECK_EQ(y.data_size(), 1);
|
||||
CHECK_EQ(y.flags().contiguous, true);
|
||||
CHECK_EQ(y.flags().row_contiguous, true);
|
||||
CHECK_EQ(y.flags().col_contiguous, true);
|
||||
|
||||
x = array({1.0f});
|
||||
y = slice(x, {-10}, {10}, {10});
|
||||
eval(y);
|
||||
CHECK_EQ(y.data_size(), 1);
|
||||
CHECK_EQ(y.flags().contiguous, true);
|
||||
CHECK_EQ(y.flags().row_contiguous, true);
|
||||
CHECK_EQ(y.flags().col_contiguous, true);
|
||||
|
||||
x = array({1.0f, 2.0f, 3.0f}, {1, 3});
|
||||
y = slice(x, {0, 0}, {1, 3}, {1, 1});
|
||||
eval(y);
|
||||
CHECK_EQ(y.data_size(), 3);
|
||||
CHECK_EQ(y.flags().contiguous, true);
|
||||
CHECK_EQ(y.flags().row_contiguous, true);
|
||||
CHECK_EQ(y.flags().col_contiguous, true);

  x = array({1.0f, 2.0f, 3.0f}, {1, 3});
  y = slice(x, {0, 0}, {0, 3}, {1, 1});
  eval(y);
  CHECK_EQ(y.data_size(), 0);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, true);
  CHECK_EQ(y.flags().col_contiguous, true);

  x = array({1.0f, 2.0f, 3.0f}, {1, 3});
  y = slice(x, {0, 0}, {1, 2}, {1, 1});
  eval(y);
  CHECK_EQ(y.data_size(), 2);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, true);
  CHECK_EQ(y.flags().col_contiguous, true);

  x = array({1.0f, 2.0f, 3.0f}, {1, 3});
  y = slice(x, {0, 0}, {1, 2}, {2, 3});
  eval(y);
  CHECK_EQ(y.shape(), std::vector<int>{1, 1});
  CHECK_EQ(y.data_size(), 1);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, true);
  CHECK_EQ(y.flags().col_contiguous, true);

  x = array({0.0f, 1.0f, 2.0f, 3.0f}, {1, 4});
  y = slice(x, {0, 0}, {1, 4}, {1, 2});
  eval(y);
  CHECK_EQ(y.shape(), std::vector<int>{1, 2});
  CHECK_EQ(y.flags().contiguous, false);
  CHECK_EQ(y.flags().row_contiguous, false);
  CHECK_EQ(y.flags().col_contiguous, false);

  x = broadcast_to(array(1.0f), {4, 10});
  y = slice(x, {0, 0}, {4, 10}, {2, 2});
  eval(y);
  CHECK_EQ(y.shape(), std::vector<int>{2, 5});
  CHECK_EQ(y.data_size(), 1);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, false);
  CHECK_EQ(y.flags().col_contiguous, false);

  x = broadcast_to(array({1.0f, 2.0f}), {4, 2});
  y = slice(x, {0, 0}, {1, 2}, {1, 1});
  eval(y);
  CHECK_EQ(y.data_size(), 2);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, true);
  CHECK_EQ(y.flags().col_contiguous, true);

  y = slice(x, {1, 0}, {2, 2}, {1, 1});
  eval(y);
  CHECK_EQ(y.data_size(), 2);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, true);
  CHECK_EQ(y.flags().col_contiguous, true);

  x = array({0.0f, 1.0f, 2.0f, 3.0f}, {2, 2});
  y = slice(x, {0, 0}, {2, 2}, {1, 1});
  eval(y);
  CHECK_EQ(y.data_size(), 4);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, true);
  CHECK_EQ(y.flags().col_contiguous, false);

  y = slice(transpose(x), {0, 0}, {2, 2}, {1, 1});
  eval(y);
  CHECK_EQ(y.data_size(), 4);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, false);
  CHECK_EQ(y.flags().col_contiguous, true);

  x = ones({2, 4});
  auto out = split(x, 2);
  eval(out);
  for (auto y : out) {
    CHECK_EQ(y.data_size(), 4);
    CHECK_EQ(y.flags().contiguous, true);
    CHECK_EQ(y.flags().row_contiguous, true);
    CHECK_EQ(y.flags().col_contiguous, true);
  }
  out = split(x, 4, 1);
  eval(out);
  for (auto y : out) {
    CHECK_EQ(y.flags().contiguous, false);
    CHECK_EQ(y.flags().row_contiguous, false);
    CHECK_EQ(y.flags().col_contiguous, false);
  }
}
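
The flags checked throughout this test can be recovered from shape and strides alone. Below is a hypothetical helper (not part of the library) sketching the invariants: row-contiguity means the strides match a C-ordered layout, column-contiguity a Fortran-ordered one, and size-1 dimensions place no constraint on their stride. Zero-size edge cases like the `{1, 0}` broadcast above need extra care and are ignored here.

```
#include <cstddef>
#include <vector>

// Hypothetical sketch of the row/col contiguity invariants; size-1 dims
// are unconstrained, size-0 dims are not handled.
bool is_row_contiguous(
    const std::vector<int>& shape,
    const std::vector<size_t>& strides) {
  size_t expected = 1;
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
    if (shape[i] != 1 && strides[i] != expected) {
      return false;
    }
    expected *= shape[i];
  }
  return true;
}

bool is_col_contiguous(
    const std::vector<int>& shape,
    const std::vector<size_t>& strides) {
  size_t expected = 1;
  for (size_t i = 0; i < shape.size(); ++i) {
    if (shape[i] != 1 && strides[i] != expected) {
      return false;
    }
    expected *= shape[i];
  }
  return true;
}
```

For example, transposing a `{2, 3, 4}` array gives shape `{4, 3, 2}` with strides `{1, 4, 12}`: the strides fail the C-order check but satisfy the Fortran-order one, matching the `row_contiguous == false`, `col_contiguous == true` assertions above.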

TEST_CASE("test array iteration") {
  // Dim 0 arrays
  auto arr = array(1);
  CHECK_THROWS(arr.begin());

  // Iterated arrays are read only
  CHECK(std::is_const_v<decltype(*arr.begin())>);

  arr = array({1, 2, 3, 4, 5});
  int i = 0;
  for (auto a : arr) {
    i++;
    CHECK_EQ(a.item<int>(), i);
  }
  CHECK_EQ(i, 5);

  arr = array({1, 2, 3, 4}, {2, 2});
  CHECK(array_equal(*arr.begin(), array({1, 2})).item<bool>());
  CHECK(array_equal(*(arr.begin() + 1), array({3, 4})).item<bool>());
  CHECK_EQ(arr.begin() + 2, arr.end());
}

TEST_CASE("test array shared buffer") {
  std::vector<int> shape = {2, 2};
  int n_elem = shape[0] * shape[1];

  allocator::Buffer buf_b = allocator::malloc(n_elem * sizeof(float));
  void* buf_b_ptr = buf_b.raw_ptr();
  float* float_buf_b = (float*)buf_b_ptr;

  for (int i = 0; i < n_elem; i++) {
    float_buf_b[i] = 2.;
  }

  CHECK_EQ(float_buf_b[0], ((float*)buf_b_ptr)[0]);

  auto deleter = [float_buf_b](allocator::Buffer buf) {
    CHECK_EQ(float_buf_b, (float*)buf.raw_ptr());
    CHECK_EQ(float_buf_b[0], ((float*)buf.raw_ptr())[0]);
    allocator::free(buf);
  };

  array a = ones(shape, float32);
  array b = array(buf_b, shape, float32, deleter);

  eval(a + b);
}
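
The deleter is what makes handing a foreign buffer to `array` safe: it runs when the array releases its data. As a sketch (not something this test exercises), one way to let two arrays share a single external buffer is to route the actual free through a `std::shared_ptr`, so it happens exactly once when the last owner lets go:

```
#include <memory>

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  auto buf = allocator::malloc(4 * sizeof(float));

  // The shared_ptr's deleter frees the buffer once, after the last
  // array holding a copy of `release` gives up its data.
  std::shared_ptr<void> release(
      buf.raw_ptr(), [buf](void*) { allocator::free(buf); });
  auto deleter = [release](allocator::Buffer) { /* freed via release */ };

  array a(buf, {4}, float32, deleter);
  array b(buf, {4}, float32, deleter);
  eval(a + b);
}
```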
1192
tests/autograd_tests.cpp
Normal file
File diff suppressed because it is too large

33
tests/device_tests.cpp
Normal file
@ -0,0 +1,33 @@
#include "doctest/doctest.h"

#include <cstdlib>

#include "mlx/mlx.h"

using namespace mlx::core;

TEST_CASE("test device placement") {
  auto device = default_device();
  Device d = metal::is_available() ? Device::gpu : Device::cpu;
  if (std::getenv("DEVICE") == nullptr) {
    CHECK_EQ(device, d);
  }

  array x(1.0f);
  array y(1.0f);
  auto z = add(x, y, default_device());
  if (metal::is_available()) {
    z = add(x, y, Device::gpu);
    z = add(x, y, Device(Device::gpu, 0));
  } else {
    CHECK_THROWS_AS(set_default_device(Device::gpu), std::invalid_argument);
    CHECK_THROWS_AS(add(x, y, Device::gpu), std::invalid_argument);
  }

  // Set the default device to the CPU
  set_default_device(Device::cpu);
  CHECK_EQ(default_device(), Device::cpu);

  // Revert
  set_default_device(device);
}
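
A pattern that follows from this test: guard GPU placement behind `metal::is_available()` rather than hard-coding `Device::gpu`, since requesting the GPU without Metal throws `std::invalid_argument`. A minimal sketch:

```
#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  // Fall back to the CPU when Metal is unavailable.
  Device d = metal::is_available() ? Device::gpu : Device::cpu;
  set_default_device(d);

  auto z = add(array(1.0f), array(2.0f), d);
  eval(z);
}
```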
97
tests/eval_tests.cpp
Normal file
@ -0,0 +1,97 @@
#include "doctest/doctest.h"

#include "mlx/mlx.h"

using namespace mlx::core;

TEST_CASE("test eval") {
  {
    array x(1.0);
    array y(1);
    array z(true);
    eval({x, y, z});
    CHECK_EQ(x.item<float>(), 1.0);
  }

  {
    array x(1.0);
    array y = ones({2, 2});
    array z(true);
    eval({x, y, z});
    CHECK(array_equal(y, array({1.0, 1.0, 1.0, 1.0}, {2, 2})).item<bool>());
  }
}

TEST_CASE("test eval multiple") {
  auto x = ones({10, 10});
  auto y = ones({10, 10});
  eval({x, y});
  CHECK(array_equal(x, y).item<bool>());

  auto a = x + y;
  auto b = x - y;
  eval({a, b});
  CHECK(array_equal(a, full({10, 10}, 2.0f)).item<bool>());
  CHECK(array_equal(b, full({10, 10}, 0.0f)).item<bool>());

  x = ones({10, 10});
  y = ones({10, 10});
  eval(x, y);
  CHECK(array_equal(x, y).item<bool>());

  a = x + y;
  b = x - y;
  eval(a, b);
  CHECK(array_equal(a, full({10, 10}, 2.0f)).item<bool>());
  CHECK(array_equal(b, full({10, 10}, 0.0f)).item<bool>());
}

TEST_CASE("test eval with tracer") {
  auto x = array(1);
  x.set_tracer(true);

  // Ok, x is not a node
  eval(x);

  x = ones({2, 3});
  x.set_tracer(true);
  CHECK_THROWS(eval(x));

  // Ok with retain_graph=true
  eval({x}, true);

  // Make sure all arguments are checked
  auto y = ones({2, 3});
  CHECK_THROWS(eval(x, y));
}

TEST_CASE("test eval graph retention") {
  auto x = array(1);
  auto y = array(2);
  auto z = x + y;
  eval({z}, true);
  CHECK(z.has_primitive());
  CHECK(z.is_evaled());
  CHECK_EQ(z.item<int>(true), 3);
  CHECK(z.has_primitive());
  CHECK(z.is_evaled());

  CHECK_EQ(z.item<int>(), 3);
  CHECK(!z.has_primitive());
  CHECK(z.is_evaled());

  z = x + y;
  auto a = z + x;
  auto b = a + y;
  eval({b}, true);
  CHECK(z.has_primitive());
  CHECK(z.is_evaled());
  CHECK(a.has_primitive());
  CHECK(a.is_evaled());

  eval({b}, false);
  CHECK(!z.has_primitive());
  CHECK(z.is_evaled());
  CHECK(!a.has_primitive());
  CHECK(a.is_evaled());
}
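
What this test pins down: `eval({x}, true)` and `item<T>(true)` evaluate while retaining the graph (the output keeps its primitive and inputs), whereas the defaults release the graph once values are materialized. A short sketch of the distinction, mirroring the checks above:

```
#include <cassert>

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  auto z = array(1) + array(2);

  // Evaluate but keep the graph: z still knows how it was computed.
  assert(z.item<int>(true) == 3);
  assert(z.has_primitive());

  // Default: the graph is released once the value is materialized.
  assert(z.item<int>() == 3);
  assert(!z.has_primitive());
}
```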
81
tests/load_tests.cpp
Normal file
@ -0,0 +1,81 @@
#include <filesystem>
#include <stdexcept>
#include <vector>

#include "doctest/doctest.h"

#include "mlx/mlx.h"

using namespace mlx::core;

std::string get_temp_file(const std::string& name) {
  return std::filesystem::temp_directory_path().append(name);
}

TEST_CASE("test single array serialization") {
  // Basic test
  {
    auto a = random::uniform(-5.f, 5.f, {2, 5, 12}, float32);

    std::string file_path = get_temp_file("test_arr.npy");

    save(file_path, a);
    auto b = load(file_path);

    CHECK_EQ(a.dtype(), b.dtype());
    CHECK_EQ(a.shape(), b.shape());
    CHECK(array_equal(a, b).item<bool>());
  }

  // Other shapes
  {
    auto a = random::uniform(-5.f, 5.f, {1}, float32);

    std::string file_path = get_temp_file("test_arr_0.npy");

    save(file_path, a);
    auto b = load(file_path);

    CHECK_EQ(a.dtype(), b.dtype());
    CHECK_EQ(a.shape(), b.shape());
    CHECK(array_equal(a, b).item<bool>());
  }

  {
    auto a = random::uniform(-5.f, 5.f, {46}, float32);

    std::string file_path = get_temp_file("test_arr_1.npy");

    save(file_path, a);
    auto b = load(file_path);

    CHECK_EQ(a.dtype(), b.dtype());
    CHECK_EQ(a.shape(), b.shape());
    CHECK(array_equal(a, b).item<bool>());
  }

  {
    auto a = random::uniform(-5.f, 5.f, {5, 2, 1, 3, 4}, float32);

    std::string file_path = get_temp_file("test_arr_2.npy");

    save(file_path, a);
    auto b = load(file_path);

    CHECK_EQ(a.dtype(), b.dtype());
    CHECK_EQ(a.shape(), b.shape());
    CHECK(array_equal(a, b).item<bool>());
  }
}
119
tests/scheduler_tests.cpp
Normal file
@ -0,0 +1,119 @@
#include "doctest/doctest.h"

#include "mlx/mlx.h"
#include "mlx/scheduler.h"

using namespace mlx::core;

TEST_CASE("test stream management") {
  auto s1 = default_stream(default_device());
  CHECK_EQ(s1.device, default_device());

  auto s2 = new_stream(default_device());
  CHECK_EQ(s2.device, default_device());
  CHECK_NE(s1, s2);

  // Check that default streams have the correct devices
  if (metal::is_available()) {
    auto s_gpu = default_stream(Device::gpu);
    CHECK_EQ(s_gpu.device, Device::gpu);
  } else {
    CHECK_THROWS_AS(default_stream(Device::gpu), std::invalid_argument);
  }
  auto s_cpu = default_stream(Device::cpu);
  CHECK_EQ(s_cpu.device, Device::cpu);

  s_cpu = new_stream(Device::cpu);
  CHECK_EQ(s_cpu.device, Device::cpu);

  if (metal::is_available()) {
    auto s_gpu = new_stream(Device::gpu);
    CHECK_EQ(s_gpu.device, Device::gpu);
  } else {
    CHECK_THROWS_AS(new_stream(Device::gpu), std::invalid_argument);
  }
}

TEST_CASE("test asynchronous launch") {
  auto s1 = default_stream(default_device());
  auto s2 = new_stream(default_device());

  // Make sure streams execute asynchronously
  int x = 1;
  auto p1 = std::make_shared<std::promise<void>>();
  auto p2 = std::make_shared<std::promise<void>>();
  auto f1 = p1->get_future().share();
  auto f2 = p2->get_future().share();
  auto fn1 = [&x, p = std::move(p1)]() {
    x++;
    p->set_value();
  };
  auto fn2 = [&x, p = std::move(p2), f = std::move(f1)]() {
    f.wait();
    x *= 5;
    p->set_value();
  };

  // fn2 is launched first and waits on fn1, but since they are on
  // different streams there is no deadlock.
  scheduler::enqueue(s2, std::move(fn2));
  scheduler::enqueue(s1, std::move(fn1));

  f2.wait();

  CHECK_EQ(x, 10);
}

TEST_CASE("test stream placement") {
  auto s1 = default_stream(default_device());
  auto s2 = new_stream(default_device());

  {
    // Wait on stream 1
    auto p = std::make_shared<std::promise<void>>();
    auto f = p->get_future().share();
    scheduler::enqueue(s1, [f = std::move(f)]() { f.wait(); });

    // Do some work on stream 2
    auto x = zeros({100}, float32, s2);
    auto y = ones({100}, float32, s2);
    auto z = add(x, y, s2);
    eval(z);
    p->set_value();
  }

  {
    // Wait on stream 1
    auto p = std::make_shared<std::promise<void>>();
    auto f = p->get_future().share();
    scheduler::enqueue(s1, [f = std::move(f)]() { f.wait(); });

    // Do some work on stream 2
    auto fn = [&s2](array a) { return add(a, add(a, a, s2), s2); };
    auto x = zeros({100}, s2);

    // The whole vjp computation should happen on the second stream,
    // otherwise this will hang.
    auto [out, dout] = vjp(fn, x, ones({100}, s2));

    // The whole jvp computation should happen on the second stream.
    std::tie(out, dout) = jvp(fn, x, ones({100}, s2));
    eval(out, dout);

    p->set_value();
  }
}
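
Outside the scheduler internals, this placement mechanism is what the trailing `StreamOrDevice` argument on ops exposes: pass a stream and the op runs there. A minimal sketch:

```
#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  // Run a small computation entirely on a dedicated stream.
  auto s = new_stream(default_device());
  auto a = ones({512}, float32, s);
  auto b = full({512}, 2.0f, s);
  auto c = multiply(a, b, s);
  eval(c);
}
```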

TEST_CASE("test scheduler races") {
  auto x = zeros({1});
  auto y = zeros({100});
  eval(x, y);
  auto a = exp(x);
  eval(a);
  a = exp(x);
  for (int i = 0; i < 10000; ++i) {
    y = exp(y);
  }
  eval(a, y);
}
26
tests/utils_tests.cpp
Normal file
@ -0,0 +1,26 @@
#include "doctest/doctest.h"

#include "mlx/mlx.h"

using namespace mlx::core;

TEST_CASE("test type promotion") {
  for (auto t : {bool_, uint32, int32, int64, float32}) {
    auto a = array(0, t);
    CHECK_EQ(result_type({a}), t);

    std::vector<array> arrs = {array(0, t), array(0, t)};
    CHECK_EQ(result_type(arrs), t);
  }

  {
    std::vector<array> arrs = {array(false), array(0, int32)};
    CHECK_EQ(result_type(arrs), int32);
  }

  {
    std::vector<array> arrs = {array(0, int32), array(false), array(0.0f)};
    CHECK_EQ(result_type(arrs), float32);
  }
}
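
In practice `result_type` pairs with `astype` to bring mixed operands to a common dtype before combining them. A minimal sketch, assuming `astype` with its usual `(array, Dtype)` signature:

```
#include <cassert>
#include <vector>

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  std::vector<array> args = {array(1, int32), array(2.0f)};
  auto dt = result_type(args); // float32, by the promotion rules above

  // Cast both operands up to the promoted type before combining them.
  auto out = add(astype(args[0], dt), astype(args[1], dt));
  assert(out.dtype() == float32);
}
```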